feat(docker): add Docker service integration and config serialization

Add Docker service integration with FastAPI server and client implementation.
Implement serialization utilities for BrowserConfig and CrawlerRunConfig to support
Docker service communication. Clean up imports and improve error handling.

- Add Crawl4aiDockerClient class
- Implement config serialization/deserialization
- Add FastAPI server with streaming support
- Add health check endpoint
- Clean up imports and type hints
This commit is contained in:
UncleCode
2025-01-31 18:00:16 +08:00
parent ce4f04dad2
commit 53ac3ec0b4
11 changed files with 859 additions and 97 deletions

View File

@@ -1,4 +1,5 @@
# __init__.py
import warnings
from .async_webcrawler import AsyncWebCrawler, CacheMode
from .async_configs import BrowserConfig, CrawlerRunConfig
@@ -28,6 +29,7 @@ from .async_dispatcher import (
DisplayMode,
BaseDispatcher
)
from .docker_client import Crawl4aiDockerClient
from .hub import CrawlerHub
__all__ = [
@@ -59,12 +61,13 @@ __all__ = [
"CrawlerMonitor",
"DisplayMode",
"MarkdownGenerationResult",
"Crawl4aiDockerClient",
]
def is_sync_version_installed():
try:
import selenium
import selenium # noqa
return True
except ImportError:
@@ -85,9 +88,6 @@ else:
# import warnings
# print("Warning: Synchronous WebCrawler is not available. Install crawl4ai[sync] for synchronous support. However, please note that the synchronous version will be deprecated soon.")
import warnings
from pydantic import warnings as pydantic_warnings
# Disable all Pydantic warnings
warnings.filterwarnings("ignore", module="pydantic")
# pydantic_warnings.filter_warnings()
# pydantic_warnings.filter_warnings()

View File

@@ -7,17 +7,129 @@ from .config import (
SOCIAL_MEDIA_DOMAINS,
)
from .user_agent_generator import UserAgentGenerator, UAGen, ValidUAGenerator, OnlineUAGenerator
from .user_agent_generator import UAGen, ValidUAGenerator # , OnlineUAGenerator
from .extraction_strategy import ExtractionStrategy
from .chunking_strategy import ChunkingStrategy, RegexChunking
from .markdown_generation_strategy import MarkdownGenerationStrategy
from .content_filter_strategy import RelevantContentFilter, BM25ContentFilter, LLMContentFilter, PruningContentFilter
from .content_filter_strategy import RelevantContentFilter # , BM25ContentFilter, LLMContentFilter, PruningContentFilter
from .content_scraping_strategy import ContentScrapingStrategy, WebScrapingStrategy
from typing import Optional, Union, List
from typing import Union, List
from .cache_context import CacheMode
import inspect
from typing import Any, Dict
from enum import Enum
class BrowserConfig:
def to_serializable_dict(obj: Any) -> Dict:
"""
Recursively convert an object to a serializable dictionary using {type, params} structure
for complex objects.
"""
if obj is None:
return None
# Handle basic types
if isinstance(obj, (str, int, float, bool)):
return obj
# Handle Enum
if isinstance(obj, Enum):
return {
"type": obj.__class__.__name__,
"params": obj.value
}
# Handle datetime objects
if hasattr(obj, 'isoformat'):
return obj.isoformat()
# Handle lists, tuples, and sets
if isinstance(obj, (list, tuple, set)):
return [to_serializable_dict(item) for item in obj]
# Handle dictionaries - preserve them as-is
if isinstance(obj, dict):
return {
"type": "dict", # Mark as plain dictionary
"value": {str(k): to_serializable_dict(v) for k, v in obj.items()}
}
# Handle class instances
if hasattr(obj, '__class__'):
# Get constructor signature
sig = inspect.signature(obj.__class__.__init__)
params = sig.parameters
# Get current values
current_values = {}
for name, param in params.items():
if name == 'self':
continue
value = getattr(obj, name, param.default)
# Only include if different from default, considering empty values
if not (is_empty_value(value) and is_empty_value(param.default)):
if value != param.default:
current_values[name] = to_serializable_dict(value)
return {
"type": obj.__class__.__name__,
"params": current_values
}
return str(obj)
def from_serializable_dict(data: Any) -> Any:
"""
Recursively convert a serializable dictionary back to an object instance.
"""
if data is None:
return None
# Handle basic types
if isinstance(data, (str, int, float, bool)):
return data
# Handle typed data
if isinstance(data, dict) and "type" in data:
# Handle plain dictionaries
if data["type"] == "dict":
return {k: from_serializable_dict(v) for k, v in data["value"].items()}
# Import from crawl4ai for class instances
import crawl4ai
cls = getattr(crawl4ai, data["type"])
# Handle Enum
if issubclass(cls, Enum):
return cls(data["params"])
# Handle class instances
constructor_args = {
k: from_serializable_dict(v) for k, v in data["params"].items()
}
return cls(**constructor_args)
# Handle lists
if isinstance(data, list):
return [from_serializable_dict(item) for item in data]
# Handle raw dictionaries (legacy support)
if isinstance(data, dict):
return {k: from_serializable_dict(v) for k, v in data.items()}
return data
def is_empty_value(value: Any) -> bool:
"""Check if a value is effectively empty/null."""
if value is None:
return True
if isinstance(value, (list, tuple, set, dict, str)) and len(value) == 0:
return True
return False
class BrowserConfig():
"""
Configuration class for setting up a browser instance and its context in AsyncPlaywrightCrawlerStrategy.
@@ -239,8 +351,18 @@ class BrowserConfig:
config_dict.update(kwargs)
return BrowserConfig.from_kwargs(config_dict)
# Create a funciton returns dict of the object
def dump(self) -> dict:
# Serialize the object to a dictionary
return to_serializable_dict(self)
class CrawlerRunConfig:
@staticmethod
def load( data: dict) -> "BrowserConfig":
# Deserialize the object from a dictionary
return from_serializable_dict(data)
class CrawlerRunConfig():
"""
Configuration class for controlling how the crawler runs each crawl operation.
This includes parameters for content extraction, page manipulation, waiting conditions,
@@ -665,6 +787,15 @@ class CrawlerRunConfig:
)
# Create a funciton returns dict of the object
def dump(self) -> dict:
# Serialize the object to a dictionary
return to_serializable_dict(self)
@staticmethod
def load(data: dict) -> "CrawlerRunConfig":
# Deserialize the object from a dictionary
return from_serializable_dict(data)
def to_dict(self):
return {
"word_count_threshold": self.word_count_threshold,

View File

@@ -1,3 +1,4 @@
from .__version__ import __version__ as crawl4ai_version
import os
import sys
import time
@@ -10,7 +11,7 @@ import asyncio
# from contextlib import nullcontext, asynccontextmanager
from contextlib import asynccontextmanager
from .models import CrawlResult, MarkdownGenerationResult, CrawlerTaskResult, DispatchResult
from .models import CrawlResult, MarkdownGenerationResult,DispatchResult
from .async_database import async_db_manager
from .chunking_strategy import * # noqa: F403
from .chunking_strategy import RegexChunking, ChunkingStrategy, IdentityChunking
@@ -43,8 +44,7 @@ from .utils import (
RobotsParser,
)
from typing import Union, AsyncGenerator, List, TypeVar
from collections.abc import AsyncGenerator
from typing import Union, AsyncGenerator, TypeVar
CrawlResultT = TypeVar('CrawlResultT', bound=CrawlResult)
RunManyReturn = Union[List[CrawlResultT], AsyncGenerator[CrawlResultT, None]]
@@ -55,7 +55,7 @@ DeepCrawlManyReturn = Union[
AsyncGenerator[CrawlResultT, None],
]
from .__version__ import __version__ as crawl4ai_version
class AsyncWebCrawler:
@@ -768,18 +768,19 @@ class AsyncWebCrawler:
),
)
transform_result = lambda task_result: (
setattr(task_result.result, 'dispatch_result',
DispatchResult(
task_id=task_result.task_id,
memory_usage=task_result.memory_usage,
peak_memory=task_result.peak_memory,
start_time=task_result.start_time,
end_time=task_result.end_time,
error_message=task_result.error_message,
def transform_result(task_result):
return (
setattr(task_result.result, 'dispatch_result',
DispatchResult(
task_id=task_result.task_id,
memory_usage=task_result.memory_usage,
peak_memory=task_result.peak_memory,
start_time=task_result.start_time,
end_time=task_result.end_time,
error_message=task_result.error_message,
)
) or task_result.result
)
) or task_result.result
)
stream = config.stream

View File

@@ -9,16 +9,16 @@ from .utils import clean_tokens, perform_completion_with_backoff, escape_json_st
from abc import ABC, abstractmethod
import math
from snowballstemmer import stemmer
from .config import DEFAULT_PROVIDER, OVERLAP_RATE, WORD_TOKEN_RATE
from .config import DEFAULT_PROVIDER, OVERLAP_RATE, WORD_TOKEN_RATE, PROVIDER_MODELS
from .models import TokenUsage
from .prompts import PROMPT_FILTER_CONTENT
import os
import json
import hashlib
from pathlib import Path
from concurrent.futures import ThreadPoolExecutor, as_completed
from concurrent.futures import ThreadPoolExecutor
from .async_logger import AsyncLogger, LogLevel
from colorama import Fore, Style, init
from colorama import Fore, Style
class RelevantContentFilter(ABC):
"""Abstract base class for content filtering strategies"""
@@ -879,7 +879,6 @@ class LLMContentFilter(RelevantContentFilter):
colors={"chunk_count": Fore.YELLOW}
)
extracted_content = []
start_time = time.time()
# Process chunks in parallel

