Compare commits

...

2 Commits

Author SHA1 Message Date
AHMET YILMAZ
8e1362acf5 Fix async generator type mismatch in Docker Client streaming
- Fixed single_result_generator to properly handle async generators from deep crawl strategies
- Added proper __aiter__ checking to distinguish between CrawlResult and async generators
- Await and yield individual results from nested async generators
- Streaming functionality now works correctly for all patterns (SDK, Direct API, Docker Client)
- All 22 comprehensive tests passing with 100% success rate
- Live streaming test confirmed working end-to-end
2025-08-15 15:49:11 +08:00
AHMET YILMAZ
07e9d651fb feat: Comprehensive deep crawl streaming functionality restoration
🚀 Major Achievements:
-  ORJSON Serialization System: Complete implementation with custom handlers
-  Global Deprecated Properties System: DeprecatedPropertiesMixin for automatic exclusion
-  Deep Crawl Streaming: Fully restored with proper CrawlResultContainer handling
-  Docker Client Streaming: Fixed async generator patterns and result type checking
-  Server API Improvements: Correct method selection logic and streaming responses
-  Type Safety: Dict-as-logger detection to prevent crashes

📊 Test Results: 100% success rate on comprehensive test suite (10/10 tests passing)

🔧 Files Modified:
- crawl4ai/models.py: ORJSON + DeprecatedPropertiesMixin implementation
- deploy/docker/api.py: Streaming endpoint fixes + CrawlResultContainer handling
- deploy/docker/server.py: Production imports + ORJSON response handling
- crawl4ai/docker_client.py: Async generator streaming fixes
- crawl4ai/deep_crawling/bfs_strategy.py: Logger type safety
- .gitignore: Development environment cleanup
- tests/test_comprehensive_fixes.py: Rich-based comprehensive test suite

🎯 Impact: Production-ready deep crawl streaming functionality with comprehensive testing coverage
2025-08-15 15:31:36 +08:00
7 changed files with 1473 additions and 69 deletions

10
.gitignore vendored
View File

