Merge pull request #1648 from christopher-w-murphy/fix/content-relevance-filter
[Fix]: Docker server does not decode ContentRelevanceFilter
This commit is contained in:
@@ -1,16 +1,31 @@
|
||||
"""
|
||||
Test the complete fix for both the filter serialization and JSON serialization issues.
|
||||
"""
|
||||
import os
|
||||
import traceback
|
||||
from typing import Any
|
||||
|
||||
import asyncio
|
||||
import httpx
|
||||
|
||||
from crawl4ai import BrowserConfig, CacheMode, CrawlerRunConfig
|
||||
from crawl4ai.deep_crawling import BFSDeepCrawlStrategy, FilterChain, URLPatternFilter
|
||||
from crawl4ai.deep_crawling import (
|
||||
BFSDeepCrawlStrategy,
|
||||
ContentRelevanceFilter,
|
||||
FilterChain,
|
||||
URLFilter,
|
||||
URLPatternFilter,
|
||||
)
|
||||
|
||||
BASE_URL = "http://localhost:11234/" # Adjust port as needed
|
||||
CRAWL4AI_DOCKER_PORT = os.environ.get("CRAWL4AI_DOCKER_PORT", "11234")
|
||||
try:
|
||||
BASE_PORT = int(CRAWL4AI_DOCKER_PORT)
|
||||
except TypeError:
|
||||
BASE_PORT = 11234
|
||||
BASE_URL = f"http://localhost:{BASE_PORT}/" # Adjust port as needed
|
||||
|
||||
async def test_with_docker_client():
|
||||
|
||||
async def test_with_docker_client(filter_chain: list[URLFilter], max_pages: int = 20, timeout: int = 30) -> bool:
|
||||
"""Test using the Docker client (same as 1419.py)."""
|
||||
from crawl4ai.docker_client import Crawl4aiDockerClient
|
||||
|
||||
@@ -24,19 +39,10 @@ async def test_with_docker_client():
|
||||
verbose=True,
|
||||
) as client:
|
||||
|
||||
# Create filter chain - testing the serialization fix
|
||||
filter_chain = [
|
||||
URLPatternFilter(
|
||||
# patterns=["*about*", "*privacy*", "*terms*"],
|
||||
patterns=["*advanced*"],
|
||||
reverse=True
|
||||
),
|
||||
]
|
||||
|
||||
crawler_config = CrawlerRunConfig(
|
||||
deep_crawl_strategy=BFSDeepCrawlStrategy(
|
||||
max_depth=2, # Keep it shallow for testing
|
||||
# max_pages=5, # Limit pages for testing
|
||||
max_pages=max_pages, # Limit pages for testing
|
||||
filter_chain=FilterChain(filter_chain)
|
||||
),
|
||||
cache_mode=CacheMode.BYPASS,
|
||||
@@ -47,6 +53,7 @@ async def test_with_docker_client():
|
||||
["https://docs.crawl4ai.com"], # Simple test page
|
||||
browser_config=BrowserConfig(headless=True),
|
||||
crawler_config=crawler_config,
|
||||
hooks_timeout=timeout,
|
||||
)
|
||||
|
||||
if results:
|
||||
@@ -74,12 +81,11 @@ async def test_with_docker_client():
|
||||
|
||||
except Exception as e:
|
||||
print(f"❌ Docker client test failed: {e}")
|
||||
import traceback
|
||||
traceback.print_exc()
|
||||
return False
|
||||
|
||||
|
||||
async def test_with_rest_api():
|
||||
async def test_with_rest_api(filters: list[dict[str, Any]], max_pages: int = 20, timeout: int = 30) -> bool:
|
||||
"""Test using REST API directly."""
|
||||
print("\n" + "=" * 60)
|
||||
print("Testing with REST API")
|
||||
@@ -90,19 +96,11 @@ async def test_with_rest_api():
|
||||
"type": "BFSDeepCrawlStrategy",
|
||||
"params": {
|
||||
"max_depth": 2,
|
||||
# "max_pages": 5,
|
||||
"max_pages": max_pages,
|
||||
"filter_chain": {
|
||||
"type": "FilterChain",
|
||||
"params": {
|
||||
"filters": [
|
||||
{
|
||||
"type": "URLPatternFilter",
|
||||
"params": {
|
||||
"patterns": ["*advanced*"],
|
||||
"reverse": True
|
||||
}
|
||||
}
|
||||
]
|
||||
"filters": filters
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -126,7 +124,7 @@ async def test_with_rest_api():
|
||||
response = await client.post(
|
||||
f"{BASE_URL}crawl",
|
||||
json=crawl_payload,
|
||||
timeout=30
|
||||
timeout=timeout,
|
||||
)
|
||||
|
||||
if response.status_code == 200:
|
||||
@@ -150,7 +148,6 @@ async def test_with_rest_api():
|
||||
|
||||
except Exception as e:
|
||||
print(f"❌ REST API test failed: {e}")
|
||||
import traceback
|
||||
traceback.print_exc()
|
||||
return False
|
||||
|
||||
@@ -165,12 +162,62 @@ async def main():
|
||||
results = []
|
||||
|
||||
# Test 1: Docker client
|
||||
docker_passed = await test_with_docker_client()
|
||||
results.append(("Docker Client", docker_passed))
|
||||
max_pages_ = [20, 5]
|
||||
timeouts = [30, 60]
|
||||
filter_chain_test_cases = [
|
||||
[
|
||||
URLPatternFilter(
|
||||
# patterns=["*about*", "*privacy*", "*terms*"],
|
||||
patterns=["*advanced*"],
|
||||
reverse=True
|
||||
),
|
||||
],
|
||||
[
|
||||
ContentRelevanceFilter(
|
||||
query="about faq",
|
||||
threshold=0.2,
|
||||
),
|
||||
],
|
||||
]
|
||||
for idx, (filter_chain, max_pages, timeout) in enumerate(zip(filter_chain_test_cases, max_pages_, timeouts)):
|
||||
docker_passed = await test_with_docker_client(filter_chain=filter_chain, max_pages=max_pages, timeout=timeout)
|
||||
results.append((f"Docker Client w/ filter chain {idx}", docker_passed))
|
||||
|
||||
# Test 2: REST API
|
||||
rest_passed = await test_with_rest_api()
|
||||
results.append(("REST API", rest_passed))
|
||||
max_pages_ = [20, 5, 5]
|
||||
timeouts = [30, 60, 60]
|
||||
filters_test_cases = [
|
||||
[
|
||||
{
|
||||
"type": "URLPatternFilter",
|
||||
"params": {
|
||||
"patterns": ["*advanced*"],
|
||||
"reverse": True
|
||||
}
|
||||
}
|
||||
],
|
||||
[
|
||||
{
|
||||
"type": "ContentRelevanceFilter",
|
||||
"params": {
|
||||
"query": "about faq",
|
||||
"threshold": 0.2,
|
||||
}
|
||||
}
|
||||
],
|
||||
[
|
||||
{
|
||||
"type": "ContentRelevanceFilter",
|
||||
"params": {
|
||||
"query": ["about", "faq"],
|
||||
"threshold": 0.2,
|
||||
}
|
||||
}
|
||||
],
|
||||
]
|
||||
for idx, (filters, max_pages, timeout) in enumerate(zip(filters_test_cases, max_pages_, timeouts)):
|
||||
rest_passed = await test_with_rest_api(filters=filters, max_pages=max_pages, timeout=timeout)
|
||||
results.append((f"REST API w/ filters {idx}", rest_passed))
|
||||
|
||||
# Summary
|
||||
print("\n" + "=" * 60)
|
||||
@@ -186,10 +233,7 @@ async def main():
|
||||
|
||||
print("=" * 60)
|
||||
if all_passed:
|
||||
print("🎉 ALL TESTS PASSED! Both issues are fully resolved!")
|
||||
print("\nThe fixes:")
|
||||
print("1. Filter serialization: Fixed by not serializing private __slots__")
|
||||
print("2. JSON serialization: Fixed by removing property descriptors from model_dump()")
|
||||
print("🎉 ALL TESTS PASSED!")
|
||||
else:
|
||||
print("⚠️ Some tests failed. Please check the server logs for details.")
|
||||
|
||||
@@ -198,4 +242,4 @@ async def main():
|
||||
|
||||
if __name__ == "__main__":
|
||||
import sys
|
||||
sys.exit(asyncio.run(main()))
|
||||
sys.exit(asyncio.run(main()))
|
||||
|
||||
Reference in New Issue
Block a user