210
crawl4ai/docker_client.py Normal file
View File

@@ -0,0 +1,210 @@
from typing import List, Optional, Union, AsyncGenerator, Dict, Any
import httpx
import json
from urllib.parse import urljoin
from .async_configs import BrowserConfig, CrawlerRunConfig
from .models import CrawlResult
from .async_logger import AsyncLogger, LogLevel
class Crawl4aiClientError(Exception):
"""Base exception for Crawl4ai Docker client errors."""
pass
class ConnectionError(Crawl4aiClientError):
"""Raised when connection to the Docker server fails."""
pass
class RequestError(Crawl4aiClientError):
"""Raised when the server returns an error response."""
pass
class Crawl4aiDockerClient:
"""
Client for interacting with Crawl4AI Docker server.
Args:
base_url (str): Base URL of the Crawl4AI Docker server
timeout (float): Default timeout for requests in seconds
verify_ssl (bool): Whether to verify SSL certificates
verbose (bool): Whether to show logging output
log_file (str, optional): Path to log file if file logging is desired
"""
def __init__(
self,
base_url: str = "http://localhost:8000",
timeout: float = 30.0,
verify_ssl: bool = True,
verbose: bool = True,
log_file: Optional[str] = None
) -> None:
self.base_url = base_url.rstrip('/')
self.timeout = timeout
self._http_client = httpx.AsyncClient(
timeout=timeout,
verify=verify_ssl,
headers={"Content-Type": "application/json"}
)
self.logger = AsyncLogger(
log_file=log_file,
log_level=LogLevel.DEBUG,
verbose=verbose
)
async def _check_server_connection(self) -> bool:
"""Check if server is reachable."""
try:
self.logger.info("Checking server connection...", tag="INIT")
response = await self._http_client.get(f"{self.base_url}/health")
response.raise_for_status()
self.logger.success(f"Connected to server at {self.base_url}", tag="READY")
return True
except Exception as e:
self.logger.error(f"Failed to connect to server: {str(e)}", tag="ERROR")
return False
def _prepare_request_data(
self,
urls: List[str],
browser_config: Optional[BrowserConfig] = None,
crawler_config: Optional[CrawlerRunConfig] = None
) -> Dict[str, Any]:
"""Prepare request data from configs using dump methods."""
self.logger.debug("Preparing request data", tag="INIT")
data = {
"urls": urls,
"browser_config": browser_config.dump() if browser_config else {},
"crawler_config": crawler_config.dump() if crawler_config else {}
}
self.logger.debug(f"Request data prepared for {len(urls)} URLs", tag="READY")
return data
async def _make_request(
self,
method: str,
endpoint: str,
**kwargs
) -> Union[Dict, AsyncGenerator]:
"""Make HTTP request to the server with error handling."""
url = urljoin(self.base_url, endpoint)
try:
self.logger.debug(f"Making {method} request to {endpoint}", tag="FETCH")
response = await self._http_client.request(method, url, **kwargs)
response.raise_for_status()
self.logger.success(f"Request to {endpoint} successful", tag="COMPLETE")
return response
except httpx.TimeoutException as e:
error_msg = f"Request timed out: {str(e)}"
self.logger.error(error_msg, tag="ERROR")
raise ConnectionError(error_msg)
except httpx.RequestError as e:
error_msg = f"Failed to connect to server: {str(e)}"
self.logger.error(error_msg, tag="ERROR")
raise ConnectionError(error_msg)
except httpx.HTTPStatusError as e:
error_detail = ""
try:
error_data = e.response.json()
error_detail = error_data.get('detail', str(e))
except (json.JSONDecodeError, AttributeError) as json_err:
error_detail = f"{str(e)} (Failed to parse error response: {str(json_err)})"
error_msg = f"Server returned error {e.response.status_code}: {error_detail}"
self.logger.error(error_msg, tag="ERROR")
raise RequestError(error_msg)
async def crawl(
self,
urls: List[str],
browser_config: Optional[BrowserConfig] = None,
crawler_config: Optional[CrawlerRunConfig] = None
) -> Union[CrawlResult, AsyncGenerator[CrawlResult, None]]:
"""Execute a crawl operation through the Docker server."""
# Check server connection first
if not await self._check_server_connection():
raise ConnectionError("Cannot proceed with crawl - server is not reachable")
request_data = self._prepare_request_data(urls, browser_config, crawler_config)
is_streaming = crawler_config.stream if crawler_config else False
self.logger.info(
f"Starting crawl for {len(urls)} URLs {'(streaming)' if is_streaming else ''}",
tag="INIT"
)
if is_streaming:
async def result_generator() -> AsyncGenerator[CrawlResult, None]:
try:
async with self._http_client.stream(
"POST",
f"{self.base_url}/crawl",
json=request_data,
timeout=None
) as response:
response.raise_for_status()
async for line in response.aiter_lines():
if line.strip():
try:
result_dict = json.loads(line)
if "error" in result_dict:
self.logger.error_status(
url=result_dict.get('url', 'unknown'),
error=result_dict['error']
)
continue
self.logger.url_status(
url=result_dict.get('url', 'unknown'),
success=True,
timing=result_dict.get('timing', 0.0)
)
yield CrawlResult(**result_dict)
except json.JSONDecodeError as e:
self.logger.error(f"Failed to parse server response: {e}", tag="ERROR")
continue
except httpx.StreamError as e:
error_msg = f"Stream connection error: {str(e)}"
self.logger.error(error_msg, tag="ERROR")
raise ConnectionError(error_msg)
except Exception as e:
error_msg = f"Unexpected error during streaming: {str(e)}"
self.logger.error(error_msg, tag="ERROR")
raise Crawl4aiClientError(error_msg)
return result_generator()
response = await self._make_request("POST", "/crawl", json=request_data)
response_data = response.json()
if not response_data.get("success", False):
error_msg = f"Crawl operation failed: {response_data.get('error', 'Unknown error')}"
self.logger.error(error_msg, tag="ERROR")
raise RequestError(error_msg)
results = [CrawlResult(**result_dict) for result_dict in response_data.get("results", [])]
self.logger.success(f"Crawl completed successfully with {len(results)} results", tag="COMPLETE")
return results[0] if len(results) == 1 else results
async def get_schema(self) -> Dict[str, Any]:
"""Retrieve the configuration schemas from the server."""
self.logger.info("Retrieving schema from server", tag="FETCH")
response = await self._make_request("GET", "/schema")
self.logger.success("Schema retrieved successfully", tag="COMPLETE")
return response.json()
async def close(self) -> None:
"""Close the HTTP client session."""
self.logger.info("Closing client connection", tag="COMPLETE")
await self._http_client.aclose()
async def __aenter__(self) -> "Crawl4aiDockerClient":
return self
async def __aexit__(self, exc_type: Optional[type], exc_val: Optional[Exception], exc_tb: Optional[Any]) -> None:
await self.close()