Compare commits
2 Commits
fix/https-
...
feat/ahmed
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
8e1362acf5 | ||
|
|
07e9d651fb |
10
.gitignore
vendored
10
.gitignore
vendored
@@ -1,6 +1,11 @@
|
||||
# Scripts folder (private tools)
|
||||
.scripts/
|
||||
|
||||
# Local development CLI (private)
|
||||
local_dev.py
|
||||
dev
|
||||
DEV_CLI_README.md
|
||||
|
||||
# Byte-compiled / optimized / DLL files
|
||||
__pycache__/
|
||||
*.py[cod]
|
||||
@@ -270,4 +275,7 @@ docs/**/data
|
||||
.codecat/
|
||||
|
||||
docs/apps/linkdin/debug*/
|
||||
docs/apps/linkdin/samples/insights/*
|
||||
docs/apps/linkdin/samples/insights/*
|
||||
|
||||
# Production checklist (local, not for version control)
|
||||
PRODUCTION_CHECKLIST.md
|
||||
@@ -38,7 +38,14 @@ class BFSDeepCrawlStrategy(DeepCrawlStrategy):
|
||||
self.include_external = include_external
|
||||
self.score_threshold = score_threshold
|
||||
self.max_pages = max_pages
|
||||
self.logger = logger or logging.getLogger(__name__)
|
||||
# Type check for logger
|
||||
if isinstance(logger, dict):
|
||||
logging.getLogger(__name__).warning(
|
||||
"BFSDeepCrawlStrategy received a dict as logger; falling back to default logger."
|
||||
)
|
||||
self.logger = logging.getLogger(__name__)
|
||||
else:
|
||||
self.logger = logger or logging.getLogger(__name__)
|
||||
self.stats = TraversalStats(start_time=datetime.now())
|
||||
self._cancel_event = asyncio.Event()
|
||||
self._pages_crawled = 0
|
||||
|
||||
@@ -30,7 +30,7 @@ class Crawl4aiDockerClient:
|
||||
def __init__(
|
||||
self,
|
||||
base_url: str = "http://localhost:8000",
|
||||
timeout: float = 30.0,
|
||||
timeout: float = 600.0, # Increased to 10 minutes for crawling operations
|
||||
verify_ssl: bool = True,
|
||||
verbose: bool = True,
|
||||
log_file: Optional[str] = None
|
||||
@@ -113,21 +113,12 @@ class Crawl4aiDockerClient:
|
||||
self.logger.info(f"Crawling {len(urls)} URLs {'(streaming)' if is_streaming else ''}", tag="CRAWL")
|
||||
|
||||
if is_streaming:
|
||||
async def stream_results() -> AsyncGenerator[CrawlResult, None]:
|
||||
async with self._http_client.stream("POST", f"{self.base_url}/crawl/stream", json=data) as response:
|
||||
response.raise_for_status()
|
||||
async for line in response.aiter_lines():
|
||||
if line.strip():
|
||||
result = json.loads(line)
|
||||
if "error" in result:
|
||||
self.logger.error_status(url=result.get("url", "unknown"), error=result["error"])
|
||||
continue
|
||||
self.logger.url_status(url=result.get("url", "unknown"), success=True, timing=result.get("timing", 0.0))
|
||||
if result.get("status") == "completed":
|
||||
continue
|
||||
else:
|
||||
yield CrawlResult(**result)
|
||||
return stream_results()
|
||||
# For streaming, we need to return the async generator properly
|
||||
# The caller should be able to do: async for result in await client.crawl(...)
|
||||
async def streaming_wrapper():
|
||||
async for result in self._stream_crawl_results(data):
|
||||
yield result
|
||||
return streaming_wrapper()
|
||||
|
||||
response = await self._request("POST", "/crawl", json=data)
|
||||
result_data = response.json()
|
||||
@@ -138,6 +129,35 @@ class Crawl4aiDockerClient:
|
||||
self.logger.success(f"Crawl completed with {len(results)} results", tag="CRAWL")
|
||||
return results[0] if len(results) == 1 else results
|
||||
|
||||
async def _stream_crawl_results(self, data: Dict[str, Any]) -> AsyncGenerator[CrawlResult, None]:
|
||||
"""Internal method to handle streaming crawl results."""
|
||||
async with self._http_client.stream("POST", f"{self.base_url}/crawl/stream", json=data) as response:
|
||||
response.raise_for_status()
|
||||
async for line in response.aiter_lines():
|
||||
if line.strip():
|
||||
try:
|
||||
result = json.loads(line)
|
||||
if "error" in result:
|
||||
self.logger.error_status(url=result.get("url", "unknown"), error=result["error"])
|
||||
continue
|
||||
|
||||
# Check if this is a crawl result (has required fields)
|
||||
if "url" in result and "success" in result:
|
||||
self.logger.url_status(url=result.get("url", "unknown"), success=result.get("success", False), timing=result.get("timing", 0.0))
|
||||
|
||||
# Create CrawlResult object properly
|
||||
crawl_result = CrawlResult(**result)
|
||||
yield crawl_result
|
||||
# Skip status-only messages
|
||||
elif result.get("status") == "completed":
|
||||
continue
|
||||
except json.JSONDecodeError as e:
|
||||
self.logger.error(f"Failed to parse streaming response: {e}", tag="STREAM")
|
||||
continue
|
||||
except Exception as e:
|
||||
self.logger.error(f"Error processing streaming result: {e}", tag="STREAM")
|
||||
continue
|
||||
|
||||
async def get_schema(self) -> Dict[str, Any]:
|
||||
"""Retrieve configuration schemas."""
|
||||
response = await self._request("GET", "/schema")
|
||||
|
||||
@@ -1,4 +1,36 @@
|
||||
from pydantic import BaseModel, HttpUrl, PrivateAttr, Field
|
||||
|
||||
"""
|
||||
Crawl4AI Models Module
|
||||
|
||||
This module contains Pydantic models used throughout the Crawl4AI library.
|
||||
|
||||
Key Features:
|
||||
- ORJSONModel: Base model with ORJSON serialization support
|
||||
- DeprecatedPropertiesMixin: Global system for handling deprecated properties
|
||||
- CrawlResult: Main result model with backward compatibility support
|
||||
|
||||
Deprecated Properties System:
|
||||
The DeprecatedPropertiesMixin provides a global way to handle deprecated properties
|
||||
across all models. Instead of manually excluding deprecated properties in each
|
||||
model_dump() call, you can simply override the get_deprecated_properties() method:
|
||||
|
||||
Example:
|
||||
class MyModel(ORJSONModel):
|
||||
name: str
|
||||
old_field: Optional[str] = None
|
||||
|
||||
def get_deprecated_properties(self) -> set[str]:
|
||||
return {'old_field', 'another_deprecated_field'}
|
||||
|
||||
@property
|
||||
def old_field(self):
|
||||
raise AttributeError("old_field is deprecated, use name instead")
|
||||
|
||||
The system automatically excludes these properties from serialization, preventing
|
||||
property objects from appearing in JSON output.
|
||||
"""
|
||||
|
||||
from pydantic import BaseModel, ConfigDict,HttpUrl, PrivateAttr, Field
|
||||
from typing import List, Dict, Optional, Callable, Awaitable, Union, Any
|
||||
from typing import AsyncGenerator
|
||||
from typing import Generic, TypeVar
|
||||
@@ -8,7 +40,7 @@ from .ssl_certificate import SSLCertificate
|
||||
from datetime import datetime
|
||||
from datetime import timedelta
|
||||
|
||||
|
||||
import orjson
|
||||
###############################
|
||||
# Dispatcher Models
|
||||
###############################
|
||||
@@ -91,7 +123,122 @@ class TokenUsage:
|
||||
completion_tokens_details: Optional[dict] = None
|
||||
prompt_tokens_details: Optional[dict] = None
|
||||
|
||||
class UrlModel(BaseModel):
|
||||
|
||||
def orjson_default(obj):
|
||||
# Handle datetime (if not already handled by orjson)
|
||||
if isinstance(obj, datetime):
|
||||
return obj.isoformat()
|
||||
|
||||
# Handle property objects (convert to string or something else)
|
||||
if isinstance(obj, property):
|
||||
return str(obj)
|
||||
|
||||
# Last resort: convert to string
|
||||
return str(obj)
|
||||
|
||||
|
||||
class DeprecatedPropertiesMixin:
|
||||
"""
|
||||
Mixin to handle deprecated properties in Pydantic models.
|
||||
|
||||
Classes that inherit from this mixin can define deprecated properties
|
||||
that will be automatically excluded from serialization.
|
||||
|
||||
Usage:
|
||||
1. Override the get_deprecated_properties() method to return a set of deprecated property names
|
||||
2. The model_dump method will automatically exclude these properties
|
||||
|
||||
Example:
|
||||
class MyModel(ORJSONModel):
|
||||
def get_deprecated_properties(self) -> set[str]:
|
||||
return {'old_field', 'legacy_property'}
|
||||
|
||||
name: str
|
||||
old_field: Optional[str] = None # Field definition
|
||||
|
||||
@property
|
||||
def old_field(self): # Property that overrides the field
|
||||
raise AttributeError("old_field is deprecated, use name instead")
|
||||
"""
|
||||
|
||||
def get_deprecated_properties(self) -> set[str]:
|
||||
"""
|
||||
Get deprecated property names for this model.
|
||||
Override this method in subclasses to define deprecated properties.
|
||||
|
||||
Returns:
|
||||
set[str]: Set of deprecated property names
|
||||
"""
|
||||
return set()
|
||||
|
||||
@classmethod
|
||||
def get_all_deprecated_properties(cls) -> set[str]:
|
||||
"""
|
||||
Get all deprecated properties from this class and all parent classes.
|
||||
|
||||
Returns:
|
||||
set[str]: Set of all deprecated property names
|
||||
"""
|
||||
deprecated_props = set()
|
||||
# Create an instance to call the instance method
|
||||
try:
|
||||
# Try to create a dummy instance to get deprecated properties
|
||||
dummy_instance = cls.__new__(cls)
|
||||
deprecated_props.update(dummy_instance.get_deprecated_properties())
|
||||
except Exception:
|
||||
# If we can't create an instance, check for class-level definitions
|
||||
pass
|
||||
|
||||
# Also check parent classes
|
||||
for klass in cls.__mro__:
|
||||
if hasattr(klass, 'get_deprecated_properties') and klass != DeprecatedPropertiesMixin:
|
||||
try:
|
||||
dummy_instance = klass.__new__(klass)
|
||||
deprecated_props.update(dummy_instance.get_deprecated_properties())
|
||||
except Exception:
|
||||
pass
|
||||
return deprecated_props
|
||||
|
||||
def model_dump(self, *args, **kwargs):
|
||||
"""
|
||||
Override model_dump to automatically exclude deprecated properties.
|
||||
|
||||
This method:
|
||||
1. Gets the existing exclude set from kwargs
|
||||
2. Adds all deprecated properties defined in get_deprecated_properties()
|
||||
3. Calls the parent model_dump with the updated exclude set
|
||||
"""
|
||||
# Get the default exclude set, or create empty set if None
|
||||
exclude = kwargs.get('exclude', set())
|
||||
if exclude is None:
|
||||
exclude = set()
|
||||
elif not isinstance(exclude, set):
|
||||
exclude = set(exclude) if exclude else set()
|
||||
|
||||
# Add deprecated properties for this instance
|
||||
exclude.update(self.get_deprecated_properties())
|
||||
kwargs['exclude'] = exclude
|
||||
|
||||
return super().model_dump(*args, **kwargs)
|
||||
|
||||
|
||||
class ORJSONModel(DeprecatedPropertiesMixin, BaseModel):
|
||||
model_config = ConfigDict(
|
||||
ser_json_timedelta="iso8601", # Optional: format timedelta
|
||||
ser_json_bytes="utf8", # Optional: bytes → UTF-8 string
|
||||
)
|
||||
|
||||
def model_dump_json(self, **kwargs) -> bytes:
|
||||
"""Custom JSON serialization using orjson"""
|
||||
return orjson.dumps(self.model_dump(**kwargs), default=orjson_default)
|
||||
|
||||
@classmethod
|
||||
def model_validate_json(cls, json_data: Union[str, bytes], **kwargs):
|
||||
"""Custom JSON deserialization using orjson"""
|
||||
if isinstance(json_data, str):
|
||||
json_data = json_data.encode()
|
||||
return cls.model_validate(orjson.loads(json_data), **kwargs)
|
||||
class UrlModel(ORJSONModel):
|
||||
url: HttpUrl
|
||||
forced: bool = False
|
||||
|
||||
@@ -108,7 +255,7 @@ class TraversalStats:
|
||||
total_depth_reached: int = 0
|
||||
current_depth: int = 0
|
||||
|
||||
class DispatchResult(BaseModel):
|
||||
class DispatchResult(ORJSONModel):
|
||||
task_id: str
|
||||
memory_usage: float
|
||||
peak_memory: float
|
||||
@@ -116,7 +263,7 @@ class DispatchResult(BaseModel):
|
||||
end_time: Union[datetime, float]
|
||||
error_message: str = ""
|
||||
|
||||
class MarkdownGenerationResult(BaseModel):
|
||||
class MarkdownGenerationResult(ORJSONModel):
|
||||
raw_markdown: str
|
||||
markdown_with_citations: str
|
||||
references_markdown: str
|
||||
@@ -126,7 +273,7 @@ class MarkdownGenerationResult(BaseModel):
|
||||
def __str__(self):
|
||||
return self.raw_markdown
|
||||
|
||||
class CrawlResult(BaseModel):
|
||||
class CrawlResult(ORJSONModel):
|
||||
url: str
|
||||
html: str
|
||||
fit_html: Optional[str] = None
|
||||
@@ -156,6 +303,10 @@ class CrawlResult(BaseModel):
|
||||
class Config:
|
||||
arbitrary_types_allowed = True
|
||||
|
||||
def get_deprecated_properties(self) -> set[str]:
|
||||
"""Define deprecated properties that should be excluded from serialization."""
|
||||
return {'fit_html', 'fit_markdown', 'markdown_v2'}
|
||||
|
||||
# NOTE: The StringCompatibleMarkdown class, custom __init__ method, property getters/setters,
|
||||
# and model_dump override all exist to support a smooth transition from markdown as a string
|
||||
# to markdown as a MarkdownGenerationResult object, while maintaining backward compatibility.
|
||||
@@ -245,14 +396,16 @@ class CrawlResult(BaseModel):
|
||||
1. PrivateAttr fields are excluded from serialization by default
|
||||
2. We need to maintain backward compatibility by including the 'markdown' field
|
||||
in the serialized output
|
||||
3. We're transitioning from 'markdown_v2' to enhancing 'markdown' to hold
|
||||
the same type of data
|
||||
3. Uses the DeprecatedPropertiesMixin to automatically exclude deprecated properties
|
||||
|
||||
Future developers: This method ensures that the markdown content is properly
|
||||
serialized despite being stored in a private attribute. If the serialization
|
||||
requirements change, this is where you would update the logic.
|
||||
serialized despite being stored in a private attribute. The deprecated properties
|
||||
are automatically handled by the mixin.
|
||||
"""
|
||||
# Use the parent class method which handles deprecated properties automatically
|
||||
result = super().model_dump(*args, **kwargs)
|
||||
|
||||
# Add the markdown content if it exists
|
||||
if self._markdown is not None:
|
||||
result["markdown"] = self._markdown.model_dump()
|
||||
return result
|
||||
@@ -307,7 +460,7 @@ RunManyReturn = Union[
|
||||
# 1. Replace the private attribute and property with a standard field
|
||||
# 2. Update any serialization logic that might depend on the current behavior
|
||||
|
||||
class AsyncCrawlResponse(BaseModel):
|
||||
class AsyncCrawlResponse(ORJSONModel):
|
||||
html: str
|
||||
response_headers: Dict[str, str]
|
||||
js_execution_result: Optional[Dict[str, Any]] = None
|
||||
@@ -328,7 +481,7 @@ class AsyncCrawlResponse(BaseModel):
|
||||
###############################
|
||||
# Scraping Models
|
||||
###############################
|
||||
class MediaItem(BaseModel):
|
||||
class MediaItem(ORJSONModel):
|
||||
src: Optional[str] = ""
|
||||
data: Optional[str] = ""
|
||||
alt: Optional[str] = ""
|
||||
@@ -340,7 +493,7 @@ class MediaItem(BaseModel):
|
||||
width: Optional[int] = None
|
||||
|
||||
|
||||
class Link(BaseModel):
|
||||
class Link(ORJSONModel):
|
||||
href: Optional[str] = ""
|
||||
text: Optional[str] = ""
|
||||
title: Optional[str] = ""
|
||||
@@ -353,7 +506,7 @@ class Link(BaseModel):
|
||||
total_score: Optional[float] = None # Combined score from intrinsic and contextual scores
|
||||
|
||||
|
||||
class Media(BaseModel):
|
||||
class Media(ORJSONModel):
|
||||
images: List[MediaItem] = []
|
||||
videos: List[
|
||||
MediaItem
|
||||
@@ -364,12 +517,12 @@ class Media(BaseModel):
|
||||
tables: List[Dict] = [] # Table data extracted from HTML tables
|
||||
|
||||
|
||||
class Links(BaseModel):
|
||||
class Links(ORJSONModel):
|
||||
internal: List[Link] = []
|
||||
external: List[Link] = []
|
||||
|
||||
|
||||
class ScrapingResult(BaseModel):
|
||||
class ScrapingResult(ORJSONModel):
|
||||
cleaned_html: str
|
||||
success: bool
|
||||
media: Media = Media()
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
import os
|
||||
import json
|
||||
import orjson
|
||||
import asyncio
|
||||
from typing import List, Tuple, Dict
|
||||
from functools import partial
|
||||
@@ -384,27 +385,60 @@ def create_task_response(task: dict, task_id: str, base_url: str) -> dict:
|
||||
|
||||
async def stream_results(crawler: AsyncWebCrawler, results_gen: AsyncGenerator) -> AsyncGenerator[bytes, None]:
|
||||
"""Stream results with heartbeats and completion markers."""
|
||||
import json
|
||||
from utils import datetime_handler
|
||||
import orjson
|
||||
from datetime import datetime
|
||||
import inspect
|
||||
|
||||
def orjson_default(obj):
|
||||
# Handle datetime (if not already handled by orjson)
|
||||
if isinstance(obj, datetime):
|
||||
return obj.isoformat()
|
||||
# Handle property objects (convert to string or something else)
|
||||
if isinstance(obj, property):
|
||||
return str(obj)
|
||||
# Last resort: convert to string
|
||||
return str(obj)
|
||||
|
||||
try:
|
||||
async for result in results_gen:
|
||||
try:
|
||||
server_memory_mb = _get_memory_mb()
|
||||
result_dict = result.model_dump()
|
||||
result_dict['server_memory_mb'] = server_memory_mb
|
||||
# If PDF exists, encode it to base64
|
||||
if result_dict.get('pdf') is not None:
|
||||
result_dict['pdf'] = b64encode(result_dict['pdf']).decode('utf-8')
|
||||
logger.info(f"Streaming result for {result_dict.get('url', 'unknown')}")
|
||||
data = json.dumps(result_dict, default=datetime_handler) + "\n"
|
||||
yield data.encode('utf-8')
|
||||
except Exception as e:
|
||||
logger.error(f"Serialization error: {e}")
|
||||
error_response = {"error": str(e), "url": getattr(result, 'url', 'unknown')}
|
||||
yield (json.dumps(error_response) + "\n").encode('utf-8')
|
||||
logger.info(f"Starting streaming with results_gen type: {type(results_gen)}")
|
||||
logger.info(f"Is results_gen async generator: {inspect.isasyncgen(results_gen)}")
|
||||
|
||||
# Check if results_gen is actually an async generator vs another type
|
||||
if inspect.isasyncgen(results_gen):
|
||||
logger.info("Processing as async generator")
|
||||
async for result in results_gen:
|
||||
try:
|
||||
logger.info(f"Processing streaming result of type: {type(result)}")
|
||||
|
||||
# Check if this result is actually a CrawlResult
|
||||
if hasattr(result, 'model_dump_json'):
|
||||
server_memory_mb = _get_memory_mb()
|
||||
result_json = result.model_dump_json()
|
||||
result_dict = orjson.loads(result_json)
|
||||
result_dict['server_memory_mb'] = server_memory_mb
|
||||
|
||||
if result_dict.get('pdf') is not None:
|
||||
result_dict['pdf'] = b64encode(result_dict['pdf']).decode('utf-8')
|
||||
|
||||
logger.info(f"Streaming result for {result_dict.get('url', 'unknown')}")
|
||||
data = orjson.dumps(result_dict, default=orjson_default).decode('utf-8') + "\n"
|
||||
yield data.encode('utf-8')
|
||||
else:
|
||||
logger.error(f"Result doesn't have model_dump_json method: {type(result)}")
|
||||
error_response = {"error": f"Invalid result type: {type(result)}", "url": "unknown"}
|
||||
yield (orjson.dumps(error_response).decode('utf-8') + "\n").encode('utf-8')
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Serialization error: {e}")
|
||||
logger.error(f"Result type was: {type(result)}")
|
||||
error_response = {"error": str(e), "url": getattr(result, 'url', 'unknown')}
|
||||
yield (orjson.dumps(error_response).decode('utf-8') + "\n").encode('utf-8')
|
||||
else:
|
||||
logger.error(f"results_gen is not an async generator: {type(results_gen)}")
|
||||
error_response = {"error": f"Invalid results_gen type: {type(results_gen)}"}
|
||||
yield (orjson.dumps(error_response).decode('utf-8') + "\n").encode('utf-8')
|
||||
|
||||
yield json.dumps({"status": "completed"}).encode('utf-8')
|
||||
yield orjson.dumps({"status": "completed"}).decode('utf-8').encode('utf-8')
|
||||
|
||||
except asyncio.CancelledError:
|
||||
logger.warning("Client disconnected during streaming")
|
||||
@@ -472,7 +506,9 @@ async def handle_crawl_request(
|
||||
# Process results to handle PDF bytes
|
||||
processed_results = []
|
||||
for result in results:
|
||||
result_dict = result.model_dump()
|
||||
# Use ORJSON serialization to handle property objects properly
|
||||
result_json = result.model_dump_json()
|
||||
result_dict = orjson.loads(result_json)
|
||||
# If PDF exists, encode it to base64
|
||||
if result_dict.get('pdf') is not None:
|
||||
result_dict['pdf'] = b64encode(result_dict['pdf']).decode('utf-8')
|
||||
@@ -522,8 +558,19 @@ async def handle_stream_crawl_request(
|
||||
browser_config.verbose = False
|
||||
crawler_config = CrawlerRunConfig.load(crawler_config)
|
||||
crawler_config.scraping_strategy = LXMLWebScrapingStrategy()
|
||||
crawler_config.stream = True
|
||||
# Don't force stream=True here - let the deep crawl strategy control its own streaming behavior
|
||||
|
||||
# Apply global base config (this was missing!)
|
||||
base_config = config["crawler"]["base_config"]
|
||||
for key, value in base_config.items():
|
||||
if hasattr(crawler_config, key):
|
||||
print(f"[DEBUG] Applying base_config: {key} = {value}")
|
||||
setattr(crawler_config, key, value)
|
||||
|
||||
print(f"[DEBUG] Deep crawl strategy: {type(crawler_config.deep_crawl_strategy).__name__ if crawler_config.deep_crawl_strategy else 'None'}")
|
||||
print(f"[DEBUG] Stream mode: {crawler_config.stream}")
|
||||
print(f"[DEBUG] Simulate user: {getattr(crawler_config, 'simulate_user', 'Not set')}")
|
||||
|
||||
dispatcher = MemoryAdaptiveDispatcher(
|
||||
memory_threshold_percent=config["crawler"]["memory_threshold_percent"],
|
||||
rate_limiter=RateLimiter(
|
||||
@@ -537,11 +584,58 @@ async def handle_stream_crawl_request(
|
||||
# crawler = AsyncWebCrawler(config=browser_config)
|
||||
# await crawler.start()
|
||||
|
||||
results_gen = await crawler.arun_many(
|
||||
urls=urls,
|
||||
config=crawler_config,
|
||||
dispatcher=dispatcher
|
||||
)
|
||||
# Use correct method based on URL count (same as regular endpoint)
|
||||
if len(urls) == 1:
|
||||
# For single URL, use arun to get CrawlResult, then wrap in async generator
|
||||
single_result_container = await crawler.arun(
|
||||
url=urls[0],
|
||||
config=crawler_config,
|
||||
dispatcher=dispatcher
|
||||
)
|
||||
|
||||
async def single_result_generator():
|
||||
# Handle CrawlResultContainer - extract the actual results
|
||||
if hasattr(single_result_container, '_results'):
|
||||
# It's a CrawlResultContainer - iterate over the internal results
|
||||
for result in single_result_container._results:
|
||||
# Check if the result is an async generator (from deep crawl)
|
||||
if hasattr(result, '__aiter__'):
|
||||
async for sub_result in result:
|
||||
yield sub_result
|
||||
else:
|
||||
yield result
|
||||
elif hasattr(single_result_container, '__aiter__'):
|
||||
# It's an async generator (from streaming deep crawl)
|
||||
async for result in single_result_container:
|
||||
yield result
|
||||
elif hasattr(single_result_container, '__iter__') and not hasattr(single_result_container, 'url'):
|
||||
# It's iterable but not a CrawlResult itself
|
||||
for result in single_result_container:
|
||||
# Check if each result is an async generator
|
||||
if hasattr(result, '__aiter__'):
|
||||
async for sub_result in result:
|
||||
yield sub_result
|
||||
else:
|
||||
yield result
|
||||
else:
|
||||
# It's a single CrawlResult
|
||||
yield single_result_container
|
||||
|
||||
results_gen = single_result_generator()
|
||||
else:
|
||||
# For multiple URLs, use arun_many
|
||||
results_gen = await crawler.arun_many(
|
||||
urls=urls,
|
||||
config=crawler_config,
|
||||
dispatcher=dispatcher
|
||||
)
|
||||
|
||||
# If results_gen is a list (e.g., from deep crawl), convert to async generator
|
||||
if isinstance(results_gen, list):
|
||||
async def convert_list_to_generator():
|
||||
for result in results_gen:
|
||||
yield result
|
||||
results_gen = convert_list_to_generator()
|
||||
|
||||
return crawler, results_gen
|
||||
|
||||
|
||||
@@ -7,13 +7,16 @@ Crawl4AI FastAPI entry‑point
|
||||
"""
|
||||
|
||||
# ── stdlib & 3rd‑party imports ───────────────────────────────
|
||||
from datetime import datetime
|
||||
|
||||
import orjson
|
||||
from crawler_pool import get_crawler, close_all, janitor
|
||||
from crawl4ai import AsyncWebCrawler, BrowserConfig, CrawlerRunConfig
|
||||
from auth import create_access_token, get_token_dependency, TokenRequest
|
||||
from pydantic import BaseModel
|
||||
from typing import Optional, List, Dict
|
||||
from fastapi import Request, Depends
|
||||
from fastapi.responses import FileResponse
|
||||
from fastapi.responses import FileResponse, ORJSONResponse
|
||||
import base64
|
||||
import re
|
||||
from crawl4ai import AsyncWebCrawler, BrowserConfig, CrawlerRunConfig
|
||||
@@ -32,6 +35,8 @@ from schemas import (
|
||||
JSEndpointRequest,
|
||||
)
|
||||
|
||||
# Use the proper serialization functions from async_configs
|
||||
from crawl4ai.async_configs import to_serializable_dict
|
||||
from utils import (
|
||||
FilterType, load_config, setup_logging, verify_email_domain
|
||||
)
|
||||
@@ -112,11 +117,26 @@ async def lifespan(_: FastAPI):
|
||||
app.state.janitor.cancel()
|
||||
await close_all()
|
||||
|
||||
def orjson_default(obj):
|
||||
# Handle datetime (if not already handled by orjson)
|
||||
if isinstance(obj, datetime):
|
||||
return obj.isoformat()
|
||||
|
||||
# Handle property objects (convert to string or something else)
|
||||
if isinstance(obj, property):
|
||||
return str(obj)
|
||||
|
||||
# Last resort: convert to string
|
||||
return str(obj)
|
||||
|
||||
def orjson_dumps(v, *, default):
|
||||
return orjson.dumps(v, default=orjson_default).decode()
|
||||
# ───────────────────── FastAPI instance ──────────────────────
|
||||
app = FastAPI(
|
||||
title=config["app"]["title"],
|
||||
version=config["app"]["version"],
|
||||
lifespan=lifespan,
|
||||
default_response_class=ORJSONResponse
|
||||
)
|
||||
|
||||
# ── static playground ──────────────────────────────────────
|
||||
@@ -435,15 +455,20 @@ async def crawl(
|
||||
"""
|
||||
Crawl a list of URLs and return the results as JSON.
|
||||
"""
|
||||
if not crawl_request.urls:
|
||||
raise HTTPException(400, "At least one URL required")
|
||||
res = await handle_crawl_request(
|
||||
urls=crawl_request.urls,
|
||||
browser_config=crawl_request.browser_config,
|
||||
crawler_config=crawl_request.crawler_config,
|
||||
config=config,
|
||||
)
|
||||
return JSONResponse(res)
|
||||
try:
|
||||
if not crawl_request.urls:
|
||||
raise HTTPException(400, "At least one URL required")
|
||||
res = await handle_crawl_request(
|
||||
urls=crawl_request.urls,
|
||||
browser_config=crawl_request.browser_config,
|
||||
crawler_config=crawl_request.crawler_config,
|
||||
config=config,
|
||||
)
|
||||
# handle_crawl_request returns a dictionary, so we can pass it directly to ORJSONResponse
|
||||
return ORJSONResponse(res)
|
||||
except Exception as e:
|
||||
print(f"Error occurred: {e}")
|
||||
return ORJSONResponse({"error": str(e)}, status_code=500)
|
||||
|
||||
|
||||
@app.post("/crawl/stream")
|
||||
|
||||
1097
tests/test_comprehensive_fixes.py
Normal file
1097
tests/test_comprehensive_fixes.py
Normal file
File diff suppressed because it is too large
Load Diff
Reference in New Issue
Block a user