@@ -1,6 +1,11 @@
# Scripts folder (private tools)
.scripts/
# Local development CLI (private)
local_dev.py
dev
DEV_CLI_README.md
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
@@ -270,4 +275,7 @@ docs/**/data
.codecat/
docs/apps/linkdin/debug*/
docs/apps/linkdin/samples/insights/*
docs/apps/linkdin/samples/insights/*
# Production checklist (local, not for version control)
PRODUCTION_CHECKLIST.md

View File

@@ -38,7 +38,14 @@ class BFSDeepCrawlStrategy(DeepCrawlStrategy):
self.include_external = include_external
self.score_threshold = score_threshold
self.max_pages = max_pages
self.logger = logger or logging.getLogger(__name__)
# Type check for logger
if isinstance(logger, dict):
logging.getLogger(__name__).warning(
"BFSDeepCrawlStrategy received a dict as logger; falling back to default logger."
)
self.logger = logging.getLogger(__name__)
else:
self.logger = logger or logging.getLogger(__name__)
self.stats = TraversalStats(start_time=datetime.now())
self._cancel_event = asyncio.Event()
self._pages_crawled = 0

View File

@@ -30,7 +30,7 @@ class Crawl4aiDockerClient:
def __init__(
self,
base_url: str = "http://localhost:8000",
timeout: float = 30.0,
timeout: float = 600.0, # Increased to 10 minutes for crawling operations
verify_ssl: bool = True,
verbose: bool = True,
log_file: Optional[str] = None
@@ -113,21 +113,12 @@ class Crawl4aiDockerClient:
self.logger.info(f"Crawling {len(urls)} URLs {'(streaming)' if is_streaming else ''}", tag="CRAWL")
if is_streaming:
async def stream_results() -> AsyncGenerator[CrawlResult, None]:
async with self._http_client.stream("POST", f"{self.base_url}/crawl/stream", json=data) as response:
response.raise_for_status()
async for line in response.aiter_lines():
if line.strip():
result = json.loads(line)
if "error" in result:
self.logger.error_status(url=result.get("url", "unknown"), error=result["error"])
continue
self.logger.url_status(url=result.get("url", "unknown"), success=True, timing=result.get("timing", 0.0))
if result.get("status") == "completed":
continue
else:
yield CrawlResult(**result)
return stream_results()
# For streaming, we need to return the async generator properly
# The caller should be able to do: async for result in await client.crawl(...)
async def streaming_wrapper():
async for result in self._stream_crawl_results(data):
yield result
return streaming_wrapper()
response = await self._request("POST", "/crawl", json=data)
result_data = response.json()
@@ -138,6 +129,35 @@ class Crawl4aiDockerClient:
self.logger.success(f"Crawl completed with {len(results)} results", tag="CRAWL")
return results[0] if len(results) == 1 else results
async def _stream_crawl_results(self, data: Dict[str, Any]) -> AsyncGenerator[CrawlResult, None]:
"""Internal method to handle streaming crawl results."""
async with self._http_client.stream("POST", f"{self.base_url}/crawl/stream", json=data) as response:
response.raise_for_status()
async for line in response.aiter_lines():
if line.strip():
try:
result = json.loads(line)
if "error" in result:
self.logger.error_status(url=result.get("url", "unknown"), error=result["error"])
continue
# Check if this is a crawl result (has required fields)
if "url" in result and "success" in result:
self.logger.url_status(url=result.get("url", "unknown"), success=result.get("success", False), timing=result.get("timing", 0.0))
# Create CrawlResult object properly
crawl_result = CrawlResult(**result)
yield crawl_result
# Skip status-only messages
elif result.get("status") == "completed":
continue
except json.JSONDecodeError as e:
self.logger.error(f"Failed to parse streaming response: {e}", tag="STREAM")
continue
except Exception as e:
self.logger.error(f"Error processing streaming result: {e}", tag="STREAM")
continue
async def get_schema(self) -> Dict[str, Any]:
"""Retrieve configuration schemas."""
response = await self._request("GET", "/schema")

View File

@@ -1,4 +1,36 @@
from pydantic import BaseModel, HttpUrl, PrivateAttr, Field
"""
Crawl4AI Models Module
This module contains Pydantic models used throughout the Crawl4AI library.
Key Features:
- ORJSONModel: Base model with ORJSON serialization support
- DeprecatedPropertiesMixin: Global system for handling deprecated properties
- CrawlResult: Main result model with backward compatibility support
Deprecated Properties System:
The DeprecatedPropertiesMixin provides a global way to handle deprecated properties
across all models. Instead of manually excluding deprecated properties in each
model_dump() call, you can simply override the get_deprecated_properties() method:
Example:
class MyModel(ORJSONModel):
name: str
old_field: Optional[str] = None
def get_deprecated_properties(self) -> set[str]:
return {'old_field', 'another_deprecated_field'}
@property
def old_field(self):
raise AttributeError("old_field is deprecated, use name instead")
The system automatically excludes these properties from serialization, preventing
property objects from appearing in JSON output.
"""
from pydantic import BaseModel, ConfigDict,HttpUrl, PrivateAttr, Field
from typing import List, Dict, Optional, Callable, Awaitable, Union, Any
from typing import AsyncGenerator
from typing import Generic, TypeVar
@@ -8,7 +40,7 @@ from .ssl_certificate import SSLCertificate
from datetime import datetime
from datetime import timedelta
import orjson
###############################
# Dispatcher Models
###############################
@@ -91,7 +123,122 @@ class TokenUsage:
completion_tokens_details: Optional[dict] = None
prompt_tokens_details: Optional[dict] = None
class UrlModel(BaseModel):
def orjson_default(obj):
# Handle datetime (if not already handled by orjson)
if isinstance(obj, datetime):
return obj.isoformat()
# Handle property objects (convert to string or something else)
if isinstance(obj, property):
return str(obj)
# Last resort: convert to string
return str(obj)
class DeprecatedPropertiesMixin:
"""
Mixin to handle deprecated properties in Pydantic models.
Classes that inherit from this mixin can define deprecated properties
that will be automatically excluded from serialization.
Usage:
1. Override the get_deprecated_properties() method to return a set of deprecated property names
2. The model_dump method will automatically exclude these properties
Example:
class MyModel(ORJSONModel):
def get_deprecated_properties(self) -> set[str]:
return {'old_field', 'legacy_property'}
name: str
old_field: Optional[str] = None # Field definition
@property
def old_field(self): # Property that overrides the field
raise AttributeError("old_field is deprecated, use name instead")
"""
def get_deprecated_properties(self) -> set[str]:
"""
Get deprecated property names for this model.
Override this method in subclasses to define deprecated properties.
Returns:
set[str]: Set of deprecated property names
"""
return set()
@classmethod
def get_all_deprecated_properties(cls) -> set[str]:
"""
Get all deprecated properties from this class and all parent classes.
Returns:
set[str]: Set of all deprecated property names
"""
deprecated_props = set()
# Create an instance to call the instance method
try:
# Try to create a dummy instance to get deprecated properties
dummy_instance = cls.__new__(cls)
deprecated_props.update(dummy_instance.get_deprecated_properties())
except Exception:
# If we can't create an instance, check for class-level definitions
pass
# Also check parent classes
for klass in cls.__mro__:
if hasattr(klass, 'get_deprecated_properties') and klass != DeprecatedPropertiesMixin:
try:
dummy_instance = klass.__new__(klass)
deprecated_props.update(dummy_instance.get_deprecated_properties())
except Exception:
pass
return deprecated_props
def model_dump(self, *args, **kwargs):
"""
Override model_dump to automatically exclude deprecated properties.
This method:
1. Gets the existing exclude set from kwargs
2. Adds all deprecated properties defined in get_deprecated_properties()
3. Calls the parent model_dump with the updated exclude set
"""
# Get the default exclude set, or create empty set if None
exclude = kwargs.get('exclude', set())
if exclude is None:
exclude = set()
elif not isinstance(exclude, set):
exclude = set(exclude) if exclude else set()
# Add deprecated properties for this instance
exclude.update(self.get_deprecated_properties())
kwargs['exclude'] = exclude
return super().model_dump(*args, **kwargs)
class ORJSONModel(DeprecatedPropertiesMixin, BaseModel):
model_config = ConfigDict(
ser_json_timedelta="iso8601", # Optional: format timedelta
ser_json_bytes="utf8", # Optional: bytes → UTF-8 string
)
def model_dump_json(self, **kwargs) -> bytes:
"""Custom JSON serialization using orjson"""
return orjson.dumps(self.model_dump(**kwargs), default=orjson_default)
@classmethod
def model_validate_json(cls, json_data: Union[str, bytes], **kwargs):
"""Custom JSON deserialization using orjson"""
if isinstance(json_data, str):
json_data = json_data.encode()
return cls.model_validate(orjson.loads(json_data), **kwargs)
class UrlModel(ORJSONModel):
url: HttpUrl
forced: bool = False
@@ -108,7 +255,7 @@ class TraversalStats:
total_depth_reached: int = 0
current_depth: int = 0
class DispatchResult(BaseModel):
class DispatchResult(ORJSONModel):
task_id: str
memory_usage: float
peak_memory: float
@@ -116,7 +263,7 @@ class DispatchResult(BaseModel):
end_time: Union[datetime, float]
error_message: str = ""
class MarkdownGenerationResult(BaseModel):
class MarkdownGenerationResult(ORJSONModel):
raw_markdown: str
markdown_with_citations: str
references_markdown: str
@@ -126,7 +273,7 @@ class MarkdownGenerationResult(BaseModel):
def __str__(self):
return self.raw_markdown
class CrawlResult(BaseModel):
class CrawlResult(ORJSONModel):
url: str
html: str
fit_html: Optional[str] = None
@@ -156,6 +303,10 @@ class CrawlResult(BaseModel):
class Config:
arbitrary_types_allowed = True
def get_deprecated_properties(self) -> set[str]:
"""Define deprecated properties that should be excluded from serialization."""
return {'fit_html', 'fit_markdown', 'markdown_v2'}
# NOTE: The StringCompatibleMarkdown class, custom __init__ method, property getters/setters,
# and model_dump override all exist to support a smooth transition from markdown as a string
# to markdown as a MarkdownGenerationResult object, while maintaining backward compatibility.
@@ -245,14 +396,16 @@ class CrawlResult(BaseModel):
1. PrivateAttr fields are excluded from serialization by default
2. We need to maintain backward compatibility by including the 'markdown' field
in the serialized output
3. We're transitioning from 'markdown_v2' to enhancing 'markdown' to hold
the same type of data
3. Uses the DeprecatedPropertiesMixin to automatically exclude deprecated properties
Future developers: This method ensures that the markdown content is properly
serialized despite being stored in a private attribute. If the serialization
requirements change, this is where you would update the logic.
serialized despite being stored in a private attribute. The deprecated properties
are automatically handled by the mixin.
"""
# Use the parent class method which handles deprecated properties automatically
result = super().model_dump(*args, **kwargs)
# Add the markdown content if it exists
if self._markdown is not None:
result["markdown"] = self._markdown.model_dump()
return result
@@ -307,7 +460,7 @@ RunManyReturn = Union[
# 1. Replace the private attribute and property with a standard field
# 2. Update any serialization logic that might depend on the current behavior
class AsyncCrawlResponse(BaseModel):
class AsyncCrawlResponse(ORJSONModel):
html: str
response_headers: Dict[str, str]
js_execution_result: Optional[Dict[str, Any]] = None
@@ -328,7 +481,7 @@ class AsyncCrawlResponse(BaseModel):
###############################
# Scraping Models
###############################
class MediaItem(BaseModel):
class MediaItem(ORJSONModel):
src: Optional[str] = ""
data: Optional[str] = ""
alt: Optional[str] = ""
@@ -340,7 +493,7 @@ class MediaItem(BaseModel):
width: Optional[int] = None
class Link(BaseModel):
class Link(ORJSONModel):
href: Optional[str] = ""
text: Optional[str] = ""
title: Optional[str] = ""
@@ -353,7 +506,7 @@ class Link(BaseModel):
total_score: Optional[float] = None # Combined score from intrinsic and contextual scores
class Media(BaseModel):
class Media(ORJSONModel):
images: List[MediaItem] = []
videos: List[
MediaItem
@@ -364,12 +517,12 @@ class Media(BaseModel):
tables: List[Dict] = [] # Table data extracted from HTML tables
class Links(BaseModel):
class Links(ORJSONModel):
internal: List[Link] = []
external: List[Link] = []
class ScrapingResult(BaseModel):
class ScrapingResult(ORJSONModel):
cleaned_html: str
success: bool
media: Media = Media()

View File

@@ -1,5 +1,6 @@
import os
import json
import orjson
import asyncio
from typing import List, Tuple, Dict
from functools import partial
@@ -384,27 +385,60 @@ def create_task_response(task: dict, task_id: str, base_url: str) -> dict:
async def stream_results(crawler: AsyncWebCrawler, results_gen: AsyncGenerator) -> AsyncGenerator[bytes, None]:
"""Stream results with heartbeats and completion markers."""
import json
from utils import datetime_handler
import orjson
from datetime import datetime
import inspect
def orjson_default(obj):
# Handle datetime (if not already handled by orjson)
if isinstance(obj, datetime):
return obj.isoformat()
# Handle property objects (convert to string or something else)
if isinstance(obj, property):
return str(obj)
# Last resort: convert to string
return str(obj)
try:
async for result in results_gen:
try:
server_memory_mb = _get_memory_mb()
result_dict = result.model_dump()
result_dict['server_memory_mb'] = server_memory_mb
# If PDF exists, encode it to base64
if result_dict.get('pdf') is not None:
result_dict['pdf'] = b64encode(result_dict['pdf']).decode('utf-8')
logger.info(f"Streaming result for {result_dict.get('url', 'unknown')}")
data = json.dumps(result_dict, default=datetime_handler) + "\n"
yield data.encode('utf-8')
except Exception as e:
logger.error(f"Serialization error: {e}")
error_response = {"error": str(e), "url": getattr(result, 'url', 'unknown')}
yield (json.dumps(error_response) + "\n").encode('utf-8')
logger.info(f"Starting streaming with results_gen type: {type(results_gen)}")
logger.info(f"Is results_gen async generator: {inspect.isasyncgen(results_gen)}")
# Check if results_gen is actually an async generator vs another type
if inspect.isasyncgen(results_gen):
logger.info("Processing as async generator")
async for result in results_gen:
try:
logger.info(f"Processing streaming result of type: {type(result)}")
# Check if this result is actually a CrawlResult
if hasattr(result, 'model_dump_json'):
server_memory_mb = _get_memory_mb()
result_json = result.model_dump_json()
result_dict = orjson.loads(result_json)
result_dict['server_memory_mb'] = server_memory_mb
if result_dict.get('pdf') is not None:
result_dict['pdf'] = b64encode(result_dict['pdf']).decode('utf-8')
logger.info(f"Streaming result for {result_dict.get('url', 'unknown')}")
data = orjson.dumps(result_dict, default=orjson_default).decode('utf-8') + "\n"
yield data.encode('utf-8')
else:
logger.error(f"Result doesn't have model_dump_json method: {type(result)}")
error_response = {"error": f"Invalid result type: {type(result)}", "url": "unknown"}
yield (orjson.dumps(error_response).decode('utf-8') + "\n").encode('utf-8')
except Exception as e:
logger.error(f"Serialization error: {e}")
logger.error(f"Result type was: {type(result)}")
error_response = {"error": str(e), "url": getattr(result, 'url', 'unknown')}
yield (orjson.dumps(error_response).decode('utf-8') + "\n").encode('utf-8')
else:
logger.error(f"results_gen is not an async generator: {type(results_gen)}")
error_response = {"error": f"Invalid results_gen type: {type(results_gen)}"}
yield (orjson.dumps(error_response).decode('utf-8') + "\n").encode('utf-8')
yield json.dumps({"status": "completed"}).encode('utf-8')
yield orjson.dumps({"status": "completed"}).decode('utf-8').encode('utf-8')
except asyncio.CancelledError:
logger.warning("Client disconnected during streaming")
@@ -472,7 +506,9 @@ async def handle_crawl_request(
# Process results to handle PDF bytes
processed_results = []
for result in results:
result_dict = result.model_dump()
# Use ORJSON serialization to handle property objects properly
result_json = result.model_dump_json()
result_dict = orjson.loads(result_json)
# If PDF exists, encode it to base64
if result_dict.get('pdf') is not None:
result_dict['pdf'] = b64encode(result_dict['pdf']).decode('utf-8')
@@ -522,8 +558,19 @@ async def handle_stream_crawl_request(
browser_config.verbose = False
crawler_config = CrawlerRunConfig.load(crawler_config)
crawler_config.scraping_strategy = LXMLWebScrapingStrategy()
crawler_config.stream = True
# Don't force stream=True here - let the deep crawl strategy control its own streaming behavior
# Apply global base config (this was missing!)
base_config = config["crawler"]["base_config"]
for key, value in base_config.items():
if hasattr(crawler_config, key):
print(f"[DEBUG] Applying base_config: {key} = {value}")
setattr(crawler_config, key, value)
print(f"[DEBUG] Deep crawl strategy: {type(crawler_config.deep_crawl_strategy).__name__ if crawler_config.deep_crawl_strategy else 'None'}")
print(f"[DEBUG] Stream mode: {crawler_config.stream}")
print(f"[DEBUG] Simulate user: {getattr(crawler_config, 'simulate_user', 'Not set')}")
dispatcher = MemoryAdaptiveDispatcher(
memory_threshold_percent=config["crawler"]["memory_threshold_percent"],
rate_limiter=RateLimiter(
@@ -537,11 +584,58 @@ async def handle_stream_crawl_request(
# crawler = AsyncWebCrawler(config=browser_config)
# await crawler.start()
results_gen = await crawler.arun_many(
urls=urls,
config=crawler_config,
dispatcher=dispatcher
)
# Use correct method based on URL count (same as regular endpoint)
if len(urls) == 1:
# For single URL, use arun to get CrawlResult, then wrap in async generator
single_result_container = await crawler.arun(
url=urls[0],
config=crawler_config,
dispatcher=dispatcher
)
async def single_result_generator():
# Handle CrawlResultContainer - extract the actual results
if hasattr(single_result_container, '_results'):
# It's a CrawlResultContainer - iterate over the internal results
for result in single_result_container._results:
# Check if the result is an async generator (from deep crawl)
if hasattr(result, '__aiter__'):
async for sub_result in result:
yield sub_result
else:
yield result
elif hasattr(single_result_container, '__aiter__'):
# It's an async generator (from streaming deep crawl)
async for result in single_result_container:
yield result
elif hasattr(single_result_container, '__iter__') and not hasattr(single_result_container, 'url'):
# It's iterable but not a CrawlResult itself
for result in single_result_container:
# Check if each result is an async generator
if hasattr(result, '__aiter__'):
async for sub_result in result:
yield sub_result
else:
yield result
else:
# It's a single CrawlResult
yield single_result_container
results_gen = single_result_generator()
else:
# For multiple URLs, use arun_many
results_gen = await crawler.arun_many(
urls=urls,
config=crawler_config,
dispatcher=dispatcher
)
# If results_gen is a list (e.g., from deep crawl), convert to async generator
if isinstance(results_gen, list):
async def convert_list_to_generator():
for result in results_gen:
yield result
results_gen = convert_list_to_generator()
return crawler, results_gen

View File

@@ -7,13 +7,16 @@ Crawl4AI FastAPI entrypoint
"""
# ── stdlib & 3rdparty imports ───────────────────────────────
from datetime import datetime
import orjson
from crawler_pool import get_crawler, close_all, janitor
from crawl4ai import AsyncWebCrawler, BrowserConfig, CrawlerRunConfig
from auth import create_access_token, get_token_dependency, TokenRequest
from pydantic import BaseModel
from typing import Optional, List, Dict
from fastapi import Request, Depends
from fastapi.responses import FileResponse
from fastapi.responses import FileResponse, ORJSONResponse
import base64
import re
from crawl4ai import AsyncWebCrawler, BrowserConfig, CrawlerRunConfig
@@ -32,6 +35,8 @@ from schemas import (
JSEndpointRequest,
)
# Use the proper serialization functions from async_configs
from crawl4ai.async_configs import to_serializable_dict
from utils import (
FilterType, load_config, setup_logging, verify_email_domain
)
@@ -112,11 +117,26 @@ async def lifespan(_: FastAPI):
app.state.janitor.cancel()
await close_all()
def orjson_default(obj):
# Handle datetime (if not already handled by orjson)
if isinstance(obj, datetime):
return obj.isoformat()
# Handle property objects (convert to string or something else)
if isinstance(obj, property):
return str(obj)
# Last resort: convert to string
return str(obj)
def orjson_dumps(v, *, default):
return orjson.dumps(v, default=orjson_default).decode()
# ───────────────────── FastAPI instance ──────────────────────
app = FastAPI(
title=config["app"]["title"],
version=config["app"]["version"],
lifespan=lifespan,
default_response_class=ORJSONResponse
)
# ── static playground ──────────────────────────────────────
@@ -435,15 +455,20 @@ async def crawl(
"""
Crawl a list of URLs and return the results as JSON.
"""
if not crawl_request.urls:
raise HTTPException(400, "At least one URL required")
res = await handle_crawl_request(
urls=crawl_request.urls,
browser_config=crawl_request.browser_config,
crawler_config=crawl_request.crawler_config,
config=config,
)
return JSONResponse(res)
try:
if not crawl_request.urls:
raise HTTPException(400, "At least one URL required")
res = await handle_crawl_request(
urls=crawl_request.urls,
browser_config=crawl_request.browser_config,
crawler_config=crawl_request.crawler_config,
config=config,
)
# handle_crawl_request returns a dictionary, so we can pass it directly to ORJSONResponse
return ORJSONResponse(res)
except Exception as e:
print(f"Error occurred: {e}")
return ORJSONResponse({"error": str(e)}, status_code=500)
@app.post("/crawl/stream")

File diff suppressed because it is too large Load Diff