feat: enhance crawling functionality with anti-bot strategies and headless mode options (browser adapters, undetected/stealth browser)

This commit is contained in:
AHMET YILMAZ
2025-10-03 18:02:10 +08:00
parent a599db8f7b
commit 5dc34dd210
4 changed files with 421 additions and 257 deletions

View File

@@ -1,57 +1,63 @@
import os
import json
import asyncio
from typing import List, Tuple, Dict
from functools import partial
from uuid import uuid4
from datetime import datetime, timezone
from base64 import b64encode
import json
import logging
from typing import Optional, AsyncGenerator
import os
import time
from base64 import b64encode
from datetime import datetime, timezone
from functools import partial
from typing import AsyncGenerator, Dict, List, Optional, Tuple
from urllib.parse import unquote
from fastapi import HTTPException, Request, status
from fastapi.background import BackgroundTasks
from fastapi.responses import JSONResponse
from redis import asyncio as aioredis
from crawl4ai import (
AsyncUrlSeeder,
AsyncWebCrawler,
CrawlerRunConfig,
LLMExtractionStrategy,
CacheMode,
BrowserConfig,
MemoryAdaptiveDispatcher,
RateLimiter,
CacheMode,
CrawlerRunConfig,
LLMConfig,
AsyncUrlSeeder,
SeedingConfig
LLMExtractionStrategy,
MemoryAdaptiveDispatcher,
PlaywrightAdapter,
RateLimiter,
SeedingConfig,
UndetectedAdapter,
)
from crawl4ai.utils import perform_completion_with_backoff
# Import StealthAdapter with fallback for compatibility
try:
from crawl4ai import StealthAdapter
except ImportError:
# Fallback: import directly from browser_adapter module
try:
import os
import sys
# Add the project root to path for development
project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", ".."))
sys.path.insert(0, project_root)
from crawl4ai.browser_adapter import StealthAdapter
except ImportError:
# If all else fails, create a simple fallback
from crawl4ai.browser_adapter import PlaywrightAdapter
class StealthAdapter(PlaywrightAdapter):
"""Fallback StealthAdapter that uses PlaywrightAdapter"""
pass
from crawl4ai.content_filter_strategy import (
PruningContentFilter,
BM25ContentFilter,
LLMContentFilter
LLMContentFilter,
PruningContentFilter,
)
from crawl4ai.markdown_generation_strategy import DefaultMarkdownGenerator
from crawl4ai.content_scraping_strategy import LXMLWebScrapingStrategy
from utils import (
TaskStatus,
FilterType,
get_base_url,
is_task_id,
should_cleanup_task,
decode_redis_hash,
get_llm_api_key,
validate_llm_provider,
get_llm_temperature,
get_llm_base_url
)
import psutil, time
from crawl4ai.markdown_generation_strategy import DefaultMarkdownGenerator
from crawl4ai.utils import perform_completion_with_backoff
logger = logging.getLogger(__name__)
# --- Helper to get memory ---
def _get_memory_mb():
try:
@@ -61,17 +67,39 @@ def _get_memory_mb():
return None
async def handle_llm_qa(
url: str,
query: str,
config: dict
) -> str:
# --- Helper to get browser adapter based on anti_bot_strategy ---
def _get_browser_adapter(anti_bot_strategy: str, browser_config: BrowserConfig):
    """Select the browser adapter instance for the requested anti-bot strategy.

    Args:
        anti_bot_strategy: One of ``"default"``, ``"stealth"``, ``"undetected"``
            or ``"max_evasion"``.
        browser_config: The loaded browser configuration; only consulted for its
            ``enable_stealth`` flag when the strategy is ``"default"``.

    Returns:
        A new adapter instance (StealthAdapter, UndetectedAdapter, or
        PlaywrightAdapter).
    """
    # Both "undetected" and "max_evasion" map onto the undetected driver,
    # which is the strongest evasion option currently wired up.
    if anti_bot_strategy in ("undetected", "max_evasion"):
        return UndetectedAdapter()
    if anti_bot_strategy == "stealth":
        return StealthAdapter()
    # "default": honour an explicit stealth flag on the browser config,
    # otherwise fall back to the plain Playwright adapter.
    if getattr(browser_config, "enable_stealth", False):
        return StealthAdapter()
    return PlaywrightAdapter()
# --- Helper to apply headless setting ---
def _apply_headless_setting(browser_config: BrowserConfig, headless: bool):
    """Set the ``headless`` flag on *browser_config* and return it.

    The config object is mutated in place; the same object is returned so
    callers may chain or reassign it.
    """
    setattr(browser_config, "headless", headless)
    return browser_config
