feat: enhance crawling functionality with anti-bot strategies and headless mode options (Browser adapters, 12. Undetected/stealth browser)

This commit is contained in:
AHMET YILMAZ
2025-10-03 18:02:10 +08:00
parent a599db8f7b
commit 5dc34dd210
4 changed files with 421 additions and 257 deletions

View File

@@ -1,57 +1,63 @@
import os
import json
import asyncio import asyncio
from typing import List, Tuple, Dict import json
from functools import partial
from uuid import uuid4
from datetime import datetime, timezone
from base64 import b64encode
import logging import logging
from typing import Optional, AsyncGenerator import os
import time
from base64 import b64encode
from datetime import datetime, timezone
from functools import partial
from typing import AsyncGenerator, Dict, List, Optional, Tuple
from urllib.parse import unquote from urllib.parse import unquote
from fastapi import HTTPException, Request, status from fastapi import HTTPException, Request, status
from fastapi.background import BackgroundTasks from fastapi.background import BackgroundTasks
from fastapi.responses import JSONResponse from fastapi.responses import JSONResponse
from redis import asyncio as aioredis from redis import asyncio as aioredis
from crawl4ai import ( from crawl4ai import (
AsyncUrlSeeder,
AsyncWebCrawler, AsyncWebCrawler,
CrawlerRunConfig,
LLMExtractionStrategy,
CacheMode,
BrowserConfig, BrowserConfig,
MemoryAdaptiveDispatcher, CacheMode,
RateLimiter, CrawlerRunConfig,
LLMConfig, LLMConfig,
AsyncUrlSeeder, LLMExtractionStrategy,
SeedingConfig MemoryAdaptiveDispatcher,
PlaywrightAdapter,
RateLimiter,
SeedingConfig,
UndetectedAdapter,
) )
from crawl4ai.utils import perform_completion_with_backoff
# Import StealthAdapter with fallback for compatibility.
# Resolution order:
#   1. the name exported by the top-level ``crawl4ai`` package;
#   2. ``crawl4ai.browser_adapter``, after putting the directory two levels
#      above this file on ``sys.path`` (development checkouts);
#   3. a local subclass of PlaywrightAdapter that adds nothing.
try:
    from crawl4ai import StealthAdapter
except ImportError:
    # Fallback: import directly from browser_adapter module
    try:
        import os
        import sys

        # Add the project root to path for development
        project_root = os.path.abspath(
            os.path.join(os.path.dirname(__file__), "..", "..")
        )
        sys.path.insert(0, project_root)
        from crawl4ai.browser_adapter import StealthAdapter
    except ImportError:
        # If all else fails, create a simple fallback
        from crawl4ai.browser_adapter import PlaywrightAdapter

        # Make the degradation visible: from here on, anything that asks for
        # StealthAdapter gets a plain PlaywrightAdapter subclass instead.
        logging.getLogger(__name__).warning(
            "StealthAdapter could not be imported from crawl4ai; "
            "falling back to PlaywrightAdapter (no stealth behaviour)."
        )

        class StealthAdapter(PlaywrightAdapter):
            """Fallback StealthAdapter that uses PlaywrightAdapter"""

            pass
from crawl4ai.content_filter_strategy import ( from crawl4ai.content_filter_strategy import (
PruningContentFilter,
BM25ContentFilter, BM25ContentFilter,
LLMContentFilter LLMContentFilter,
PruningContentFilter,
) )
from crawl4ai.markdown_generation_strategy import DefaultMarkdownGenerator
from crawl4ai.content_scraping_strategy import LXMLWebScrapingStrategy from crawl4ai.content_scraping_strategy import LXMLWebScrapingStrategy
from crawl4ai.markdown_generation_strategy import DefaultMarkdownGenerator
from utils import ( from crawl4ai.utils import perform_completion_with_backoff
TaskStatus,
FilterType,
get_base_url,
is_task_id,
should_cleanup_task,
decode_redis_hash,
get_llm_api_key,
validate_llm_provider,
get_llm_temperature,
get_llm_base_url
)
import psutil, time
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
# --- Helper to get memory --- # --- Helper to get memory ---
def _get_memory_mb(): def _get_memory_mb():
try: try:
@@ -61,17 +67,39 @@ def _get_memory_mb():
return None return None
# --- Helper to get browser adapter based on anti_bot_strategy ---
def _get_browser_adapter(anti_bot_strategy: str, browser_config: BrowserConfig):
    """Get the appropriate browser adapter based on anti_bot_strategy.

    String strategy names are matched case-insensitively and with surrounding
    whitespace ignored, so ``"Stealth"`` or ``" undetected "`` select the same
    adapter as their canonical spelling instead of silently falling through
    to the default branch.

    Args:
        anti_bot_strategy: One of ``"stealth"``, ``"undetected"``,
            ``"max_evasion"`` or ``"default"``. Any other value (including
            ``None``) is treated as ``"default"``.
        browser_config: Browser configuration; only its ``enable_stealth``
            attribute is read, and only for the default strategy. A missing
            attribute counts as ``False``.

    Returns:
        A newly constructed ``StealthAdapter``, ``UndetectedAdapter`` or
        ``PlaywrightAdapter`` instance.
    """
    strategy = anti_bot_strategy
    if isinstance(strategy, str):
        strategy = strategy.strip().lower()

    if strategy == "stealth":
        return StealthAdapter()
    if strategy in ("undetected", "max_evasion"):
        # "max_evasion" uses the undetected adapter for maximum evasion.
        return UndetectedAdapter()

    # "default" (and anything unrecognised): if stealth is enabled in the
    # browser config, use the stealth adapter; otherwise plain Playwright.
    if getattr(browser_config, "enable_stealth", False):
        return StealthAdapter()
    return PlaywrightAdapter()
