feat(docker): add Docker service integration and config serialization

Add Docker service integration with FastAPI server and client implementation.
Implement serialization utilities for BrowserConfig and CrawlerRunConfig to support
Docker service communication. Clean up imports and improve error handling.

- Add Crawl4aiDockerClient class
- Implement config serialization/deserialization
- Add FastAPI server with streaming support
- Add health check endpoint
- Clean up imports and type hints

Author: UncleCode
Date: 2025-01-31 18:00:16 +08:00
Parent: ce4f04dad2
Commit: 53ac3ec0b4
11 changed files with 859 additions and 97 deletions
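
Before the diff, here is a minimal sketch of how the pieces added by this commit might be exercised end to end. It mirrors the usage in tests/docker/test_docker.py below; the health-check path, port, and response handling are assumptions for illustration only and are not confirmed by this excerpt.

import asyncio

import httpx

from crawl4ai import BrowserConfig, CacheMode, CrawlerRunConfig
from crawl4ai.docker_client import Crawl4aiDockerClient


async def smoke_test() -> None:
    # Hypothetical health-check call; the /health path and port 8000 are assumptions.
    async with httpx.AsyncClient() as http:
        resp = await http.get("http://localhost:8000/health")
        print("Server health status code:", resp.status_code)

    # Minimal client usage, mirroring the new test file in this commit.
    async with Crawl4aiDockerClient(verbose=True) as client:
        result = await client.crawl(
            urls=["https://example.com"],
            browser_config=BrowserConfig(headless=True),
            crawler_config=CrawlerRunConfig(cache_mode=CacheMode.BYPASS),
        )
        print("Crawl succeeded:", result.success)


if __name__ == "__main__":
    asyncio.run(smoke_test())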

tests/docker/test_docker.py (new file, 174 lines)

@@ -0,0 +1,174 @@
import requests
import time
import httpx
import asyncio
from typing import Dict, Any
from crawl4ai import (
    BrowserConfig, CrawlerRunConfig, DefaultMarkdownGenerator,
    PruningContentFilter, JsonCssExtractionStrategy, LLMContentFilter, CacheMode
)
from crawl4ai.docker_client import Crawl4aiDockerClient


class Crawl4AiTester:
    def __init__(self, base_url: str = "http://localhost:11235"):
        self.base_url = base_url

    def submit_and_wait(
        self, request_data: Dict[str, Any], timeout: int = 300
    ) -> Dict[str, Any]:
        # Submit crawl job
        response = requests.post(f"{self.base_url}/crawl", json=request_data)
        task_id = response.json()["task_id"]
        print(f"Task ID: {task_id}")

        # Poll for result
        start_time = time.time()
        while True:
            if time.time() - start_time > timeout:
                raise TimeoutError(
                    f"Task {task_id} did not complete within {timeout} seconds"
                )
            result = requests.get(f"{self.base_url}/task/{task_id}")
            status = result.json()
            if status["status"] == "failed":
                print("Task failed:", status.get("error"))
                raise Exception(f"Task failed: {status.get('error')}")
            if status["status"] == "completed":
                return status
            time.sleep(2)


async def test_direct_api():
    """Test direct API endpoints without using the client SDK"""
    print("\n=== Testing Direct API Calls ===")

    # Test 1: Basic crawl with content filtering
    browser_config = BrowserConfig(
        headless=True,
        viewport_width=1200,
        viewport_height=800
    )
    crawler_config = CrawlerRunConfig(
        cache_mode=CacheMode.BYPASS,
        markdown_generator=DefaultMarkdownGenerator(
            content_filter=PruningContentFilter(
                threshold=0.48,
                threshold_type="fixed",
                min_word_threshold=0
            ),
            options={"ignore_links": True}
        )
    )
    request_data = {
        "urls": ["https://example.com"],
        "browser_config": browser_config.dump(),
        "crawler_config": crawler_config.dump()
    }

    # Make direct API call
    async with httpx.AsyncClient() as client:
        response = await client.post(
            "http://localhost:8000/crawl",
            json=request_data,
            timeout=300
        )
        assert response.status_code == 200
        result = response.json()
        print("Basic crawl result:", result["success"])

    # Test 2: Structured extraction with JSON CSS
    schema = {
        "baseSelector": "article.post",
        "fields": [
            {"name": "title", "selector": "h1", "type": "text"},
            {"name": "content", "selector": ".content", "type": "html"}
        ]
    }
    crawler_config = CrawlerRunConfig(
        cache_mode=CacheMode.BYPASS,
        extraction_strategy=JsonCssExtractionStrategy(schema=schema)
    )
    request_data["crawler_config"] = crawler_config.dump()
    async with httpx.AsyncClient() as client:
        response = await client.post(
            "http://localhost:8000/crawl",
            json=request_data
        )
        assert response.status_code == 200
        result = response.json()
        print("Structured extraction result:", result["success"])

    # Test 3: Get schema
    # async with httpx.AsyncClient() as client:
    #     response = await client.get("http://localhost:8000/schema")
    #     assert response.status_code == 200
    #     schemas = response.json()
    #     print("Retrieved schemas for:", list(schemas.keys()))


async def test_with_client():
    """Test using the Crawl4AI Docker client SDK"""
    print("\n=== Testing Client SDK ===")
    async with Crawl4aiDockerClient(verbose=True) as client:
        # Test 1: Basic crawl
        browser_config = BrowserConfig(headless=True)
        crawler_config = CrawlerRunConfig(
            cache_mode=CacheMode.BYPASS,
            markdown_generator=DefaultMarkdownGenerator(
                content_filter=PruningContentFilter(
                    threshold=0.48,
                    threshold_type="fixed"
                )
            )
        )
        result = await client.crawl(
            urls=["https://example.com"],
            browser_config=browser_config,
            crawler_config=crawler_config
        )
        print("Client SDK basic crawl:", result.success)

        # Test 2: LLM extraction with streaming
        crawler_config = CrawlerRunConfig(
            cache_mode=CacheMode.BYPASS,
            markdown_generator=DefaultMarkdownGenerator(
                content_filter=LLMContentFilter(
provider="openai/gpt-40",
instruction="Extract key technical concepts"
)
),
stream=True
)
async for result in await client.crawl(
urls=["https://example.com"],
browser_config=browser_config,
crawler_config=crawler_config
):
print(f"Streaming result for: {result.url}")
# # Test 3: Get schema
# schemas = await client.get_schema()
# print("Retrieved client schemas for:", list(schemas.keys()))


async def main():
    """Run all tests"""
    # Test direct API
    print("Testing direct API calls...")
    # await test_direct_api()

    # Test client SDK
    print("\nTesting client SDK...")
    await test_with_client()


if __name__ == "__main__":
    asyncio.run(main())


@@ -0,0 +1,253 @@
import inspect
from typing import Any, Dict
from enum import Enum


def to_serializable_dict(obj: Any) -> Dict:
    """
    Recursively convert an object to a serializable dictionary using a
    {type, params} structure for complex objects.
    """
    if obj is None:
        return None

    # Handle basic types
    if isinstance(obj, (str, int, float, bool)):
        return obj

    # Handle Enum
    if isinstance(obj, Enum):
        return {
            "type": obj.__class__.__name__,
            "params": obj.value
        }

    # Handle datetime objects
    if hasattr(obj, 'isoformat'):
        return obj.isoformat()

    # Handle lists, tuples, and sets
    if isinstance(obj, (list, tuple, set)):
        return [to_serializable_dict(item) for item in obj]

    # Handle dictionaries - preserve them as-is
    if isinstance(obj, dict):
        return {
            "type": "dict",  # Mark as plain dictionary
            "value": {str(k): to_serializable_dict(v) for k, v in obj.items()}
        }

    # Handle class instances
    if hasattr(obj, '__class__'):
        # Get constructor signature
        sig = inspect.signature(obj.__class__.__init__)
        params = sig.parameters

        # Get current values
        current_values = {}
        for name, param in params.items():
            if name == 'self':
                continue
            value = getattr(obj, name, param.default)
            # Only include if different from default, considering empty values
            if not (is_empty_value(value) and is_empty_value(param.default)):
                if value != param.default:
                    current_values[name] = to_serializable_dict(value)

        return {
            "type": obj.__class__.__name__,
            "params": current_values
        }

    return str(obj)


def from_serializable_dict(data: Any) -> Any:
    """
    Recursively convert a serializable dictionary back to an object instance.
    """
    if data is None:
        return None

    # Handle basic types
    if isinstance(data, (str, int, float, bool)):
        return data

    # Handle typed data
    if isinstance(data, dict) and "type" in data:
        # Handle plain dictionaries
        if data["type"] == "dict":
            return {k: from_serializable_dict(v) for k, v in data["value"].items()}

        # Import from crawl4ai for class instances
        import crawl4ai
        cls = getattr(crawl4ai, data["type"])

        # Handle Enum
        if issubclass(cls, Enum):
            return cls(data["params"])

        # Handle class instances
        constructor_args = {
            k: from_serializable_dict(v) for k, v in data["params"].items()
        }
        return cls(**constructor_args)

    # Handle lists
    if isinstance(data, list):
        return [from_serializable_dict(item) for item in data]

    # Handle raw dictionaries (legacy support)
    if isinstance(data, dict):
        return {k: from_serializable_dict(v) for k, v in data.items()}

    return data


def is_empty_value(value: Any) -> bool:
    """Check if a value is effectively empty/null."""
    if value is None:
        return True
    if isinstance(value, (list, tuple, set, dict, str)) and len(value) == 0:
        return True
    return False


# if __name__ == "__main__":
#     from crawl4ai import (
#         CrawlerRunConfig, CacheMode, DefaultMarkdownGenerator,
#         PruningContentFilter, BM25ContentFilter, LLMContentFilter,
#         JsonCssExtractionStrategy, CosineStrategy, RegexChunking,
#         WebScrapingStrategy, LXMLWebScrapingStrategy
#     )
#
#     # Test Case 1: BM25 content filtering through markdown generator
#     config1 = CrawlerRunConfig(
#         cache_mode=CacheMode.BYPASS,
#         markdown_generator=DefaultMarkdownGenerator(
#             content_filter=BM25ContentFilter(
#                 user_query="technology articles",
#                 bm25_threshold=1.2,
#                 language="english"
#             )
#         ),
#         chunking_strategy=RegexChunking(patterns=[r"\n\n", r"\.\s+"]),
#         excluded_tags=["nav", "footer", "aside"],
#         remove_overlay_elements=True
#     )
#
#     # Serialize
#     serialized = to_serializable_dict(config1)
#     print("\nSerialized Config:")
#     print(serialized)
#
#     # Example output structure would now look like:
#     """
#     {
#         "type": "CrawlerRunConfig",
#         "params": {
#             "cache_mode": {
#                 "type": "CacheMode",
#                 "params": "bypass"
#             },
#             "markdown_generator": {
#                 "type": "DefaultMarkdownGenerator",
#                 "params": {
#                     "content_filter": {
#                         "type": "BM25ContentFilter",
#                         "params": {
#                             "user_query": "technology articles",
#                             "bm25_threshold": 1.2,
#                             "language": "english"
#                         }
#                     }
#                 }
#             }
#         }
#     }
#     """
#
#     # Deserialize
#     deserialized = from_serializable_dict(serialized)
#     print("\nDeserialized Config:")
#     print(to_serializable_dict(deserialized))
#
#     # Verify they match
#     assert to_serializable_dict(config1) == to_serializable_dict(deserialized)
#     print("\nVerification passed: Configuration matches after serialization/deserialization!")


if __name__ == "__main__":
    from crawl4ai import (
        CrawlerRunConfig, CacheMode, DefaultMarkdownGenerator,
        PruningContentFilter, BM25ContentFilter, LLMContentFilter,
        JsonCssExtractionStrategy, RegexChunking,
        WebScrapingStrategy, LXMLWebScrapingStrategy
    )

    # Test Case 1: BM25 content filtering through markdown generator
    config1 = CrawlerRunConfig(
        cache_mode=CacheMode.BYPASS,
        markdown_generator=DefaultMarkdownGenerator(
            content_filter=BM25ContentFilter(
                user_query="technology articles",
                bm25_threshold=1.2,
                language="english"
            )
        ),
        chunking_strategy=RegexChunking(patterns=[r"\n\n", r"\.\s+"]),
        excluded_tags=["nav", "footer", "aside"],
        remove_overlay_elements=True
    )

    # Test Case 2: Structured JSON CSS extraction with a pruning content filter
    schema = {
        "baseSelector": "article.post",
        "fields": [
            {"name": "title", "selector": "h1", "type": "text"},
            {"name": "content", "selector": ".content", "type": "html"}
        ]
    }
    config2 = CrawlerRunConfig(
        extraction_strategy=JsonCssExtractionStrategy(schema=schema),
        markdown_generator=DefaultMarkdownGenerator(
            content_filter=PruningContentFilter(
                threshold=0.48,
                threshold_type="fixed",
                min_word_threshold=0
            ),
            options={"ignore_links": True}
        ),
        scraping_strategy=LXMLWebScrapingStrategy()
    )

    # Test Case 3: LLM content filter
    config3 = CrawlerRunConfig(
        markdown_generator=DefaultMarkdownGenerator(
            content_filter=LLMContentFilter(
                provider="openai/gpt-4",
                instruction="Extract key technical concepts",
                chunk_token_threshold=2000,
                overlap_rate=0.1
            ),
            options={"ignore_images": True}
        ),
        scraping_strategy=WebScrapingStrategy()
    )

    # Test all configurations
    test_configs = [config1, config2, config3]
    for i, config in enumerate(test_configs, 1):
        print(f"\nTesting Configuration {i}:")

        # Serialize
        serialized = to_serializable_dict(config)
        print(f"\nSerialized Config {i}:")
        print(serialized)

        # Deserialize
        deserialized = from_serializable_dict(serialized)
        print(f"\nDeserialized Config {i}:")
        print(to_serializable_dict(deserialized))  # Convert back to dict for comparison

        # Verify they match
        assert to_serializable_dict(config) == to_serializable_dict(deserialized)
        print(f"\nVerification passed: Configuration {i} matches after serialization/deserialization!")