async def handle_llm_qa(url: str, query: str, config: dict) -> str:
"""Process QA using LLM with crawled content as context."""
try:
if not url.startswith(('http://', 'https://')) and not url.startswith(("raw:", "raw://")):
url = 'https://' + url
if not url.startswith(("http://", "https://")) and not url.startswith(
("raw:", "raw://")
):
url = "https://" + url
# Extract base URL by finding last '?q=' occurrence
last_q_index = url.rfind('?q=')
last_q_index = url.rfind("?q=")
if last_q_index != -1:
url = url[:last_q_index]
@@ -81,7 +109,7 @@ async def handle_llm_qa(
if not result.success:
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=result.error_message
detail=result.error_message,
)
content = result.markdown.fit_markdown or result.markdown.raw_markdown
@@ -101,17 +129,17 @@ async def handle_llm_qa(
prompt_with_variables=prompt,
api_token=get_llm_api_key(config), # Returns None to let litellm handle it
temperature=get_llm_temperature(config),
base_url=get_llm_base_url(config)
base_url=get_llm_base_url(config),
)
return response.choices[0].message.content
except Exception as e:
logger.error(f"QA processing error: {str(e)}", exc_info=True)
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=str(e)
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=str(e)
)
async def process_llm_extraction(
redis: aioredis.Redis,
config: dict,
@@ -122,25 +150,27 @@ async def process_llm_extraction(
cache: str = "0",
provider: Optional[str] = None,
temperature: Optional[float] = None,
base_url: Optional[str] = None
base_url: Optional[str] = None,
) -> None:
"""Process LLM extraction in background."""
try:
# Validate provider
is_valid, error_msg = validate_llm_provider(config, provider)
if not is_valid:
await redis.hset(f"task:{task_id}", mapping={
"status": TaskStatus.FAILED,
"error": error_msg
})
await redis.hset(
f"task:{task_id}",
mapping={"status": TaskStatus.FAILED, "error": error_msg},
)
return
api_key = get_llm_api_key(config, provider) # Returns None to let litellm handle it
api_key = get_llm_api_key(
config, provider
) # Returns None to let litellm handle it
llm_strategy = LLMExtractionStrategy(
llm_config=LLMConfig(
provider=provider or config["llm"]["provider"],
api_token=api_key,
temperature=temperature or get_llm_temperature(config, provider),
base_url=base_url or get_llm_base_url(config, provider)
base_url=base_url or get_llm_base_url(config, provider),
),
instruction=instruction,
schema=json.loads(schema) if schema else None,
@@ -154,32 +184,32 @@ async def process_llm_extraction(
config=CrawlerRunConfig(
extraction_strategy=llm_strategy,
scraping_strategy=LXMLWebScrapingStrategy(),
cache_mode=cache_mode
)
cache_mode=cache_mode,
),
)
if not result.success:
await redis.hset(f"task:{task_id}", mapping={
"status": TaskStatus.FAILED,
"error": result.error_message
})
await redis.hset(
f"task:{task_id}",
mapping={"status": TaskStatus.FAILED, "error": result.error_message},
)
return
try:
content = json.loads(result.extracted_content)
except json.JSONDecodeError:
content = result.extracted_content
await redis.hset(f"task:{task_id}", mapping={
"status": TaskStatus.COMPLETED,
"result": json.dumps(content)
})
await redis.hset(
f"task:{task_id}",
mapping={"status": TaskStatus.COMPLETED, "result": json.dumps(content)},
)
except Exception as e:
logger.error(f"LLM extraction error: {str(e)}", exc_info=True)
await redis.hset(f"task:{task_id}", mapping={
"status": TaskStatus.FAILED,
"error": str(e)
})
await redis.hset(
f"task:{task_id}", mapping={"status": TaskStatus.FAILED, "error": str(e)}
)
async def handle_markdown_request(
url: str,
@@ -189,7 +219,7 @@ async def handle_markdown_request(
config: Optional[dict] = None,
provider: Optional[str] = None,
temperature: Optional[float] = None,
base_url: Optional[str] = None
base_url: Optional[str] = None,
) -> str:
"""Handle markdown generation requests."""
try:
@@ -198,12 +228,13 @@ async def handle_markdown_request(
is_valid, error_msg = validate_llm_provider(config, provider)
if not is_valid:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=error_msg
status_code=status.HTTP_400_BAD_REQUEST, detail=error_msg
)
decoded_url = unquote(url)
if not decoded_url.startswith(('http://', 'https://')) and not decoded_url.startswith(("raw:", "raw://")):
decoded_url = 'https://' + decoded_url
if not decoded_url.startswith(
("http://", "https://")
) and not decoded_url.startswith(("raw:", "raw://")):
decoded_url = "https://" + decoded_url
if filter_type == FilterType.RAW:
md_generator = DefaultMarkdownGenerator()
@@ -214,12 +245,15 @@ async def handle_markdown_request(
FilterType.LLM: LLMContentFilter(
llm_config=LLMConfig(
provider=provider or config["llm"]["provider"],
api_token=get_llm_api_key(config, provider), # Returns None to let litellm handle it
temperature=temperature or get_llm_temperature(config, provider),
base_url=base_url or get_llm_base_url(config, provider)
api_token=get_llm_api_key(
config, provider
), # Returns None to let litellm handle it
temperature=temperature
or get_llm_temperature(config, provider),
base_url=base_url or get_llm_base_url(config, provider),
),
instruction=query or "Extract main content"
)
instruction=query or "Extract main content",
),
}[filter_type]
md_generator = DefaultMarkdownGenerator(content_filter=content_filter)
@@ -231,27 +265,29 @@ async def handle_markdown_request(
config=CrawlerRunConfig(
markdown_generator=md_generator,
scraping_strategy=LXMLWebScrapingStrategy(),
cache_mode=cache_mode
)
cache_mode=cache_mode,
),
)
if not result.success:
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=result.error_message
detail=result.error_message,
)
return (result.markdown.raw_markdown
if filter_type == FilterType.RAW
else result.markdown.fit_markdown)
return (
result.markdown.raw_markdown
if filter_type == FilterType.RAW
else result.markdown.fit_markdown
)
except Exception as e:
logger.error(f"Markdown error: {str(e)}", exc_info=True)
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=str(e)
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=str(e)
)
async def handle_llm_request(
redis: aioredis.Redis,
background_tasks: BackgroundTasks,
@@ -263,27 +299,27 @@ async def handle_llm_request(
config: Optional[dict] = None,
provider: Optional[str] = None,
temperature: Optional[float] = None,
api_base_url: Optional[str] = None
api_base_url: Optional[str] = None,
) -> JSONResponse:
"""Handle LLM extraction requests."""
base_url = get_base_url(request)
try:
if is_task_id(input_path):
return await handle_task_status(
redis, input_path, base_url
)
return await handle_task_status(redis, input_path, base_url)
if not query:
return JSONResponse({
"message": "Please provide an instruction",
"_links": {
"example": {
"href": f"{base_url}/llm/{input_path}?q=Extract+main+content",
"title": "Try this example"
}
return JSONResponse(
{
"message": "Please provide an instruction",
"_links": {
"example": {
"href": f"{base_url}/llm/{input_path}?q=Extract+main+content",
"title": "Try this example",
}
},
}
})
)
return await create_new_task(
redis,
@@ -296,31 +332,25 @@ async def handle_llm_request(
config,
provider,
temperature,
api_base_url
api_base_url,
)
except Exception as e:
logger.error(f"LLM endpoint error: {str(e)}", exc_info=True)
return JSONResponse({
"error": str(e),
"_links": {
"retry": {"href": str(request.url)}
}
}, status_code=status.HTTP_500_INTERNAL_SERVER_ERROR)
return JSONResponse(
{"error": str(e), "_links": {"retry": {"href": str(request.url)}}},
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
)
async def handle_task_status(
redis: aioredis.Redis,
task_id: str,
base_url: str,
*,
keep: bool = False
redis: aioredis.Redis, task_id: str, base_url: str, *, keep: bool = False
) -> JSONResponse:
"""Handle task status check requests."""
task = await redis.hgetall(f"task:{task_id}")
if not task:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="Task not found"
status_code=status.HTTP_404_NOT_FOUND, detail="Task not found"
)
task = decode_redis_hash(task)
@@ -332,6 +362,7 @@ async def handle_task_status(
return JSONResponse(response)
async def create_new_task(
redis: aioredis.Redis,
background_tasks: BackgroundTasks,
@@ -343,21 +374,27 @@ async def create_new_task(
config: dict,
provider: Optional[str] = None,
temperature: Optional[float] = None,
api_base_url: Optional[str] = None
api_base_url: Optional[str] = None,
) -> JSONResponse:
"""Create and initialize a new task."""
decoded_url = unquote(input_path)
if not decoded_url.startswith(('http://', 'https://')) and not decoded_url.startswith(("raw:", "raw://")):
decoded_url = 'https://' + decoded_url
if not decoded_url.startswith(
("http://", "https://")
) and not decoded_url.startswith(("raw:", "raw://")):
decoded_url = "https://" + decoded_url
from datetime import datetime
task_id = f"llm_{int(datetime.now().timestamp())}_{id(background_tasks)}"
await redis.hset(f"task:{task_id}", mapping={
"status": TaskStatus.PROCESSING,
"created_at": datetime.now().isoformat(),
"url": decoded_url
})
await redis.hset(
f"task:{task_id}",
mapping={
"status": TaskStatus.PROCESSING,
"created_at": datetime.now().isoformat(),
"url": decoded_url,
},
)
background_tasks.add_task(
process_llm_extraction,
@@ -370,18 +407,21 @@ async def create_new_task(
cache,
provider,
temperature,
api_base_url
api_base_url,
)
return JSONResponse({
"task_id": task_id,
"status": TaskStatus.PROCESSING,
"url": decoded_url,
"_links": {
"self": {"href": f"{base_url}/llm/{task_id}"},
"status": {"href": f"{base_url}/llm/{task_id}"}
return JSONResponse(
{
"task_id": task_id,
"status": TaskStatus.PROCESSING,
"url": decoded_url,
"_links": {
"self": {"href": f"{base_url}/llm/{task_id}"},
"status": {"href": f"{base_url}/llm/{task_id}"},
},
}
})
)
def create_task_response(task: dict, task_id: str, base_url: str) -> dict:
"""Create response for task status check."""
@@ -392,8 +432,8 @@ def create_task_response(task: dict, task_id: str, base_url: str) -> dict:
"url": task["url"],
"_links": {
"self": {"href": f"{base_url}/llm/{task_id}"},
"refresh": {"href": f"{base_url}/llm/{task_id}"}
}
"refresh": {"href": f"{base_url}/llm/{task_id}"},
},
}
if task["status"] == TaskStatus.COMPLETED:
@@ -403,9 +443,13 @@ def create_task_response(task: dict, task_id: str, base_url: str) -> dict:
return response
async def stream_results(crawler: AsyncWebCrawler, results_gen: AsyncGenerator) -> AsyncGenerator[bytes, None]:
async def stream_results(
crawler: AsyncWebCrawler, results_gen: AsyncGenerator
) -> AsyncGenerator[bytes, None]:
"""Stream results with heartbeats and completion markers."""
import json
from utils import datetime_handler
try:
@@ -413,23 +457,29 @@ async def stream_results(crawler: AsyncWebCrawler, results_gen: AsyncGenerator)
try:
server_memory_mb = _get_memory_mb()
result_dict = result.model_dump()
result_dict['server_memory_mb'] = server_memory_mb
result_dict["server_memory_mb"] = server_memory_mb
# Ensure fit_html is JSON-serializable
if "fit_html" in result_dict and not (result_dict["fit_html"] is None or isinstance(result_dict["fit_html"], str)):
if "fit_html" in result_dict and not (
result_dict["fit_html"] is None
or isinstance(result_dict["fit_html"], str)
):
result_dict["fit_html"] = None
# If PDF exists, encode it to base64
if result_dict.get('pdf') is not None:
result_dict['pdf'] = b64encode(result_dict['pdf']).decode('utf-8')
if result_dict.get("pdf") is not None:
result_dict["pdf"] = b64encode(result_dict["pdf"]).decode("utf-8")
logger.info(f"Streaming result for {result_dict.get('url', 'unknown')}")
data = json.dumps(result_dict, default=datetime_handler) + "\n"
yield data.encode('utf-8')
yield data.encode("utf-8")
except Exception as e:
logger.error(f"Serialization error: {e}")
error_response = {"error": str(e), "url": getattr(result, 'url', 'unknown')}
yield (json.dumps(error_response) + "\n").encode('utf-8')
error_response = {
"error": str(e),
"url": getattr(result, "url", "unknown"),
}
yield (json.dumps(error_response) + "\n").encode("utf-8")
yield json.dumps({"status": "completed"}).encode("utf-8")
yield json.dumps({"status": "completed"}).encode('utf-8')
except asyncio.CancelledError:
logger.warning("Client disconnected during streaming")
finally:
@@ -439,51 +489,70 @@ async def stream_results(crawler: AsyncWebCrawler, results_gen: AsyncGenerator)
# logger.error(f"Crawler cleanup error: {e}")
pass
async def handle_crawl_request(
urls: List[str],
browser_config: dict,
crawler_config: dict,
config: dict,
hooks_config: Optional[dict] = None
hooks_config: Optional[dict] = None,
anti_bot_strategy: str = "default",
headless: bool = True,
) -> dict:
"""Handle non-streaming crawl requests with optional hooks."""
start_mem_mb = _get_memory_mb() # <--- Get memory before
start_mem_mb = _get_memory_mb() # <--- Get memory before
start_time = time.time()
mem_delta_mb = None
peak_mem_mb = start_mem_mb
hook_manager = None
try:
urls = [('https://' + url) if not url.startswith(('http://', 'https://')) and not url.startswith(("raw:", "raw://")) else url for url in urls]
urls = [
("https://" + url)
if not url.startswith(("http://", "https://"))
and not url.startswith(("raw:", "raw://"))
else url
for url in urls
]
browser_config = BrowserConfig.load(browser_config)
_apply_headless_setting(browser_config, headless)
crawler_config = CrawlerRunConfig.load(crawler_config)
# Configure browser adapter based on anti_bot_strategy
browser_adapter = _get_browser_adapter(anti_bot_strategy, browser_config)
# TODO: add support for other dispatchers
dispatcher = MemoryAdaptiveDispatcher(
memory_threshold_percent=config["crawler"]["memory_threshold_percent"],
rate_limiter=RateLimiter(
base_delay=tuple(config["crawler"]["rate_limiter"]["base_delay"])
) if config["crawler"]["rate_limiter"]["enabled"] else None
)
if config["crawler"]["rate_limiter"]["enabled"]
else None,
)
from crawler_pool import get_crawler
crawler = await get_crawler(browser_config)
crawler = await get_crawler(browser_config, browser_adapter)
# crawler: AsyncWebCrawler = AsyncWebCrawler(config=browser_config)
# await crawler.start()
# Attach hooks if provided
hooks_status = {}
if hooks_config:
from hook_manager import attach_user_hooks_to_crawler, UserHookManager
hook_manager = UserHookManager(timeout=hooks_config.get('timeout', 30))
from hook_manager import UserHookManager, attach_user_hooks_to_crawler
hook_manager = UserHookManager(timeout=hooks_config.get("timeout", 30))
hooks_status, hook_manager = await attach_user_hooks_to_crawler(
crawler,
hooks_config.get('code', {}),
timeout=hooks_config.get('timeout', 30),
hook_manager=hook_manager
hooks_config.get("code", {}),
timeout=hooks_config.get("timeout", 30),
hook_manager=hook_manager,
)
logger.info(f"Hooks attachment status: {hooks_status['status']}")
base_config = config["crawler"]["base_config"]
# Iterate on key-value pairs in global_config then use hasattr to set them
for key, value in base_config.items():
@@ -495,32 +564,38 @@ async def handle_crawl_request(
results = []
func = getattr(crawler, "arun" if len(urls) == 1 else "arun_many")
partial_func = partial(func,
urls[0] if len(urls) == 1 else urls,
config=crawler_config,
dispatcher=dispatcher)
partial_func = partial(
func,
urls[0] if len(urls) == 1 else urls,
config=crawler_config,
dispatcher=dispatcher,
)
results = await partial_func()
# Ensure results is always a list
if not isinstance(results, list):
results = [results]
# await crawler.close()
end_mem_mb = _get_memory_mb() # <--- Get memory after
end_mem_mb = _get_memory_mb() # <--- Get memory after
end_time = time.time()
if start_mem_mb is not None and end_mem_mb is not None:
mem_delta_mb = end_mem_mb - start_mem_mb # <--- Calculate delta
peak_mem_mb = max(peak_mem_mb if peak_mem_mb else 0, end_mem_mb) # <--- Get peak memory
logger.info(f"Memory usage: Start: {start_mem_mb} MB, End: {end_mem_mb} MB, Delta: {mem_delta_mb} MB, Peak: {peak_mem_mb} MB")
mem_delta_mb = end_mem_mb - start_mem_mb # <--- Calculate delta
peak_mem_mb = max(
peak_mem_mb if peak_mem_mb else 0, end_mem_mb
) # <--- Get peak memory
logger.info(
f"Memory usage: Start: {start_mem_mb} MB, End: {end_mem_mb} MB, Delta: {mem_delta_mb} MB, Peak: {peak_mem_mb} MB"
)
# Process results to handle PDF bytes
processed_results = []
for result in results:
try:
# Check if result has model_dump method (is a proper CrawlResult)
if hasattr(result, 'model_dump'):
if hasattr(result, "model_dump"):
result_dict = result.model_dump()
elif isinstance(result, dict):
result_dict = result
@@ -528,48 +603,53 @@ async def handle_crawl_request(
# Handle unexpected result type
logger.warning(f"Unexpected result type: {type(result)}")
result_dict = {
"url": str(result) if hasattr(result, '__str__') else "unknown",
"url": str(result) if hasattr(result, "__str__") else "unknown",
"success": False,
"error_message": f"Unexpected result type: {type(result).__name__}"
"error_message": f"Unexpected result type: {type(result).__name__}",
}
# if fit_html is not a string, set it to None to avoid serialization errors
if "fit_html" in result_dict and not (result_dict["fit_html"] is None or isinstance(result_dict["fit_html"], str)):
if "fit_html" in result_dict and not (
result_dict["fit_html"] is None
or isinstance(result_dict["fit_html"], str)
):
result_dict["fit_html"] = None
# If PDF exists, encode it to base64
if result_dict.get('pdf') is not None and isinstance(result_dict.get('pdf'), bytes):
result_dict['pdf'] = b64encode(result_dict['pdf']).decode('utf-8')
if result_dict.get("pdf") is not None and isinstance(
result_dict.get("pdf"), bytes
):
result_dict["pdf"] = b64encode(result_dict["pdf"]).decode("utf-8")
processed_results.append(result_dict)
except Exception as e:
logger.error(f"Error processing result: {e}")
processed_results.append({
"url": "unknown",
"success": False,
"error_message": str(e)
})
processed_results.append(
{"url": "unknown", "success": False, "error_message": str(e)}
)
response = {
"success": True,
"results": processed_results,
"server_processing_time_s": end_time - start_time,
"server_memory_delta_mb": mem_delta_mb,
"server_peak_memory_mb": peak_mem_mb
"server_peak_memory_mb": peak_mem_mb,
}
# Add hooks information if hooks were used
if hooks_config and hook_manager:
from hook_manager import UserHookManager
if isinstance(hook_manager, UserHookManager):
try:
# Ensure all hook data is JSON serializable
import json
hook_data = {
"status": hooks_status,
"execution_log": hook_manager.execution_log,
"errors": hook_manager.errors,
"summary": hook_manager.get_summary()
"summary": hook_manager.get_summary(),
}
# Test that it's serializable
json.dumps(hook_data)
@@ -577,17 +657,22 @@ async def handle_crawl_request(
except (TypeError, ValueError) as e:
logger.error(f"Hook data not JSON serializable: {e}")
response["hooks"] = {
"status": {"status": "error", "message": "Hook data serialization failed"},
"status": {
"status": "error",
"message": "Hook data serialization failed",
},
"execution_log": [],
"errors": [{"error": str(e)}],
"summary": {}
"summary": {},
}
return response
except Exception as e:
logger.error(f"Crawl error: {str(e)}", exc_info=True)
if 'crawler' in locals() and crawler.ready: # Check if crawler was initialized and started
if (
"crawler" in locals() and crawler.ready
): # Check if crawler was initialized and started
# try:
# await crawler.close()
# except Exception as close_e:
@@ -601,19 +686,26 @@ async def handle_crawl_request(
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=json.dumps({ # Send structured error
"error": str(e),
"server_memory_delta_mb": mem_delta_mb,
"server_peak_memory_mb": max(peak_mem_mb if peak_mem_mb else 0, end_mem_mb_error or 0)
})
detail=json.dumps(
{ # Send structured error
"error": str(e),
"server_memory_delta_mb": mem_delta_mb,
"server_peak_memory_mb": max(
peak_mem_mb if peak_mem_mb else 0, end_mem_mb_error or 0
),
}
),
)
async def handle_stream_crawl_request(
urls: List[str],
browser_config: dict,
crawler_config: dict,
config: dict,
hooks_config: Optional[dict] = None
hooks_config: Optional[dict] = None,
anti_bot_strategy: str = "default",
headless: bool = True,
) -> Tuple[AsyncWebCrawler, AsyncGenerator, Optional[Dict]]:
"""Handle streaming crawl requests with optional hooks."""
hooks_info = None
@@ -621,60 +713,68 @@ async def handle_stream_crawl_request(
browser_config = BrowserConfig.load(browser_config)
# browser_config.verbose = True # Set to False or remove for production stress testing
browser_config.verbose = False
_apply_headless_setting(browser_config, headless)
crawler_config = CrawlerRunConfig.load(crawler_config)
crawler_config.scraping_strategy = LXMLWebScrapingStrategy()
crawler_config.stream = True
# Configure browser adapter based on anti_bot_strategy
browser_adapter = _get_browser_adapter(anti_bot_strategy, browser_config)
dispatcher = MemoryAdaptiveDispatcher(
memory_threshold_percent=config["crawler"]["memory_threshold_percent"],
rate_limiter=RateLimiter(
base_delay=tuple(config["crawler"]["rate_limiter"]["base_delay"])
)
),
)
from crawler_pool import get_crawler
crawler = await get_crawler(browser_config)
crawler = await get_crawler(browser_config, browser_adapter)
# crawler = AsyncWebCrawler(config=browser_config)
# await crawler.start()
# Attach hooks if provided
if hooks_config:
from hook_manager import attach_user_hooks_to_crawler, UserHookManager
hook_manager = UserHookManager(timeout=hooks_config.get('timeout', 30))
from hook_manager import UserHookManager, attach_user_hooks_to_crawler
hook_manager = UserHookManager(timeout=hooks_config.get("timeout", 30))
hooks_status, hook_manager = await attach_user_hooks_to_crawler(
crawler,
hooks_config.get('code', {}),
timeout=hooks_config.get('timeout', 30),
hook_manager=hook_manager
hooks_config.get("code", {}),
timeout=hooks_config.get("timeout", 30),
hook_manager=hook_manager,
)
logger.info(
f"Hooks attachment status for streaming: {hooks_status['status']}"
)
logger.info(f"Hooks attachment status for streaming: {hooks_status['status']}")
# Include hook manager in hooks_info for proper tracking
hooks_info = {'status': hooks_status, 'manager': hook_manager}
hooks_info = {"status": hooks_status, "manager": hook_manager}
results_gen = await crawler.arun_many(
urls=urls,
config=crawler_config,
dispatcher=dispatcher
urls=urls, config=crawler_config, dispatcher=dispatcher
)
return crawler, results_gen, hooks_info
except Exception as e:
# Make sure to close crawler if started during an error here
if 'crawler' in locals() and crawler.ready:
if "crawler" in locals() and crawler.ready:
# try:
# await crawler.close()
# except Exception as close_e:
# logger.error(f"Error closing crawler during stream setup exception: {close_e}")
logger.error(f"Error closing crawler during stream setup exception: {str(e)}")
logger.error(
f"Error closing crawler during stream setup exception: {str(e)}"
)
logger.error(f"Stream crawl error: {str(e)}", exc_info=True)
# Raising HTTPException here will prevent streaming response
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=str(e)
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=str(e)
)
async def handle_crawl_job(
redis,
background_tasks: BackgroundTasks,
@@ -689,13 +789,16 @@ async def handle_crawl_job(
lets /crawl/job/{task_id} polling fetch the result.
"""
task_id = f"crawl_{uuid4().hex[:8]}"
await redis.hset(f"task:{task_id}", mapping={
"status": TaskStatus.PROCESSING, # <-- keep enum values consistent
"created_at": datetime.now(timezone.utc).replace(tzinfo=None).isoformat(),
"url": json.dumps(urls), # store list as JSON string
"result": "",
"error": "",
})
await redis.hset(
f"task:{task_id}",
mapping={
"status": TaskStatus.PROCESSING, # <-- keep enum values consistent
"created_at": datetime.now(timezone.utc).replace(tzinfo=None).isoformat(),
"url": json.dumps(urls), # store list as JSON string
"result": "",
"error": "",
},
)
async def _runner():
try:
@@ -705,21 +808,28 @@ async def handle_crawl_job(
crawler_config=crawler_config,
config=config,
)
await redis.hset(f"task:{task_id}", mapping={
"status": TaskStatus.COMPLETED,
"result": json.dumps(result),
})
await redis.hset(
f"task:{task_id}",
mapping={
"status": TaskStatus.COMPLETED,
"result": json.dumps(result),
},
)
await asyncio.sleep(5) # Give Redis time to process the update
except Exception as exc:
await redis.hset(f"task:{task_id}", mapping={
"status": TaskStatus.FAILED,
"error": str(exc),
})
await redis.hset(
f"task:{task_id}",
mapping={
"status": TaskStatus.FAILED,
"error": str(exc),
},
)
background_tasks.add_task(_runner)
return {"task_id": task_id}
async def handle_seed(url ,cfg):
async def handle_seed(url, cfg):
# Create the configuration from the request body
try:
seeding_config = cfg
@@ -732,7 +842,7 @@ async def handle_seed(url ,cfg):
return urls
except Exception as e:
return {
"seeded_urls": [],
"count": 0,
"message": "No URLs found for the given domain and configuration.",
}
"seeded_urls": [],
"count": 0,
"message": "No URLs found for the given domain and configuration.",
}

View File

@@ -1,10 +1,27 @@
# crawler_pool.py (new file)
import asyncio, json, hashlib, time, psutil
import asyncio
import hashlib
import json
import time
from contextlib import suppress
from typing import Dict
from typing import Dict, Optional
import psutil
from crawl4ai import AsyncWebCrawler, BrowserConfig
from typing import Dict
from utils import load_config
from crawl4ai.async_crawler_strategy import AsyncPlaywrightCrawlerStrategy
# Import browser adapters with fallback
try:
from crawl4ai.browser_adapter import BrowserAdapter, PlaywrightAdapter
except ImportError:
# Fallback for development environment
import os
import sys
sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", ".."))
from crawl4ai.browser_adapter import BrowserAdapter, PlaywrightAdapter
from utils import load_config
CONFIG = load_config()
@@ -12,25 +29,44 @@ POOL: Dict[str, AsyncWebCrawler] = {}
LAST_USED: Dict[str, float] = {}
LOCK = asyncio.Lock()
MEM_LIMIT = CONFIG.get("crawler", {}).get("memory_threshold_percent", 95.0) # % RAM refuse new browsers above this
IDLE_TTL = CONFIG.get("crawler", {}).get("pool", {}).get("idle_ttl_sec", 1800) # close if unused for 30min
MEM_LIMIT = CONFIG.get("crawler", {}).get(
"memory_threshold_percent", 95.0
) # % RAM refuse new browsers above this
IDLE_TTL = (
CONFIG.get("crawler", {}).get("pool", {}).get("idle_ttl_sec", 1800)
) # close if unused for 30min
def _sig(cfg: BrowserConfig) -> str:
payload = json.dumps(cfg.to_dict(), sort_keys=True, separators=(",",":"))
def _sig(cfg: BrowserConfig, adapter: Optional[BrowserAdapter] = None) -> str:
    """Return a stable pool key for a (browser config, adapter) pair.

    Two requests share a pooled crawler only when both the canonically
    serialized browser configuration AND the adapter class name match.
    A missing adapter is keyed as "PlaywrightAdapter", matching the
    default adapter used by the pool.
    """
    serialized = json.dumps(cfg.to_dict(), sort_keys=True, separators=(",", ":"))
    kind = adapter.__class__.__name__ if adapter else "PlaywrightAdapter"
    return hashlib.sha1(f"{serialized}:{kind}".encode()).hexdigest()
async def get_crawler(cfg: BrowserConfig) -> AsyncWebCrawler:
async def get_crawler(
cfg: BrowserConfig, adapter: Optional[BrowserAdapter] = None
) -> AsyncWebCrawler:
try:
sig = _sig(cfg)
sig = _sig(cfg, adapter)
async with LOCK:
if sig in POOL:
LAST_USED[sig] = time.time();
LAST_USED[sig] = time.time()
return POOL[sig]
if psutil.virtual_memory().percent >= MEM_LIMIT:
raise MemoryError("RAM pressure new browser denied")
crawler = AsyncWebCrawler(config=cfg, thread_safe=False)
# Create strategy with the specified adapter
strategy = AsyncPlaywrightCrawlerStrategy(
browser_config=cfg, browser_adapter=adapter or PlaywrightAdapter()
)
crawler = AsyncWebCrawler(
config=cfg, crawler_strategy=strategy, thread_safe=False
)
await crawler.start()
POOL[sig] = crawler; LAST_USED[sig] = time.time()
POOL[sig] = crawler
LAST_USED[sig] = time.time()
return crawler
except MemoryError as e:
raise MemoryError(f"RAM pressure new browser denied: {e}")
@@ -44,10 +80,16 @@ async def get_crawler(cfg: BrowserConfig) -> AsyncWebCrawler:
POOL.pop(sig, None)
LAST_USED.pop(sig, None)
# If we failed to start the browser, we should remove it from the pool
async def close_all():
async with LOCK:
await asyncio.gather(*(c.close() for c in POOL.values()), return_exceptions=True)
POOL.clear(); LAST_USED.clear()
await asyncio.gather(
*(c.close() for c in POOL.values()), return_exceptions=True
)
POOL.clear()
LAST_USED.clear()
async def janitor():
while True:
@@ -56,5 +98,7 @@ async def janitor():
async with LOCK:
for sig, crawler in list(POOL.items()):
if now - LAST_USED[sig] > IDLE_TTL:
with suppress(Exception): await crawler.close()
POOL.pop(sig, None); LAST_USED.pop(sig, None)
with suppress(Exception):
await crawler.close()
POOL.pop(sig, None)
LAST_USED.pop(sig, None)

View File

@@ -1,5 +1,5 @@
from enum import Enum
from typing import Any, Dict, List, Optional
from typing import Any, Dict, List, Literal, Optional
from pydantic import BaseModel, Field
from utils import FilterType
@@ -10,6 +10,11 @@ class CrawlRequest(BaseModel):
browser_config: Optional[Dict] = Field(default_factory=dict)
crawler_config: Optional[Dict] = Field(default_factory=dict)
anti_bot_strategy: Literal["default", "stealth", "undetected", "max_evasion"] = (
Field("default", description="The anti-bot strategy to use for the crawl.")
)
headless: bool = Field(True, description="Run the browser in headless mode.")
class HookConfig(BaseModel):
"""Configuration for user-provided hooks"""

View File

@@ -49,6 +49,7 @@ from rank_bm25 import BM25Okapi
from redis import asyncio as aioredis
from routers import adaptive, scripts
from schemas import (
CrawlRequest,
CrawlRequestWithHooks,
HTMLRequest,
JSEndpointRequest,
@@ -575,7 +576,7 @@ async def metrics():
@mcp_tool("crawl")
async def crawl(
request: Request,
crawl_request: CrawlRequestWithHooks,
crawl_request: CrawlRequest | CrawlRequestWithHooks,
_td: Dict = Depends(token_dep),
):
"""
@@ -592,7 +593,7 @@ async def crawl(
# Prepare hooks config if provided
hooks_config = None
if crawl_request.hooks:
if hasattr(crawl_request, 'hooks') and crawl_request.hooks:
hooks_config = {
"code": crawl_request.hooks.code,
"timeout": crawl_request.hooks.timeout,
@@ -604,6 +605,8 @@ async def crawl(
crawler_config=crawl_request.crawler_config,
config=config,
hooks_config=hooks_config,
anti_bot_strategy=crawl_request.anti_bot_strategy,
headless=crawl_request.headless,
)
# check if all of the results are not successful
if all(not result["success"] for result in results["results"]):
@@ -627,9 +630,9 @@ async def crawl_stream(
async def stream_process(crawl_request: CrawlRequestWithHooks):
# Prepare hooks config if provided# Prepare hooks config if provided
# Prepare hooks config if provided
hooks_config = None
if crawl_request.hooks:
if hasattr(crawl_request, 'hooks') and crawl_request.hooks:
hooks_config = {
"code": crawl_request.hooks.code,
"timeout": crawl_request.hooks.timeout,
@@ -641,6 +644,8 @@ async def stream_process(crawl_request: CrawlRequestWithHooks):
crawler_config=crawl_request.crawler_config,
config=config,
hooks_config=hooks_config,
anti_bot_strategy=crawl_request.anti_bot_strategy,
headless=crawl_request.headless,
)
# Add hooks info to response headers if available