feat(docker): add Docker service integration and config serialization

Add Docker service integration with a FastAPI server and a client implementation.
Implement serialization utilities for BrowserConfig and CrawlerRunConfig to support
Docker service communication. Clean up imports and improve error handling.

- Add Crawl4aiDockerClient class
- Implement config serialization/deserialization
- Add FastAPI server with streaming support
- Add health check endpoint
- Clean up imports and type hints
UncleCode
2025-01-31 18:00:16 +08:00
parent ce4f04dad2
commit 53ac3ec0b4
11 changed files with 859 additions and 97 deletions
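
For orientation, a minimal usage sketch of the pieces this commit adds; it assumes a Crawl4AI Docker server is already listening at http://localhost:8000, and the printed field is illustrative:

import asyncio
from crawl4ai import BrowserConfig, CrawlerRunConfig, CacheMode
from crawl4ai.docker_client import Crawl4aiDockerClient

async def main():
    # Configs travel over HTTP as plain dicts produced by dump()
    browser_config = BrowserConfig(headless=True)
    crawler_config = CrawlerRunConfig(cache_mode=CacheMode.BYPASS)

    async with Crawl4aiDockerClient(base_url="http://localhost:8000") as client:
        result = await client.crawl(
            urls=["https://example.com"],
            browser_config=browser_config,
            crawler_config=crawler_config,
        )
        print(result.success)

asyncio.run(main())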

View File

@@ -1,4 +1,5 @@
# __init__.py
import warnings
from .async_webcrawler import AsyncWebCrawler, CacheMode
from .async_configs import BrowserConfig, CrawlerRunConfig
@@ -28,6 +29,7 @@ from .async_dispatcher import (
DisplayMode,
BaseDispatcher
)
from .docker_client import Crawl4aiDockerClient
from .hub import CrawlerHub
__all__ = [
@@ -59,12 +61,13 @@ __all__ = [
"CrawlerMonitor",
"DisplayMode",
"MarkdownGenerationResult",
"Crawl4aiDockerClient",
]
def is_sync_version_installed():
try:
import selenium
import selenium # noqa
return True
except ImportError:
@@ -85,9 +88,6 @@ else:
# import warnings
# print("Warning: Synchronous WebCrawler is not available. Install crawl4ai[sync] for synchronous support. However, please note that the synchronous version will be deprecated soon.")
import warnings
from pydantic import warnings as pydantic_warnings
# Disable all Pydantic warnings
warnings.filterwarnings("ignore", module="pydantic")
# pydantic_warnings.filter_warnings()
# pydantic_warnings.filter_warnings()

View File

@@ -7,17 +7,129 @@ from .config import (
SOCIAL_MEDIA_DOMAINS,
)
from .user_agent_generator import UserAgentGenerator, UAGen, ValidUAGenerator, OnlineUAGenerator
from .user_agent_generator import UAGen, ValidUAGenerator # , OnlineUAGenerator
from .extraction_strategy import ExtractionStrategy
from .chunking_strategy import ChunkingStrategy, RegexChunking
from .markdown_generation_strategy import MarkdownGenerationStrategy
from .content_filter_strategy import RelevantContentFilter, BM25ContentFilter, LLMContentFilter, PruningContentFilter
from .content_filter_strategy import RelevantContentFilter # , BM25ContentFilter, LLMContentFilter, PruningContentFilter
from .content_scraping_strategy import ContentScrapingStrategy, WebScrapingStrategy
from typing import Optional, Union, List
from typing import Union, List
from .cache_context import CacheMode
import inspect
from typing import Any, Dict
from enum import Enum
class BrowserConfig:
def to_serializable_dict(obj: Any) -> Dict:
"""
Recursively convert an object to a serializable dictionary using {type, params} structure
for complex objects.
"""
if obj is None:
return None
# Handle basic types
if isinstance(obj, (str, int, float, bool)):
return obj
# Handle Enum
if isinstance(obj, Enum):
return {
"type": obj.__class__.__name__,
"params": obj.value
}
# Handle datetime objects
if hasattr(obj, 'isoformat'):
return obj.isoformat()
# Handle lists, tuples, and sets
if isinstance(obj, (list, tuple, set)):
return [to_serializable_dict(item) for item in obj]
# Handle dictionaries - preserve them as-is
if isinstance(obj, dict):
return {
"type": "dict", # Mark as plain dictionary
"value": {str(k): to_serializable_dict(v) for k, v in obj.items()}
}
# Handle class instances
if hasattr(obj, '__class__'):
# Get constructor signature
sig = inspect.signature(obj.__class__.__init__)
params = sig.parameters
# Get current values
current_values = {}
for name, param in params.items():
if name == 'self':
continue
value = getattr(obj, name, param.default)
# Only include if different from default, considering empty values
if not (is_empty_value(value) and is_empty_value(param.default)):
if value != param.default:
current_values[name] = to_serializable_dict(value)
return {
"type": obj.__class__.__name__,
"params": current_values
}
return str(obj)
def from_serializable_dict(data: Any) -> Any:
"""
Recursively convert a serializable dictionary back to an object instance.
"""
if data is None:
return None
# Handle basic types
if isinstance(data, (str, int, float, bool)):
return data
# Handle typed data
if isinstance(data, dict) and "type" in data:
# Handle plain dictionaries
if data["type"] == "dict":
return {k: from_serializable_dict(v) for k, v in data["value"].items()}
# Import from crawl4ai for class instances
import crawl4ai
cls = getattr(crawl4ai, data["type"])
# Handle Enum
if issubclass(cls, Enum):
return cls(data["params"])
# Handle class instances
constructor_args = {
k: from_serializable_dict(v) for k, v in data["params"].items()
}
return cls(**constructor_args)
# Handle lists
if isinstance(data, list):
return [from_serializable_dict(item) for item in data]
# Handle raw dictionaries (legacy support)
if isinstance(data, dict):
return {k: from_serializable_dict(v) for k, v in data.items()}
return data
def is_empty_value(value: Any) -> bool:
"""Check if a value is effectively empty/null."""
if value is None:
return True
if isinstance(value, (list, tuple, set, dict, str)) and len(value) == 0:
return True
return False
class BrowserConfig():
"""
Configuration class for setting up a browser instance and its context in AsyncPlaywrightCrawlerStrategy.
@@ -239,8 +351,18 @@ class BrowserConfig:
config_dict.update(kwargs)
return BrowserConfig.from_kwargs(config_dict)
# Create a function that returns a dict representation of the object
def dump(self) -> dict:
# Serialize the object to a dictionary
return to_serializable_dict(self)
class CrawlerRunConfig:
@staticmethod
def load(data: dict) -> "BrowserConfig":
# Deserialize the object from a dictionary
return from_serializable_dict(data)
class CrawlerRunConfig():
"""
Configuration class for controlling how the crawler runs each crawl operation.
This includes parameters for content extraction, page manipulation, waiting conditions,
@@ -665,6 +787,15 @@ class CrawlerRunConfig:
)
# Create a function that returns a dict representation of the object
def dump(self) -> dict:
# Serialize the object to a dictionary
return to_serializable_dict(self)
@staticmethod
def load(data: dict) -> "CrawlerRunConfig":
# Deserialize the object from a dictionary
return from_serializable_dict(data)
def to_dict(self):
return {
"word_count_threshold": self.word_count_threshold,

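As a rough sketch of the round-trip the new dump()/load() helpers are meant to support (the serialized params only contain fields that differ from their defaults, and the exact values shown are indicative):

from crawl4ai import CrawlerRunConfig, CacheMode

config = CrawlerRunConfig(
    cache_mode=CacheMode.BYPASS,
    excluded_tags=["nav", "footer"],
)

data = config.dump()
# Roughly:
# {
#     "type": "CrawlerRunConfig",
#     "params": {
#         "cache_mode": {"type": "CacheMode", "params": "bypass"},
#         "excluded_tags": ["nav", "footer"],
#     },
# }

restored = CrawlerRunConfig.load(data)
assert restored.dump() == data
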
View File

@@ -1,3 +1,4 @@
from .__version__ import __version__ as crawl4ai_version
import os
import sys
import time
@@ -10,7 +11,7 @@ import asyncio
# from contextlib import nullcontext, asynccontextmanager
from contextlib import asynccontextmanager
from .models import CrawlResult, MarkdownGenerationResult, CrawlerTaskResult, DispatchResult
from .models import CrawlResult, MarkdownGenerationResult, DispatchResult
from .async_database import async_db_manager
from .chunking_strategy import * # noqa: F403
from .chunking_strategy import RegexChunking, ChunkingStrategy, IdentityChunking
@@ -43,8 +44,7 @@ from .utils import (
RobotsParser,
)
from typing import Union, AsyncGenerator, List, TypeVar
from collections.abc import AsyncGenerator
from typing import Union, AsyncGenerator, TypeVar
CrawlResultT = TypeVar('CrawlResultT', bound=CrawlResult)
RunManyReturn = Union[List[CrawlResultT], AsyncGenerator[CrawlResultT, None]]
@@ -55,7 +55,7 @@ DeepCrawlManyReturn = Union[
AsyncGenerator[CrawlResultT, None],
]
from .__version__ import __version__ as crawl4ai_version
class AsyncWebCrawler:
@@ -768,18 +768,19 @@ class AsyncWebCrawler:
),
)
transform_result = lambda task_result: (
setattr(task_result.result, 'dispatch_result',
DispatchResult(
task_id=task_result.task_id,
memory_usage=task_result.memory_usage,
peak_memory=task_result.peak_memory,
start_time=task_result.start_time,
end_time=task_result.end_time,
error_message=task_result.error_message,
def transform_result(task_result):
return (
setattr(task_result.result, 'dispatch_result',
DispatchResult(
task_id=task_result.task_id,
memory_usage=task_result.memory_usage,
peak_memory=task_result.peak_memory,
start_time=task_result.start_time,
end_time=task_result.end_time,
error_message=task_result.error_message,
)
) or task_result.result
)
) or task_result.result
)
stream = config.stream

View File

@@ -9,16 +9,16 @@ from .utils import clean_tokens, perform_completion_with_backoff, escape_json_st
from abc import ABC, abstractmethod
import math
from snowballstemmer import stemmer
from .config import DEFAULT_PROVIDER, OVERLAP_RATE, WORD_TOKEN_RATE
from .config import DEFAULT_PROVIDER, OVERLAP_RATE, WORD_TOKEN_RATE, PROVIDER_MODELS
from .models import TokenUsage
from .prompts import PROMPT_FILTER_CONTENT
import os
import json
import hashlib
from pathlib import Path
from concurrent.futures import ThreadPoolExecutor, as_completed
from concurrent.futures import ThreadPoolExecutor
from .async_logger import AsyncLogger, LogLevel
from colorama import Fore, Style, init
from colorama import Fore, Style
class RelevantContentFilter(ABC):
"""Abstract base class for content filtering strategies"""
@@ -879,7 +879,6 @@ class LLMContentFilter(RelevantContentFilter):
colors={"chunk_count": Fore.YELLOW}
)
extracted_content = []
start_time = time.time()
# Process chunks in parallel

crawl4ai/docker_client.py (new file, 210 lines added)
View File

@@ -0,0 +1,210 @@
from typing import List, Optional, Union, AsyncGenerator, Dict, Any
import httpx
import json
from urllib.parse import urljoin
from .async_configs import BrowserConfig, CrawlerRunConfig
from .models import CrawlResult
from .async_logger import AsyncLogger, LogLevel
class Crawl4aiClientError(Exception):
"""Base exception for Crawl4ai Docker client errors."""
pass
class ConnectionError(Crawl4aiClientError):
"""Raised when connection to the Docker server fails."""
pass
class RequestError(Crawl4aiClientError):
"""Raised when the server returns an error response."""
pass
class Crawl4aiDockerClient:
"""
Client for interacting with Crawl4AI Docker server.
Args:
base_url (str): Base URL of the Crawl4AI Docker server
timeout (float): Default timeout for requests in seconds
verify_ssl (bool): Whether to verify SSL certificates
verbose (bool): Whether to show logging output
log_file (str, optional): Path to log file if file logging is desired
"""
def __init__(
self,
base_url: str = "http://localhost:8000",
timeout: float = 30.0,
verify_ssl: bool = True,
verbose: bool = True,
log_file: Optional[str] = None
) -> None:
self.base_url = base_url.rstrip('/')
self.timeout = timeout
self._http_client = httpx.AsyncClient(
timeout=timeout,
verify=verify_ssl,
headers={"Content-Type": "application/json"}
)
self.logger = AsyncLogger(
log_file=log_file,
log_level=LogLevel.DEBUG,
verbose=verbose
)
async def _check_server_connection(self) -> bool:
"""Check if server is reachable."""
try:
self.logger.info("Checking server connection...", tag="INIT")
response = await self._http_client.get(f"{self.base_url}/health")
response.raise_for_status()
self.logger.success(f"Connected to server at {self.base_url}", tag="READY")
return True
except Exception as e:
self.logger.error(f"Failed to connect to server: {str(e)}", tag="ERROR")
return False
def _prepare_request_data(
self,
urls: List[str],
browser_config: Optional[BrowserConfig] = None,
crawler_config: Optional[CrawlerRunConfig] = None
) -> Dict[str, Any]:
"""Prepare request data from configs using dump methods."""
self.logger.debug("Preparing request data", tag="INIT")
data = {
"urls": urls,
"browser_config": browser_config.dump() if browser_config else {},
"crawler_config": crawler_config.dump() if crawler_config else {}
}
self.logger.debug(f"Request data prepared for {len(urls)} URLs", tag="READY")
return data
async def _make_request(
self,
method: str,
endpoint: str,
**kwargs
) -> Union[Dict, AsyncGenerator]:
"""Make HTTP request to the server with error handling."""
url = urljoin(self.base_url, endpoint)
try:
self.logger.debug(f"Making {method} request to {endpoint}", tag="FETCH")
response = await self._http_client.request(method, url, **kwargs)
response.raise_for_status()
self.logger.success(f"Request to {endpoint} successful", tag="COMPLETE")
return response
except httpx.TimeoutException as e:
error_msg = f"Request timed out: {str(e)}"
self.logger.error(error_msg, tag="ERROR")
raise ConnectionError(error_msg)
except httpx.RequestError as e:
error_msg = f"Failed to connect to server: {str(e)}"
self.logger.error(error_msg, tag="ERROR")
raise ConnectionError(error_msg)
except httpx.HTTPStatusError as e:
error_detail = ""
try:
error_data = e.response.json()
error_detail = error_data.get('detail', str(e))
except (json.JSONDecodeError, AttributeError) as json_err:
error_detail = f"{str(e)} (Failed to parse error response: {str(json_err)})"
error_msg = f"Server returned error {e.response.status_code}: {error_detail}"
self.logger.error(error_msg, tag="ERROR")
raise RequestError(error_msg)
async def crawl(
self,
urls: List[str],
browser_config: Optional[BrowserConfig] = None,
crawler_config: Optional[CrawlerRunConfig] = None
) -> Union[CrawlResult, AsyncGenerator[CrawlResult, None]]:
"""Execute a crawl operation through the Docker server."""
# Check server connection first
if not await self._check_server_connection():
raise ConnectionError("Cannot proceed with crawl - server is not reachable")
request_data = self._prepare_request_data(urls, browser_config, crawler_config)
is_streaming = crawler_config.stream if crawler_config else False
self.logger.info(
f"Starting crawl for {len(urls)} URLs {'(streaming)' if is_streaming else ''}",
tag="INIT"
)
if is_streaming:
async def result_generator() -> AsyncGenerator[CrawlResult, None]:
try:
async with self._http_client.stream(
"POST",
f"{self.base_url}/crawl",
json=request_data,
timeout=None
) as response:
response.raise_for_status()
async for line in response.aiter_lines():
if line.strip():
try:
result_dict = json.loads(line)
if "error" in result_dict:
self.logger.error_status(
url=result_dict.get('url', 'unknown'),
error=result_dict['error']
)
continue
self.logger.url_status(
url=result_dict.get('url', 'unknown'),
success=True,
timing=result_dict.get('timing', 0.0)
)
yield CrawlResult(**result_dict)
except json.JSONDecodeError as e:
self.logger.error(f"Failed to parse server response: {e}", tag="ERROR")
continue
except httpx.StreamError as e:
error_msg = f"Stream connection error: {str(e)}"
self.logger.error(error_msg, tag="ERROR")
raise ConnectionError(error_msg)
except Exception as e:
error_msg = f"Unexpected error during streaming: {str(e)}"
self.logger.error(error_msg, tag="ERROR")
raise Crawl4aiClientError(error_msg)
return result_generator()
response = await self._make_request("POST", "/crawl", json=request_data)
response_data = response.json()
if not response_data.get("success", False):
error_msg = f"Crawl operation failed: {response_data.get('error', 'Unknown error')}"
self.logger.error(error_msg, tag="ERROR")
raise RequestError(error_msg)
results = [CrawlResult(**result_dict) for result_dict in response_data.get("results", [])]
self.logger.success(f"Crawl completed successfully with {len(results)} results", tag="COMPLETE")
return results[0] if len(results) == 1 else results
async def get_schema(self) -> Dict[str, Any]:
"""Retrieve the configuration schemas from the server."""
self.logger.info("Retrieving schema from server", tag="FETCH")
response = await self._make_request("GET", "/schema")
self.logger.success("Schema retrieved successfully", tag="COMPLETE")
return response.json()
async def close(self) -> None:
"""Close the HTTP client session."""
self.logger.info("Closing client connection", tag="COMPLETE")
await self._http_client.aclose()
async def __aenter__(self) -> "Crawl4aiDockerClient":
return self
async def __aexit__(self, exc_type: Optional[type], exc_val: Optional[Exception], exc_tb: Optional[Any]) -> None:
await self.close()
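
A sketch of how the streaming path above is consumed from the caller's side (mirroring the test file added in this commit; the URLs are placeholders and a running server is assumed):

import asyncio
from crawl4ai import BrowserConfig, CrawlerRunConfig
from crawl4ai.docker_client import Crawl4aiDockerClient

async def stream_example():
    async with Crawl4aiDockerClient(verbose=True) as client:
        # When crawler_config.stream is True, crawl() returns an async generator
        async for result in await client.crawl(
            urls=["https://example.com", "https://example.org"],
            browser_config=BrowserConfig(headless=True),
            crawler_config=CrawlerRunConfig(stream=True),
        ):
            print(result.url, result.success)

asyncio.run(stream_example())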

View File

@@ -8,6 +8,7 @@ from crawl4ai import (
BM25ContentFilter,
LLMContentFilter,
# Add other strategy classes as needed
)
class StrategyConfig(BaseModel):

View File

@@ -1,10 +1,10 @@
# pyright: ignore
import os, sys
sys.path.append(os.path.dirname(os.path.realpath(__file__)))
from fastapi import FastAPI, HTTPException
from fastapi.responses import StreamingResponse
import json
import asyncio
from typing import AsyncGenerator
from datetime import datetime
from crawl4ai import (
BrowserConfig,
CrawlerRunConfig,
@@ -12,67 +12,36 @@ from crawl4ai import (
MemoryAdaptiveDispatcher,
RateLimiter,
)
from .models import CrawlRequest, CrawlResponse
class CrawlJSONEncoder(json.JSONEncoder):
"""Custom JSON encoder for crawler results"""
def default(self, obj):
if isinstance(obj, datetime):
return obj.isoformat()
if isinstance(obj, bytes):
return obj.decode('utf-8', errors='ignore')
if hasattr(obj, 'model_dump'):
return obj.model_dump()
if hasattr(obj, '__dict__'):
return {k: v for k, v in obj.__dict__.items() if not k.startswith('_')}
return str(obj) # Fallback to string representation
def serialize_result(result) -> dict:
"""Safely serialize a crawler result"""
try:
# Convert to dict handling special cases
if hasattr(result, 'model_dump'):
result_dict = result.model_dump()
else:
result_dict = {
k: v for k, v in result.__dict__.items()
if not k.startswith('_')
}
# Remove known non-serializable objects
result_dict.pop('ssl_certificate', None)
result_dict.pop('downloaded_files', None)
return result_dict
except Exception as e:
print(f"Error serializing result: {e}")
return {"error": str(e), "url": getattr(result, 'url', 'unknown')}
from models import CrawlRequest, CrawlResponse
app = FastAPI(title="Crawl4AI API")
async def stream_results(crawler: AsyncWebCrawler, results_gen: AsyncGenerator) -> AsyncGenerator[bytes, None]:
"""Stream results and manage crawler lifecycle"""
def datetime_handler(obj):
"""Custom handler for datetime objects during JSON serialization"""
if hasattr(obj, 'isoformat'):
return obj.isoformat()
raise TypeError(f"Object of type {type(obj)} is not JSON serializable")
try:
async for result in results_gen:
try:
# Handle serialization of result
result_dict = serialize_result(result)
# Remove non-serializable objects
# Use dump method for serialization
result_dict = result.model_dump()
print(f"Streaming result for URL: {result_dict['url']}, Success: {result_dict['success']}")
yield (json.dumps(result_dict, cls=CrawlJSONEncoder) + "\n").encode('utf-8')
# Use custom JSON encoder with datetime handler
yield (json.dumps(result_dict, default=datetime_handler) + "\n").encode('utf-8')
except Exception as e:
# Log error but continue streaming
print(f"Error serializing result: {e}")
error_response = {
"error": str(e),
"url": getattr(result, 'url', 'unknown')
}
yield (json.dumps(error_response) + "\n").encode('utf-8')
yield (json.dumps(error_response, default=datetime_handler) + "\n").encode('utf-8')
except asyncio.CancelledError:
# Handle client disconnection gracefully
print("Client disconnected, cleaning up...")
finally:
# Ensure crawler cleanup happens in all cases
try:
await crawler.close()
except Exception as e:
@@ -80,17 +49,17 @@ async def stream_results(crawler: AsyncWebCrawler, results_gen: AsyncGenerator)
@app.post("/crawl")
async def crawl(request: CrawlRequest):
browser_config, crawler_config = request.get_configs()
# Load configs using our new utilities
browser_config = BrowserConfig.load(request.browser_config)
crawler_config = CrawlerRunConfig.load(request.crawler_config)
dispatcher = MemoryAdaptiveDispatcher(
memory_threshold_percent=75.0,
memory_threshold_percent=95.0,
rate_limiter=RateLimiter(base_delay=(1.0, 2.0)),
# monitor=CrawlerMonitor(display_mode=DisplayMode.DETAILED)
)
try:
if crawler_config.stream:
# For streaming, manage crawler lifecycle manually
crawler = AsyncWebCrawler(config=browser_config)
await crawler.start()
@@ -105,29 +74,14 @@ async def crawl(request: CrawlRequest):
media_type='application/x-ndjson'
)
else:
# For non-streaming, use context manager
async with AsyncWebCrawler(config=browser_config) as crawler:
results = await crawler.arun_many(
urls=request.urls,
config=crawler_config,
dispatcher=dispatcher
)
# Handle serialization of results
results_dict = []
for result in results:
try:
result_dict = {
k: v for k, v in (result.model_dump() if hasattr(result, 'model_dump')
else result.__dict__).items()
if not k.startswith('_')
}
result_dict.pop('ssl_certificate', None)
result_dict.pop('downloaded_files', None)
results_dict.append(result_dict)
except Exception as e:
print(f"Error serializing result: {e}")
continue
# Use dump method for each result
results_dict = [result.model_dump() for result in results]
return CrawlResponse(success=True, results=results_dict)
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
@@ -140,9 +94,12 @@ async def get_schema():
"crawler": CrawlerRunConfig.model_json_schema()
}
@app.get("/health")
async def health():
return {"status": "ok"}
if __name__ == "__main__":
import uvicorn
# Run in auto reload mode
# WARNING: You must pass the application as an import string to enable 'reload' or 'workers'.
uvicorn.run("server:app", host="0.0.0.0", port=8000, reload=True)
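
The /crawl endpoint streams newline-delimited JSON (application/x-ndjson) when crawler_config.stream is true; a rough sketch of consuming it without the client SDK (the host/port and the printed fields are assumptions for illustration):

import asyncio
import json
import httpx
from crawl4ai import BrowserConfig, CrawlerRunConfig

async def consume_stream():
    payload = {
        "urls": ["https://example.com"],
        "browser_config": BrowserConfig(headless=True).dump(),
        "crawler_config": CrawlerRunConfig(stream=True).dump(),
    }
    async with httpx.AsyncClient(timeout=None) as client:
        async with client.stream("POST", "http://localhost:8000/crawl", json=payload) as response:
            response.raise_for_status()
            # Each non-empty line is one serialized CrawlResult
            async for line in response.aiter_lines():
                if line.strip():
                    result = json.loads(line)
                    print(result.get("url"), result.get("success"))

asyncio.run(consume_stream())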

deploy/docker/utils.py (new file, 36 lines added)
View File

@@ -0,0 +1,36 @@
import json
from datetime import datetime
class CrawlJSONEncoder(json.JSONEncoder):
"""Custom JSON encoder for crawler results"""
def default(self, obj):
if isinstance(obj, datetime):
return obj.isoformat()
if isinstance(obj, bytes):
return obj.decode('utf-8', errors='ignore')
if hasattr(obj, 'model_dump'):
return obj.model_dump()
if hasattr(obj, '__dict__'):
return {k: v for k, v in obj.__dict__.items() if not k.startswith('_')}
return str(obj) # Fallback to string representation
def serialize_result(result) -> dict:
"""Safely serialize a crawler result"""
try:
# Convert to dict handling special cases
if hasattr(result, 'model_dump'):
result_dict = result.model_dump()
else:
result_dict = {
k: v for k, v in result.__dict__.items()
if not k.startswith('_')
}
# Remove known non-serializable objects
result_dict.pop('ssl_certificate', None)
result_dict.pop('downloaded_files', None)
return result_dict
except Exception as e:
print(f"Error serializing result: {e}")
return {"error": str(e), "url": getattr(result, 'url', 'unknown')}
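
A small sketch of how these helpers compose (FakeResult is a hypothetical stand-in for a crawler result; real results expose many more fields):

import json
from datetime import datetime
# Assumes the utils module above is importable from the working directory
from utils import CrawlJSONEncoder, serialize_result

class FakeResult:
    """Hypothetical stand-in for a crawl result object."""
    def __init__(self):
        self.url = "https://example.com"
        self.success = True
        self.crawled_at = datetime.now()
        self._internal = object()  # dropped by the underscore filter in serialize_result

# serialize_result() falls back to __dict__ since FakeResult has no model_dump(),
# and CrawlJSONEncoder turns the datetime into an ISO string.
line = json.dumps(serialize_result(FakeResult()), cls=CrawlJSONEncoder)
print(line)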

docker_client.py (new file, 0 lines added)
View File

tests/docker/test_docker.py (new file, 174 lines added)
View File

@@ -0,0 +1,174 @@
import requests
import time
import httpx
import asyncio
from typing import Dict, Any
from crawl4ai import (
BrowserConfig, CrawlerRunConfig, DefaultMarkdownGenerator,
PruningContentFilter, JsonCssExtractionStrategy, LLMContentFilter, CacheMode
)
from crawl4ai.docker_client import Crawl4aiDockerClient
class Crawl4AiTester:
def __init__(self, base_url: str = "http://localhost:11235"):
self.base_url = base_url
def submit_and_wait(
self, request_data: Dict[str, Any], timeout: int = 300
) -> Dict[str, Any]:
# Submit crawl job
response = requests.post(f"{self.base_url}/crawl", json=request_data)
task_id = response.json()["task_id"]
print(f"Task ID: {task_id}")
# Poll for result
start_time = time.time()
while True:
if time.time() - start_time > timeout:
raise TimeoutError(
f"Task {task_id} did not complete within {timeout} seconds"
)
result = requests.get(f"{self.base_url}/task/{task_id}")
status = result.json()
if status["status"] == "failed":
print("Task failed:", status.get("error"))
raise Exception(f"Task failed: {status.get('error')}")
if status["status"] == "completed":
return status
time.sleep(2)
async def test_direct_api():
"""Test direct API endpoints without using the client SDK"""
print("\n=== Testing Direct API Calls ===")
# Test 1: Basic crawl with content filtering
browser_config = BrowserConfig(
headless=True,
viewport_width=1200,
viewport_height=800
)
crawler_config = CrawlerRunConfig(
cache_mode=CacheMode.BYPASS,
markdown_generator=DefaultMarkdownGenerator(
content_filter=PruningContentFilter(
threshold=0.48,
threshold_type="fixed",
min_word_threshold=0
),
options={"ignore_links": True}
)
)
request_data = {
"urls": ["https://example.com"],
"browser_config": browser_config.dump(),
"crawler_config": crawler_config.dump()
}
# Make direct API call
async with httpx.AsyncClient() as client:
response = await client.post(
"http://localhost:8000/crawl",
json=request_data,
timeout=300
)
assert response.status_code == 200
result = response.json()
print("Basic crawl result:", result["success"])
# Test 2: Structured extraction with JSON CSS
schema = {
"baseSelector": "article.post",
"fields": [
{"name": "title", "selector": "h1", "type": "text"},
{"name": "content", "selector": ".content", "type": "html"}
]
}
crawler_config = CrawlerRunConfig(
cache_mode=CacheMode.BYPASS,
extraction_strategy=JsonCssExtractionStrategy(schema=schema)
)
request_data["crawler_config"] = crawler_config.dump()
async with httpx.AsyncClient() as client:
response = await client.post(
"http://localhost:8000/crawl",
json=request_data
)
assert response.status_code == 200
result = response.json()
print("Structured extraction result:", result["success"])
# Test 3: Get schema
# async with httpx.AsyncClient() as client:
# response = await client.get("http://localhost:8000/schema")
# assert response.status_code == 200
# schemas = response.json()
# print("Retrieved schemas for:", list(schemas.keys()))
async def test_with_client():
"""Test using the Crawl4AI Docker client SDK"""
print("\n=== Testing Client SDK ===")
async with Crawl4aiDockerClient(verbose=True) as client:
# Test 1: Basic crawl
browser_config = BrowserConfig(headless=True)
crawler_config = CrawlerRunConfig(
cache_mode=CacheMode.BYPASS,
markdown_generator=DefaultMarkdownGenerator(
content_filter=PruningContentFilter(
threshold=0.48,
threshold_type="fixed"
)
)
)
result = await client.crawl(
urls=["https://example.com"],
browser_config=browser_config,
crawler_config=crawler_config
)
print("Client SDK basic crawl:", result.success)
# Test 2: LLM extraction with streaming
crawler_config = CrawlerRunConfig(
cache_mode=CacheMode.BYPASS,
markdown_generator=DefaultMarkdownGenerator(
content_filter=LLMContentFilter(
provider="openai/gpt-40",
instruction="Extract key technical concepts"
)
),
stream=True
)
async for result in await client.crawl(
urls=["https://example.com"],
browser_config=browser_config,
crawler_config=crawler_config
):
print(f"Streaming result for: {result.url}")
# # Test 3: Get schema
# schemas = await client.get_schema()
# print("Retrieved client schemas for:", list(schemas.keys()))
async def main():
"""Run all tests"""
# Test direct API
print("Testing direct API calls...")
# await test_direct_api()
# Test client SDK
print("\nTesting client SDK...")
await test_with_client()
if __name__ == "__main__":
asyncio.run(main())

View File

@@ -0,0 +1,253 @@
import inspect
from typing import Any, Dict
from enum import Enum
def to_serializable_dict(obj: Any) -> Dict:
"""
Recursively convert an object to a serializable dictionary using {type, params} structure
for complex objects.
"""
if obj is None:
return None
# Handle basic types
if isinstance(obj, (str, int, float, bool)):
return obj
# Handle Enum
if isinstance(obj, Enum):
return {
"type": obj.__class__.__name__,
"params": obj.value
}
# Handle datetime objects
if hasattr(obj, 'isoformat'):
return obj.isoformat()
# Handle lists, tuples, and sets
if isinstance(obj, (list, tuple, set)):
return [to_serializable_dict(item) for item in obj]
# Handle dictionaries - preserve them as-is
if isinstance(obj, dict):
return {
"type": "dict", # Mark as plain dictionary
"value": {str(k): to_serializable_dict(v) for k, v in obj.items()}
}
# Handle class instances
if hasattr(obj, '__class__'):
# Get constructor signature
sig = inspect.signature(obj.__class__.__init__)
params = sig.parameters
# Get current values
current_values = {}
for name, param in params.items():
if name == 'self':
continue
value = getattr(obj, name, param.default)
# Only include if different from default, considering empty values
if not (is_empty_value(value) and is_empty_value(param.default)):
if value != param.default:
current_values[name] = to_serializable_dict(value)
return {
"type": obj.__class__.__name__,
"params": current_values
}
return str(obj)
def from_serializable_dict(data: Any) -> Any:
"""
Recursively convert a serializable dictionary back to an object instance.
"""
if data is None:
return None
# Handle basic types
if isinstance(data, (str, int, float, bool)):
return data
# Handle typed data
if isinstance(data, dict) and "type" in data:
# Handle plain dictionaries
if data["type"] == "dict":
return {k: from_serializable_dict(v) for k, v in data["value"].items()}
# Import from crawl4ai for class instances
import crawl4ai
cls = getattr(crawl4ai, data["type"])
# Handle Enum
if issubclass(cls, Enum):
return cls(data["params"])
# Handle class instances
constructor_args = {
k: from_serializable_dict(v) for k, v in data["params"].items()
}
return cls(**constructor_args)
# Handle lists
if isinstance(data, list):
return [from_serializable_dict(item) for item in data]
# Handle raw dictionaries (legacy support)
if isinstance(data, dict):
return {k: from_serializable_dict(v) for k, v in data.items()}
return data
def is_empty_value(value: Any) -> bool:
"""Check if a value is effectively empty/null."""
if value is None:
return True
if isinstance(value, (list, tuple, set, dict, str)) and len(value) == 0:
return True
return False
# if __name__ == "__main__":
# from crawl4ai import (
# CrawlerRunConfig, CacheMode, DefaultMarkdownGenerator,
# PruningContentFilter, BM25ContentFilter, LLMContentFilter,
# JsonCssExtractionStrategy, CosineStrategy, RegexChunking,
# WebScrapingStrategy, LXMLWebScrapingStrategy
# )
# # Test Case 1: BM25 content filtering through markdown generator
# config1 = CrawlerRunConfig(
# cache_mode=CacheMode.BYPASS,
# markdown_generator=DefaultMarkdownGenerator(
# content_filter=BM25ContentFilter(
# user_query="technology articles",
# bm25_threshold=1.2,
# language="english"
# )
# ),
# chunking_strategy=RegexChunking(patterns=[r"\n\n", r"\.\s+"]),
# excluded_tags=["nav", "footer", "aside"],
# remove_overlay_elements=True
# )
# # Serialize
# serialized = to_serializable_dict(config1)
# print("\nSerialized Config:")
# print(serialized)
# # Example output structure would now look like:
# """
# {
# "type": "CrawlerRunConfig",
# "params": {
# "cache_mode": {
# "type": "CacheMode",
# "params": "bypass"
# },
# "markdown_generator": {
# "type": "DefaultMarkdownGenerator",
# "params": {
# "content_filter": {
# "type": "BM25ContentFilter",
# "params": {
# "user_query": "technology articles",
# "bm25_threshold": 1.2,
# "language": "english"
# }
# }
# }
# }
# }
# }
# """
# # Deserialize
# deserialized = from_serializable_dict(serialized)
# print("\nDeserialized Config:")
# print(to_serializable_dict(deserialized))
# # Verify they match
# assert to_serializable_dict(config1) == to_serializable_dict(deserialized)
# print("\nVerification passed: Configuration matches after serialization/deserialization!")
if __name__ == "__main__":
from crawl4ai import (
CrawlerRunConfig, CacheMode, DefaultMarkdownGenerator,
PruningContentFilter, BM25ContentFilter, LLMContentFilter,
JsonCssExtractionStrategy, RegexChunking,
WebScrapingStrategy, LXMLWebScrapingStrategy
)
# Test Case 1: BM25 content filtering through markdown generator
config1 = CrawlerRunConfig(
cache_mode=CacheMode.BYPASS,
markdown_generator=DefaultMarkdownGenerator(
content_filter=BM25ContentFilter(
user_query="technology articles",
bm25_threshold=1.2,
language="english"
)
),
chunking_strategy=RegexChunking(patterns=[r"\n\n", r"\.\s+"]),
excluded_tags=["nav", "footer", "aside"],
remove_overlay_elements=True
)
# Test Case 2: LLM-based extraction with pruning filter
schema = {
"baseSelector": "article.post",
"fields": [
{"name": "title", "selector": "h1", "type": "text"},
{"name": "content", "selector": ".content", "type": "html"}
]
}
config2 = CrawlerRunConfig(
extraction_strategy=JsonCssExtractionStrategy(schema=schema),
markdown_generator=DefaultMarkdownGenerator(
content_filter=PruningContentFilter(
threshold=0.48,
threshold_type="fixed",
min_word_threshold=0
),
options={"ignore_links": True}
),
scraping_strategy=LXMLWebScrapingStrategy()
)
# Test Case 3: LLM content filter
config3 = CrawlerRunConfig(
markdown_generator=DefaultMarkdownGenerator(
content_filter=LLMContentFilter(
provider="openai/gpt-4",
instruction="Extract key technical concepts",
chunk_token_threshold=2000,
overlap_rate=0.1
),
options={"ignore_images": True}
),
scraping_strategy=WebScrapingStrategy()
)
# Test all configurations
test_configs = [config1, config2, config3]
for i, config in enumerate(test_configs, 1):
print(f"\nTesting Configuration {i}:")
# Serialize
serialized = to_serializable_dict(config)
print(f"\nSerialized Config {i}:")
print(serialized)
# Deserialize
deserialized = from_serializable_dict(serialized)
print(f"\nDeserialized Config {i}:")
print(to_serializable_dict(deserialized)) # Convert back to dict for comparison
# Verify they match
assert to_serializable_dict(config) == to_serializable_dict(deserialized)
print(f"\nVerification passed: Configuration {i} matches after serialization/deserialization!")