Compare commits
2 Commits
pdf_proces
...
feat/ahmed
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
8e1362acf5 | ||
|
|
07e9d651fb |
8
.gitignore
vendored
8
.gitignore
vendored
@@ -1,6 +1,11 @@
|
|||||||
# Scripts folder (private tools)
|
# Scripts folder (private tools)
|
||||||
.scripts/
|
.scripts/
|
||||||
|
|
||||||
|
# Local development CLI (private)
|
||||||
|
local_dev.py
|
||||||
|
dev
|
||||||
|
DEV_CLI_README.md
|
||||||
|
|
||||||
# Byte-compiled / optimized / DLL files
|
# Byte-compiled / optimized / DLL files
|
||||||
__pycache__/
|
__pycache__/
|
||||||
*.py[cod]
|
*.py[cod]
|
||||||
@@ -271,3 +276,6 @@ docs/**/data
|
|||||||
|
|
||||||
docs/apps/linkdin/debug*/
|
docs/apps/linkdin/debug*/
|
||||||
docs/apps/linkdin/samples/insights/*
|
docs/apps/linkdin/samples/insights/*
|
||||||
|
|
||||||
|
# Production checklist (local, not for version control)
|
||||||
|
PRODUCTION_CHECKLIST.md
|
||||||
@@ -38,6 +38,13 @@ class BFSDeepCrawlStrategy(DeepCrawlStrategy):
|
|||||||
self.include_external = include_external
|
self.include_external = include_external
|
||||||
self.score_threshold = score_threshold
|
self.score_threshold = score_threshold
|
||||||
self.max_pages = max_pages
|
self.max_pages = max_pages
|
||||||
|
# Type check for logger
|
||||||
|
if isinstance(logger, dict):
|
||||||
|
logging.getLogger(__name__).warning(
|
||||||
|
"BFSDeepCrawlStrategy received a dict as logger; falling back to default logger."
|
||||||
|
)
|
||||||
|
self.logger = logging.getLogger(__name__)
|
||||||
|
else:
|
||||||
self.logger = logger or logging.getLogger(__name__)
|
self.logger = logger or logging.getLogger(__name__)
|
||||||
self.stats = TraversalStats(start_time=datetime.now())
|
self.stats = TraversalStats(start_time=datetime.now())
|
||||||
self._cancel_event = asyncio.Event()
|
self._cancel_event = asyncio.Event()
|
||||||
|
|||||||
@@ -30,7 +30,7 @@ class Crawl4aiDockerClient:
|
|||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
base_url: str = "http://localhost:8000",
|
base_url: str = "http://localhost:8000",
|
||||||
timeout: float = 30.0,
|
timeout: float = 600.0, # Increased to 10 minutes for crawling operations
|
||||||
verify_ssl: bool = True,
|
verify_ssl: bool = True,
|
||||||
verbose: bool = True,
|
verbose: bool = True,
|
||||||
log_file: Optional[str] = None
|
log_file: Optional[str] = None
|
||||||
@@ -113,21 +113,12 @@ class Crawl4aiDockerClient:
|
|||||||
self.logger.info(f"Crawling {len(urls)} URLs {'(streaming)' if is_streaming else ''}", tag="CRAWL")
|
self.logger.info(f"Crawling {len(urls)} URLs {'(streaming)' if is_streaming else ''}", tag="CRAWL")
|
||||||
|
|
||||||
if is_streaming:
|
if is_streaming:
|
||||||
async def stream_results() -> AsyncGenerator[CrawlResult, None]:
|
# For streaming, we need to return the async generator properly
|
||||||
async with self._http_client.stream("POST", f"{self.base_url}/crawl/stream", json=data) as response:
|
# The caller should be able to do: async for result in await client.crawl(...)
|
||||||
response.raise_for_status()
|
async def streaming_wrapper():
|
||||||
async for line in response.aiter_lines():
|
async for result in self._stream_crawl_results(data):
|
||||||
if line.strip():
|
yield result
|
||||||
result = json.loads(line)
|
return streaming_wrapper()
|
||||||
if "error" in result:
|
|
||||||
self.logger.error_status(url=result.get("url", "unknown"), error=result["error"])
|
|
||||||
continue
|
|
||||||
self.logger.url_status(url=result.get("url", "unknown"), success=True, timing=result.get("timing", 0.0))
|
|
||||||
if result.get("status") == "completed":
|
|
||||||
continue
|
|
||||||
else:
|
|
||||||
yield CrawlResult(**result)
|
|
||||||
return stream_results()
|
|
||||||
|
|
||||||
response = await self._request("POST", "/crawl", json=data)
|
response = await self._request("POST", "/crawl", json=data)
|
||||||
result_data = response.json()
|
result_data = response.json()
|
||||||
@@ -138,6 +129,35 @@ class Crawl4aiDockerClient:
|
|||||||
self.logger.success(f"Crawl completed with {len(results)} results", tag="CRAWL")
|
self.logger.success(f"Crawl completed with {len(results)} results", tag="CRAWL")
|
||||||
return results[0] if len(results) == 1 else results
|
return results[0] if len(results) == 1 else results
|
||||||
|
|
||||||
|
async def _stream_crawl_results(self, data: Dict[str, Any]) -> AsyncGenerator[CrawlResult, None]:
|
||||||
|
"""Internal method to handle streaming crawl results."""
|
||||||
|
async with self._http_client.stream("POST", f"{self.base_url}/crawl/stream", json=data) as response:
|
||||||
|
response.raise_for_status()
|
||||||
|
async for line in response.aiter_lines():
|
||||||
|
if line.strip():
|
||||||
|
try:
|
||||||
|
result = json.loads(line)
|
||||||
|
if "error" in result:
|
||||||
|
self.logger.error_status(url=result.get("url", "unknown"), error=result["error"])
|
||||||
|
continue
|
||||||
|
|
||||||
|
# Check if this is a crawl result (has required fields)
|
||||||
|
if "url" in result and "success" in result:
|
||||||
|
self.logger.url_status(url=result.get("url", "unknown"), success=result.get("success", False), timing=result.get("timing", 0.0))
|
||||||
|
|
||||||
|
# Create CrawlResult object properly
|
||||||
|
crawl_result = CrawlResult(**result)
|
||||||
|
yield crawl_result
|
||||||
|
# Skip status-only messages
|
||||||
|
elif result.get("status") == "completed":
|
||||||
|
continue
|
||||||
|
except json.JSONDecodeError as e:
|
||||||
|
self.logger.error(f"Failed to parse streaming response: {e}", tag="STREAM")
|
||||||
|
continue
|
||||||
|
except Exception as e:
|
||||||
|
self.logger.error(f"Error processing streaming result: {e}", tag="STREAM")
|
||||||
|
continue
|
||||||
|
|
||||||
async def get_schema(self) -> Dict[str, Any]:
|
async def get_schema(self) -> Dict[str, Any]:
|
||||||
"""Retrieve configuration schemas."""
|
"""Retrieve configuration schemas."""
|
||||||
response = await self._request("GET", "/schema")
|
response = await self._request("GET", "/schema")
|
||||||
|
|||||||
@@ -1,4 +1,36 @@
|
|||||||
from pydantic import BaseModel, HttpUrl, PrivateAttr, Field
|
|
||||||
|
"""
|
||||||
|
Crawl4AI Models Module
|
||||||
|
|
||||||
|
This module contains Pydantic models used throughout the Crawl4AI library.
|
||||||
|
|
||||||
|
Key Features:
|
||||||
|
- ORJSONModel: Base model with ORJSON serialization support
|
||||||
|
- DeprecatedPropertiesMixin: Global system for handling deprecated properties
|
||||||
|
- CrawlResult: Main result model with backward compatibility support
|
||||||
|
|
||||||
|
Deprecated Properties System:
|
||||||
|
The DeprecatedPropertiesMixin provides a global way to handle deprecated properties
|
||||||
|
across all models. Instead of manually excluding deprecated properties in each
|
||||||
|
model_dump() call, you can simply override the get_deprecated_properties() method:
|
||||||
|
|
||||||
|
Example:
|
||||||
|
class MyModel(ORJSONModel):
|
||||||
|
name: str
|
||||||
|
old_field: Optional[str] = None
|
||||||
|
|
||||||
|
def get_deprecated_properties(self) -> set[str]:
|
||||||
|
return {'old_field', 'another_deprecated_field'}
|
||||||
|
|
||||||
|
@property
|
||||||
|
def old_field(self):
|
||||||
|
raise AttributeError("old_field is deprecated, use name instead")
|
||||||
|
|
||||||
|
The system automatically excludes these properties from serialization, preventing
|
||||||
|
property objects from appearing in JSON output.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from pydantic import BaseModel, ConfigDict,HttpUrl, PrivateAttr, Field
|
||||||
from typing import List, Dict, Optional, Callable, Awaitable, Union, Any
|
from typing import List, Dict, Optional, Callable, Awaitable, Union, Any
|
||||||
from typing import AsyncGenerator
|
from typing import AsyncGenerator
|
||||||
from typing import Generic, TypeVar
|
from typing import Generic, TypeVar
|
||||||
@@ -8,7 +40,7 @@ from .ssl_certificate import SSLCertificate
|
|||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from datetime import timedelta
|
from datetime import timedelta
|
||||||
|
|
||||||
|
import orjson
|
||||||
###############################
|
###############################
|
||||||
# Dispatcher Models
|
# Dispatcher Models
|
||||||
###############################
|
###############################
|
||||||
@@ -91,7 +123,122 @@ class TokenUsage:
|
|||||||
completion_tokens_details: Optional[dict] = None
|
completion_tokens_details: Optional[dict] = None
|
||||||
prompt_tokens_details: Optional[dict] = None
|
prompt_tokens_details: Optional[dict] = None
|
||||||
|
|
||||||
class UrlModel(BaseModel):
|
|
||||||
|
def orjson_default(obj):
|
||||||
|
# Handle datetime (if not already handled by orjson)
|
||||||
|
if isinstance(obj, datetime):
|
||||||
|
return obj.isoformat()
|
||||||
|
|
||||||
|
# Handle property objects (convert to string or something else)
|
||||||
|
if isinstance(obj, property):
|
||||||
|
return str(obj)
|
||||||
|
|
||||||
|
# Last resort: convert to string
|
||||||
|
return str(obj)
|
||||||
|
|
||||||
|
|
||||||
|
class DeprecatedPropertiesMixin:
|
||||||
|
"""
|
||||||
|
Mixin to handle deprecated properties in Pydantic models.
|
||||||
|
|
||||||
|
Classes that inherit from this mixin can define deprecated properties
|
||||||
|
that will be automatically excluded from serialization.
|
||||||
|
|
||||||
|
Usage:
|
||||||
|
1. Override the get_deprecated_properties() method to return a set of deprecated property names
|
||||||
|
2. The model_dump method will automatically exclude these properties
|
||||||
|
|
||||||
|
Example:
|
||||||
|
class MyModel(ORJSONModel):
|
||||||
|
def get_deprecated_properties(self) -> set[str]:
|
||||||
|
return {'old_field', 'legacy_property'}
|
||||||
|
|
||||||
|
name: str
|
||||||
|
old_field: Optional[str] = None # Field definition
|
||||||
|
|
||||||
|
@property
|
||||||
|
def old_field(self): # Property that overrides the field
|
||||||
|
raise AttributeError("old_field is deprecated, use name instead")
|
||||||
|
"""
|
||||||
|
|
||||||
|
def get_deprecated_properties(self) -> set[str]:
|
||||||
|
"""
|
||||||
|
Get deprecated property names for this model.
|
||||||
|
Override this method in subclasses to define deprecated properties.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
set[str]: Set of deprecated property names
|
||||||
|
"""
|
||||||
|
return set()
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def get_all_deprecated_properties(cls) -> set[str]:
|
||||||
|
"""
|
||||||
|
Get all deprecated properties from this class and all parent classes.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
set[str]: Set of all deprecated property names
|
||||||
|
"""
|
||||||
|
deprecated_props = set()
|
||||||
|
# Create an instance to call the instance method
|
||||||
|
try:
|
||||||
|
# Try to create a dummy instance to get deprecated properties
|
||||||
|
dummy_instance = cls.__new__(cls)
|
||||||
|
deprecated_props.update(dummy_instance.get_deprecated_properties())
|
||||||
|
except Exception:
|
||||||
|
# If we can't create an instance, check for class-level definitions
|
||||||
|
pass
|
||||||
|
|
||||||
|
# Also check parent classes
|
||||||
|
for klass in cls.__mro__:
|
||||||
|
if hasattr(klass, 'get_deprecated_properties') and klass != DeprecatedPropertiesMixin:
|
||||||
|
try:
|
||||||
|
dummy_instance = klass.__new__(klass)
|
||||||
|
deprecated_props.update(dummy_instance.get_deprecated_properties())
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
return deprecated_props
|
||||||
|
|
||||||
|
def model_dump(self, *args, **kwargs):
|
||||||
|
"""
|
||||||
|
Override model_dump to automatically exclude deprecated properties.
|
||||||
|
|
||||||
|
This method:
|
||||||
|
1. Gets the existing exclude set from kwargs
|
||||||
|
2. Adds all deprecated properties defined in get_deprecated_properties()
|
||||||
|
3. Calls the parent model_dump with the updated exclude set
|
||||||
|
"""
|
||||||
|
# Get the default exclude set, or create empty set if None
|
||||||
|
exclude = kwargs.get('exclude', set())
|
||||||
|
if exclude is None:
|
||||||
|
exclude = set()
|
||||||
|
elif not isinstance(exclude, set):
|
||||||
|
exclude = set(exclude) if exclude else set()
|
||||||
|
|
||||||
|
# Add deprecated properties for this instance
|
||||||
|
exclude.update(self.get_deprecated_properties())
|
||||||
|
kwargs['exclude'] = exclude
|
||||||
|
|
||||||
|
return super().model_dump(*args, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
class ORJSONModel(DeprecatedPropertiesMixin, BaseModel):
|
||||||
|
model_config = ConfigDict(
|
||||||
|
ser_json_timedelta="iso8601", # Optional: format timedelta
|
||||||
|
ser_json_bytes="utf8", # Optional: bytes → UTF-8 string
|
||||||
|
)
|
||||||
|
|
||||||
|
def model_dump_json(self, **kwargs) -> bytes:
|
||||||
|
"""Custom JSON serialization using orjson"""
|
||||||
|
return orjson.dumps(self.model_dump(**kwargs), default=orjson_default)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def model_validate_json(cls, json_data: Union[str, bytes], **kwargs):
|
||||||
|
"""Custom JSON deserialization using orjson"""
|
||||||
|
if isinstance(json_data, str):
|
||||||
|
json_data = json_data.encode()
|
||||||
|
return cls.model_validate(orjson.loads(json_data), **kwargs)
|
||||||
|
class UrlModel(ORJSONModel):
|
||||||
url: HttpUrl
|
url: HttpUrl
|
||||||
forced: bool = False
|
forced: bool = False
|
||||||
|
|
||||||
@@ -108,7 +255,7 @@ class TraversalStats:
|
|||||||
total_depth_reached: int = 0
|
total_depth_reached: int = 0
|
||||||
current_depth: int = 0
|
current_depth: int = 0
|
||||||
|
|
||||||
class DispatchResult(BaseModel):
|
class DispatchResult(ORJSONModel):
|
||||||
task_id: str
|
task_id: str
|
||||||
memory_usage: float
|
memory_usage: float
|
||||||
peak_memory: float
|
peak_memory: float
|
||||||
@@ -116,7 +263,7 @@ class DispatchResult(BaseModel):
|
|||||||
end_time: Union[datetime, float]
|
end_time: Union[datetime, float]
|
||||||
error_message: str = ""
|
error_message: str = ""
|
||||||
|
|
||||||
class MarkdownGenerationResult(BaseModel):
|
class MarkdownGenerationResult(ORJSONModel):
|
||||||
raw_markdown: str
|
raw_markdown: str
|
||||||
markdown_with_citations: str
|
markdown_with_citations: str
|
||||||
references_markdown: str
|
references_markdown: str
|
||||||
@@ -126,7 +273,7 @@ class MarkdownGenerationResult(BaseModel):
|
|||||||
def __str__(self):
|
def __str__(self):
|
||||||
return self.raw_markdown
|
return self.raw_markdown
|
||||||
|
|
||||||
class CrawlResult(BaseModel):
|
class CrawlResult(ORJSONModel):
|
||||||
url: str
|
url: str
|
||||||
html: str
|
html: str
|
||||||
fit_html: Optional[str] = None
|
fit_html: Optional[str] = None
|
||||||
@@ -156,6 +303,10 @@ class CrawlResult(BaseModel):
|
|||||||
class Config:
|
class Config:
|
||||||
arbitrary_types_allowed = True
|
arbitrary_types_allowed = True
|
||||||
|
|
||||||
|
def get_deprecated_properties(self) -> set[str]:
|
||||||
|
"""Define deprecated properties that should be excluded from serialization."""
|
||||||
|
return {'fit_html', 'fit_markdown', 'markdown_v2'}
|
||||||
|
|
||||||
# NOTE: The StringCompatibleMarkdown class, custom __init__ method, property getters/setters,
|
# NOTE: The StringCompatibleMarkdown class, custom __init__ method, property getters/setters,
|
||||||
# and model_dump override all exist to support a smooth transition from markdown as a string
|
# and model_dump override all exist to support a smooth transition from markdown as a string
|
||||||
# to markdown as a MarkdownGenerationResult object, while maintaining backward compatibility.
|
# to markdown as a MarkdownGenerationResult object, while maintaining backward compatibility.
|
||||||
@@ -245,14 +396,16 @@ class CrawlResult(BaseModel):
|
|||||||
1. PrivateAttr fields are excluded from serialization by default
|
1. PrivateAttr fields are excluded from serialization by default
|
||||||
2. We need to maintain backward compatibility by including the 'markdown' field
|
2. We need to maintain backward compatibility by including the 'markdown' field
|
||||||
in the serialized output
|
in the serialized output
|
||||||
3. We're transitioning from 'markdown_v2' to enhancing 'markdown' to hold
|
3. Uses the DeprecatedPropertiesMixin to automatically exclude deprecated properties
|
||||||
the same type of data
|
|
||||||
|
|
||||||
Future developers: This method ensures that the markdown content is properly
|
Future developers: This method ensures that the markdown content is properly
|
||||||
serialized despite being stored in a private attribute. If the serialization
|
serialized despite being stored in a private attribute. The deprecated properties
|
||||||
requirements change, this is where you would update the logic.
|
are automatically handled by the mixin.
|
||||||
"""
|
"""
|
||||||
|
# Use the parent class method which handles deprecated properties automatically
|
||||||
result = super().model_dump(*args, **kwargs)
|
result = super().model_dump(*args, **kwargs)
|
||||||
|
|
||||||
|
# Add the markdown content if it exists
|
||||||
if self._markdown is not None:
|
if self._markdown is not None:
|
||||||
result["markdown"] = self._markdown.model_dump()
|
result["markdown"] = self._markdown.model_dump()
|
||||||
return result
|
return result
|
||||||
@@ -307,7 +460,7 @@ RunManyReturn = Union[
|
|||||||
# 1. Replace the private attribute and property with a standard field
|
# 1. Replace the private attribute and property with a standard field
|
||||||
# 2. Update any serialization logic that might depend on the current behavior
|
# 2. Update any serialization logic that might depend on the current behavior
|
||||||
|
|
||||||
class AsyncCrawlResponse(BaseModel):
|
class AsyncCrawlResponse(ORJSONModel):
|
||||||
html: str
|
html: str
|
||||||
response_headers: Dict[str, str]
|
response_headers: Dict[str, str]
|
||||||
js_execution_result: Optional[Dict[str, Any]] = None
|
js_execution_result: Optional[Dict[str, Any]] = None
|
||||||
@@ -328,7 +481,7 @@ class AsyncCrawlResponse(BaseModel):
|
|||||||
###############################
|
###############################
|
||||||
# Scraping Models
|
# Scraping Models
|
||||||
###############################
|
###############################
|
||||||
class MediaItem(BaseModel):
|
class MediaItem(ORJSONModel):
|
||||||
src: Optional[str] = ""
|
src: Optional[str] = ""
|
||||||
data: Optional[str] = ""
|
data: Optional[str] = ""
|
||||||
alt: Optional[str] = ""
|
alt: Optional[str] = ""
|
||||||
@@ -340,7 +493,7 @@ class MediaItem(BaseModel):
|
|||||||
width: Optional[int] = None
|
width: Optional[int] = None
|
||||||
|
|
||||||
|
|
||||||
class Link(BaseModel):
|
class Link(ORJSONModel):
|
||||||
href: Optional[str] = ""
|
href: Optional[str] = ""
|
||||||
text: Optional[str] = ""
|
text: Optional[str] = ""
|
||||||
title: Optional[str] = ""
|
title: Optional[str] = ""
|
||||||
@@ -353,7 +506,7 @@ class Link(BaseModel):
|
|||||||
total_score: Optional[float] = None # Combined score from intrinsic and contextual scores
|
total_score: Optional[float] = None # Combined score from intrinsic and contextual scores
|
||||||
|
|
||||||
|
|
||||||
class Media(BaseModel):
|
class Media(ORJSONModel):
|
||||||
images: List[MediaItem] = []
|
images: List[MediaItem] = []
|
||||||
videos: List[
|
videos: List[
|
||||||
MediaItem
|
MediaItem
|
||||||
@@ -364,12 +517,12 @@ class Media(BaseModel):
|
|||||||
tables: List[Dict] = [] # Table data extracted from HTML tables
|
tables: List[Dict] = [] # Table data extracted from HTML tables
|
||||||
|
|
||||||
|
|
||||||
class Links(BaseModel):
|
class Links(ORJSONModel):
|
||||||
internal: List[Link] = []
|
internal: List[Link] = []
|
||||||
external: List[Link] = []
|
external: List[Link] = []
|
||||||
|
|
||||||
|
|
||||||
class ScrapingResult(BaseModel):
|
class ScrapingResult(ORJSONModel):
|
||||||
cleaned_html: str
|
cleaned_html: str
|
||||||
success: bool
|
success: bool
|
||||||
media: Media = Media()
|
media: Media = Media()
|
||||||
|
|||||||
@@ -1,5 +1,6 @@
|
|||||||
import os
|
import os
|
||||||
import json
|
import json
|
||||||
|
import orjson
|
||||||
import asyncio
|
import asyncio
|
||||||
from typing import List, Tuple, Dict
|
from typing import List, Tuple, Dict
|
||||||
from functools import partial
|
from functools import partial
|
||||||
@@ -384,27 +385,60 @@ def create_task_response(task: dict, task_id: str, base_url: str) -> dict:
|
|||||||
|
|
||||||
async def stream_results(crawler: AsyncWebCrawler, results_gen: AsyncGenerator) -> AsyncGenerator[bytes, None]:
|
async def stream_results(crawler: AsyncWebCrawler, results_gen: AsyncGenerator) -> AsyncGenerator[bytes, None]:
|
||||||
"""Stream results with heartbeats and completion markers."""
|
"""Stream results with heartbeats and completion markers."""
|
||||||
import json
|
import orjson
|
||||||
from utils import datetime_handler
|
from datetime import datetime
|
||||||
|
import inspect
|
||||||
|
|
||||||
|
def orjson_default(obj):
|
||||||
|
# Handle datetime (if not already handled by orjson)
|
||||||
|
if isinstance(obj, datetime):
|
||||||
|
return obj.isoformat()
|
||||||
|
# Handle property objects (convert to string or something else)
|
||||||
|
if isinstance(obj, property):
|
||||||
|
return str(obj)
|
||||||
|
# Last resort: convert to string
|
||||||
|
return str(obj)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
logger.info(f"Starting streaming with results_gen type: {type(results_gen)}")
|
||||||
|
logger.info(f"Is results_gen async generator: {inspect.isasyncgen(results_gen)}")
|
||||||
|
|
||||||
|
# Check if results_gen is actually an async generator vs another type
|
||||||
|
if inspect.isasyncgen(results_gen):
|
||||||
|
logger.info("Processing as async generator")
|
||||||
async for result in results_gen:
|
async for result in results_gen:
|
||||||
try:
|
try:
|
||||||
|
logger.info(f"Processing streaming result of type: {type(result)}")
|
||||||
|
|
||||||
|
# Check if this result is actually a CrawlResult
|
||||||
|
if hasattr(result, 'model_dump_json'):
|
||||||
server_memory_mb = _get_memory_mb()
|
server_memory_mb = _get_memory_mb()
|
||||||
result_dict = result.model_dump()
|
result_json = result.model_dump_json()
|
||||||
|
result_dict = orjson.loads(result_json)
|
||||||
result_dict['server_memory_mb'] = server_memory_mb
|
result_dict['server_memory_mb'] = server_memory_mb
|
||||||
# If PDF exists, encode it to base64
|
|
||||||
if result_dict.get('pdf') is not None:
|
if result_dict.get('pdf') is not None:
|
||||||
result_dict['pdf'] = b64encode(result_dict['pdf']).decode('utf-8')
|
result_dict['pdf'] = b64encode(result_dict['pdf']).decode('utf-8')
|
||||||
|
|
||||||
logger.info(f"Streaming result for {result_dict.get('url', 'unknown')}")
|
logger.info(f"Streaming result for {result_dict.get('url', 'unknown')}")
|
||||||
data = json.dumps(result_dict, default=datetime_handler) + "\n"
|
data = orjson.dumps(result_dict, default=orjson_default).decode('utf-8') + "\n"
|
||||||
yield data.encode('utf-8')
|
yield data.encode('utf-8')
|
||||||
|
else:
|
||||||
|
logger.error(f"Result doesn't have model_dump_json method: {type(result)}")
|
||||||
|
error_response = {"error": f"Invalid result type: {type(result)}", "url": "unknown"}
|
||||||
|
yield (orjson.dumps(error_response).decode('utf-8') + "\n").encode('utf-8')
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Serialization error: {e}")
|
logger.error(f"Serialization error: {e}")
|
||||||
|
logger.error(f"Result type was: {type(result)}")
|
||||||
error_response = {"error": str(e), "url": getattr(result, 'url', 'unknown')}
|
error_response = {"error": str(e), "url": getattr(result, 'url', 'unknown')}
|
||||||
yield (json.dumps(error_response) + "\n").encode('utf-8')
|
yield (orjson.dumps(error_response).decode('utf-8') + "\n").encode('utf-8')
|
||||||
|
else:
|
||||||
|
logger.error(f"results_gen is not an async generator: {type(results_gen)}")
|
||||||
|
error_response = {"error": f"Invalid results_gen type: {type(results_gen)}"}
|
||||||
|
yield (orjson.dumps(error_response).decode('utf-8') + "\n").encode('utf-8')
|
||||||
|
|
||||||
yield json.dumps({"status": "completed"}).encode('utf-8')
|
yield orjson.dumps({"status": "completed"}).decode('utf-8').encode('utf-8')
|
||||||
|
|
||||||
except asyncio.CancelledError:
|
except asyncio.CancelledError:
|
||||||
logger.warning("Client disconnected during streaming")
|
logger.warning("Client disconnected during streaming")
|
||||||
@@ -472,7 +506,9 @@ async def handle_crawl_request(
|
|||||||
# Process results to handle PDF bytes
|
# Process results to handle PDF bytes
|
||||||
processed_results = []
|
processed_results = []
|
||||||
for result in results:
|
for result in results:
|
||||||
result_dict = result.model_dump()
|
# Use ORJSON serialization to handle property objects properly
|
||||||
|
result_json = result.model_dump_json()
|
||||||
|
result_dict = orjson.loads(result_json)
|
||||||
# If PDF exists, encode it to base64
|
# If PDF exists, encode it to base64
|
||||||
if result_dict.get('pdf') is not None:
|
if result_dict.get('pdf') is not None:
|
||||||
result_dict['pdf'] = b64encode(result_dict['pdf']).decode('utf-8')
|
result_dict['pdf'] = b64encode(result_dict['pdf']).decode('utf-8')
|
||||||
@@ -522,7 +558,18 @@ async def handle_stream_crawl_request(
|
|||||||
browser_config.verbose = False
|
browser_config.verbose = False
|
||||||
crawler_config = CrawlerRunConfig.load(crawler_config)
|
crawler_config = CrawlerRunConfig.load(crawler_config)
|
||||||
crawler_config.scraping_strategy = LXMLWebScrapingStrategy()
|
crawler_config.scraping_strategy = LXMLWebScrapingStrategy()
|
||||||
crawler_config.stream = True
|
# Don't force stream=True here - let the deep crawl strategy control its own streaming behavior
|
||||||
|
|
||||||
|
# Apply global base config (this was missing!)
|
||||||
|
base_config = config["crawler"]["base_config"]
|
||||||
|
for key, value in base_config.items():
|
||||||
|
if hasattr(crawler_config, key):
|
||||||
|
print(f"[DEBUG] Applying base_config: {key} = {value}")
|
||||||
|
setattr(crawler_config, key, value)
|
||||||
|
|
||||||
|
print(f"[DEBUG] Deep crawl strategy: {type(crawler_config.deep_crawl_strategy).__name__ if crawler_config.deep_crawl_strategy else 'None'}")
|
||||||
|
print(f"[DEBUG] Stream mode: {crawler_config.stream}")
|
||||||
|
print(f"[DEBUG] Simulate user: {getattr(crawler_config, 'simulate_user', 'Not set')}")
|
||||||
|
|
||||||
dispatcher = MemoryAdaptiveDispatcher(
|
dispatcher = MemoryAdaptiveDispatcher(
|
||||||
memory_threshold_percent=config["crawler"]["memory_threshold_percent"],
|
memory_threshold_percent=config["crawler"]["memory_threshold_percent"],
|
||||||
@@ -537,12 +584,59 @@ async def handle_stream_crawl_request(
|
|||||||
# crawler = AsyncWebCrawler(config=browser_config)
|
# crawler = AsyncWebCrawler(config=browser_config)
|
||||||
# await crawler.start()
|
# await crawler.start()
|
||||||
|
|
||||||
|
# Use correct method based on URL count (same as regular endpoint)
|
||||||
|
if len(urls) == 1:
|
||||||
|
# For single URL, use arun to get CrawlResult, then wrap in async generator
|
||||||
|
single_result_container = await crawler.arun(
|
||||||
|
url=urls[0],
|
||||||
|
config=crawler_config,
|
||||||
|
dispatcher=dispatcher
|
||||||
|
)
|
||||||
|
|
||||||
|
async def single_result_generator():
|
||||||
|
# Handle CrawlResultContainer - extract the actual results
|
||||||
|
if hasattr(single_result_container, '_results'):
|
||||||
|
# It's a CrawlResultContainer - iterate over the internal results
|
||||||
|
for result in single_result_container._results:
|
||||||
|
# Check if the result is an async generator (from deep crawl)
|
||||||
|
if hasattr(result, '__aiter__'):
|
||||||
|
async for sub_result in result:
|
||||||
|
yield sub_result
|
||||||
|
else:
|
||||||
|
yield result
|
||||||
|
elif hasattr(single_result_container, '__aiter__'):
|
||||||
|
# It's an async generator (from streaming deep crawl)
|
||||||
|
async for result in single_result_container:
|
||||||
|
yield result
|
||||||
|
elif hasattr(single_result_container, '__iter__') and not hasattr(single_result_container, 'url'):
|
||||||
|
# It's iterable but not a CrawlResult itself
|
||||||
|
for result in single_result_container:
|
||||||
|
# Check if each result is an async generator
|
||||||
|
if hasattr(result, '__aiter__'):
|
||||||
|
async for sub_result in result:
|
||||||
|
yield sub_result
|
||||||
|
else:
|
||||||
|
yield result
|
||||||
|
else:
|
||||||
|
# It's a single CrawlResult
|
||||||
|
yield single_result_container
|
||||||
|
|
||||||
|
results_gen = single_result_generator()
|
||||||
|
else:
|
||||||
|
# For multiple URLs, use arun_many
|
||||||
results_gen = await crawler.arun_many(
|
results_gen = await crawler.arun_many(
|
||||||
urls=urls,
|
urls=urls,
|
||||||
config=crawler_config,
|
config=crawler_config,
|
||||||
dispatcher=dispatcher
|
dispatcher=dispatcher
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# If results_gen is a list (e.g., from deep crawl), convert to async generator
|
||||||
|
if isinstance(results_gen, list):
|
||||||
|
async def convert_list_to_generator():
|
||||||
|
for result in results_gen:
|
||||||
|
yield result
|
||||||
|
results_gen = convert_list_to_generator()
|
||||||
|
|
||||||
return crawler, results_gen
|
return crawler, results_gen
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
|||||||
@@ -7,13 +7,16 @@ Crawl4AI FastAPI entry‑point
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
# ── stdlib & 3rd‑party imports ───────────────────────────────
|
# ── stdlib & 3rd‑party imports ───────────────────────────────
|
||||||
|
from datetime import datetime
|
||||||
|
|
||||||
|
import orjson
|
||||||
from crawler_pool import get_crawler, close_all, janitor
|
from crawler_pool import get_crawler, close_all, janitor
|
||||||
from crawl4ai import AsyncWebCrawler, BrowserConfig, CrawlerRunConfig
|
from crawl4ai import AsyncWebCrawler, BrowserConfig, CrawlerRunConfig
|
||||||
from auth import create_access_token, get_token_dependency, TokenRequest
|
from auth import create_access_token, get_token_dependency, TokenRequest
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
from typing import Optional, List, Dict
|
from typing import Optional, List, Dict
|
||||||
from fastapi import Request, Depends
|
from fastapi import Request, Depends
|
||||||
from fastapi.responses import FileResponse
|
from fastapi.responses import FileResponse, ORJSONResponse
|
||||||
import base64
|
import base64
|
||||||
import re
|
import re
|
||||||
from crawl4ai import AsyncWebCrawler, BrowserConfig, CrawlerRunConfig
|
from crawl4ai import AsyncWebCrawler, BrowserConfig, CrawlerRunConfig
|
||||||
@@ -32,6 +35,8 @@ from schemas import (
|
|||||||
JSEndpointRequest,
|
JSEndpointRequest,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Use the proper serialization functions from async_configs
|
||||||
|
from crawl4ai.async_configs import to_serializable_dict
|
||||||
from utils import (
|
from utils import (
|
||||||
FilterType, load_config, setup_logging, verify_email_domain
|
FilterType, load_config, setup_logging, verify_email_domain
|
||||||
)
|
)
|
||||||
@@ -112,11 +117,26 @@ async def lifespan(_: FastAPI):
|
|||||||
app.state.janitor.cancel()
|
app.state.janitor.cancel()
|
||||||
await close_all()
|
await close_all()
|
||||||
|
|
||||||
|
def orjson_default(obj):
|
||||||
|
# Handle datetime (if not already handled by orjson)
|
||||||
|
if isinstance(obj, datetime):
|
||||||
|
return obj.isoformat()
|
||||||
|
|
||||||
|
# Handle property objects (convert to string or something else)
|
||||||
|
if isinstance(obj, property):
|
||||||
|
return str(obj)
|
||||||
|
|
||||||
|
# Last resort: convert to string
|
||||||
|
return str(obj)
|
||||||
|
|
||||||
|
def orjson_dumps(v, *, default):
|
||||||
|
return orjson.dumps(v, default=orjson_default).decode()
|
||||||
# ───────────────────── FastAPI instance ──────────────────────
|
# ───────────────────── FastAPI instance ──────────────────────
|
||||||
app = FastAPI(
|
app = FastAPI(
|
||||||
title=config["app"]["title"],
|
title=config["app"]["title"],
|
||||||
version=config["app"]["version"],
|
version=config["app"]["version"],
|
||||||
lifespan=lifespan,
|
lifespan=lifespan,
|
||||||
|
default_response_class=ORJSONResponse
|
||||||
)
|
)
|
||||||
|
|
||||||
# ── static playground ──────────────────────────────────────
|
# ── static playground ──────────────────────────────────────
|
||||||
@@ -435,6 +455,7 @@ async def crawl(
|
|||||||
"""
|
"""
|
||||||
Crawl a list of URLs and return the results as JSON.
|
Crawl a list of URLs and return the results as JSON.
|
||||||
"""
|
"""
|
||||||
|
try:
|
||||||
if not crawl_request.urls:
|
if not crawl_request.urls:
|
||||||
raise HTTPException(400, "At least one URL required")
|
raise HTTPException(400, "At least one URL required")
|
||||||
res = await handle_crawl_request(
|
res = await handle_crawl_request(
|
||||||
@@ -443,7 +464,11 @@ async def crawl(
|
|||||||
crawler_config=crawl_request.crawler_config,
|
crawler_config=crawl_request.crawler_config,
|
||||||
config=config,
|
config=config,
|
||||||
)
|
)
|
||||||
return JSONResponse(res)
|
# handle_crawl_request returns a dictionary, so we can pass it directly to ORJSONResponse
|
||||||
|
return ORJSONResponse(res)
|
||||||
|
except Exception as e:
|
||||||
|
print(f"Error occurred: {e}")
|
||||||
|
return ORJSONResponse({"error": str(e)}, status_code=500)
|
||||||
|
|
||||||
|
|
||||||
@app.post("/crawl/stream")
|
@app.post("/crawl/stream")
|
||||||
|
|||||||
1097
tests/test_comprehensive_fixes.py
Normal file
1097
tests/test_comprehensive_fixes.py
Normal file
File diff suppressed because it is too large
Load Diff
Reference in New Issue
Block a user