Compare commits
7 Commits
feature/do
...
fix/docker
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
6e728096fa | ||
|
|
4e1c4bd24e | ||
|
|
cce3390a2d | ||
|
|
4fe2d01361 | ||
|
|
38f3ea42a7 | ||
|
|
102352eac4 | ||
|
|
c09a57644f |
@@ -97,13 +97,16 @@ def to_serializable_dict(obj: Any, ignore_default_value : bool = False) -> Dict:
|
|||||||
if value != param.default and not ignore_default_value:
|
if value != param.default and not ignore_default_value:
|
||||||
current_values[name] = to_serializable_dict(value)
|
current_values[name] = to_serializable_dict(value)
|
||||||
|
|
||||||
if hasattr(obj, '__slots__'):
|
# Don't serialize private __slots__ - they're internal implementation details
|
||||||
for slot in obj.__slots__:
|
# not constructor parameters. This was causing URLPatternFilter to fail
|
||||||
if slot.startswith('_'): # Handle private slots
|
# because _simple_suffixes was being serialized as 'simple_suffixes'
|
||||||
attr_name = slot[1:] # Remove leading '_'
|
# if hasattr(obj, '__slots__'):
|
||||||
value = getattr(obj, slot, None)
|
# for slot in obj.__slots__:
|
||||||
if value is not None:
|
# if slot.startswith('_'): # Handle private slots
|
||||||
current_values[attr_name] = to_serializable_dict(value)
|
# attr_name = slot[1:] # Remove leading '_'
|
||||||
|
# value = getattr(obj, slot, None)
|
||||||
|
# if value is not None:
|
||||||
|
# current_values[attr_name] = to_serializable_dict(value)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -47,7 +47,13 @@ class BestFirstCrawlingStrategy(DeepCrawlStrategy):
|
|||||||
self.url_scorer = url_scorer
|
self.url_scorer = url_scorer
|
||||||
self.include_external = include_external
|
self.include_external = include_external
|
||||||
self.max_pages = max_pages
|
self.max_pages = max_pages
|
||||||
self.logger = logger or logging.getLogger(__name__)
|
# self.logger = logger or logging.getLogger(__name__)
|
||||||
|
# Ensure logger is always a Logger instance, not a dict from serialization
|
||||||
|
if isinstance(logger, logging.Logger):
|
||||||
|
self.logger = logger
|
||||||
|
else:
|
||||||
|
# Create a new logger if logger is None, dict, or any other non-Logger type
|
||||||
|
self.logger = logging.getLogger(__name__)
|
||||||
self.stats = TraversalStats(start_time=datetime.now())
|
self.stats = TraversalStats(start_time=datetime.now())
|
||||||
self._cancel_event = asyncio.Event()
|
self._cancel_event = asyncio.Event()
|
||||||
self._pages_crawled = 0
|
self._pages_crawled = 0
|
||||||
|
|||||||
@@ -38,7 +38,13 @@ class BFSDeepCrawlStrategy(DeepCrawlStrategy):
|
|||||||
self.include_external = include_external
|
self.include_external = include_external
|
||||||
self.score_threshold = score_threshold
|
self.score_threshold = score_threshold
|
||||||
self.max_pages = max_pages
|
self.max_pages = max_pages
|
||||||
self.logger = logger or logging.getLogger(__name__)
|
# self.logger = logger or logging.getLogger(__name__)
|
||||||
|
# Ensure logger is always a Logger instance, not a dict from serialization
|
||||||
|
if isinstance(logger, logging.Logger):
|
||||||
|
self.logger = logger
|
||||||
|
else:
|
||||||
|
# Create a new logger if logger is None, dict, or any other non-Logger type
|
||||||
|
self.logger = logging.getLogger(__name__)
|
||||||
self.stats = TraversalStats(start_time=datetime.now())
|
self.stats = TraversalStats(start_time=datetime.now())
|
||||||
self._cancel_event = asyncio.Event()
|
self._cancel_event = asyncio.Event()
|
||||||
self._pages_crawled = 0
|
self._pages_crawled = 0
|
||||||
|
|||||||
@@ -120,6 +120,9 @@ class URLPatternFilter(URLFilter):
|
|||||||
"""Pattern filter balancing speed and completeness"""
|
"""Pattern filter balancing speed and completeness"""
|
||||||
|
|
||||||
__slots__ = (
|
__slots__ = (
|
||||||
|
"patterns", # Store original patterns for serialization
|
||||||
|
"use_glob", # Store original use_glob for serialization
|
||||||
|
"reverse", # Store original reverse for serialization
|
||||||
"_simple_suffixes",
|
"_simple_suffixes",
|
||||||
"_simple_prefixes",
|
"_simple_prefixes",
|
||||||
"_domain_patterns",
|
"_domain_patterns",
|
||||||
@@ -142,6 +145,11 @@ class URLPatternFilter(URLFilter):
|
|||||||
reverse: bool = False,
|
reverse: bool = False,
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
# Store original constructor params for serialization
|
||||||
|
self.patterns = patterns
|
||||||
|
self.use_glob = use_glob
|
||||||
|
self.reverse = reverse
|
||||||
|
|
||||||
self._reverse = reverse
|
self._reverse = reverse
|
||||||
patterns = [patterns] if isinstance(patterns, (str, Pattern)) else patterns
|
patterns = [patterns] if isinstance(patterns, (str, Pattern)) else patterns
|
||||||
|
|
||||||
|
|||||||
@@ -253,6 +253,16 @@ class CrawlResult(BaseModel):
|
|||||||
requirements change, this is where you would update the logic.
|
requirements change, this is where you would update the logic.
|
||||||
"""
|
"""
|
||||||
result = super().model_dump(*args, **kwargs)
|
result = super().model_dump(*args, **kwargs)
|
||||||
|
|
||||||
|
# Stringify any property descriptors that might have been included (rather than deleting them)
|
||||||
|
# These deprecated properties should not be in the serialized output
|
||||||
|
for key in ['fit_html', 'fit_markdown', 'markdown_v2']:
|
||||||
|
if key in result and isinstance(result[key], property):
|
||||||
|
# del result[key]
|
||||||
|
# Nasrin: I decided to convert it to string instead of removing it.
|
||||||
|
result[key] = str(result[key])
|
||||||
|
|
||||||
|
# Add the markdown field properly
|
||||||
if self._markdown is not None:
|
if self._markdown is not None:
|
||||||
result["markdown"] = self._markdown.model_dump()
|
result["markdown"] = self._markdown.model_dump()
|
||||||
return result
|
return result
|
||||||
|
|||||||
@@ -28,25 +28,43 @@ def create_access_token(data: dict, expires_delta: Optional[timedelta] = None) -
|
|||||||
signing_key = get_jwk_from_secret(SECRET_KEY)
|
signing_key = get_jwk_from_secret(SECRET_KEY)
|
||||||
return instance.encode(to_encode, signing_key, alg='HS256')
|
return instance.encode(to_encode, signing_key, alg='HS256')
|
||||||
|
|
||||||
def verify_token(credentials: HTTPAuthorizationCredentials = Depends(security)) -> Dict:
|
def verify_token(credentials: HTTPAuthorizationCredentials) -> Dict:
|
||||||
"""Verify the JWT token from the Authorization header."""
|
"""Verify the JWT token from the Authorization header."""
|
||||||
|
|
||||||
if credentials is None:
|
if not credentials or not credentials.credentials:
|
||||||
return None
|
raise HTTPException(
|
||||||
|
status_code=401,
|
||||||
|
detail="No token provided",
|
||||||
|
headers={"WWW-Authenticate": "Bearer"}
|
||||||
|
)
|
||||||
|
|
||||||
token = credentials.credentials
|
token = credentials.credentials
|
||||||
verifying_key = get_jwk_from_secret(SECRET_KEY)
|
verifying_key = get_jwk_from_secret(SECRET_KEY)
|
||||||
try:
|
try:
|
||||||
payload = instance.decode(token, verifying_key, do_time_check=True, algorithms='HS256')
|
payload = instance.decode(token, verifying_key, do_time_check=True, algorithms='HS256')
|
||||||
return payload
|
return payload
|
||||||
except Exception:
|
except Exception as e:
|
||||||
raise HTTPException(status_code=401, detail="Invalid or expired token")
|
raise HTTPException(
|
||||||
|
status_code=401,
|
||||||
|
detail=f"Invalid or expired token: {str(e)}",
|
||||||
|
headers={"WWW-Authenticate": "Bearer"}
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def get_token_dependency(config: Dict):
|
def get_token_dependency(config: Dict):
|
||||||
"""Return the token dependency if JWT is enabled, else a function that returns None."""
|
"""Return the token dependency if JWT is enabled, else a function that returns None."""
|
||||||
|
|
||||||
if config.get("security", {}).get("jwt_enabled", False):
|
if config.get("security", {}).get("jwt_enabled", False):
|
||||||
return verify_token
|
def jwt_required(credentials: HTTPAuthorizationCredentials = Depends(security)) -> Dict:
|
||||||
|
"""Enforce JWT authentication when enabled."""
|
||||||
|
if credentials is None:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=401,
|
||||||
|
detail="Authentication required. Please provide a valid Bearer token.",
|
||||||
|
headers={"WWW-Authenticate": "Bearer"}
|
||||||
|
)
|
||||||
|
return verify_token(credentials)
|
||||||
|
return jwt_required
|
||||||
else:
|
else:
|
||||||
return lambda: None
|
return lambda: None
|
||||||
|
|
||||||
|
|||||||
@@ -126,30 +126,6 @@ Factors:
|
|||||||
- URL depth (fewer slashes = higher authority)
|
- URL depth (fewer slashes = higher authority)
|
||||||
- Clean URL structure
|
- Clean URL structure
|
||||||
|
|
||||||
### Custom Link Scoring
|
|
||||||
|
|
||||||
```python
|
|
||||||
class CustomLinkScorer:
|
|
||||||
def score(self, link: Link, query: str, state: CrawlState) -> float:
|
|
||||||
# Prioritize specific URL patterns
|
|
||||||
if "/api/reference/" in link.href:
|
|
||||||
return 2.0 # Double the score
|
|
||||||
|
|
||||||
# Deprioritize certain sections
|
|
||||||
if "/archive/" in link.href:
|
|
||||||
return 0.1 # Reduce score by 90%
|
|
||||||
|
|
||||||
# Default scoring
|
|
||||||
return 1.0
|
|
||||||
|
|
||||||
# Use with adaptive crawler
|
|
||||||
adaptive = AdaptiveCrawler(
|
|
||||||
crawler,
|
|
||||||
config=config,
|
|
||||||
link_scorer=CustomLinkScorer()
|
|
||||||
)
|
|
||||||
```
|
|
||||||
|
|
||||||
## Domain-Specific Configurations
|
## Domain-Specific Configurations
|
||||||
|
|
||||||
### Technical Documentation
|
### Technical Documentation
|
||||||
@@ -230,8 +206,12 @@ config = AdaptiveConfig(
|
|||||||
|
|
||||||
# Periodically clean state
|
# Periodically clean state
|
||||||
if len(state.knowledge_base) > 1000:
|
if len(state.knowledge_base) > 1000:
|
||||||
# Keep only most relevant
|
# Keep only the top 500 most relevant docs
|
||||||
state.knowledge_base = get_top_relevant(state.knowledge_base, 500)
|
top_content = adaptive.get_relevant_content(top_k=500)
|
||||||
|
keep_indices = {d["index"] for d in top_content}
|
||||||
|
state.knowledge_base = [
|
||||||
|
doc for i, doc in enumerate(state.knowledge_base) if i in keep_indices
|
||||||
|
]
|
||||||
```
|
```
|
||||||
|
|
||||||
### Parallel Processing
|
### Parallel Processing
|
||||||
@@ -252,18 +232,6 @@ tasks = [
|
|||||||
results = await asyncio.gather(*tasks)
|
results = await asyncio.gather(*tasks)
|
||||||
```
|
```
|
||||||
|
|
||||||
### Caching Strategy
|
|
||||||
|
|
||||||
```python
|
|
||||||
# Enable caching for repeated crawls
|
|
||||||
async with AsyncWebCrawler(
|
|
||||||
config=BrowserConfig(
|
|
||||||
cache_mode=CacheMode.ENABLED
|
|
||||||
)
|
|
||||||
) as crawler:
|
|
||||||
adaptive = AdaptiveCrawler(crawler, config)
|
|
||||||
```
|
|
||||||
|
|
||||||
## Debugging & Analysis
|
## Debugging & Analysis
|
||||||
|
|
||||||
### Enable Verbose Logging
|
### Enable Verbose Logging
|
||||||
@@ -322,9 +290,9 @@ with open("crawl_analysis.json", "w") as f:
|
|||||||
### Implementing a Custom Strategy
|
### Implementing a Custom Strategy
|
||||||
|
|
||||||
```python
|
```python
|
||||||
from crawl4ai.adaptive_crawler import BaseStrategy
|
from crawl4ai.adaptive_crawler import CrawlStrategy
|
||||||
|
|
||||||
class DomainSpecificStrategy(BaseStrategy):
|
class DomainSpecificStrategy(CrawlStrategy):
|
||||||
def calculate_coverage(self, state: CrawlState) -> float:
|
def calculate_coverage(self, state: CrawlState) -> float:
|
||||||
# Custom coverage calculation
|
# Custom coverage calculation
|
||||||
# e.g., weight certain terms more heavily
|
# e.g., weight certain terms more heavily
|
||||||
@@ -351,7 +319,7 @@ adaptive = AdaptiveCrawler(
|
|||||||
### Combining Strategies
|
### Combining Strategies
|
||||||
|
|
||||||
```python
|
```python
|
||||||
class HybridStrategy(BaseStrategy):
|
class HybridStrategy(CrawlStrategy):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
self.strategies = [
|
self.strategies = [
|
||||||
TechnicalDocStrategy(),
|
TechnicalDocStrategy(),
|
||||||
|
|||||||
@@ -79,7 +79,7 @@ if __name__ == "__main__":
|
|||||||
asyncio.run(main())
|
asyncio.run(main())
|
||||||
```
|
```
|
||||||
|
|
||||||
> IMPORTANT: By default cache mode is set to `CacheMode.ENABLED`. So to have fresh content, you need to set it to `CacheMode.BYPASS`
|
> IMPORTANT: By default, the cache mode is `CacheMode.BYPASS`, so you always get fresh content. Set it to `CacheMode.ENABLED` to enable caching.
|
||||||
|
|
||||||
We’ll explore more advanced config in later tutorials (like enabling proxies, PDF output, multi-tab sessions, etc.). For now, just note how you pass these objects to manage crawling.
|
We’ll explore more advanced config in later tutorials (like enabling proxies, PDF output, multi-tab sessions, etc.). For now, just note how you pass these objects to manage crawling.
|
||||||
|
|
||||||
|
|||||||
201
tests/docker/test_filter_deep_crawl.py
Normal file
201
tests/docker/test_filter_deep_crawl.py
Normal file
@@ -0,0 +1,201 @@
|
|||||||
|
"""
|
||||||
|
Test the complete fix for both the filter serialization and JSON serialization issues.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import httpx
|
||||||
|
|
||||||
|
from crawl4ai import BrowserConfig, CacheMode, CrawlerRunConfig
|
||||||
|
from crawl4ai.deep_crawling import BFSDeepCrawlStrategy, FilterChain, URLPatternFilter
|
||||||
|
|
||||||
|
BASE_URL = "http://localhost:11234/" # Adjust port as needed
|
||||||
|
|
||||||
|
async def test_with_docker_client():
|
||||||
|
"""Test using the Docker client (same as 1419.py)."""
|
||||||
|
from crawl4ai.docker_client import Crawl4aiDockerClient
|
||||||
|
|
||||||
|
print("=" * 60)
|
||||||
|
print("Testing with Docker Client")
|
||||||
|
print("=" * 60)
|
||||||
|
|
||||||
|
try:
|
||||||
|
async with Crawl4aiDockerClient(
|
||||||
|
base_url=BASE_URL,
|
||||||
|
verbose=True,
|
||||||
|
) as client:
|
||||||
|
|
||||||
|
# Create filter chain - testing the serialization fix
|
||||||
|
filter_chain = [
|
||||||
|
URLPatternFilter(
|
||||||
|
# patterns=["*about*", "*privacy*", "*terms*"],
|
||||||
|
patterns=["*advanced*"],
|
||||||
|
reverse=True
|
||||||
|
),
|
||||||
|
]
|
||||||
|
|
||||||
|
crawler_config = CrawlerRunConfig(
|
||||||
|
deep_crawl_strategy=BFSDeepCrawlStrategy(
|
||||||
|
max_depth=2, # Keep it shallow for testing
|
||||||
|
# max_pages=5, # Limit pages for testing
|
||||||
|
filter_chain=FilterChain(filter_chain)
|
||||||
|
),
|
||||||
|
cache_mode=CacheMode.BYPASS,
|
||||||
|
)
|
||||||
|
|
||||||
|
print("\n1. Testing crawl with filters...")
|
||||||
|
results = await client.crawl(
|
||||||
|
["https://docs.crawl4ai.com"], # Simple test page
|
||||||
|
browser_config=BrowserConfig(headless=True),
|
||||||
|
crawler_config=crawler_config,
|
||||||
|
)
|
||||||
|
|
||||||
|
if results:
|
||||||
|
print(f"✅ Crawl succeeded! Type: {type(results)}")
|
||||||
|
if hasattr(results, 'success'):
|
||||||
|
print(f"✅ Results success: {results.success}")
|
||||||
|
# Test that we can iterate results without JSON errors
|
||||||
|
if hasattr(results, '__iter__'):
|
||||||
|
for i, result in enumerate(results):
|
||||||
|
if hasattr(result, 'url'):
|
||||||
|
print(f" Result {i}: {result.url[:50]}...")
|
||||||
|
else:
|
||||||
|
print(f" Result {i}: {str(result)[:50]}...")
|
||||||
|
else:
|
||||||
|
# Handle list of results
|
||||||
|
print(f"✅ Got {len(results)} results")
|
||||||
|
for i, result in enumerate(results[:3]): # Show first 3
|
||||||
|
print(f" Result {i}: {result.url[:50]}...")
|
||||||
|
else:
|
||||||
|
print("❌ Crawl failed - no results returned")
|
||||||
|
return False
|
||||||
|
|
||||||
|
print("\n✅ Docker client test completed successfully!")
|
||||||
|
return True
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
print(f"❌ Docker client test failed: {e}")
|
||||||
|
import traceback
|
||||||
|
traceback.print_exc()
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
async def test_with_rest_api():
|
||||||
|
"""Test using REST API directly."""
|
||||||
|
print("\n" + "=" * 60)
|
||||||
|
print("Testing with REST API")
|
||||||
|
print("=" * 60)
|
||||||
|
|
||||||
|
# Create filter configuration
|
||||||
|
deep_crawl_strategy_payload = {
|
||||||
|
"type": "BFSDeepCrawlStrategy",
|
||||||
|
"params": {
|
||||||
|
"max_depth": 2,
|
||||||
|
# "max_pages": 5,
|
||||||
|
"filter_chain": {
|
||||||
|
"type": "FilterChain",
|
||||||
|
"params": {
|
||||||
|
"filters": [
|
||||||
|
{
|
||||||
|
"type": "URLPatternFilter",
|
||||||
|
"params": {
|
||||||
|
"patterns": ["*advanced*"],
|
||||||
|
"reverse": True
|
||||||
|
}
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
crawl_payload = {
|
||||||
|
"urls": ["https://docs.crawl4ai.com"],
|
||||||
|
"browser_config": {"type": "BrowserConfig", "params": {"headless": True}},
|
||||||
|
"crawler_config": {
|
||||||
|
"type": "CrawlerRunConfig",
|
||||||
|
"params": {
|
||||||
|
"deep_crawl_strategy": deep_crawl_strategy_payload,
|
||||||
|
"cache_mode": "bypass"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
try:
|
||||||
|
async with httpx.AsyncClient() as client:
|
||||||
|
print("\n1. Sending crawl request to REST API...")
|
||||||
|
response = await client.post(
|
||||||
|
f"{BASE_URL}crawl",
|
||||||
|
json=crawl_payload,
|
||||||
|
timeout=30
|
||||||
|
)
|
||||||
|
|
||||||
|
if response.status_code == 200:
|
||||||
|
print(f"✅ REST API returned 200 OK")
|
||||||
|
data = response.json()
|
||||||
|
if data.get("success"):
|
||||||
|
results = data.get("results", [])
|
||||||
|
print(f"✅ Got {len(results)} results")
|
||||||
|
for i, result in enumerate(results[:3]):
|
||||||
|
print(f" Result {i}: {result.get('url', 'unknown')[:50]}...")
|
||||||
|
else:
|
||||||
|
print(f"❌ Crawl not successful: {data}")
|
||||||
|
return False
|
||||||
|
else:
|
||||||
|
print(f"❌ REST API returned {response.status_code}")
|
||||||
|
print(f" Response: {response.text[:500]}")
|
||||||
|
return False
|
||||||
|
|
||||||
|
print("\n✅ REST API test completed successfully!")
|
||||||
|
return True
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
print(f"❌ REST API test failed: {e}")
|
||||||
|
import traceback
|
||||||
|
traceback.print_exc()
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
async def main():
|
||||||
|
"""Run all tests."""
|
||||||
|
print("\n🧪 TESTING COMPLETE FIX FOR DOCKER FILTER AND JSON ISSUES")
|
||||||
|
print("=" * 60)
|
||||||
|
print("Make sure the server is running with the updated code!")
|
||||||
|
print("=" * 60)
|
||||||
|
|
||||||
|
results = []
|
||||||
|
|
||||||
|
# Test 1: Docker client
|
||||||
|
docker_passed = await test_with_docker_client()
|
||||||
|
results.append(("Docker Client", docker_passed))
|
||||||
|
|
||||||
|
# Test 2: REST API
|
||||||
|
rest_passed = await test_with_rest_api()
|
||||||
|
results.append(("REST API", rest_passed))
|
||||||
|
|
||||||
|
# Summary
|
||||||
|
print("\n" + "=" * 60)
|
||||||
|
print("FINAL TEST SUMMARY")
|
||||||
|
print("=" * 60)
|
||||||
|
|
||||||
|
all_passed = True
|
||||||
|
for test_name, passed in results:
|
||||||
|
status = "✅ PASSED" if passed else "❌ FAILED"
|
||||||
|
print(f"{test_name:20} {status}")
|
||||||
|
if not passed:
|
||||||
|
all_passed = False
|
||||||
|
|
||||||
|
print("=" * 60)
|
||||||
|
if all_passed:
|
||||||
|
print("🎉 ALL TESTS PASSED! Both issues are fully resolved!")
|
||||||
|
print("\nThe fixes:")
|
||||||
|
print("1. Filter serialization: Fixed by not serializing private __slots__")
|
||||||
|
print("2. JSON serialization: Fixed by removing property descriptors from model_dump()")
|
||||||
|
else:
|
||||||
|
print("⚠️ Some tests failed. Please check the server logs for details.")
|
||||||
|
|
||||||
|
return 0 if all_passed else 1
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
import sys
|
||||||
|
sys.exit(asyncio.run(main()))
|
||||||
Reference in New Issue
Block a user