# --- Helper to apply headless setting ---
def _apply_headless_setting(browser_config: BrowserConfig, headless: bool):
"""Apply headless setting to browser config."""
browser_config.headless = headless
return browser_config
async def handle_llm_qa(url: str, query: str, config: dict) -> str:
"""Process QA using LLM with crawled content as context.""" """Process QA using LLM with crawled content as context."""
try: try:
if not url.startswith(('http://', 'https://')) and not url.startswith(("raw:", "raw://")): if not url.startswith(("http://", "https://")) and not url.startswith(
url = 'https://' + url ("raw:", "raw://")
):
url = "https://" + url
# Extract base URL by finding last '?q=' occurrence # Extract base URL by finding last '?q=' occurrence
last_q_index = url.rfind('?q=') last_q_index = url.rfind("?q=")
if last_q_index != -1: if last_q_index != -1:
url = url[:last_q_index] url = url[:last_q_index]
@@ -81,7 +109,7 @@ async def handle_llm_qa(
if not result.success: if not result.success:
raise HTTPException( raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=result.error_message detail=result.error_message,
) )
content = result.markdown.fit_markdown or result.markdown.raw_markdown content = result.markdown.fit_markdown or result.markdown.raw_markdown
@@ -101,17 +129,17 @@ async def handle_llm_qa(
prompt_with_variables=prompt, prompt_with_variables=prompt,
api_token=get_llm_api_key(config), # Returns None to let litellm handle it api_token=get_llm_api_key(config), # Returns None to let litellm handle it
temperature=get_llm_temperature(config), temperature=get_llm_temperature(config),
base_url=get_llm_base_url(config) base_url=get_llm_base_url(config),
) )
return response.choices[0].message.content return response.choices[0].message.content
except Exception as e: except Exception as e:
logger.error(f"QA processing error: {str(e)}", exc_info=True) logger.error(f"QA processing error: {str(e)}", exc_info=True)
raise HTTPException( raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=str(e)
detail=str(e)
) )
async def process_llm_extraction( async def process_llm_extraction(
redis: aioredis.Redis, redis: aioredis.Redis,
config: dict, config: dict,
@@ -122,25 +150,27 @@ async def process_llm_extraction(
cache: str = "0", cache: str = "0",
provider: Optional[str] = None, provider: Optional[str] = None,
temperature: Optional[float] = None, temperature: Optional[float] = None,
base_url: Optional[str] = None base_url: Optional[str] = None,
) -> None: ) -> None:
"""Process LLM extraction in background.""" """Process LLM extraction in background."""
try: try:
# Validate provider # Validate provider
is_valid, error_msg = validate_llm_provider(config, provider) is_valid, error_msg = validate_llm_provider(config, provider)
if not is_valid: if not is_valid:
await redis.hset(f"task:{task_id}", mapping={ await redis.hset(
"status": TaskStatus.FAILED, f"task:{task_id}",
"error": error_msg mapping={"status": TaskStatus.FAILED, "error": error_msg},
}) )
return return
api_key = get_llm_api_key(config, provider) # Returns None to let litellm handle it api_key = get_llm_api_key(
config, provider
) # Returns None to let litellm handle it
llm_strategy = LLMExtractionStrategy( llm_strategy = LLMExtractionStrategy(
llm_config=LLMConfig( llm_config=LLMConfig(
provider=provider or config["llm"]["provider"], provider=provider or config["llm"]["provider"],
api_token=api_key, api_token=api_key,
temperature=temperature or get_llm_temperature(config, provider), temperature=temperature or get_llm_temperature(config, provider),
base_url=base_url or get_llm_base_url(config, provider) base_url=base_url or get_llm_base_url(config, provider),
), ),
instruction=instruction, instruction=instruction,
schema=json.loads(schema) if schema else None, schema=json.loads(schema) if schema else None,
@@ -154,32 +184,32 @@ async def process_llm_extraction(
config=CrawlerRunConfig( config=CrawlerRunConfig(
extraction_strategy=llm_strategy, extraction_strategy=llm_strategy,
scraping_strategy=LXMLWebScrapingStrategy(), scraping_strategy=LXMLWebScrapingStrategy(),
cache_mode=cache_mode cache_mode=cache_mode,
) ),
) )
if not result.success: if not result.success:
await redis.hset(f"task:{task_id}", mapping={ await redis.hset(
"status": TaskStatus.FAILED, f"task:{task_id}",
"error": result.error_message mapping={"status": TaskStatus.FAILED, "error": result.error_message},
}) )
return return
try: try:
content = json.loads(result.extracted_content) content = json.loads(result.extracted_content)
except json.JSONDecodeError: except json.JSONDecodeError:
content = result.extracted_content content = result.extracted_content
await redis.hset(f"task:{task_id}", mapping={ await redis.hset(
"status": TaskStatus.COMPLETED, f"task:{task_id}",
"result": json.dumps(content) mapping={"status": TaskStatus.COMPLETED, "result": json.dumps(content)},
}) )
except Exception as e: except Exception as e:
logger.error(f"LLM extraction error: {str(e)}", exc_info=True) logger.error(f"LLM extraction error: {str(e)}", exc_info=True)
await redis.hset(f"task:{task_id}", mapping={ await redis.hset(
"status": TaskStatus.FAILED, f"task:{task_id}", mapping={"status": TaskStatus.FAILED, "error": str(e)}
"error": str(e) )
})
async def handle_markdown_request( async def handle_markdown_request(
url: str, url: str,
@@ -189,7 +219,7 @@ async def handle_markdown_request(
config: Optional[dict] = None, config: Optional[dict] = None,
provider: Optional[str] = None, provider: Optional[str] = None,
temperature: Optional[float] = None, temperature: Optional[float] = None,
base_url: Optional[str] = None base_url: Optional[str] = None,
) -> str: ) -> str:
"""Handle markdown generation requests.""" """Handle markdown generation requests."""
try: try:
@@ -198,12 +228,13 @@ async def handle_markdown_request(
is_valid, error_msg = validate_llm_provider(config, provider) is_valid, error_msg = validate_llm_provider(config, provider)
if not is_valid: if not is_valid:
raise HTTPException( raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST, status_code=status.HTTP_400_BAD_REQUEST, detail=error_msg
detail=error_msg
) )
decoded_url = unquote(url) decoded_url = unquote(url)
if not decoded_url.startswith(('http://', 'https://')) and not decoded_url.startswith(("raw:", "raw://")): if not decoded_url.startswith(
decoded_url = 'https://' + decoded_url ("http://", "https://")
) and not decoded_url.startswith(("raw:", "raw://")):
decoded_url = "https://" + decoded_url
if filter_type == FilterType.RAW: if filter_type == FilterType.RAW:
md_generator = DefaultMarkdownGenerator() md_generator = DefaultMarkdownGenerator()
@@ -214,12 +245,15 @@ async def handle_markdown_request(
FilterType.LLM: LLMContentFilter( FilterType.LLM: LLMContentFilter(
llm_config=LLMConfig( llm_config=LLMConfig(
provider=provider or config["llm"]["provider"], provider=provider or config["llm"]["provider"],
api_token=get_llm_api_key(config, provider), # Returns None to let litellm handle it api_token=get_llm_api_key(
temperature=temperature or get_llm_temperature(config, provider), config, provider
base_url=base_url or get_llm_base_url(config, provider) ), # Returns None to let litellm handle it
temperature=temperature
or get_llm_temperature(config, provider),
base_url=base_url or get_llm_base_url(config, provider),
), ),
instruction=query or "Extract main content" instruction=query or "Extract main content",
) ),
}[filter_type] }[filter_type]
md_generator = DefaultMarkdownGenerator(content_filter=content_filter) md_generator = DefaultMarkdownGenerator(content_filter=content_filter)
@@ -231,27 +265,29 @@ async def handle_markdown_request(
config=CrawlerRunConfig( config=CrawlerRunConfig(
markdown_generator=md_generator, markdown_generator=md_generator,
scraping_strategy=LXMLWebScrapingStrategy(), scraping_strategy=LXMLWebScrapingStrategy(),
cache_mode=cache_mode cache_mode=cache_mode,
) ),
) )
if not result.success: if not result.success:
raise HTTPException( raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=result.error_message detail=result.error_message,
) )
return (result.markdown.raw_markdown return (
if filter_type == FilterType.RAW result.markdown.raw_markdown
else result.markdown.fit_markdown) if filter_type == FilterType.RAW
else result.markdown.fit_markdown
)
except Exception as e: except Exception as e:
logger.error(f"Markdown error: {str(e)}", exc_info=True) logger.error(f"Markdown error: {str(e)}", exc_info=True)
raise HTTPException( raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=str(e)
detail=str(e)
) )
async def handle_llm_request( async def handle_llm_request(
redis: aioredis.Redis, redis: aioredis.Redis,
background_tasks: BackgroundTasks, background_tasks: BackgroundTasks,
@@ -263,27 +299,27 @@ async def handle_llm_request(
config: Optional[dict] = None, config: Optional[dict] = None,
provider: Optional[str] = None, provider: Optional[str] = None,
temperature: Optional[float] = None, temperature: Optional[float] = None,
api_base_url: Optional[str] = None api_base_url: Optional[str] = None,
) -> JSONResponse: ) -> JSONResponse:
"""Handle LLM extraction requests.""" """Handle LLM extraction requests."""
base_url = get_base_url(request) base_url = get_base_url(request)
try: try:
if is_task_id(input_path): if is_task_id(input_path):
return await handle_task_status( return await handle_task_status(redis, input_path, base_url)
redis, input_path, base_url
)
if not query: if not query:
return JSONResponse({ return JSONResponse(
"message": "Please provide an instruction", {
"_links": { "message": "Please provide an instruction",
"example": { "_links": {
"href": f"{base_url}/llm/{input_path}?q=Extract+main+content", "example": {
"title": "Try this example" "href": f"{base_url}/llm/{input_path}?q=Extract+main+content",
} "title": "Try this example",
}
},
} }
}) )
return await create_new_task( return await create_new_task(
redis, redis,
@@ -296,31 +332,25 @@ async def handle_llm_request(
config, config,
provider, provider,
temperature, temperature,
api_base_url api_base_url,
) )
except Exception as e: except Exception as e:
logger.error(f"LLM endpoint error: {str(e)}", exc_info=True) logger.error(f"LLM endpoint error: {str(e)}", exc_info=True)
return JSONResponse({ return JSONResponse(
"error": str(e), {"error": str(e), "_links": {"retry": {"href": str(request.url)}}},
"_links": { status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
"retry": {"href": str(request.url)} )
}
}, status_code=status.HTTP_500_INTERNAL_SERVER_ERROR)
async def handle_task_status( async def handle_task_status(
redis: aioredis.Redis, redis: aioredis.Redis, task_id: str, base_url: str, *, keep: bool = False
task_id: str,
base_url: str,
*,
keep: bool = False
) -> JSONResponse: ) -> JSONResponse:
"""Handle task status check requests.""" """Handle task status check requests."""
task = await redis.hgetall(f"task:{task_id}") task = await redis.hgetall(f"task:{task_id}")
if not task: if not task:
raise HTTPException( raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND, status_code=status.HTTP_404_NOT_FOUND, detail="Task not found"
detail="Task not found"
) )
task = decode_redis_hash(task) task = decode_redis_hash(task)
@@ -332,6 +362,7 @@ async def handle_task_status(
return JSONResponse(response) return JSONResponse(response)
async def create_new_task( async def create_new_task(
redis: aioredis.Redis, redis: aioredis.Redis,
background_tasks: BackgroundTasks, background_tasks: BackgroundTasks,
@@ -343,21 +374,27 @@ async def create_new_task(
config: dict, config: dict,
provider: Optional[str] = None, provider: Optional[str] = None,
temperature: Optional[float] = None, temperature: Optional[float] = None,
api_base_url: Optional[str] = None api_base_url: Optional[str] = None,
) -> JSONResponse: ) -> JSONResponse:
"""Create and initialize a new task.""" """Create and initialize a new task."""
decoded_url = unquote(input_path) decoded_url = unquote(input_path)
if not decoded_url.startswith(('http://', 'https://')) and not decoded_url.startswith(("raw:", "raw://")): if not decoded_url.startswith(
decoded_url = 'https://' + decoded_url ("http://", "https://")
) and not decoded_url.startswith(("raw:", "raw://")):
decoded_url = "https://" + decoded_url
from datetime import datetime from datetime import datetime
task_id = f"llm_{int(datetime.now().timestamp())}_{id(background_tasks)}" task_id = f"llm_{int(datetime.now().timestamp())}_{id(background_tasks)}"
await redis.hset(f"task:{task_id}", mapping={ await redis.hset(
"status": TaskStatus.PROCESSING, f"task:{task_id}",
"created_at": datetime.now().isoformat(), mapping={
"url": decoded_url "status": TaskStatus.PROCESSING,
}) "created_at": datetime.now().isoformat(),
"url": decoded_url,
},
)
background_tasks.add_task( background_tasks.add_task(
process_llm_extraction, process_llm_extraction,
@@ -370,18 +407,21 @@ async def create_new_task(
cache, cache,
provider, provider,
temperature, temperature,
api_base_url api_base_url,
) )
return JSONResponse({ return JSONResponse(
"task_id": task_id, {
"status": TaskStatus.PROCESSING, "task_id": task_id,
"url": decoded_url, "status": TaskStatus.PROCESSING,
"_links": { "url": decoded_url,
"self": {"href": f"{base_url}/llm/{task_id}"}, "_links": {
"status": {"href": f"{base_url}/llm/{task_id}"} "self": {"href": f"{base_url}/llm/{task_id}"},
"status": {"href": f"{base_url}/llm/{task_id}"},
},
} }
}) )
def create_task_response(task: dict, task_id: str, base_url: str) -> dict: def create_task_response(task: dict, task_id: str, base_url: str) -> dict:
"""Create response for task status check.""" """Create response for task status check."""
@@ -392,8 +432,8 @@ def create_task_response(task: dict, task_id: str, base_url: str) -> dict:
"url": task["url"], "url": task["url"],
"_links": { "_links": {
"self": {"href": f"{base_url}/llm/{task_id}"}, "self": {"href": f"{base_url}/llm/{task_id}"},
"refresh": {"href": f"{base_url}/llm/{task_id}"} "refresh": {"href": f"{base_url}/llm/{task_id}"},
} },
} }
if task["status"] == TaskStatus.COMPLETED: if task["status"] == TaskStatus.COMPLETED:
@@ -403,9 +443,13 @@ def create_task_response(task: dict, task_id: str, base_url: str) -> dict:
return response return response
async def stream_results(crawler: AsyncWebCrawler, results_gen: AsyncGenerator) -> AsyncGenerator[bytes, None]:
async def stream_results(
crawler: AsyncWebCrawler, results_gen: AsyncGenerator
) -> AsyncGenerator[bytes, None]:
"""Stream results with heartbeats and completion markers.""" """Stream results with heartbeats and completion markers."""
import json import json
from utils import datetime_handler from utils import datetime_handler
try: try:
@@ -413,23 +457,29 @@ async def stream_results(crawler: AsyncWebCrawler, results_gen: AsyncGenerator)
try: try:
server_memory_mb = _get_memory_mb() server_memory_mb = _get_memory_mb()
result_dict = result.model_dump() result_dict = result.model_dump()
result_dict['server_memory_mb'] = server_memory_mb result_dict["server_memory_mb"] = server_memory_mb
# Ensure fit_html is JSON-serializable # Ensure fit_html is JSON-serializable
if "fit_html" in result_dict and not (result_dict["fit_html"] is None or isinstance(result_dict["fit_html"], str)): if "fit_html" in result_dict and not (
result_dict["fit_html"] is None
or isinstance(result_dict["fit_html"], str)
):
result_dict["fit_html"] = None result_dict["fit_html"] = None
# If PDF exists, encode it to base64 # If PDF exists, encode it to base64
if result_dict.get('pdf') is not None: if result_dict.get("pdf") is not None:
result_dict['pdf'] = b64encode(result_dict['pdf']).decode('utf-8') result_dict["pdf"] = b64encode(result_dict["pdf"]).decode("utf-8")
logger.info(f"Streaming result for {result_dict.get('url', 'unknown')}") logger.info(f"Streaming result for {result_dict.get('url', 'unknown')}")
data = json.dumps(result_dict, default=datetime_handler) + "\n" data = json.dumps(result_dict, default=datetime_handler) + "\n"
yield data.encode('utf-8') yield data.encode("utf-8")
except Exception as e: except Exception as e:
logger.error(f"Serialization error: {e}") logger.error(f"Serialization error: {e}")
error_response = {"error": str(e), "url": getattr(result, 'url', 'unknown')} error_response = {
yield (json.dumps(error_response) + "\n").encode('utf-8') "error": str(e),
"url": getattr(result, "url", "unknown"),
}
yield (json.dumps(error_response) + "\n").encode("utf-8")
yield json.dumps({"status": "completed"}).encode("utf-8")
yield json.dumps({"status": "completed"}).encode('utf-8')
except asyncio.CancelledError: except asyncio.CancelledError:
logger.warning("Client disconnected during streaming") logger.warning("Client disconnected during streaming")
finally: finally:
@@ -439,51 +489,70 @@ async def stream_results(crawler: AsyncWebCrawler, results_gen: AsyncGenerator)
# logger.error(f"Crawler cleanup error: {e}") # logger.error(f"Crawler cleanup error: {e}")
pass pass
async def handle_crawl_request( async def handle_crawl_request(
urls: List[str], urls: List[str],
browser_config: dict, browser_config: dict,
crawler_config: dict, crawler_config: dict,
config: dict, config: dict,
hooks_config: Optional[dict] = None hooks_config: Optional[dict] = None,
anti_bot_strategy: str = "default",
headless: bool = True,
) -> dict: ) -> dict:
"""Handle non-streaming crawl requests with optional hooks.""" """Handle non-streaming crawl requests with optional hooks."""
start_mem_mb = _get_memory_mb() # <--- Get memory before start_mem_mb = _get_memory_mb() # <--- Get memory before
start_time = time.time() start_time = time.time()
mem_delta_mb = None mem_delta_mb = None
peak_mem_mb = start_mem_mb peak_mem_mb = start_mem_mb
hook_manager = None hook_manager = None
try: try:
urls = [('https://' + url) if not url.startswith(('http://', 'https://')) and not url.startswith(("raw:", "raw://")) else url for url in urls] urls = [
("https://" + url)
if not url.startswith(("http://", "https://"))
and not url.startswith(("raw:", "raw://"))
else url
for url in urls
]
browser_config = BrowserConfig.load(browser_config) browser_config = BrowserConfig.load(browser_config)
_apply_headless_setting(browser_config, headless)
crawler_config = CrawlerRunConfig.load(crawler_config) crawler_config = CrawlerRunConfig.load(crawler_config)
# Configure browser adapter based on anti_bot_strategy
browser_adapter = _get_browser_adapter(anti_bot_strategy, browser_config)
# TODO: add support for other dispatchers
dispatcher = MemoryAdaptiveDispatcher( dispatcher = MemoryAdaptiveDispatcher(
memory_threshold_percent=config["crawler"]["memory_threshold_percent"], memory_threshold_percent=config["crawler"]["memory_threshold_percent"],
rate_limiter=RateLimiter( rate_limiter=RateLimiter(
base_delay=tuple(config["crawler"]["rate_limiter"]["base_delay"]) base_delay=tuple(config["crawler"]["rate_limiter"]["base_delay"])
) if config["crawler"]["rate_limiter"]["enabled"] else None )
if config["crawler"]["rate_limiter"]["enabled"]
else None,
) )
from crawler_pool import get_crawler from crawler_pool import get_crawler
crawler = await get_crawler(browser_config)
crawler = await get_crawler(browser_config, browser_adapter)
# crawler: AsyncWebCrawler = AsyncWebCrawler(config=browser_config) # crawler: AsyncWebCrawler = AsyncWebCrawler(config=browser_config)
# await crawler.start() # await crawler.start()
# Attach hooks if provided # Attach hooks if provided
hooks_status = {} hooks_status = {}
if hooks_config: if hooks_config:
from hook_manager import attach_user_hooks_to_crawler, UserHookManager from hook_manager import UserHookManager, attach_user_hooks_to_crawler
hook_manager = UserHookManager(timeout=hooks_config.get('timeout', 30))
hook_manager = UserHookManager(timeout=hooks_config.get("timeout", 30))
hooks_status, hook_manager = await attach_user_hooks_to_crawler( hooks_status, hook_manager = await attach_user_hooks_to_crawler(
crawler, crawler,
hooks_config.get('code', {}), hooks_config.get("code", {}),
timeout=hooks_config.get('timeout', 30), timeout=hooks_config.get("timeout", 30),
hook_manager=hook_manager hook_manager=hook_manager,
) )
logger.info(f"Hooks attachment status: {hooks_status['status']}") logger.info(f"Hooks attachment status: {hooks_status['status']}")
base_config = config["crawler"]["base_config"] base_config = config["crawler"]["base_config"]
# Iterate on key-value pairs in global_config then use hasattr to set them # Iterate on key-value pairs in global_config then use hasattr to set them
for key, value in base_config.items(): for key, value in base_config.items():
@@ -495,32 +564,38 @@ async def handle_crawl_request(
results = [] results = []
func = getattr(crawler, "arun" if len(urls) == 1 else "arun_many") func = getattr(crawler, "arun" if len(urls) == 1 else "arun_many")
partial_func = partial(func, partial_func = partial(
urls[0] if len(urls) == 1 else urls, func,
config=crawler_config, urls[0] if len(urls) == 1 else urls,
dispatcher=dispatcher) config=crawler_config,
dispatcher=dispatcher,
)
results = await partial_func() results = await partial_func()
# Ensure results is always a list # Ensure results is always a list
if not isinstance(results, list): if not isinstance(results, list):
results = [results] results = [results]
# await crawler.close() # await crawler.close()
end_mem_mb = _get_memory_mb() # <--- Get memory after end_mem_mb = _get_memory_mb() # <--- Get memory after
end_time = time.time() end_time = time.time()
if start_mem_mb is not None and end_mem_mb is not None: if start_mem_mb is not None and end_mem_mb is not None:
mem_delta_mb = end_mem_mb - start_mem_mb # <--- Calculate delta mem_delta_mb = end_mem_mb - start_mem_mb # <--- Calculate delta
peak_mem_mb = max(peak_mem_mb if peak_mem_mb else 0, end_mem_mb) # <--- Get peak memory peak_mem_mb = max(
logger.info(f"Memory usage: Start: {start_mem_mb} MB, End: {end_mem_mb} MB, Delta: {mem_delta_mb} MB, Peak: {peak_mem_mb} MB") peak_mem_mb if peak_mem_mb else 0, end_mem_mb
) # <--- Get peak memory
logger.info(
f"Memory usage: Start: {start_mem_mb} MB, End: {end_mem_mb} MB, Delta: {mem_delta_mb} MB, Peak: {peak_mem_mb} MB"
)
# Process results to handle PDF bytes # Process results to handle PDF bytes
processed_results = [] processed_results = []
for result in results: for result in results:
try: try:
# Check if result has model_dump method (is a proper CrawlResult) # Check if result has model_dump method (is a proper CrawlResult)
if hasattr(result, 'model_dump'): if hasattr(result, "model_dump"):
result_dict = result.model_dump() result_dict = result.model_dump()
elif isinstance(result, dict): elif isinstance(result, dict):
result_dict = result result_dict = result
@@ -528,48 +603,53 @@ async def handle_crawl_request(
# Handle unexpected result type # Handle unexpected result type
logger.warning(f"Unexpected result type: {type(result)}") logger.warning(f"Unexpected result type: {type(result)}")
result_dict = { result_dict = {
"url": str(result) if hasattr(result, '__str__') else "unknown", "url": str(result) if hasattr(result, "__str__") else "unknown",
"success": False, "success": False,
"error_message": f"Unexpected result type: {type(result).__name__}" "error_message": f"Unexpected result type: {type(result).__name__}",
} }
# if fit_html is not a string, set it to None to avoid serialization errors # if fit_html is not a string, set it to None to avoid serialization errors
if "fit_html" in result_dict and not (result_dict["fit_html"] is None or isinstance(result_dict["fit_html"], str)): if "fit_html" in result_dict and not (
result_dict["fit_html"] is None
or isinstance(result_dict["fit_html"], str)
):
result_dict["fit_html"] = None result_dict["fit_html"] = None
# If PDF exists, encode it to base64 # If PDF exists, encode it to base64
if result_dict.get('pdf') is not None and isinstance(result_dict.get('pdf'), bytes): if result_dict.get("pdf") is not None and isinstance(
result_dict['pdf'] = b64encode(result_dict['pdf']).decode('utf-8') result_dict.get("pdf"), bytes
):
result_dict["pdf"] = b64encode(result_dict["pdf"]).decode("utf-8")
processed_results.append(result_dict) processed_results.append(result_dict)
except Exception as e: except Exception as e:
logger.error(f"Error processing result: {e}") logger.error(f"Error processing result: {e}")
processed_results.append({ processed_results.append(
"url": "unknown", {"url": "unknown", "success": False, "error_message": str(e)}
"success": False, )
"error_message": str(e)
})
response = { response = {
"success": True, "success": True,
"results": processed_results, "results": processed_results,
"server_processing_time_s": end_time - start_time, "server_processing_time_s": end_time - start_time,
"server_memory_delta_mb": mem_delta_mb, "server_memory_delta_mb": mem_delta_mb,
"server_peak_memory_mb": peak_mem_mb "server_peak_memory_mb": peak_mem_mb,
} }
# Add hooks information if hooks were used # Add hooks information if hooks were used
if hooks_config and hook_manager: if hooks_config and hook_manager:
from hook_manager import UserHookManager from hook_manager import UserHookManager
if isinstance(hook_manager, UserHookManager): if isinstance(hook_manager, UserHookManager):
try: try:
# Ensure all hook data is JSON serializable # Ensure all hook data is JSON serializable
import json import json
hook_data = { hook_data = {
"status": hooks_status, "status": hooks_status,
"execution_log": hook_manager.execution_log, "execution_log": hook_manager.execution_log,
"errors": hook_manager.errors, "errors": hook_manager.errors,
"summary": hook_manager.get_summary() "summary": hook_manager.get_summary(),
} }
# Test that it's serializable # Test that it's serializable
json.dumps(hook_data) json.dumps(hook_data)
@@ -577,17 +657,22 @@ async def handle_crawl_request(
except (TypeError, ValueError) as e: except (TypeError, ValueError) as e:
logger.error(f"Hook data not JSON serializable: {e}") logger.error(f"Hook data not JSON serializable: {e}")
response["hooks"] = { response["hooks"] = {
"status": {"status": "error", "message": "Hook data serialization failed"}, "status": {
"status": "error",
"message": "Hook data serialization failed",
},
"execution_log": [], "execution_log": [],
"errors": [{"error": str(e)}], "errors": [{"error": str(e)}],
"summary": {} "summary": {},
} }
return response return response
except Exception as e: except Exception as e:
logger.error(f"Crawl error: {str(e)}", exc_info=True) logger.error(f"Crawl error: {str(e)}", exc_info=True)
if 'crawler' in locals() and crawler.ready: # Check if crawler was initialized and started if (
"crawler" in locals() and crawler.ready
): # Check if crawler was initialized and started
# try: # try:
# await crawler.close() # await crawler.close()
# except Exception as close_e: # except Exception as close_e:
@@ -601,19 +686,26 @@ async def handle_crawl_request(
raise HTTPException( raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=json.dumps({ # Send structured error detail=json.dumps(
"error": str(e), { # Send structured error
"server_memory_delta_mb": mem_delta_mb, "error": str(e),
"server_peak_memory_mb": max(peak_mem_mb if peak_mem_mb else 0, end_mem_mb_error or 0) "server_memory_delta_mb": mem_delta_mb,
}) "server_peak_memory_mb": max(
peak_mem_mb if peak_mem_mb else 0, end_mem_mb_error or 0
),
}
),
) )
async def handle_stream_crawl_request( async def handle_stream_crawl_request(
urls: List[str], urls: List[str],
browser_config: dict, browser_config: dict,
crawler_config: dict, crawler_config: dict,
config: dict, config: dict,
hooks_config: Optional[dict] = None hooks_config: Optional[dict] = None,
anti_bot_strategy: str = "default",
headless: bool = True,
) -> Tuple[AsyncWebCrawler, AsyncGenerator, Optional[Dict]]: ) -> Tuple[AsyncWebCrawler, AsyncGenerator, Optional[Dict]]:
"""Handle streaming crawl requests with optional hooks.""" """Handle streaming crawl requests with optional hooks."""
hooks_info = None hooks_info = None
@@ -621,60 +713,68 @@ async def handle_stream_crawl_request(
browser_config = BrowserConfig.load(browser_config) browser_config = BrowserConfig.load(browser_config)
# browser_config.verbose = True # Set to False or remove for production stress testing # browser_config.verbose = True # Set to False or remove for production stress testing
browser_config.verbose = False browser_config.verbose = False
_apply_headless_setting(browser_config, headless)
crawler_config = CrawlerRunConfig.load(crawler_config) crawler_config = CrawlerRunConfig.load(crawler_config)
crawler_config.scraping_strategy = LXMLWebScrapingStrategy() crawler_config.scraping_strategy = LXMLWebScrapingStrategy()
crawler_config.stream = True crawler_config.stream = True
# Configure browser adapter based on anti_bot_strategy
browser_adapter = _get_browser_adapter(anti_bot_strategy, browser_config)
dispatcher = MemoryAdaptiveDispatcher( dispatcher = MemoryAdaptiveDispatcher(
memory_threshold_percent=config["crawler"]["memory_threshold_percent"], memory_threshold_percent=config["crawler"]["memory_threshold_percent"],
rate_limiter=RateLimiter( rate_limiter=RateLimiter(
base_delay=tuple(config["crawler"]["rate_limiter"]["base_delay"]) base_delay=tuple(config["crawler"]["rate_limiter"]["base_delay"])
) ),
) )
from crawler_pool import get_crawler from crawler_pool import get_crawler
crawler = await get_crawler(browser_config)
crawler = await get_crawler(browser_config, browser_adapter)
# crawler = AsyncWebCrawler(config=browser_config) # crawler = AsyncWebCrawler(config=browser_config)
# await crawler.start() # await crawler.start()
# Attach hooks if provided # Attach hooks if provided
if hooks_config: if hooks_config:
from hook_manager import attach_user_hooks_to_crawler, UserHookManager from hook_manager import UserHookManager, attach_user_hooks_to_crawler
hook_manager = UserHookManager(timeout=hooks_config.get('timeout', 30))
hook_manager = UserHookManager(timeout=hooks_config.get("timeout", 30))
hooks_status, hook_manager = await attach_user_hooks_to_crawler( hooks_status, hook_manager = await attach_user_hooks_to_crawler(
crawler, crawler,
hooks_config.get('code', {}), hooks_config.get("code", {}),
timeout=hooks_config.get('timeout', 30), timeout=hooks_config.get("timeout", 30),
hook_manager=hook_manager hook_manager=hook_manager,
)
logger.info(
f"Hooks attachment status for streaming: {hooks_status['status']}"
) )
logger.info(f"Hooks attachment status for streaming: {hooks_status['status']}")
# Include hook manager in hooks_info for proper tracking # Include hook manager in hooks_info for proper tracking
hooks_info = {'status': hooks_status, 'manager': hook_manager} hooks_info = {"status": hooks_status, "manager": hook_manager}
results_gen = await crawler.arun_many( results_gen = await crawler.arun_many(
urls=urls, urls=urls, config=crawler_config, dispatcher=dispatcher
config=crawler_config,
dispatcher=dispatcher
) )
return crawler, results_gen, hooks_info return crawler, results_gen, hooks_info
except Exception as e: except Exception as e:
# Make sure to close crawler if started during an error here # Make sure to close crawler if started during an error here
if 'crawler' in locals() and crawler.ready: if "crawler" in locals() and crawler.ready:
# try: # try:
# await crawler.close() # await crawler.close()
# except Exception as close_e: # except Exception as close_e:
# logger.error(f"Error closing crawler during stream setup exception: {close_e}") # logger.error(f"Error closing crawler during stream setup exception: {close_e}")
logger.error(f"Error closing crawler during stream setup exception: {str(e)}") logger.error(
f"Error closing crawler during stream setup exception: {str(e)}"
)
logger.error(f"Stream crawl error: {str(e)}", exc_info=True) logger.error(f"Stream crawl error: {str(e)}", exc_info=True)
# Raising HTTPException here will prevent streaming response # Raising HTTPException here will prevent streaming response
raise HTTPException( raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=str(e)
detail=str(e)
) )
async def handle_crawl_job( async def handle_crawl_job(
redis, redis,
background_tasks: BackgroundTasks, background_tasks: BackgroundTasks,
@@ -689,13 +789,16 @@ async def handle_crawl_job(
lets /crawl/job/{task_id} polling fetch the result. lets /crawl/job/{task_id} polling fetch the result.
""" """
task_id = f"crawl_{uuid4().hex[:8]}" task_id = f"crawl_{uuid4().hex[:8]}"
await redis.hset(f"task:{task_id}", mapping={ await redis.hset(
"status": TaskStatus.PROCESSING, # <-- keep enum values consistent f"task:{task_id}",
"created_at": datetime.now(timezone.utc).replace(tzinfo=None).isoformat(), mapping={
"url": json.dumps(urls), # store list as JSON string "status": TaskStatus.PROCESSING, # <-- keep enum values consistent
"result": "", "created_at": datetime.now(timezone.utc).replace(tzinfo=None).isoformat(),
"error": "", "url": json.dumps(urls), # store list as JSON string
}) "result": "",
"error": "",
},
)
async def _runner(): async def _runner():
try: try:
@@ -705,21 +808,28 @@ async def handle_crawl_job(
crawler_config=crawler_config, crawler_config=crawler_config,
config=config, config=config,
) )
await redis.hset(f"task:{task_id}", mapping={ await redis.hset(
"status": TaskStatus.COMPLETED, f"task:{task_id}",
"result": json.dumps(result), mapping={
}) "status": TaskStatus.COMPLETED,
"result": json.dumps(result),
},
)
await asyncio.sleep(5) # Give Redis time to process the update await asyncio.sleep(5) # Give Redis time to process the update
except Exception as exc: except Exception as exc:
await redis.hset(f"task:{task_id}", mapping={ await redis.hset(
"status": TaskStatus.FAILED, f"task:{task_id}",
"error": str(exc), mapping={
}) "status": TaskStatus.FAILED,
"error": str(exc),
},
)
background_tasks.add_task(_runner) background_tasks.add_task(_runner)
return {"task_id": task_id} return {"task_id": task_id}
async def handle_seed(url ,cfg):
async def handle_seed(url, cfg):
# Create the configuration from the request body # Create the configuration from the request body
try: try:
seeding_config = cfg seeding_config = cfg
@@ -732,7 +842,7 @@ async def handle_seed(url ,cfg):
return urls return urls
except Exception as e: except Exception as e:
return { return {
"seeded_urls": [], "seeded_urls": [],
"count": 0, "count": 0,
"message": "No URLs found for the given domain and configuration.", "message": "No URLs found for the given domain and configuration.",
} }

View File

@@ -1,10 +1,27 @@
# crawler_pool.py (new file) # crawler_pool.py (new file)
import asyncio, json, hashlib, time, psutil import asyncio
import hashlib
import json
import time
from contextlib import suppress from contextlib import suppress
from typing import Dict from typing import Dict, Optional
import psutil
from crawl4ai import AsyncWebCrawler, BrowserConfig from crawl4ai import AsyncWebCrawler, BrowserConfig
from typing import Dict from crawl4ai.async_crawler_strategy import AsyncPlaywrightCrawlerStrategy
from utils import load_config
# Import browser adapters with fallback
try:
from crawl4ai.browser_adapter import BrowserAdapter, PlaywrightAdapter
except ImportError:
# Fallback for development environment
import os
import sys
sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", ".."))
from crawl4ai.browser_adapter import BrowserAdapter, PlaywrightAdapter
from utils import load_config
CONFIG = load_config() CONFIG = load_config()
@@ -12,25 +29,44 @@ POOL: Dict[str, AsyncWebCrawler] = {}
LAST_USED: Dict[str, float] = {} LAST_USED: Dict[str, float] = {}
LOCK = asyncio.Lock() LOCK = asyncio.Lock()
MEM_LIMIT = CONFIG.get("crawler", {}).get("memory_threshold_percent", 95.0) # % RAM refuse new browsers above this MEM_LIMIT = CONFIG.get("crawler", {}).get(
IDLE_TTL = CONFIG.get("crawler", {}).get("pool", {}).get("idle_ttl_sec", 1800) # close if unused for 30min "memory_threshold_percent", 95.0
) # % RAM refuse new browsers above this
IDLE_TTL = (
CONFIG.get("crawler", {}).get("pool", {}).get("idle_ttl_sec", 1800)
) # close if unused for 30min
def _sig(cfg: BrowserConfig) -> str:
payload = json.dumps(cfg.to_dict(), sort_keys=True, separators=(",",":")) def _sig(cfg: BrowserConfig, adapter: Optional[BrowserAdapter] = None) -> str:
config_payload = json.dumps(cfg.to_dict(), sort_keys=True, separators=(",", ":"))
adapter_name = adapter.__class__.__name__ if adapter else "PlaywrightAdapter"
payload = f"{config_payload}:{adapter_name}"
return hashlib.sha1(payload.encode()).hexdigest() return hashlib.sha1(payload.encode()).hexdigest()
async def get_crawler(cfg: BrowserConfig) -> AsyncWebCrawler:
async def get_crawler(
cfg: BrowserConfig, adapter: Optional[BrowserAdapter] = None
) -> AsyncWebCrawler:
try: try:
sig = _sig(cfg) sig = _sig(cfg, adapter)
async with LOCK: async with LOCK:
if sig in POOL: if sig in POOL:
LAST_USED[sig] = time.time(); LAST_USED[sig] = time.time()
return POOL[sig] return POOL[sig]
if psutil.virtual_memory().percent >= MEM_LIMIT: if psutil.virtual_memory().percent >= MEM_LIMIT:
raise MemoryError("RAM pressure new browser denied") raise MemoryError("RAM pressure new browser denied")
crawler = AsyncWebCrawler(config=cfg, thread_safe=False)
# Create strategy with the specified adapter
strategy = AsyncPlaywrightCrawlerStrategy(
browser_config=cfg, browser_adapter=adapter or PlaywrightAdapter()
)
crawler = AsyncWebCrawler(
config=cfg, crawler_strategy=strategy, thread_safe=False
)
await crawler.start() await crawler.start()
POOL[sig] = crawler; LAST_USED[sig] = time.time() POOL[sig] = crawler
LAST_USED[sig] = time.time()
return crawler return crawler
except MemoryError as e: except MemoryError as e:
raise MemoryError(f"RAM pressure new browser denied: {e}") raise MemoryError(f"RAM pressure new browser denied: {e}")
@@ -44,10 +80,16 @@ async def get_crawler(cfg: BrowserConfig) -> AsyncWebCrawler:
POOL.pop(sig, None) POOL.pop(sig, None)
LAST_USED.pop(sig, None) LAST_USED.pop(sig, None)
# If we failed to start the browser, we should remove it from the pool # If we failed to start the browser, we should remove it from the pool
async def close_all(): async def close_all():
async with LOCK: async with LOCK:
await asyncio.gather(*(c.close() for c in POOL.values()), return_exceptions=True) await asyncio.gather(
POOL.clear(); LAST_USED.clear() *(c.close() for c in POOL.values()), return_exceptions=True
)
POOL.clear()
LAST_USED.clear()
async def janitor(): async def janitor():
while True: while True:
@@ -56,5 +98,7 @@ async def janitor():
async with LOCK: async with LOCK:
for sig, crawler in list(POOL.items()): for sig, crawler in list(POOL.items()):
if now - LAST_USED[sig] > IDLE_TTL: if now - LAST_USED[sig] > IDLE_TTL:
with suppress(Exception): await crawler.close() with suppress(Exception):
POOL.pop(sig, None); LAST_USED.pop(sig, None) await crawler.close()
POOL.pop(sig, None)
LAST_USED.pop(sig, None)

View File

@@ -1,5 +1,5 @@
from enum import Enum from enum import Enum
from typing import Any, Dict, List, Optional from typing import Any, Dict, List, Literal, Optional
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
from utils import FilterType from utils import FilterType
@@ -10,6 +10,11 @@ class CrawlRequest(BaseModel):
browser_config: Optional[Dict] = Field(default_factory=dict) browser_config: Optional[Dict] = Field(default_factory=dict)
crawler_config: Optional[Dict] = Field(default_factory=dict) crawler_config: Optional[Dict] = Field(default_factory=dict)
anti_bot_strategy: Literal["default", "stealth", "undetected", "max_evasion"] = (
Field("default", description="The anti-bot strategy to use for the crawl.")
)
headless: bool = Field(True, description="Run the browser in headless mode.")
class HookConfig(BaseModel): class HookConfig(BaseModel):
"""Configuration for user-provided hooks""" """Configuration for user-provided hooks"""

View File

@@ -49,6 +49,7 @@ from rank_bm25 import BM25Okapi
from redis import asyncio as aioredis from redis import asyncio as aioredis
from routers import adaptive, scripts from routers import adaptive, scripts
from schemas import ( from schemas import (
CrawlRequest,
CrawlRequestWithHooks, CrawlRequestWithHooks,
HTMLRequest, HTMLRequest,
JSEndpointRequest, JSEndpointRequest,
@@ -575,7 +576,7 @@ async def metrics():
@mcp_tool("crawl") @mcp_tool("crawl")
async def crawl( async def crawl(
request: Request, request: Request,
crawl_request: CrawlRequestWithHooks, crawl_request: CrawlRequest | CrawlRequestWithHooks,
_td: Dict = Depends(token_dep), _td: Dict = Depends(token_dep),
): ):
""" """
@@ -592,7 +593,7 @@ async def crawl(
# Prepare hooks config if provided # Prepare hooks config if provided
hooks_config = None hooks_config = None
if crawl_request.hooks: if hasattr(crawl_request, 'hooks') and crawl_request.hooks:
hooks_config = { hooks_config = {
"code": crawl_request.hooks.code, "code": crawl_request.hooks.code,
"timeout": crawl_request.hooks.timeout, "timeout": crawl_request.hooks.timeout,
@@ -604,6 +605,8 @@ async def crawl(
crawler_config=crawl_request.crawler_config, crawler_config=crawl_request.crawler_config,
config=config, config=config,
hooks_config=hooks_config, hooks_config=hooks_config,
anti_bot_strategy=crawl_request.anti_bot_strategy,
headless=crawl_request.headless,
) )
# check if all of the results are not successful # check if all of the results are not successful
if all(not result["success"] for result in results["results"]): if all(not result["success"] for result in results["results"]):
@@ -627,9 +630,9 @@ async def crawl_stream(
async def stream_process(crawl_request: CrawlRequestWithHooks): async def stream_process(crawl_request: CrawlRequestWithHooks):
# Prepare hooks config if provided# Prepare hooks config if provided # Prepare hooks config if provided
hooks_config = None hooks_config = None
if crawl_request.hooks: if hasattr(crawl_request, 'hooks') and crawl_request.hooks:
hooks_config = { hooks_config = {
"code": crawl_request.hooks.code, "code": crawl_request.hooks.code,
"timeout": crawl_request.hooks.timeout, "timeout": crawl_request.hooks.timeout,
@@ -641,6 +644,8 @@ async def stream_process(crawl_request: CrawlRequestWithHooks):
crawler_config=crawl_request.crawler_config, crawler_config=crawl_request.crawler_config,
config=config, config=config,
hooks_config=hooks_config, hooks_config=hooks_config,
anti_bot_strategy=crawl_request.anti_bot_strategy,
headless=crawl_request.headless,
) )
# Add hooks info to response headers if available # Add hooks info to response headers if available