Compare commits
7 Commits
feature/do
...
fix/docker
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
6e728096fa | ||
|
|
4e1c4bd24e | ||
|
|
cce3390a2d | ||
|
|
4fe2d01361 | ||
|
|
38f3ea42a7 | ||
|
|
102352eac4 | ||
|
|
c09a57644f |
@@ -97,13 +97,16 @@ def to_serializable_dict(obj: Any, ignore_default_value : bool = False) -> Dict:
|
||||
if value != param.default and not ignore_default_value:
|
||||
current_values[name] = to_serializable_dict(value)
|
||||
|
||||
if hasattr(obj, '__slots__'):
|
||||
for slot in obj.__slots__:
|
||||
if slot.startswith('_'): # Handle private slots
|
||||
attr_name = slot[1:] # Remove leading '_'
|
||||
value = getattr(obj, slot, None)
|
||||
if value is not None:
|
||||
current_values[attr_name] = to_serializable_dict(value)
|
||||
# Don't serialize private __slots__ - they're internal implementation details
|
||||
# not constructor parameters. This was causing URLPatternFilter to fail
|
||||
# because _simple_suffixes was being serialized as 'simple_suffixes'
|
||||
# if hasattr(obj, '__slots__'):
|
||||
# for slot in obj.__slots__:
|
||||
# if slot.startswith('_'): # Handle private slots
|
||||
# attr_name = slot[1:] # Remove leading '_'
|
||||
# value = getattr(obj, slot, None)
|
||||
# if value is not None:
|
||||
# current_values[attr_name] = to_serializable_dict(value)
|
||||
|
||||
|
||||
|
||||
|
||||
@@ -47,7 +47,13 @@ class BestFirstCrawlingStrategy(DeepCrawlStrategy):
|
||||
self.url_scorer = url_scorer
|
||||
self.include_external = include_external
|
||||
self.max_pages = max_pages
|
||||
self.logger = logger or logging.getLogger(__name__)
|
||||
# self.logger = logger or logging.getLogger(__name__)
|
||||
# Ensure logger is always a Logger instance, not a dict from serialization
|
||||
if isinstance(logger, logging.Logger):
|
||||
self.logger = logger
|
||||
else:
|
||||
# Create a new logger if logger is None, dict, or any other non-Logger type
|
||||
self.logger = logging.getLogger(__name__)
|
||||
self.stats = TraversalStats(start_time=datetime.now())
|
||||
self._cancel_event = asyncio.Event()
|
||||
self._pages_crawled = 0
|
||||
|
||||
@@ -38,7 +38,13 @@ class BFSDeepCrawlStrategy(DeepCrawlStrategy):
|
||||
self.include_external = include_external
|
||||
self.score_threshold = score_threshold
|
||||
self.max_pages = max_pages
|
||||
self.logger = logger or logging.getLogger(__name__)
|
||||
# self.logger = logger or logging.getLogger(__name__)
|
||||
# Ensure logger is always a Logger instance, not a dict from serialization
|
||||
if isinstance(logger, logging.Logger):
|
||||
self.logger = logger
|
||||
else:
|
||||
# Create a new logger if logger is None, dict, or any other non-Logger type
|
||||
self.logger = logging.getLogger(__name__)
|
||||
self.stats = TraversalStats(start_time=datetime.now())
|
||||
self._cancel_event = asyncio.Event()
|
||||
self._pages_crawled = 0
|
||||
|
||||
@@ -120,6 +120,9 @@ class URLPatternFilter(URLFilter):
|
||||
"""Pattern filter balancing speed and completeness"""
|
||||
|
||||
__slots__ = (
|
||||
"patterns", # Store original patterns for serialization
|
||||
"use_glob", # Store original use_glob for serialization
|
||||
"reverse", # Store original reverse for serialization
|
||||
"_simple_suffixes",
|
||||
"_simple_prefixes",
|
||||
"_domain_patterns",
|
||||
@@ -142,6 +145,11 @@ class URLPatternFilter(URLFilter):
|
||||
reverse: bool = False,
|
||||
):
|
||||
super().__init__()
|
||||
# Store original constructor params for serialization
|
||||
self.patterns = patterns
|
||||
self.use_glob = use_glob
|
||||
self.reverse = reverse
|
||||
|
||||
self._reverse = reverse
|
||||
patterns = [patterns] if isinstance(patterns, (str, Pattern)) else patterns
|
||||
|
||||
|
||||
@@ -253,6 +253,16 @@ class CrawlResult(BaseModel):
|
||||
requirements change, this is where you would update the logic.
|
||||
"""
|
||||
result = super().model_dump(*args, **kwargs)
|
||||
|
||||
# Remove any property descriptors that might have been included
|
||||
# These deprecated properties should not be in the serialized output
|
||||
for key in ['fit_html', 'fit_markdown', 'markdown_v2']:
|
||||
if key in result and isinstance(result[key], property):
|
||||
# del result[key]
|
||||
# Nasrin: I decided to convert it to string instead of removing it.
|
||||
result[key] = str(result[key])
|
||||
|
||||
# Add the markdown field properly
|
||||
if self._markdown is not None:
|
||||
result["markdown"] = self._markdown.model_dump()
|
||||
return result
|
||||
|
||||
@@ -28,25 +28,43 @@ def create_access_token(data: dict, expires_delta: Optional[timedelta] = None) -
|
||||
signing_key = get_jwk_from_secret(SECRET_KEY)
|
||||
return instance.encode(to_encode, signing_key, alg='HS256')
|
||||
|
||||
def verify_token(credentials: HTTPAuthorizationCredentials = Depends(security)) -> Dict:
|
||||
def verify_token(credentials: HTTPAuthorizationCredentials) -> Dict:
|
||||
"""Verify the JWT token from the Authorization header."""
|
||||
|
||||
if credentials is None:
|
||||
return None
|
||||
|
||||
if not credentials or not credentials.credentials:
|
||||
raise HTTPException(
|
||||
status_code=401,
|
||||
detail="No token provided",
|
||||
headers={"WWW-Authenticate": "Bearer"}
|
||||
)
|
||||
|
||||
token = credentials.credentials
|
||||
verifying_key = get_jwk_from_secret(SECRET_KEY)
|
||||
try:
|
||||
payload = instance.decode(token, verifying_key, do_time_check=True, algorithms='HS256')
|
||||
return payload
|
||||
except Exception:
|
||||
raise HTTPException(status_code=401, detail="Invalid or expired token")
|
||||
except Exception as e:
|
||||
raise HTTPException(
|
||||
status_code=401,
|
||||
detail=f"Invalid or expired token: {str(e)}",
|
||||
headers={"WWW-Authenticate": "Bearer"}
|
||||
)
|
||||
|
||||
|
||||
def get_token_dependency(config: Dict):
|
||||
"""Return the token dependency if JWT is enabled, else a function that returns None."""
|
||||
|
||||
|
||||
if config.get("security", {}).get("jwt_enabled", False):
|
||||
return verify_token
|
||||
def jwt_required(credentials: HTTPAuthorizationCredentials = Depends(security)) -> Dict:
|
||||
"""Enforce JWT authentication when enabled."""
|
||||
if credentials is None:
|
||||
raise HTTPException(
|
||||
status_code=401,
|
||||
detail="Authentication required. Please provide a valid Bearer token.",
|
||||
headers={"WWW-Authenticate": "Bearer"}
|
||||
)
|
||||
return verify_token(credentials)
|
||||
return jwt_required
|
||||
else:
|
||||
return lambda: None
|
||||
|
||||
|
||||
@@ -38,8 +38,8 @@ rate_limiting:
|
||||
|
||||
# Security Configuration
|
||||
security:
|
||||
enabled: false
|
||||
jwt_enabled: false
|
||||
enabled: false
|
||||
jwt_enabled: false
|
||||
https_redirect: false
|
||||
trusted_hosts: ["*"]
|
||||
headers:
|
||||
|
||||
@@ -126,30 +126,6 @@ Factors:
|
||||
- URL depth (fewer slashes = higher authority)
|
||||
- Clean URL structure
|
||||
|
||||
### Custom Link Scoring
|
||||
|
||||
```python
|
||||
class CustomLinkScorer:
|
||||
def score(self, link: Link, query: str, state: CrawlState) -> float:
|
||||
# Prioritize specific URL patterns
|
||||
if "/api/reference/" in link.href:
|
||||
return 2.0 # Double the score
|
||||
|
||||
# Deprioritize certain sections
|
||||
if "/archive/" in link.href:
|
||||
return 0.1 # Reduce score by 90%
|
||||
|
||||
# Default scoring
|
||||
return 1.0
|
||||
|
||||
# Use with adaptive crawler
|
||||
adaptive = AdaptiveCrawler(
|
||||
crawler,
|
||||
config=config,
|
||||
link_scorer=CustomLinkScorer()
|
||||
)
|
||||
```
|
||||
|
||||
## Domain-Specific Configurations
|
||||
|
||||
### Technical Documentation
|
||||
@@ -230,8 +206,12 @@ config = AdaptiveConfig(
|
||||
|
||||
# Periodically clean state
|
||||
if len(state.knowledge_base) > 1000:
|
||||
# Keep only most relevant
|
||||
state.knowledge_base = get_top_relevant(state.knowledge_base, 500)
|
||||
# Keep only the top 500 most relevant docs
|
||||
top_content = adaptive.get_relevant_content(top_k=500)
|
||||
keep_indices = {d["index"] for d in top_content}
|
||||
state.knowledge_base = [
|
||||
doc for i, doc in enumerate(state.knowledge_base) if i in keep_indices
|
||||
]
|
||||
```
|
||||
|
||||
### Parallel Processing
|
||||
@@ -252,18 +232,6 @@ tasks = [
|
||||
results = await asyncio.gather(*tasks)
|
||||
```
|
||||
|
||||
### Caching Strategy
|
||||
|
||||
```python
|
||||
# Enable caching for repeated crawls
|
||||
async with AsyncWebCrawler(
|
||||
config=BrowserConfig(
|
||||
cache_mode=CacheMode.ENABLED
|
||||
)
|
||||
) as crawler:
|
||||
adaptive = AdaptiveCrawler(crawler, config)
|
||||
```
|
||||
|
||||
## Debugging & Analysis
|
||||
|
||||
### Enable Verbose Logging
|
||||
@@ -322,9 +290,9 @@ with open("crawl_analysis.json", "w") as f:
|
||||
### Implementing a Custom Strategy
|
||||
|
||||
```python
|
||||
from crawl4ai.adaptive_crawler import BaseStrategy
|
||||
from crawl4ai.adaptive_crawler import CrawlStrategy
|
||||
|
||||
class DomainSpecificStrategy(BaseStrategy):
|
||||
class DomainSpecificStrategy(CrawlStrategy):
|
||||
def calculate_coverage(self, state: CrawlState) -> float:
|
||||
# Custom coverage calculation
|
||||
# e.g., weight certain terms more heavily
|
||||
@@ -351,7 +319,7 @@ adaptive = AdaptiveCrawler(
|
||||
### Combining Strategies
|
||||
|
||||
```python
|
||||
class HybridStrategy(BaseStrategy):
|
||||
class HybridStrategy(CrawlStrategy):
|
||||
def __init__(self):
|
||||
self.strategies = [
|
||||
TechnicalDocStrategy(),
|
||||
|
||||
@@ -79,7 +79,7 @@ if __name__ == "__main__":
|
||||
asyncio.run(main())
|
||||
```
|
||||
|
||||
> IMPORTANT: By default cache mode is set to `CacheMode.ENABLED`. So to have fresh content, you need to set it to `CacheMode.BYPASS`
|
||||
> IMPORTANT: By default cache mode is set to `CacheMode.BYPASS` to have fresh content. Set `CacheMode.ENABLED` to enable caching.
|
||||
|
||||
We’ll explore more advanced config in later tutorials (like enabling proxies, PDF output, multi-tab sessions, etc.). For now, just note how you pass these objects to manage crawling.
|
||||
|
||||
|
||||
201
tests/docker/test_filter_deep_crawl.py
Normal file
201
tests/docker/test_filter_deep_crawl.py
Normal file
@@ -0,0 +1,201 @@
|
||||
"""
|
||||
Test the complete fix for both the filter serialization and JSON serialization issues.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import httpx
|
||||
|
||||
from crawl4ai import BrowserConfig, CacheMode, CrawlerRunConfig
|
||||
from crawl4ai.deep_crawling import BFSDeepCrawlStrategy, FilterChain, URLPatternFilter
|
||||
|
||||
BASE_URL = "http://localhost:11234/" # Adjust port as needed
|
||||
|
||||
async def test_with_docker_client():
|
||||
"""Test using the Docker client (same as 1419.py)."""
|
||||
from crawl4ai.docker_client import Crawl4aiDockerClient
|
||||
|
||||
print("=" * 60)
|
||||
print("Testing with Docker Client")
|
||||
print("=" * 60)
|
||||
|
||||
try:
|
||||
async with Crawl4aiDockerClient(
|
||||
base_url=BASE_URL,
|
||||
verbose=True,
|
||||
) as client:
|
||||
|
||||
# Create filter chain - testing the serialization fix
|
||||
filter_chain = [
|
||||
URLPatternFilter(
|
||||
# patterns=["*about*", "*privacy*", "*terms*"],
|
||||
patterns=["*advanced*"],
|
||||
reverse=True
|
||||
),
|
||||
]
|
||||
|
||||
crawler_config = CrawlerRunConfig(
|
||||
deep_crawl_strategy=BFSDeepCrawlStrategy(
|
||||
max_depth=2, # Keep it shallow for testing
|
||||
# max_pages=5, # Limit pages for testing
|
||||
filter_chain=FilterChain(filter_chain)
|
||||
),
|
||||
cache_mode=CacheMode.BYPASS,
|
||||
)
|
||||
|
||||
print("\n1. Testing crawl with filters...")
|
||||
results = await client.crawl(
|
||||
["https://docs.crawl4ai.com"], # Simple test page
|
||||
browser_config=BrowserConfig(headless=True),
|
||||
crawler_config=crawler_config,
|
||||
)
|
||||
|
||||
if results:
|
||||
print(f"✅ Crawl succeeded! Type: {type(results)}")
|
||||
if hasattr(results, 'success'):
|
||||
print(f"✅ Results success: {results.success}")
|
||||
# Test that we can iterate results without JSON errors
|
||||
if hasattr(results, '__iter__'):
|
||||
for i, result in enumerate(results):
|
||||
if hasattr(result, 'url'):
|
||||
print(f" Result {i}: {result.url[:50]}...")
|
||||
else:
|
||||
print(f" Result {i}: {str(result)[:50]}...")
|
||||
else:
|
||||
# Handle list of results
|
||||
print(f"✅ Got {len(results)} results")
|
||||
for i, result in enumerate(results[:3]): # Show first 3
|
||||
print(f" Result {i}: {result.url[:50]}...")
|
||||
else:
|
||||
print("❌ Crawl failed - no results returned")
|
||||
return False
|
||||
|
||||
print("\n✅ Docker client test completed successfully!")
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
print(f"❌ Docker client test failed: {e}")
|
||||
import traceback
|
||||
traceback.print_exc()
|
||||
return False
|
||||
|
||||
|
||||
async def test_with_rest_api():
|
||||
"""Test using REST API directly."""
|
||||
print("\n" + "=" * 60)
|
||||
print("Testing with REST API")
|
||||
print("=" * 60)
|
||||
|
||||
# Create filter configuration
|
||||
deep_crawl_strategy_payload = {
|
||||
"type": "BFSDeepCrawlStrategy",
|
||||
"params": {
|
||||
"max_depth": 2,
|
||||
# "max_pages": 5,
|
||||
"filter_chain": {
|
||||
"type": "FilterChain",
|
||||
"params": {
|
||||
"filters": [
|
||||
{
|
||||
"type": "URLPatternFilter",
|
||||
"params": {
|
||||
"patterns": ["*advanced*"],
|
||||
"reverse": True
|
||||
}
|
||||
}
|
||||
]
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
crawl_payload = {
|
||||
"urls": ["https://docs.crawl4ai.com"],
|
||||
"browser_config": {"type": "BrowserConfig", "params": {"headless": True}},
|
||||
"crawler_config": {
|
||||
"type": "CrawlerRunConfig",
|
||||
"params": {
|
||||
"deep_crawl_strategy": deep_crawl_strategy_payload,
|
||||
"cache_mode": "bypass"
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
try:
|
||||
async with httpx.AsyncClient() as client:
|
||||
print("\n1. Sending crawl request to REST API...")
|
||||
response = await client.post(
|
||||
f"{BASE_URL}crawl",
|
||||
json=crawl_payload,
|
||||
timeout=30
|
||||
)
|
||||
|
||||
if response.status_code == 200:
|
||||
print(f"✅ REST API returned 200 OK")
|
||||
data = response.json()
|
||||
if data.get("success"):
|
||||
results = data.get("results", [])
|
||||
print(f"✅ Got {len(results)} results")
|
||||
for i, result in enumerate(results[:3]):
|
||||
print(f" Result {i}: {result.get('url', 'unknown')[:50]}...")
|
||||
else:
|
||||
print(f"❌ Crawl not successful: {data}")
|
||||
return False
|
||||
else:
|
||||
print(f"❌ REST API returned {response.status_code}")
|
||||
print(f" Response: {response.text[:500]}")
|
||||
return False
|
||||
|
||||
print("\n✅ REST API test completed successfully!")
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
print(f"❌ REST API test failed: {e}")
|
||||
import traceback
|
||||
traceback.print_exc()
|
||||
return False
|
||||
|
||||
|
||||
async def main():
|
||||
"""Run all tests."""
|
||||
print("\n🧪 TESTING COMPLETE FIX FOR DOCKER FILTER AND JSON ISSUES")
|
||||
print("=" * 60)
|
||||
print("Make sure the server is running with the updated code!")
|
||||
print("=" * 60)
|
||||
|
||||
results = []
|
||||
|
||||
# Test 1: Docker client
|
||||
docker_passed = await test_with_docker_client()
|
||||
results.append(("Docker Client", docker_passed))
|
||||
|
||||
# Test 2: REST API
|
||||
rest_passed = await test_with_rest_api()
|
||||
results.append(("REST API", rest_passed))
|
||||
|
||||
# Summary
|
||||
print("\n" + "=" * 60)
|
||||
print("FINAL TEST SUMMARY")
|
||||
print("=" * 60)
|
||||
|
||||
all_passed = True
|
||||
for test_name, passed in results:
|
||||
status = "✅ PASSED" if passed else "❌ FAILED"
|
||||
print(f"{test_name:20} {status}")
|
||||
if not passed:
|
||||
all_passed = False
|
||||
|
||||
print("=" * 60)
|
||||
if all_passed:
|
||||
print("🎉 ALL TESTS PASSED! Both issues are fully resolved!")
|
||||
print("\nThe fixes:")
|
||||
print("1. Filter serialization: Fixed by not serializing private __slots__")
|
||||
print("2. JSON serialization: Fixed by removing property descriptors from model_dump()")
|
||||
else:
|
||||
print("⚠️ Some tests failed. Please check the server logs for details.")
|
||||
|
||||
return 0 if all_passed else 1
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import sys
|
||||
sys.exit(asyncio.run(main()))
|
||||
Reference in New Issue
Block a user