feat(docker): add Docker service integration and config serialization
Add Docker service integration with FastAPI server and client implementation. Implement serialization utilities for BrowserConfig and CrawlerRunConfig to support Docker service communication. Clean up imports and improve error handling. - Add Crawl4aiDockerClient class - Implement config serialization/deserialization - Add FastAPI server with streaming support - Add health check endpoint - Clean up imports and type hints
This commit is contained in:
@@ -1,4 +1,5 @@
|
||||
# __init__.py
|
||||
import warnings
|
||||
|
||||
from .async_webcrawler import AsyncWebCrawler, CacheMode
|
||||
from .async_configs import BrowserConfig, CrawlerRunConfig
|
||||
@@ -28,6 +29,7 @@ from .async_dispatcher import (
|
||||
DisplayMode,
|
||||
BaseDispatcher
|
||||
)
|
||||
from .docker_client import Crawl4aiDockerClient
|
||||
from .hub import CrawlerHub
|
||||
|
||||
__all__ = [
|
||||
@@ -59,12 +61,13 @@ __all__ = [
|
||||
"CrawlerMonitor",
|
||||
"DisplayMode",
|
||||
"MarkdownGenerationResult",
|
||||
"Crawl4aiDockerClient",
|
||||
]
|
||||
|
||||
|
||||
def is_sync_version_installed():
|
||||
try:
|
||||
import selenium
|
||||
import selenium # noqa
|
||||
|
||||
return True
|
||||
except ImportError:
|
||||
@@ -85,9 +88,6 @@ else:
|
||||
# import warnings
|
||||
# print("Warning: Synchronous WebCrawler is not available. Install crawl4ai[sync] for synchronous support. However, please note that the synchronous version will be deprecated soon.")
|
||||
|
||||
import warnings
|
||||
from pydantic import warnings as pydantic_warnings
|
||||
|
||||
# Disable all Pydantic warnings
|
||||
warnings.filterwarnings("ignore", module="pydantic")
|
||||
# pydantic_warnings.filter_warnings()
|
||||
# pydantic_warnings.filter_warnings()
|
||||
|
||||
@@ -7,17 +7,129 @@ from .config import (
|
||||
SOCIAL_MEDIA_DOMAINS,
|
||||
)
|
||||
|
||||
from .user_agent_generator import UserAgentGenerator, UAGen, ValidUAGenerator, OnlineUAGenerator
|
||||
from .user_agent_generator import UAGen, ValidUAGenerator # , OnlineUAGenerator
|
||||
from .extraction_strategy import ExtractionStrategy
|
||||
from .chunking_strategy import ChunkingStrategy, RegexChunking
|
||||
from .markdown_generation_strategy import MarkdownGenerationStrategy
|
||||
from .content_filter_strategy import RelevantContentFilter, BM25ContentFilter, LLMContentFilter, PruningContentFilter
|
||||
from .content_filter_strategy import RelevantContentFilter # , BM25ContentFilter, LLMContentFilter, PruningContentFilter
|
||||
from .content_scraping_strategy import ContentScrapingStrategy, WebScrapingStrategy
|
||||
from typing import Optional, Union, List
|
||||
from typing import Union, List
|
||||
from .cache_context import CacheMode
|
||||
|
||||
import inspect
|
||||
from typing import Any, Dict
|
||||
from enum import Enum
|
||||
|
||||
class BrowserConfig:
|
||||
def to_serializable_dict(obj: Any) -> Dict:
|
||||
"""
|
||||
Recursively convert an object to a serializable dictionary using {type, params} structure
|
||||
for complex objects.
|
||||
"""
|
||||
if obj is None:
|
||||
return None
|
||||
|
||||
# Handle basic types
|
||||
if isinstance(obj, (str, int, float, bool)):
|
||||
return obj
|
||||
|
||||
# Handle Enum
|
||||
if isinstance(obj, Enum):
|
||||
return {
|
||||
"type": obj.__class__.__name__,
|
||||
"params": obj.value
|
||||
}
|
||||
|
||||
# Handle datetime objects
|
||||
if hasattr(obj, 'isoformat'):
|
||||
return obj.isoformat()
|
||||
|
||||
# Handle lists, tuples, and sets
|
||||
if isinstance(obj, (list, tuple, set)):
|
||||
return [to_serializable_dict(item) for item in obj]
|
||||
|
||||
# Handle dictionaries - preserve them as-is
|
||||
if isinstance(obj, dict):
|
||||
return {
|
||||
"type": "dict", # Mark as plain dictionary
|
||||
"value": {str(k): to_serializable_dict(v) for k, v in obj.items()}
|
||||
}
|
||||
|
||||
# Handle class instances
|
||||
if hasattr(obj, '__class__'):
|
||||
# Get constructor signature
|
||||
sig = inspect.signature(obj.__class__.__init__)
|
||||
params = sig.parameters
|
||||
|
||||
# Get current values
|
||||
current_values = {}
|
||||
for name, param in params.items():
|
||||
if name == 'self':
|
||||
continue
|
||||
|
||||
value = getattr(obj, name, param.default)
|
||||
|
||||
# Only include if different from default, considering empty values
|
||||
if not (is_empty_value(value) and is_empty_value(param.default)):
|
||||
if value != param.default:
|
||||
current_values[name] = to_serializable_dict(value)
|
||||
|
||||
return {
|
||||
"type": obj.__class__.__name__,
|
||||
"params": current_values
|
||||
}
|
||||
|
||||
return str(obj)
|
||||
|
||||
def from_serializable_dict(data: Any) -> Any:
|
||||
"""
|
||||
Recursively convert a serializable dictionary back to an object instance.
|
||||
"""
|
||||
if data is None:
|
||||
return None
|
||||
|
||||
# Handle basic types
|
||||
if isinstance(data, (str, int, float, bool)):
|
||||
return data
|
||||
|
||||
# Handle typed data
|
||||
if isinstance(data, dict) and "type" in data:
|
||||
# Handle plain dictionaries
|
||||
if data["type"] == "dict":
|
||||
return {k: from_serializable_dict(v) for k, v in data["value"].items()}
|
||||
|
||||
# Import from crawl4ai for class instances
|
||||
import crawl4ai
|
||||
cls = getattr(crawl4ai, data["type"])
|
||||
|
||||
# Handle Enum
|
||||
if issubclass(cls, Enum):
|
||||
return cls(data["params"])
|
||||
|
||||
# Handle class instances
|
||||
constructor_args = {
|
||||
k: from_serializable_dict(v) for k, v in data["params"].items()
|
||||
}
|
||||
return cls(**constructor_args)
|
||||
|
||||
# Handle lists
|
||||
if isinstance(data, list):
|
||||
return [from_serializable_dict(item) for item in data]
|
||||
|
||||
# Handle raw dictionaries (legacy support)
|
||||
if isinstance(data, dict):
|
||||
return {k: from_serializable_dict(v) for k, v in data.items()}
|
||||
|
||||
return data
|
||||
|
||||
def is_empty_value(value: Any) -> bool:
|
||||
"""Check if a value is effectively empty/null."""
|
||||
if value is None:
|
||||
return True
|
||||
if isinstance(value, (list, tuple, set, dict, str)) and len(value) == 0:
|
||||
return True
|
||||
return False
|
||||
|
||||
class BrowserConfig():
|
||||
"""
|
||||
Configuration class for setting up a browser instance and its context in AsyncPlaywrightCrawlerStrategy.
|
||||
|
||||
@@ -239,8 +351,18 @@ class BrowserConfig:
|
||||
config_dict.update(kwargs)
|
||||
return BrowserConfig.from_kwargs(config_dict)
|
||||
|
||||
# Create a funciton returns dict of the object
|
||||
def dump(self) -> dict:
|
||||
# Serialize the object to a dictionary
|
||||
return to_serializable_dict(self)
|
||||
|
||||
class CrawlerRunConfig:
|
||||
@staticmethod
|
||||
def load( data: dict) -> "BrowserConfig":
|
||||
# Deserialize the object from a dictionary
|
||||
return from_serializable_dict(data)
|
||||
|
||||
|
||||
class CrawlerRunConfig():
|
||||
"""
|
||||
Configuration class for controlling how the crawler runs each crawl operation.
|
||||
This includes parameters for content extraction, page manipulation, waiting conditions,
|
||||
@@ -665,6 +787,15 @@ class CrawlerRunConfig:
|
||||
)
|
||||
|
||||
# Create a funciton returns dict of the object
|
||||
def dump(self) -> dict:
|
||||
# Serialize the object to a dictionary
|
||||
return to_serializable_dict(self)
|
||||
|
||||
@staticmethod
|
||||
def load(data: dict) -> "CrawlerRunConfig":
|
||||
# Deserialize the object from a dictionary
|
||||
return from_serializable_dict(data)
|
||||
|
||||
def to_dict(self):
|
||||
return {
|
||||
"word_count_threshold": self.word_count_threshold,
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
from .__version__ import __version__ as crawl4ai_version
|
||||
import os
|
||||
import sys
|
||||
import time
|
||||
@@ -10,7 +11,7 @@ import asyncio
|
||||
|
||||
# from contextlib import nullcontext, asynccontextmanager
|
||||
from contextlib import asynccontextmanager
|
||||
from .models import CrawlResult, MarkdownGenerationResult, CrawlerTaskResult, DispatchResult
|
||||
from .models import CrawlResult, MarkdownGenerationResult,DispatchResult
|
||||
from .async_database import async_db_manager
|
||||
from .chunking_strategy import * # noqa: F403
|
||||
from .chunking_strategy import RegexChunking, ChunkingStrategy, IdentityChunking
|
||||
@@ -43,8 +44,7 @@ from .utils import (
|
||||
RobotsParser,
|
||||
)
|
||||
|
||||
from typing import Union, AsyncGenerator, List, TypeVar
|
||||
from collections.abc import AsyncGenerator
|
||||
from typing import Union, AsyncGenerator, TypeVar
|
||||
|
||||
CrawlResultT = TypeVar('CrawlResultT', bound=CrawlResult)
|
||||
RunManyReturn = Union[List[CrawlResultT], AsyncGenerator[CrawlResultT, None]]
|
||||
@@ -55,7 +55,7 @@ DeepCrawlManyReturn = Union[
|
||||
AsyncGenerator[CrawlResultT, None],
|
||||
]
|
||||
|
||||
from .__version__ import __version__ as crawl4ai_version
|
||||
|
||||
|
||||
|
||||
class AsyncWebCrawler:
|
||||
@@ -768,18 +768,19 @@ class AsyncWebCrawler:
|
||||
),
|
||||
)
|
||||
|
||||
transform_result = lambda task_result: (
|
||||
setattr(task_result.result, 'dispatch_result',
|
||||
DispatchResult(
|
||||
task_id=task_result.task_id,
|
||||
memory_usage=task_result.memory_usage,
|
||||
peak_memory=task_result.peak_memory,
|
||||
start_time=task_result.start_time,
|
||||
end_time=task_result.end_time,
|
||||
error_message=task_result.error_message,
|
||||
def transform_result(task_result):
|
||||
return (
|
||||
setattr(task_result.result, 'dispatch_result',
|
||||
DispatchResult(
|
||||
task_id=task_result.task_id,
|
||||
memory_usage=task_result.memory_usage,
|
||||
peak_memory=task_result.peak_memory,
|
||||
start_time=task_result.start_time,
|
||||
end_time=task_result.end_time,
|
||||
error_message=task_result.error_message,
|
||||
)
|
||||
) or task_result.result
|
||||
)
|
||||
) or task_result.result
|
||||
)
|
||||
|
||||
stream = config.stream
|
||||
|
||||
|
||||
@@ -9,16 +9,16 @@ from .utils import clean_tokens, perform_completion_with_backoff, escape_json_st
|
||||
from abc import ABC, abstractmethod
|
||||
import math
|
||||
from snowballstemmer import stemmer
|
||||
from .config import DEFAULT_PROVIDER, OVERLAP_RATE, WORD_TOKEN_RATE
|
||||
from .config import DEFAULT_PROVIDER, OVERLAP_RATE, WORD_TOKEN_RATE, PROVIDER_MODELS
|
||||
from .models import TokenUsage
|
||||
from .prompts import PROMPT_FILTER_CONTENT
|
||||
import os
|
||||
import json
|
||||
import hashlib
|
||||
from pathlib import Path
|
||||
from concurrent.futures import ThreadPoolExecutor, as_completed
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from .async_logger import AsyncLogger, LogLevel
|
||||
from colorama import Fore, Style, init
|
||||
from colorama import Fore, Style
|
||||
|
||||
class RelevantContentFilter(ABC):
|
||||
"""Abstract base class for content filtering strategies"""
|
||||
@@ -879,7 +879,6 @@ class LLMContentFilter(RelevantContentFilter):
|
||||
colors={"chunk_count": Fore.YELLOW}
|
||||
)
|
||||
|
||||
extracted_content = []
|
||||
start_time = time.time()
|
||||
|
||||
# Process chunks in parallel
|
||||
|
||||
210
crawl4ai/docker_client.py
Normal file
210
crawl4ai/docker_client.py
Normal file
@@ -0,0 +1,210 @@
|
||||
from typing import List, Optional, Union, AsyncGenerator, Dict, Any
|
||||
import httpx
|
||||
import json
|
||||
from urllib.parse import urljoin
|
||||
|
||||
from .async_configs import BrowserConfig, CrawlerRunConfig
|
||||
from .models import CrawlResult
|
||||
from .async_logger import AsyncLogger, LogLevel
|
||||
|
||||
|
||||
class Crawl4aiClientError(Exception):
|
||||
"""Base exception for Crawl4ai Docker client errors."""
|
||||
pass
|
||||
|
||||
|
||||
class ConnectionError(Crawl4aiClientError):
|
||||
"""Raised when connection to the Docker server fails."""
|
||||
pass
|
||||
|
||||
|
||||
class RequestError(Crawl4aiClientError):
|
||||
"""Raised when the server returns an error response."""
|
||||
pass
|
||||
|
||||
|
||||
class Crawl4aiDockerClient:
|
||||
"""
|
||||
Client for interacting with Crawl4AI Docker server.
|
||||
|
||||
Args:
|
||||
base_url (str): Base URL of the Crawl4AI Docker server
|
||||
timeout (float): Default timeout for requests in seconds
|
||||
verify_ssl (bool): Whether to verify SSL certificates
|
||||
verbose (bool): Whether to show logging output
|
||||
log_file (str, optional): Path to log file if file logging is desired
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
base_url: str = "http://localhost:8000",
|
||||
timeout: float = 30.0,
|
||||
verify_ssl: bool = True,
|
||||
verbose: bool = True,
|
||||
log_file: Optional[str] = None
|
||||
) -> None:
|
||||
self.base_url = base_url.rstrip('/')
|
||||
self.timeout = timeout
|
||||
self._http_client = httpx.AsyncClient(
|
||||
timeout=timeout,
|
||||
verify=verify_ssl,
|
||||
headers={"Content-Type": "application/json"}
|
||||
)
|
||||
self.logger = AsyncLogger(
|
||||
log_file=log_file,
|
||||
log_level=LogLevel.DEBUG,
|
||||
verbose=verbose
|
||||
)
|
||||
|
||||
async def _check_server_connection(self) -> bool:
|
||||
"""Check if server is reachable."""
|
||||
try:
|
||||
self.logger.info("Checking server connection...", tag="INIT")
|
||||
response = await self._http_client.get(f"{self.base_url}/health")
|
||||
response.raise_for_status()
|
||||
self.logger.success(f"Connected to server at {self.base_url}", tag="READY")
|
||||
return True
|
||||
except Exception as e:
|
||||
self.logger.error(f"Failed to connect to server: {str(e)}", tag="ERROR")
|
||||
return False
|
||||
|
||||
def _prepare_request_data(
|
||||
self,
|
||||
urls: List[str],
|
||||
browser_config: Optional[BrowserConfig] = None,
|
||||
crawler_config: Optional[CrawlerRunConfig] = None
|
||||
) -> Dict[str, Any]:
|
||||
"""Prepare request data from configs using dump methods."""
|
||||
self.logger.debug("Preparing request data", tag="INIT")
|
||||
data = {
|
||||
"urls": urls,
|
||||
"browser_config": browser_config.dump() if browser_config else {},
|
||||
"crawler_config": crawler_config.dump() if crawler_config else {}
|
||||
}
|
||||
self.logger.debug(f"Request data prepared for {len(urls)} URLs", tag="READY")
|
||||
return data
|
||||
|
||||
async def _make_request(
|
||||
self,
|
||||
method: str,
|
||||
endpoint: str,
|
||||
**kwargs
|
||||
) -> Union[Dict, AsyncGenerator]:
|
||||
"""Make HTTP request to the server with error handling."""
|
||||
url = urljoin(self.base_url, endpoint)
|
||||
|
||||
try:
|
||||
self.logger.debug(f"Making {method} request to {endpoint}", tag="FETCH")
|
||||
response = await self._http_client.request(method, url, **kwargs)
|
||||
response.raise_for_status()
|
||||
self.logger.success(f"Request to {endpoint} successful", tag="COMPLETE")
|
||||
return response
|
||||
except httpx.TimeoutException as e:
|
||||
error_msg = f"Request timed out: {str(e)}"
|
||||
self.logger.error(error_msg, tag="ERROR")
|
||||
raise ConnectionError(error_msg)
|
||||
except httpx.RequestError as e:
|
||||
error_msg = f"Failed to connect to server: {str(e)}"
|
||||
self.logger.error(error_msg, tag="ERROR")
|
||||
raise ConnectionError(error_msg)
|
||||
except httpx.HTTPStatusError as e:
|
||||
error_detail = ""
|
||||
try:
|
||||
error_data = e.response.json()
|
||||
error_detail = error_data.get('detail', str(e))
|
||||
except (json.JSONDecodeError, AttributeError) as json_err:
|
||||
error_detail = f"{str(e)} (Failed to parse error response: {str(json_err)})"
|
||||
|
||||
error_msg = f"Server returned error {e.response.status_code}: {error_detail}"
|
||||
self.logger.error(error_msg, tag="ERROR")
|
||||
raise RequestError(error_msg)
|
||||
|
||||
async def crawl(
|
||||
self,
|
||||
urls: List[str],
|
||||
browser_config: Optional[BrowserConfig] = None,
|
||||
crawler_config: Optional[CrawlerRunConfig] = None
|
||||
) -> Union[CrawlResult, AsyncGenerator[CrawlResult, None]]:
|
||||
"""Execute a crawl operation through the Docker server."""
|
||||
# Check server connection first
|
||||
if not await self._check_server_connection():
|
||||
raise ConnectionError("Cannot proceed with crawl - server is not reachable")
|
||||
|
||||
request_data = self._prepare_request_data(urls, browser_config, crawler_config)
|
||||
is_streaming = crawler_config.stream if crawler_config else False
|
||||
|
||||
self.logger.info(
|
||||
f"Starting crawl for {len(urls)} URLs {'(streaming)' if is_streaming else ''}",
|
||||
tag="INIT"
|
||||
)
|
||||
|
||||
if is_streaming:
|
||||
async def result_generator() -> AsyncGenerator[CrawlResult, None]:
|
||||
try:
|
||||
async with self._http_client.stream(
|
||||
"POST",
|
||||
f"{self.base_url}/crawl",
|
||||
json=request_data,
|
||||
timeout=None
|
||||
) as response:
|
||||
response.raise_for_status()
|
||||
async for line in response.aiter_lines():
|
||||
if line.strip():
|
||||
try:
|
||||
result_dict = json.loads(line)
|
||||
if "error" in result_dict:
|
||||
self.logger.error_status(
|
||||
url=result_dict.get('url', 'unknown'),
|
||||
error=result_dict['error']
|
||||
)
|
||||
continue
|
||||
|
||||
self.logger.url_status(
|
||||
url=result_dict.get('url', 'unknown'),
|
||||
success=True,
|
||||
timing=result_dict.get('timing', 0.0)
|
||||
)
|
||||
yield CrawlResult(**result_dict)
|
||||
except json.JSONDecodeError as e:
|
||||
self.logger.error(f"Failed to parse server response: {e}", tag="ERROR")
|
||||
continue
|
||||
except httpx.StreamError as e:
|
||||
error_msg = f"Stream connection error: {str(e)}"
|
||||
self.logger.error(error_msg, tag="ERROR")
|
||||
raise ConnectionError(error_msg)
|
||||
except Exception as e:
|
||||
error_msg = f"Unexpected error during streaming: {str(e)}"
|
||||
self.logger.error(error_msg, tag="ERROR")
|
||||
raise Crawl4aiClientError(error_msg)
|
||||
|
||||
return result_generator()
|
||||
|
||||
response = await self._make_request("POST", "/crawl", json=request_data)
|
||||
response_data = response.json()
|
||||
|
||||
if not response_data.get("success", False):
|
||||
error_msg = f"Crawl operation failed: {response_data.get('error', 'Unknown error')}"
|
||||
self.logger.error(error_msg, tag="ERROR")
|
||||
raise RequestError(error_msg)
|
||||
|
||||
results = [CrawlResult(**result_dict) for result_dict in response_data.get("results", [])]
|
||||
self.logger.success(f"Crawl completed successfully with {len(results)} results", tag="COMPLETE")
|
||||
return results[0] if len(results) == 1 else results
|
||||
|
||||
async def get_schema(self) -> Dict[str, Any]:
|
||||
"""Retrieve the configuration schemas from the server."""
|
||||
self.logger.info("Retrieving schema from server", tag="FETCH")
|
||||
response = await self._make_request("GET", "/schema")
|
||||
self.logger.success("Schema retrieved successfully", tag="COMPLETE")
|
||||
return response.json()
|
||||
|
||||
async def close(self) -> None:
|
||||
"""Close the HTTP client session."""
|
||||
self.logger.info("Closing client connection", tag="COMPLETE")
|
||||
await self._http_client.aclose()
|
||||
|
||||
async def __aenter__(self) -> "Crawl4aiDockerClient":
|
||||
return self
|
||||
|
||||
async def __aexit__(self, exc_type: Optional[type], exc_val: Optional[Exception], exc_tb: Optional[Any]) -> None:
|
||||
await self.close()
|
||||
@@ -8,6 +8,7 @@ from crawl4ai import (
|
||||
BM25ContentFilter,
|
||||
LLMContentFilter,
|
||||
# Add other strategy classes as needed
|
||||
|
||||
)
|
||||
|
||||
class StrategyConfig(BaseModel):
|
||||
|
||||
@@ -1,10 +1,10 @@
|
||||
# pyright: ignore
|
||||
import os, sys
|
||||
sys.path.append(os.path.dirname(os.path.realpath(__file__)))
|
||||
from fastapi import FastAPI, HTTPException
|
||||
from fastapi.responses import StreamingResponse
|
||||
import json
|
||||
import asyncio
|
||||
from typing import AsyncGenerator
|
||||
from datetime import datetime
|
||||
from crawl4ai import (
|
||||
BrowserConfig,
|
||||
CrawlerRunConfig,
|
||||
@@ -12,67 +12,36 @@ from crawl4ai import (
|
||||
MemoryAdaptiveDispatcher,
|
||||
RateLimiter,
|
||||
)
|
||||
from .models import CrawlRequest, CrawlResponse
|
||||
|
||||
class CrawlJSONEncoder(json.JSONEncoder):
|
||||
"""Custom JSON encoder for crawler results"""
|
||||
def default(self, obj):
|
||||
if isinstance(obj, datetime):
|
||||
return obj.isoformat()
|
||||
if isinstance(obj, bytes):
|
||||
return obj.decode('utf-8', errors='ignore')
|
||||
if hasattr(obj, 'model_dump'):
|
||||
return obj.model_dump()
|
||||
if hasattr(obj, '__dict__'):
|
||||
return {k: v for k, v in obj.__dict__.items() if not k.startswith('_')}
|
||||
return str(obj) # Fallback to string representation
|
||||
|
||||
def serialize_result(result) -> dict:
|
||||
"""Safely serialize a crawler result"""
|
||||
try:
|
||||
# Convert to dict handling special cases
|
||||
if hasattr(result, 'model_dump'):
|
||||
result_dict = result.model_dump()
|
||||
else:
|
||||
result_dict = {
|
||||
k: v for k, v in result.__dict__.items()
|
||||
if not k.startswith('_')
|
||||
}
|
||||
|
||||
# Remove known non-serializable objects
|
||||
result_dict.pop('ssl_certificate', None)
|
||||
result_dict.pop('downloaded_files', None)
|
||||
|
||||
return result_dict
|
||||
except Exception as e:
|
||||
print(f"Error serializing result: {e}")
|
||||
return {"error": str(e), "url": getattr(result, 'url', 'unknown')}
|
||||
from models import CrawlRequest, CrawlResponse
|
||||
|
||||
app = FastAPI(title="Crawl4AI API")
|
||||
|
||||
async def stream_results(crawler: AsyncWebCrawler, results_gen: AsyncGenerator) -> AsyncGenerator[bytes, None]:
|
||||
"""Stream results and manage crawler lifecycle"""
|
||||
def datetime_handler(obj):
|
||||
"""Custom handler for datetime objects during JSON serialization"""
|
||||
if hasattr(obj, 'isoformat'):
|
||||
return obj.isoformat()
|
||||
raise TypeError(f"Object of type {type(obj)} is not JSON serializable")
|
||||
|
||||
try:
|
||||
async for result in results_gen:
|
||||
try:
|
||||
# Handle serialization of result
|
||||
result_dict = serialize_result(result)
|
||||
# Remove non-serializable objects
|
||||
# Use dump method for serialization
|
||||
result_dict = result.model_dump()
|
||||
print(f"Streaming result for URL: {result_dict['url']}, Success: {result_dict['success']}")
|
||||
yield (json.dumps(result_dict, cls=CrawlJSONEncoder) + "\n").encode('utf-8')
|
||||
# Use custom JSON encoder with datetime handler
|
||||
yield (json.dumps(result_dict, default=datetime_handler) + "\n").encode('utf-8')
|
||||
except Exception as e:
|
||||
# Log error but continue streaming
|
||||
print(f"Error serializing result: {e}")
|
||||
error_response = {
|
||||
"error": str(e),
|
||||
"url": getattr(result, 'url', 'unknown')
|
||||
}
|
||||
yield (json.dumps(error_response) + "\n").encode('utf-8')
|
||||
yield (json.dumps(error_response, default=datetime_handler) + "\n").encode('utf-8')
|
||||
except asyncio.CancelledError:
|
||||
# Handle client disconnection gracefully
|
||||
print("Client disconnected, cleaning up...")
|
||||
finally:
|
||||
# Ensure crawler cleanup happens in all cases
|
||||
try:
|
||||
await crawler.close()
|
||||
except Exception as e:
|
||||
@@ -80,17 +49,17 @@ async def stream_results(crawler: AsyncWebCrawler, results_gen: AsyncGenerator)
|
||||
|
||||
@app.post("/crawl")
|
||||
async def crawl(request: CrawlRequest):
|
||||
browser_config, crawler_config = request.get_configs()
|
||||
# Load configs using our new utilities
|
||||
browser_config = BrowserConfig.load(request.browser_config)
|
||||
crawler_config = CrawlerRunConfig.load(request.crawler_config)
|
||||
|
||||
dispatcher = MemoryAdaptiveDispatcher(
|
||||
memory_threshold_percent=75.0,
|
||||
memory_threshold_percent=95.0,
|
||||
rate_limiter=RateLimiter(base_delay=(1.0, 2.0)),
|
||||
# monitor=CrawlerMonitor(display_mode=DisplayMode.DETAILED)
|
||||
)
|
||||
|
||||
try:
|
||||
if crawler_config.stream:
|
||||
# For streaming, manage crawler lifecycle manually
|
||||
crawler = AsyncWebCrawler(config=browser_config)
|
||||
await crawler.start()
|
||||
|
||||
@@ -105,29 +74,14 @@ async def crawl(request: CrawlRequest):
|
||||
media_type='application/x-ndjson'
|
||||
)
|
||||
else:
|
||||
# For non-streaming, use context manager
|
||||
async with AsyncWebCrawler(config=browser_config) as crawler:
|
||||
results = await crawler.arun_many(
|
||||
urls=request.urls,
|
||||
config=crawler_config,
|
||||
dispatcher=dispatcher
|
||||
)
|
||||
# Handle serialization of results
|
||||
results_dict = []
|
||||
for result in results:
|
||||
try:
|
||||
result_dict = {
|
||||
k: v for k, v in (result.model_dump() if hasattr(result, 'model_dump')
|
||||
else result.__dict__).items()
|
||||
if not k.startswith('_')
|
||||
}
|
||||
result_dict.pop('ssl_certificate', None)
|
||||
result_dict.pop('downloaded_files', None)
|
||||
results_dict.append(result_dict)
|
||||
except Exception as e:
|
||||
print(f"Error serializing result: {e}")
|
||||
continue
|
||||
|
||||
# Use dump method for each result
|
||||
results_dict = [result.model_dump() for result in results]
|
||||
return CrawlResponse(success=True, results=results_dict)
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
@@ -140,9 +94,12 @@ async def get_schema():
|
||||
"crawler": CrawlerRunConfig.model_json_schema()
|
||||
}
|
||||
|
||||
@app.get("/health")
|
||||
async def health():
|
||||
return {"status": "ok"}
|
||||
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import uvicorn
|
||||
# Run in auto reload mode
|
||||
# WARNING: You must pass the application as an import string to enable 'reload' or 'workers'.
|
||||
uvicorn.run("server:app", host="0.0.0.0", port=8000, reload=True)
|
||||
36
deploy/docker/utils.py
Normal file
36
deploy/docker/utils.py
Normal file
@@ -0,0 +1,36 @@
|
||||
import json
|
||||
from datetime import datetime
|
||||
|
||||
class CrawlJSONEncoder(json.JSONEncoder):
|
||||
"""Custom JSON encoder for crawler results"""
|
||||
def default(self, obj):
|
||||
if isinstance(obj, datetime):
|
||||
return obj.isoformat()
|
||||
if isinstance(obj, bytes):
|
||||
return obj.decode('utf-8', errors='ignore')
|
||||
if hasattr(obj, 'model_dump'):
|
||||
return obj.model_dump()
|
||||
if hasattr(obj, '__dict__'):
|
||||
return {k: v for k, v in obj.__dict__.items() if not k.startswith('_')}
|
||||
return str(obj) # Fallback to string representation
|
||||
|
||||
def serialize_result(result) -> dict:
|
||||
"""Safely serialize a crawler result"""
|
||||
try:
|
||||
# Convert to dict handling special cases
|
||||
if hasattr(result, 'model_dump'):
|
||||
result_dict = result.model_dump()
|
||||
else:
|
||||
result_dict = {
|
||||
k: v for k, v in result.__dict__.items()
|
||||
if not k.startswith('_')
|
||||
}
|
||||
|
||||
# Remove known non-serializable objects
|
||||
result_dict.pop('ssl_certificate', None)
|
||||
result_dict.pop('downloaded_files', None)
|
||||
|
||||
return result_dict
|
||||
except Exception as e:
|
||||
print(f"Error serializing result: {e}")
|
||||
return {"error": str(e), "url": getattr(result, 'url', 'unknown')}
|
||||
0
docker_client.py
Normal file
0
docker_client.py
Normal file
174
tests/docker/test_docker.py
Normal file
174
tests/docker/test_docker.py
Normal file
@@ -0,0 +1,174 @@
|
||||
import requests
|
||||
import time
|
||||
import httpx
|
||||
import asyncio
|
||||
from typing import Dict, Any
|
||||
from crawl4ai import (
|
||||
BrowserConfig, CrawlerRunConfig, DefaultMarkdownGenerator,
|
||||
PruningContentFilter, JsonCssExtractionStrategy, LLMContentFilter, CacheMode
|
||||
)
|
||||
from crawl4ai.docker_client import Crawl4aiDockerClient
|
||||
|
||||
class Crawl4AiTester:
|
||||
def __init__(self, base_url: str = "http://localhost:11235"):
|
||||
self.base_url = base_url
|
||||
|
||||
def submit_and_wait(
|
||||
self, request_data: Dict[str, Any], timeout: int = 300
|
||||
) -> Dict[str, Any]:
|
||||
# Submit crawl job
|
||||
response = requests.post(f"{self.base_url}/crawl", json=request_data)
|
||||
task_id = response.json()["task_id"]
|
||||
print(f"Task ID: {task_id}")
|
||||
|
||||
# Poll for result
|
||||
start_time = time.time()
|
||||
while True:
|
||||
if time.time() - start_time > timeout:
|
||||
raise TimeoutError(
|
||||
f"Task {task_id} did not complete within {timeout} seconds"
|
||||
)
|
||||
|
||||
result = requests.get(f"{self.base_url}/task/{task_id}")
|
||||
status = result.json()
|
||||
|
||||
if status["status"] == "failed":
|
||||
print("Task failed:", status.get("error"))
|
||||
raise Exception(f"Task failed: {status.get('error')}")
|
||||
|
||||
if status["status"] == "completed":
|
||||
return status
|
||||
|
||||
time.sleep(2)
|
||||
|
||||
async def test_direct_api():
|
||||
"""Test direct API endpoints without using the client SDK"""
|
||||
print("\n=== Testing Direct API Calls ===")
|
||||
|
||||
# Test 1: Basic crawl with content filtering
|
||||
browser_config = BrowserConfig(
|
||||
headless=True,
|
||||
viewport_width=1200,
|
||||
viewport_height=800
|
||||
)
|
||||
|
||||
crawler_config = CrawlerRunConfig(
|
||||
cache_mode=CacheMode.BYPASS,
|
||||
markdown_generator=DefaultMarkdownGenerator(
|
||||
content_filter=PruningContentFilter(
|
||||
threshold=0.48,
|
||||
threshold_type="fixed",
|
||||
min_word_threshold=0
|
||||
),
|
||||
options={"ignore_links": True}
|
||||
)
|
||||
)
|
||||
|
||||
request_data = {
|
||||
"urls": ["https://example.com"],
|
||||
"browser_config": browser_config.dump(),
|
||||
"crawler_config": crawler_config.dump()
|
||||
}
|
||||
|
||||
# Make direct API call
|
||||
async with httpx.AsyncClient() as client:
|
||||
response = await client.post(
|
||||
"http://localhost:8000/crawl",
|
||||
json=request_data,
|
||||
timeout=300
|
||||
)
|
||||
assert response.status_code == 200
|
||||
result = response.json()
|
||||
print("Basic crawl result:", result["success"])
|
||||
|
||||
# Test 2: Structured extraction with JSON CSS
|
||||
schema = {
|
||||
"baseSelector": "article.post",
|
||||
"fields": [
|
||||
{"name": "title", "selector": "h1", "type": "text"},
|
||||
{"name": "content", "selector": ".content", "type": "html"}
|
||||
]
|
||||
}
|
||||
|
||||
crawler_config = CrawlerRunConfig(
|
||||
cache_mode=CacheMode.BYPASS,
|
||||
extraction_strategy=JsonCssExtractionStrategy(schema=schema)
|
||||
)
|
||||
|
||||
request_data["crawler_config"] = crawler_config.dump()
|
||||
|
||||
async with httpx.AsyncClient() as client:
|
||||
response = await client.post(
|
||||
"http://localhost:8000/crawl",
|
||||
json=request_data
|
||||
)
|
||||
assert response.status_code == 200
|
||||
result = response.json()
|
||||
print("Structured extraction result:", result["success"])
|
||||
|
||||
# Test 3: Get schema
|
||||
# async with httpx.AsyncClient() as client:
|
||||
# response = await client.get("http://localhost:8000/schema")
|
||||
# assert response.status_code == 200
|
||||
# schemas = response.json()
|
||||
# print("Retrieved schemas for:", list(schemas.keys()))
|
||||
|
||||
async def test_with_client():
    """Exercise the Crawl4AI Docker service through the client SDK.

    Covers:
      1. A basic non-streaming crawl with a pruning content filter.
      2. A streaming crawl driven by an LLM content filter.

    Requires the Docker service to be reachable at the client's default
    endpoint; results are printed, not returned.
    """
    print("\n=== Testing Client SDK ===")

    async with Crawl4aiDockerClient(verbose=True) as client:
        # Test 1: Basic crawl
        browser_config = BrowserConfig(headless=True)
        crawler_config = CrawlerRunConfig(
            cache_mode=CacheMode.BYPASS,
            markdown_generator=DefaultMarkdownGenerator(
                content_filter=PruningContentFilter(
                    threshold=0.48,
                    threshold_type="fixed"
                )
            )
        )

        result = await client.crawl(
            urls=["https://example.com"],
            browser_config=browser_config,
            crawler_config=crawler_config
        )
        print("Client SDK basic crawl:", result.success)

        # Test 2: LLM extraction with streaming
        crawler_config = CrawlerRunConfig(
            cache_mode=CacheMode.BYPASS,
            markdown_generator=DefaultMarkdownGenerator(
                content_filter=LLMContentFilter(
                    # Fixed: was "openai/gpt-40" — "gpt-40" is not a valid
                    # model id (typo for GPT-4o); the provider call would fail.
                    provider="openai/gpt-4o",
                    instruction="Extract key technical concepts"
                )
            ),
            stream=True
        )

        # stream=True makes client.crawl return an async iterator of results.
        async for result in await client.crawl(
            urls=["https://example.com"],
            browser_config=browser_config,
            crawler_config=crawler_config
        ):
            print(f"Streaming result for: {result.url}")

        # # Test 3: Get schema
        # schemas = await client.get_schema()
        # print("Retrieved client schemas for:", list(schemas.keys()))
||||
async def main():
    """Entry point: run the Docker-service test suites in sequence."""
    # Direct HTTP API tests (currently disabled).
    print("Testing direct API calls...")
    # await test_direct_api()

    # Client SDK tests.
    print("\nTesting client SDK...")
    await test_with_client()
||||
if __name__ == "__main__":
    # Run the full async test suite (direct API + client SDK tests).
    asyncio.run(main())
253
tests/docker/test_serialization.py
Normal file
253
tests/docker/test_serialization.py
Normal file
@@ -0,0 +1,253 @@
|
||||
import inspect
|
||||
from typing import Any, Dict
|
||||
from enum import Enum
|
||||
|
||||
def to_serializable_dict(obj: Any) -> Dict:
    """
    Recursively encode *obj* into a JSON-friendly structure.

    Complex objects become ``{"type": <class name>, "params": ...}``;
    plain dicts become ``{"type": "dict", "value": ...}`` so the decoder
    can tell them apart from encoded class instances.
    """
    if obj is None:
        return None

    # Primitives pass through untouched.
    if isinstance(obj, (str, int, float, bool)):
        return obj

    # Enum members are encoded by class name plus underlying value.
    if isinstance(obj, Enum):
        return {"type": type(obj).__name__, "params": obj.value}

    # datetime/date (anything exposing isoformat) -> ISO-8601 string.
    if hasattr(obj, 'isoformat'):
        return obj.isoformat()

    # Sequences and sets collapse to plain lists.
    if isinstance(obj, (list, tuple, set)):
        return [to_serializable_dict(element) for element in obj]

    # Plain dictionaries are tagged so they stay distinguishable from
    # encoded class instances on the way back.
    if isinstance(obj, dict):
        return {
            "type": "dict",
            "value": {str(key): to_serializable_dict(val) for key, val in obj.items()},
        }

    # Arbitrary class instances: keep only constructor args whose current
    # value differs from the declared default (and isn't empty-vs-empty).
    if hasattr(obj, '__class__'):
        signature = inspect.signature(obj.__class__.__init__)
        kept = {}
        for arg_name, parameter in signature.parameters.items():
            if arg_name == 'self':
                continue
            current = getattr(obj, arg_name, parameter.default)
            # Treat "both empty" as equal to the default — skip it.
            if is_empty_value(current) and is_empty_value(parameter.default):
                continue
            if current != parameter.default:
                kept[arg_name] = to_serializable_dict(current)
        return {"type": obj.__class__.__name__, "params": kept}

    # Last resort: stringify anything unrecognized.
    return str(obj)
||||
def from_serializable_dict(data: Any) -> Any:
    """
    Inverse of to_serializable_dict: rebuild objects from the tagged
    ``{"type", "params"}`` / ``{"type": "dict", "value"}`` structures.
    """
    if data is None:
        return None

    # Primitives come back unchanged.
    if isinstance(data, (str, int, float, bool)):
        return data

    if isinstance(data, dict) and "type" in data:
        # Tagged plain dictionary.
        if data["type"] == "dict":
            return {key: from_serializable_dict(val) for key, val in data["value"].items()}

        # Anything else is resolved as a class exported by crawl4ai.
        # NOTE(review): this instantiates whatever class the payload names —
        # only feed it trusted input.
        import crawl4ai
        target_cls = getattr(crawl4ai, data["type"])

        # Enum members are rebuilt from their stored value.
        if issubclass(target_cls, Enum):
            return target_cls(data["params"])

        # Ordinary instances: decode each constructor argument, then call init.
        kwargs = {name: from_serializable_dict(raw) for name, raw in data["params"].items()}
        return target_cls(**kwargs)

    # Lists decode element-wise.
    if isinstance(data, list):
        return [from_serializable_dict(entry) for entry in data]

    # Untagged dicts (legacy payloads) decode value-wise.
    if isinstance(data, dict):
        return {key: from_serializable_dict(val) for key, val in data.items()}

    return data
||||
def is_empty_value(value: Any) -> bool:
    """Return True for None or any zero-length str/list/tuple/set/dict."""
    if value is None:
        return True
    return isinstance(value, (list, tuple, set, dict, str)) and len(value) == 0
||||
# if __name__ == "__main__":
|
||||
# from crawl4ai import (
|
||||
# CrawlerRunConfig, CacheMode, DefaultMarkdownGenerator,
|
||||
# PruningContentFilter, BM25ContentFilter, LLMContentFilter,
|
||||
# JsonCssExtractionStrategy, CosineStrategy, RegexChunking,
|
||||
# WebScrapingStrategy, LXMLWebScrapingStrategy
|
||||
# )
|
||||
|
||||
# # Test Case 1: BM25 content filtering through markdown generator
|
||||
# config1 = CrawlerRunConfig(
|
||||
# cache_mode=CacheMode.BYPASS,
|
||||
# markdown_generator=DefaultMarkdownGenerator(
|
||||
# content_filter=BM25ContentFilter(
|
||||
# user_query="technology articles",
|
||||
# bm25_threshold=1.2,
|
||||
# language="english"
|
||||
# )
|
||||
# ),
|
||||
# chunking_strategy=RegexChunking(patterns=[r"\n\n", r"\.\s+"]),
|
||||
# excluded_tags=["nav", "footer", "aside"],
|
||||
# remove_overlay_elements=True
|
||||
# )
|
||||
|
||||
# # Serialize
|
||||
# serialized = to_serializable_dict(config1)
|
||||
# print("\nSerialized Config:")
|
||||
# print(serialized)
|
||||
|
||||
# # Example output structure would now look like:
|
||||
# """
|
||||
# {
|
||||
# "type": "CrawlerRunConfig",
|
||||
# "params": {
|
||||
# "cache_mode": {
|
||||
# "type": "CacheMode",
|
||||
# "params": "bypass"
|
||||
# },
|
||||
# "markdown_generator": {
|
||||
# "type": "DefaultMarkdownGenerator",
|
||||
# "params": {
|
||||
# "content_filter": {
|
||||
# "type": "BM25ContentFilter",
|
||||
# "params": {
|
||||
# "user_query": "technology articles",
|
||||
# "bm25_threshold": 1.2,
|
||||
# "language": "english"
|
||||
# }
|
||||
# }
|
||||
# }
|
||||
# }
|
||||
# }
|
||||
# }
|
||||
# """
|
||||
|
||||
# # Deserialize
|
||||
# deserialized = from_serializable_dict(serialized)
|
||||
# print("\nDeserialized Config:")
|
||||
# print(to_serializable_dict(deserialized))
|
||||
|
||||
# # Verify they match
|
||||
# assert to_serializable_dict(config1) == to_serializable_dict(deserialized)
|
||||
# print("\nVerification passed: Configuration matches after serialization/deserialization!")
|
||||
|
||||
if __name__ == "__main__":
    from crawl4ai import (
        CrawlerRunConfig, CacheMode, DefaultMarkdownGenerator,
        PruningContentFilter, BM25ContentFilter, LLMContentFilter,
        JsonCssExtractionStrategy, RegexChunking,
        WebScrapingStrategy, LXMLWebScrapingStrategy
    )

    # Case 1: BM25 content filtering wired through the markdown generator.
    bm25_config = CrawlerRunConfig(
        cache_mode=CacheMode.BYPASS,
        markdown_generator=DefaultMarkdownGenerator(
            content_filter=BM25ContentFilter(
                user_query="technology articles",
                bm25_threshold=1.2,
                language="english"
            )
        ),
        chunking_strategy=RegexChunking(patterns=[r"\n\n", r"\.\s+"]),
        excluded_tags=["nav", "footer", "aside"],
        remove_overlay_elements=True
    )

    # Case 2: CSS-schema extraction plus a pruning filter and LXML scraping.
    extraction_schema = {
        "baseSelector": "article.post",
        "fields": [
            {"name": "title", "selector": "h1", "type": "text"},
            {"name": "content", "selector": ".content", "type": "html"}
        ]
    }
    pruning_config = CrawlerRunConfig(
        extraction_strategy=JsonCssExtractionStrategy(schema=extraction_schema),
        markdown_generator=DefaultMarkdownGenerator(
            content_filter=PruningContentFilter(
                threshold=0.48,
                threshold_type="fixed",
                min_word_threshold=0
            ),
            options={"ignore_links": True}
        ),
        scraping_strategy=LXMLWebScrapingStrategy()
    )

    # Case 3: LLM-driven content filter with the default scraping strategy.
    llm_config = CrawlerRunConfig(
        markdown_generator=DefaultMarkdownGenerator(
            content_filter=LLMContentFilter(
                provider="openai/gpt-4",
                instruction="Extract key technical concepts",
                chunk_token_threshold=2000,
                overlap_rate=0.1
            ),
            options={"ignore_images": True}
        ),
        scraping_strategy=WebScrapingStrategy()
    )

    # Round-trip every configuration and confirm it survives unchanged.
    for idx, cfg in enumerate([bm25_config, pruning_config, llm_config], 1):
        print(f"\nTesting Configuration {idx}:")

        encoded = to_serializable_dict(cfg)
        print(f"\nSerialized Config {idx}:")
        print(encoded)

        decoded = from_serializable_dict(encoded)
        print(f"\nDeserialized Config {idx}:")
        print(to_serializable_dict(decoded))  # re-encode for an apples-to-apples compare

        assert to_serializable_dict(cfg) == to_serializable_dict(decoded)
        print(f"\nVerification passed: Configuration {idx} matches after serialization/deserialization!")
||||
Reference in New Issue
Block a user