diff --git a/deploy/docker/api.py b/deploy/docker/api.py index 959e64c0..66d02d6d 100644 --- a/deploy/docker/api.py +++ b/deploy/docker/api.py @@ -1,57 +1,63 @@ -import os -import json import asyncio -from typing import List, Tuple, Dict -from functools import partial -from uuid import uuid4 -from datetime import datetime, timezone -from base64 import b64encode - +import json import logging -from typing import Optional, AsyncGenerator +import os +import time +from base64 import b64encode +from datetime import datetime, timezone +from functools import partial +from typing import AsyncGenerator, Dict, List, Optional, Tuple from urllib.parse import unquote from fastapi import HTTPException, Request, status from fastapi.background import BackgroundTasks from fastapi.responses import JSONResponse from redis import asyncio as aioredis from crawl4ai import ( + AsyncUrlSeeder, AsyncWebCrawler, - CrawlerRunConfig, - LLMExtractionStrategy, - CacheMode, BrowserConfig, - MemoryAdaptiveDispatcher, - RateLimiter, + CacheMode, + CrawlerRunConfig, LLMConfig, - AsyncUrlSeeder, - SeedingConfig + LLMExtractionStrategy, + MemoryAdaptiveDispatcher, + PlaywrightAdapter, + RateLimiter, + SeedingConfig, + UndetectedAdapter, ) -from crawl4ai.utils import perform_completion_with_backoff + +# Import StealthAdapter with fallback for compatibility +try: + from crawl4ai import StealthAdapter +except ImportError: + # Fallback: import directly from browser_adapter module + try: + import os + import sys + # Add the project root to path for development + project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..")) + sys.path.insert(0, project_root) + from crawl4ai.browser_adapter import StealthAdapter + except ImportError: + # If all else fails, create a simple fallback + from crawl4ai.browser_adapter import PlaywrightAdapter + + class StealthAdapter(PlaywrightAdapter): + """Fallback StealthAdapter that uses PlaywrightAdapter""" + pass from crawl4ai.content_filter_strategy import ( - PruningContentFilter, BM25ContentFilter, - LLMContentFilter + LLMContentFilter, + PruningContentFilter, ) -from crawl4ai.markdown_generation_strategy import DefaultMarkdownGenerator from crawl4ai.content_scraping_strategy import LXMLWebScrapingStrategy - -from utils import ( - TaskStatus, - FilterType, - get_base_url, - is_task_id, - should_cleanup_task, - decode_redis_hash, - get_llm_api_key, - validate_llm_provider, - get_llm_temperature, - get_llm_base_url -) - -import psutil, time +from crawl4ai.markdown_generation_strategy import DefaultMarkdownGenerator +from crawl4ai.utils import perform_completion_with_backoff logger = logging.getLogger(__name__) + # --- Helper to get memory --- def _get_memory_mb(): try: @@ -61,17 +67,39 @@ def _get_memory_mb(): return None -async def handle_llm_qa( - url: str, - query: str, - config: dict -) -> str: +# --- Helper to get browser adapter based on anti_bot_strategy --- +def _get_browser_adapter(anti_bot_strategy: str, browser_config: BrowserConfig): + """Get the appropriate browser adapter based on anti_bot_strategy.""" + if anti_bot_strategy == "stealth": + return StealthAdapter() + elif anti_bot_strategy == "undetected": + return UndetectedAdapter() + elif anti_bot_strategy == "max_evasion": + # Use undetected for maximum evasion + return UndetectedAdapter() + else: # "default" + # If stealth is enabled in browser config, use stealth adapter + if getattr(browser_config, "enable_stealth", False): + return StealthAdapter() + return PlaywrightAdapter() + + +# --- Helper to apply 
headless setting --- +def _apply_headless_setting(browser_config: BrowserConfig, headless: bool): + """Apply headless setting to browser config.""" + browser_config.headless = headless + return browser_config + + +async def handle_llm_qa(url: str, query: str, config: dict) -> str: """Process QA using LLM with crawled content as context.""" try: - if not url.startswith(('http://', 'https://')) and not url.startswith(("raw:", "raw://")): - url = 'https://' + url + if not url.startswith(("http://", "https://")) and not url.startswith( + ("raw:", "raw://") + ): + url = "https://" + url # Extract base URL by finding last '?q=' occurrence - last_q_index = url.rfind('?q=') + last_q_index = url.rfind("?q=") if last_q_index != -1: url = url[:last_q_index] @@ -81,7 +109,7 @@ async def handle_llm_qa( if not result.success: raise HTTPException( status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, - detail=result.error_message + detail=result.error_message, ) content = result.markdown.fit_markdown or result.markdown.raw_markdown @@ -101,17 +129,17 @@ async def handle_llm_qa( prompt_with_variables=prompt, api_token=get_llm_api_key(config), # Returns None to let litellm handle it temperature=get_llm_temperature(config), - base_url=get_llm_base_url(config) + base_url=get_llm_base_url(config), ) return response.choices[0].message.content except Exception as e: logger.error(f"QA processing error: {str(e)}", exc_info=True) raise HTTPException( - status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, - detail=str(e) + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=str(e) ) + async def process_llm_extraction( redis: aioredis.Redis, config: dict, @@ -122,25 +150,27 @@ async def process_llm_extraction( cache: str = "0", provider: Optional[str] = None, temperature: Optional[float] = None, - base_url: Optional[str] = None + base_url: Optional[str] = None, ) -> None: """Process LLM extraction in background.""" try: # Validate provider is_valid, error_msg = validate_llm_provider(config, provider) if not is_valid: - await redis.hset(f"task:{task_id}", mapping={ - "status": TaskStatus.FAILED, - "error": error_msg - }) + await redis.hset( + f"task:{task_id}", + mapping={"status": TaskStatus.FAILED, "error": error_msg}, + ) return - api_key = get_llm_api_key(config, provider) # Returns None to let litellm handle it + api_key = get_llm_api_key( + config, provider + ) # Returns None to let litellm handle it llm_strategy = LLMExtractionStrategy( llm_config=LLMConfig( provider=provider or config["llm"]["provider"], api_token=api_key, temperature=temperature or get_llm_temperature(config, provider), - base_url=base_url or get_llm_base_url(config, provider) + base_url=base_url or get_llm_base_url(config, provider), ), instruction=instruction, schema=json.loads(schema) if schema else None, @@ -154,32 +184,32 @@ async def process_llm_extraction( config=CrawlerRunConfig( extraction_strategy=llm_strategy, scraping_strategy=LXMLWebScrapingStrategy(), - cache_mode=cache_mode - ) + cache_mode=cache_mode, + ), ) if not result.success: - await redis.hset(f"task:{task_id}", mapping={ - "status": TaskStatus.FAILED, - "error": result.error_message - }) + await redis.hset( + f"task:{task_id}", + mapping={"status": TaskStatus.FAILED, "error": result.error_message}, + ) return try: content = json.loads(result.extracted_content) except json.JSONDecodeError: content = result.extracted_content - await redis.hset(f"task:{task_id}", mapping={ - "status": TaskStatus.COMPLETED, - "result": json.dumps(content) - }) + await redis.hset( + 
f"task:{task_id}", + mapping={"status": TaskStatus.COMPLETED, "result": json.dumps(content)}, + ) except Exception as e: logger.error(f"LLM extraction error: {str(e)}", exc_info=True) - await redis.hset(f"task:{task_id}", mapping={ - "status": TaskStatus.FAILED, - "error": str(e) - }) + await redis.hset( + f"task:{task_id}", mapping={"status": TaskStatus.FAILED, "error": str(e)} + ) + async def handle_markdown_request( url: str, @@ -189,7 +219,7 @@ async def handle_markdown_request( config: Optional[dict] = None, provider: Optional[str] = None, temperature: Optional[float] = None, - base_url: Optional[str] = None + base_url: Optional[str] = None, ) -> str: """Handle markdown generation requests.""" try: @@ -198,12 +228,13 @@ async def handle_markdown_request( is_valid, error_msg = validate_llm_provider(config, provider) if not is_valid: raise HTTPException( - status_code=status.HTTP_400_BAD_REQUEST, - detail=error_msg + status_code=status.HTTP_400_BAD_REQUEST, detail=error_msg ) decoded_url = unquote(url) - if not decoded_url.startswith(('http://', 'https://')) and not decoded_url.startswith(("raw:", "raw://")): - decoded_url = 'https://' + decoded_url + if not decoded_url.startswith( + ("http://", "https://") + ) and not decoded_url.startswith(("raw:", "raw://")): + decoded_url = "https://" + decoded_url if filter_type == FilterType.RAW: md_generator = DefaultMarkdownGenerator() @@ -214,12 +245,15 @@ async def handle_markdown_request( FilterType.LLM: LLMContentFilter( llm_config=LLMConfig( provider=provider or config["llm"]["provider"], - api_token=get_llm_api_key(config, provider), # Returns None to let litellm handle it - temperature=temperature or get_llm_temperature(config, provider), - base_url=base_url or get_llm_base_url(config, provider) + api_token=get_llm_api_key( + config, provider + ), # Returns None to let litellm handle it + temperature=temperature + or get_llm_temperature(config, provider), + base_url=base_url or get_llm_base_url(config, provider), ), - instruction=query or "Extract main content" - ) + instruction=query or "Extract main content", + ), }[filter_type] md_generator = DefaultMarkdownGenerator(content_filter=content_filter) @@ -231,27 +265,29 @@ async def handle_markdown_request( config=CrawlerRunConfig( markdown_generator=md_generator, scraping_strategy=LXMLWebScrapingStrategy(), - cache_mode=cache_mode - ) + cache_mode=cache_mode, + ), ) - + if not result.success: raise HTTPException( status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, - detail=result.error_message + detail=result.error_message, ) - return (result.markdown.raw_markdown - if filter_type == FilterType.RAW - else result.markdown.fit_markdown) + return ( + result.markdown.raw_markdown + if filter_type == FilterType.RAW + else result.markdown.fit_markdown + ) except Exception as e: logger.error(f"Markdown error: {str(e)}", exc_info=True) raise HTTPException( - status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, - detail=str(e) + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=str(e) ) + async def handle_llm_request( redis: aioredis.Redis, background_tasks: BackgroundTasks, @@ -263,27 +299,27 @@ async def handle_llm_request( config: Optional[dict] = None, provider: Optional[str] = None, temperature: Optional[float] = None, - api_base_url: Optional[str] = None + api_base_url: Optional[str] = None, ) -> JSONResponse: """Handle LLM extraction requests.""" base_url = get_base_url(request) - + try: if is_task_id(input_path): - return await handle_task_status( - redis, input_path, base_url - ) + 
return await handle_task_status(redis, input_path, base_url) if not query: - return JSONResponse({ - "message": "Please provide an instruction", - "_links": { - "example": { - "href": f"{base_url}/llm/{input_path}?q=Extract+main+content", - "title": "Try this example" - } + return JSONResponse( + { + "message": "Please provide an instruction", + "_links": { + "example": { + "href": f"{base_url}/llm/{input_path}?q=Extract+main+content", + "title": "Try this example", + } + }, } - }) + ) return await create_new_task( redis, @@ -296,31 +332,25 @@ async def handle_llm_request( config, provider, temperature, - api_base_url + api_base_url, ) except Exception as e: logger.error(f"LLM endpoint error: {str(e)}", exc_info=True) - return JSONResponse({ - "error": str(e), - "_links": { - "retry": {"href": str(request.url)} - } - }, status_code=status.HTTP_500_INTERNAL_SERVER_ERROR) + return JSONResponse( + {"error": str(e), "_links": {"retry": {"href": str(request.url)}}}, + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + ) + async def handle_task_status( - redis: aioredis.Redis, - task_id: str, - base_url: str, - *, - keep: bool = False + redis: aioredis.Redis, task_id: str, base_url: str, *, keep: bool = False ) -> JSONResponse: """Handle task status check requests.""" task = await redis.hgetall(f"task:{task_id}") if not task: raise HTTPException( - status_code=status.HTTP_404_NOT_FOUND, - detail="Task not found" + status_code=status.HTTP_404_NOT_FOUND, detail="Task not found" ) task = decode_redis_hash(task) @@ -332,6 +362,7 @@ async def handle_task_status( return JSONResponse(response) + async def create_new_task( redis: aioredis.Redis, background_tasks: BackgroundTasks, @@ -343,21 +374,27 @@ async def create_new_task( config: dict, provider: Optional[str] = None, temperature: Optional[float] = None, - api_base_url: Optional[str] = None + api_base_url: Optional[str] = None, ) -> JSONResponse: """Create and initialize a new task.""" decoded_url = unquote(input_path) - if not decoded_url.startswith(('http://', 'https://')) and not decoded_url.startswith(("raw:", "raw://")): - decoded_url = 'https://' + decoded_url + if not decoded_url.startswith( + ("http://", "https://") + ) and not decoded_url.startswith(("raw:", "raw://")): + decoded_url = "https://" + decoded_url from datetime import datetime + task_id = f"llm_{int(datetime.now().timestamp())}_{id(background_tasks)}" - - await redis.hset(f"task:{task_id}", mapping={ - "status": TaskStatus.PROCESSING, - "created_at": datetime.now().isoformat(), - "url": decoded_url - }) + + await redis.hset( + f"task:{task_id}", + mapping={ + "status": TaskStatus.PROCESSING, + "created_at": datetime.now().isoformat(), + "url": decoded_url, + }, + ) background_tasks.add_task( process_llm_extraction, @@ -370,18 +407,21 @@ async def create_new_task( cache, provider, temperature, - api_base_url + api_base_url, ) - return JSONResponse({ - "task_id": task_id, - "status": TaskStatus.PROCESSING, - "url": decoded_url, - "_links": { - "self": {"href": f"{base_url}/llm/{task_id}"}, - "status": {"href": f"{base_url}/llm/{task_id}"} + return JSONResponse( + { + "task_id": task_id, + "status": TaskStatus.PROCESSING, + "url": decoded_url, + "_links": { + "self": {"href": f"{base_url}/llm/{task_id}"}, + "status": {"href": f"{base_url}/llm/{task_id}"}, + }, } - }) + ) + def create_task_response(task: dict, task_id: str, base_url: str) -> dict: """Create response for task status check.""" @@ -392,8 +432,8 @@ def create_task_response(task: dict, task_id: str, base_url: str) -> 
dict: "url": task["url"], "_links": { "self": {"href": f"{base_url}/llm/{task_id}"}, - "refresh": {"href": f"{base_url}/llm/{task_id}"} - } + "refresh": {"href": f"{base_url}/llm/{task_id}"}, + }, } if task["status"] == TaskStatus.COMPLETED: @@ -403,9 +443,13 @@ def create_task_response(task: dict, task_id: str, base_url: str) -> dict: return response -async def stream_results(crawler: AsyncWebCrawler, results_gen: AsyncGenerator) -> AsyncGenerator[bytes, None]: + +async def stream_results( + crawler: AsyncWebCrawler, results_gen: AsyncGenerator +) -> AsyncGenerator[bytes, None]: """Stream results with heartbeats and completion markers.""" import json + from utils import datetime_handler try: @@ -413,23 +457,29 @@ async def stream_results(crawler: AsyncWebCrawler, results_gen: AsyncGenerator) try: server_memory_mb = _get_memory_mb() result_dict = result.model_dump() - result_dict['server_memory_mb'] = server_memory_mb + result_dict["server_memory_mb"] = server_memory_mb # Ensure fit_html is JSON-serializable - if "fit_html" in result_dict and not (result_dict["fit_html"] is None or isinstance(result_dict["fit_html"], str)): + if "fit_html" in result_dict and not ( + result_dict["fit_html"] is None + or isinstance(result_dict["fit_html"], str) + ): result_dict["fit_html"] = None # If PDF exists, encode it to base64 - if result_dict.get('pdf') is not None: - result_dict['pdf'] = b64encode(result_dict['pdf']).decode('utf-8') + if result_dict.get("pdf") is not None: + result_dict["pdf"] = b64encode(result_dict["pdf"]).decode("utf-8") logger.info(f"Streaming result for {result_dict.get('url', 'unknown')}") data = json.dumps(result_dict, default=datetime_handler) + "\n" - yield data.encode('utf-8') + yield data.encode("utf-8") except Exception as e: logger.error(f"Serialization error: {e}") - error_response = {"error": str(e), "url": getattr(result, 'url', 'unknown')} - yield (json.dumps(error_response) + "\n").encode('utf-8') + error_response = { + "error": str(e), + "url": getattr(result, "url", "unknown"), + } + yield (json.dumps(error_response) + "\n").encode("utf-8") + + yield json.dumps({"status": "completed"}).encode("utf-8") - yield json.dumps({"status": "completed"}).encode('utf-8') - except asyncio.CancelledError: logger.warning("Client disconnected during streaming") finally: @@ -439,51 +489,70 @@ async def stream_results(crawler: AsyncWebCrawler, results_gen: AsyncGenerator) # logger.error(f"Crawler cleanup error: {e}") pass + async def handle_crawl_request( urls: List[str], browser_config: dict, crawler_config: dict, config: dict, - hooks_config: Optional[dict] = None + hooks_config: Optional[dict] = None, + anti_bot_strategy: str = "default", + headless: bool = True, ) -> dict: """Handle non-streaming crawl requests with optional hooks.""" - start_mem_mb = _get_memory_mb() # <--- Get memory before + start_mem_mb = _get_memory_mb() # <--- Get memory before start_time = time.time() mem_delta_mb = None peak_mem_mb = start_mem_mb hook_manager = None - + try: - urls = [('https://' + url) if not url.startswith(('http://', 'https://')) and not url.startswith(("raw:", "raw://")) else url for url in urls] + urls = [ + ("https://" + url) + if not url.startswith(("http://", "https://")) + and not url.startswith(("raw:", "raw://")) + else url + for url in urls + ] browser_config = BrowserConfig.load(browser_config) + _apply_headless_setting(browser_config, headless) crawler_config = CrawlerRunConfig.load(crawler_config) + # Configure browser adapter based on anti_bot_strategy + browser_adapter 
= _get_browser_adapter(anti_bot_strategy, browser_config) + + # TODO: add support for other dispatchers + dispatcher = MemoryAdaptiveDispatcher( memory_threshold_percent=config["crawler"]["memory_threshold_percent"], rate_limiter=RateLimiter( base_delay=tuple(config["crawler"]["rate_limiter"]["base_delay"]) - ) if config["crawler"]["rate_limiter"]["enabled"] else None + ) + if config["crawler"]["rate_limiter"]["enabled"] + else None, ) - + from crawler_pool import get_crawler - crawler = await get_crawler(browser_config) + + crawler = await get_crawler(browser_config, browser_adapter) # crawler: AsyncWebCrawler = AsyncWebCrawler(config=browser_config) # await crawler.start() - + # Attach hooks if provided hooks_status = {} if hooks_config: - from hook_manager import attach_user_hooks_to_crawler, UserHookManager - hook_manager = UserHookManager(timeout=hooks_config.get('timeout', 30)) + from hook_manager import UserHookManager, attach_user_hooks_to_crawler + + hook_manager = UserHookManager(timeout=hooks_config.get("timeout", 30)) hooks_status, hook_manager = await attach_user_hooks_to_crawler( crawler, - hooks_config.get('code', {}), - timeout=hooks_config.get('timeout', 30), - hook_manager=hook_manager + hooks_config.get("code", {}), + timeout=hooks_config.get("timeout", 30), + hook_manager=hook_manager, ) logger.info(f"Hooks attachment status: {hooks_status['status']}") - + base_config = config["crawler"]["base_config"] # Iterate on key-value pairs in global_config then use hasattr to set them for key, value in base_config.items(): @@ -495,32 +564,38 @@ async def handle_crawl_request( results = [] func = getattr(crawler, "arun" if len(urls) == 1 else "arun_many") - partial_func = partial(func, - urls[0] if len(urls) == 1 else urls, - config=crawler_config, - dispatcher=dispatcher) + partial_func = partial( + func, + urls[0] if len(urls) == 1 else urls, + config=crawler_config, + dispatcher=dispatcher, + ) results = await partial_func() - + # Ensure results is always a list if not isinstance(results, list): results = [results] # await crawler.close() - - end_mem_mb = _get_memory_mb() # <--- Get memory after + + end_mem_mb = _get_memory_mb() # <--- Get memory after end_time = time.time() - + if start_mem_mb is not None and end_mem_mb is not None: - mem_delta_mb = end_mem_mb - start_mem_mb # <--- Calculate delta - peak_mem_mb = max(peak_mem_mb if peak_mem_mb else 0, end_mem_mb) # <--- Get peak memory - logger.info(f"Memory usage: Start: {start_mem_mb} MB, End: {end_mem_mb} MB, Delta: {mem_delta_mb} MB, Peak: {peak_mem_mb} MB") + mem_delta_mb = end_mem_mb - start_mem_mb # <--- Calculate delta + peak_mem_mb = max( + peak_mem_mb if peak_mem_mb else 0, end_mem_mb + ) # <--- Get peak memory + logger.info( + f"Memory usage: Start: {start_mem_mb} MB, End: {end_mem_mb} MB, Delta: {mem_delta_mb} MB, Peak: {peak_mem_mb} MB" + ) # Process results to handle PDF bytes processed_results = [] for result in results: try: # Check if result has model_dump method (is a proper CrawlResult) - if hasattr(result, 'model_dump'): + if hasattr(result, "model_dump"): result_dict = result.model_dump() elif isinstance(result, dict): result_dict = result @@ -528,48 +603,53 @@ async def handle_crawl_request( # Handle unexpected result type logger.warning(f"Unexpected result type: {type(result)}") result_dict = { - "url": str(result) if hasattr(result, '__str__') else "unknown", + "url": str(result) if hasattr(result, "__str__") else "unknown", "success": False, - "error_message": f"Unexpected result type: 
{type(result).__name__}" + "error_message": f"Unexpected result type: {type(result).__name__}", } - + # if fit_html is not a string, set it to None to avoid serialization errors - if "fit_html" in result_dict and not (result_dict["fit_html"] is None or isinstance(result_dict["fit_html"], str)): + if "fit_html" in result_dict and not ( + result_dict["fit_html"] is None + or isinstance(result_dict["fit_html"], str) + ): result_dict["fit_html"] = None - + # If PDF exists, encode it to base64 - if result_dict.get('pdf') is not None and isinstance(result_dict.get('pdf'), bytes): - result_dict['pdf'] = b64encode(result_dict['pdf']).decode('utf-8') - + if result_dict.get("pdf") is not None and isinstance( + result_dict.get("pdf"), bytes + ): + result_dict["pdf"] = b64encode(result_dict["pdf"]).decode("utf-8") + processed_results.append(result_dict) except Exception as e: logger.error(f"Error processing result: {e}") - processed_results.append({ - "url": "unknown", - "success": False, - "error_message": str(e) - }) - + processed_results.append( + {"url": "unknown", "success": False, "error_message": str(e)} + ) + response = { "success": True, "results": processed_results, "server_processing_time_s": end_time - start_time, "server_memory_delta_mb": mem_delta_mb, - "server_peak_memory_mb": peak_mem_mb + "server_peak_memory_mb": peak_mem_mb, } - + # Add hooks information if hooks were used if hooks_config and hook_manager: from hook_manager import UserHookManager + if isinstance(hook_manager, UserHookManager): try: # Ensure all hook data is JSON serializable import json + hook_data = { "status": hooks_status, "execution_log": hook_manager.execution_log, "errors": hook_manager.errors, - "summary": hook_manager.get_summary() + "summary": hook_manager.get_summary(), } # Test that it's serializable json.dumps(hook_data) @@ -577,17 +657,22 @@ async def handle_crawl_request( except (TypeError, ValueError) as e: logger.error(f"Hook data not JSON serializable: {e}") response["hooks"] = { - "status": {"status": "error", "message": "Hook data serialization failed"}, + "status": { + "status": "error", + "message": "Hook data serialization failed", + }, "execution_log": [], "errors": [{"error": str(e)}], - "summary": {} + "summary": {}, } - + return response except Exception as e: logger.error(f"Crawl error: {str(e)}", exc_info=True) - if 'crawler' in locals() and crawler.ready: # Check if crawler was initialized and started + if ( + "crawler" in locals() and crawler.ready + ): # Check if crawler was initialized and started # try: # await crawler.close() # except Exception as close_e: @@ -601,19 +686,26 @@ async def handle_crawl_request( raise HTTPException( status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, - detail=json.dumps({ # Send structured error - "error": str(e), - "server_memory_delta_mb": mem_delta_mb, - "server_peak_memory_mb": max(peak_mem_mb if peak_mem_mb else 0, end_mem_mb_error or 0) - }) + detail=json.dumps( + { # Send structured error + "error": str(e), + "server_memory_delta_mb": mem_delta_mb, + "server_peak_memory_mb": max( + peak_mem_mb if peak_mem_mb else 0, end_mem_mb_error or 0 + ), + } + ), ) + async def handle_stream_crawl_request( urls: List[str], browser_config: dict, crawler_config: dict, config: dict, - hooks_config: Optional[dict] = None + hooks_config: Optional[dict] = None, + anti_bot_strategy: str = "default", + headless: bool = True, ) -> Tuple[AsyncWebCrawler, AsyncGenerator, Optional[Dict]]: """Handle streaming crawl requests with optional hooks.""" hooks_info = None @@ 
-621,60 +713,68 @@ async def handle_stream_crawl_request( browser_config = BrowserConfig.load(browser_config) # browser_config.verbose = True # Set to False or remove for production stress testing browser_config.verbose = False + _apply_headless_setting(browser_config, headless) crawler_config = CrawlerRunConfig.load(crawler_config) crawler_config.scraping_strategy = LXMLWebScrapingStrategy() crawler_config.stream = True + # Configure browser adapter based on anti_bot_strategy + browser_adapter = _get_browser_adapter(anti_bot_strategy, browser_config) + dispatcher = MemoryAdaptiveDispatcher( memory_threshold_percent=config["crawler"]["memory_threshold_percent"], rate_limiter=RateLimiter( base_delay=tuple(config["crawler"]["rate_limiter"]["base_delay"]) - ) + ), ) from crawler_pool import get_crawler - crawler = await get_crawler(browser_config) + + crawler = await get_crawler(browser_config, browser_adapter) # crawler = AsyncWebCrawler(config=browser_config) # await crawler.start() - + # Attach hooks if provided if hooks_config: - from hook_manager import attach_user_hooks_to_crawler, UserHookManager - hook_manager = UserHookManager(timeout=hooks_config.get('timeout', 30)) + from hook_manager import UserHookManager, attach_user_hooks_to_crawler + + hook_manager = UserHookManager(timeout=hooks_config.get("timeout", 30)) hooks_status, hook_manager = await attach_user_hooks_to_crawler( crawler, - hooks_config.get('code', {}), - timeout=hooks_config.get('timeout', 30), - hook_manager=hook_manager + hooks_config.get("code", {}), + timeout=hooks_config.get("timeout", 30), + hook_manager=hook_manager, + ) + logger.info( + f"Hooks attachment status for streaming: {hooks_status['status']}" ) - logger.info(f"Hooks attachment status for streaming: {hooks_status['status']}") # Include hook manager in hooks_info for proper tracking - hooks_info = {'status': hooks_status, 'manager': hook_manager} + hooks_info = {"status": hooks_status, "manager": hook_manager} results_gen = await crawler.arun_many( - urls=urls, - config=crawler_config, - dispatcher=dispatcher + urls=urls, config=crawler_config, dispatcher=dispatcher ) return crawler, results_gen, hooks_info except Exception as e: # Make sure to close crawler if started during an error here - if 'crawler' in locals() and crawler.ready: + if "crawler" in locals() and crawler.ready: # try: # await crawler.close() # except Exception as close_e: # logger.error(f"Error closing crawler during stream setup exception: {close_e}") - logger.error(f"Error closing crawler during stream setup exception: {str(e)}") + logger.error( + f"Error closing crawler during stream setup exception: {str(e)}" + ) logger.error(f"Stream crawl error: {str(e)}", exc_info=True) # Raising HTTPException here will prevent streaming response raise HTTPException( - status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, - detail=str(e) + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=str(e) ) - + + async def handle_crawl_job( redis, background_tasks: BackgroundTasks, @@ -689,13 +789,16 @@ async def handle_crawl_job( lets /crawl/job/{task_id} polling fetch the result. 
""" task_id = f"crawl_{uuid4().hex[:8]}" - await redis.hset(f"task:{task_id}", mapping={ - "status": TaskStatus.PROCESSING, # <-- keep enum values consistent - "created_at": datetime.now(timezone.utc).replace(tzinfo=None).isoformat(), - "url": json.dumps(urls), # store list as JSON string - "result": "", - "error": "", - }) + await redis.hset( + f"task:{task_id}", + mapping={ + "status": TaskStatus.PROCESSING, # <-- keep enum values consistent + "created_at": datetime.now(timezone.utc).replace(tzinfo=None).isoformat(), + "url": json.dumps(urls), # store list as JSON string + "result": "", + "error": "", + }, + ) async def _runner(): try: @@ -705,21 +808,28 @@ async def handle_crawl_job( crawler_config=crawler_config, config=config, ) - await redis.hset(f"task:{task_id}", mapping={ - "status": TaskStatus.COMPLETED, - "result": json.dumps(result), - }) + await redis.hset( + f"task:{task_id}", + mapping={ + "status": TaskStatus.COMPLETED, + "result": json.dumps(result), + }, + ) await asyncio.sleep(5) # Give Redis time to process the update except Exception as exc: - await redis.hset(f"task:{task_id}", mapping={ - "status": TaskStatus.FAILED, - "error": str(exc), - }) + await redis.hset( + f"task:{task_id}", + mapping={ + "status": TaskStatus.FAILED, + "error": str(exc), + }, + ) background_tasks.add_task(_runner) return {"task_id": task_id} -async def handle_seed(url ,cfg): + +async def handle_seed(url, cfg): # Create the configuration from the request body try: seeding_config = cfg @@ -732,7 +842,7 @@ async def handle_seed(url ,cfg): return urls except Exception as e: return { - "seeded_urls": [], - "count": 0, - "message": "No URLs found for the given domain and configuration.", - } + "seeded_urls": [], + "count": 0, + "message": "No URLs found for the given domain and configuration.", + } diff --git a/deploy/docker/crawler_pool.py b/deploy/docker/crawler_pool.py index d15102e4..8a5f9381 100644 --- a/deploy/docker/crawler_pool.py +++ b/deploy/docker/crawler_pool.py @@ -1,10 +1,27 @@ # crawler_pool.py (new file) -import asyncio, json, hashlib, time, psutil +import asyncio +import hashlib +import json +import time from contextlib import suppress -from typing import Dict +from typing import Dict, Optional + +import psutil + from crawl4ai import AsyncWebCrawler, BrowserConfig -from typing import Dict -from utils import load_config +from crawl4ai.async_crawler_strategy import AsyncPlaywrightCrawlerStrategy + +# Import browser adapters with fallback +try: + from crawl4ai.browser_adapter import BrowserAdapter, PlaywrightAdapter +except ImportError: + # Fallback for development environment + import os + import sys + + sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", "..")) + from crawl4ai.browser_adapter import BrowserAdapter, PlaywrightAdapter +from utils import load_config CONFIG = load_config() @@ -12,25 +29,44 @@ POOL: Dict[str, AsyncWebCrawler] = {} LAST_USED: Dict[str, float] = {} LOCK = asyncio.Lock() -MEM_LIMIT = CONFIG.get("crawler", {}).get("memory_threshold_percent", 95.0) # % RAM – refuse new browsers above this -IDLE_TTL = CONFIG.get("crawler", {}).get("pool", {}).get("idle_ttl_sec", 1800) # close if unused for 30 min +MEM_LIMIT = CONFIG.get("crawler", {}).get( + "memory_threshold_percent", 95.0 +) # % RAM – refuse new browsers above this +IDLE_TTL = ( + CONFIG.get("crawler", {}).get("pool", {}).get("idle_ttl_sec", 1800) +) # close if unused for 30 min -def _sig(cfg: BrowserConfig) -> str: - payload = json.dumps(cfg.to_dict(), sort_keys=True, separators=(",",":")) + 
+def _sig(cfg: BrowserConfig, adapter: Optional[BrowserAdapter] = None) -> str: + config_payload = json.dumps(cfg.to_dict(), sort_keys=True, separators=(",", ":")) + adapter_name = adapter.__class__.__name__ if adapter else "PlaywrightAdapter" + payload = f"{config_payload}:{adapter_name}" return hashlib.sha1(payload.encode()).hexdigest() -async def get_crawler(cfg: BrowserConfig) -> AsyncWebCrawler: + +async def get_crawler( + cfg: BrowserConfig, adapter: Optional[BrowserAdapter] = None +) -> AsyncWebCrawler: try: - sig = _sig(cfg) + sig = _sig(cfg, adapter) async with LOCK: if sig in POOL: - LAST_USED[sig] = time.time(); + LAST_USED[sig] = time.time() return POOL[sig] if psutil.virtual_memory().percent >= MEM_LIMIT: raise MemoryError("RAM pressure – new browser denied") - crawler = AsyncWebCrawler(config=cfg, thread_safe=False) + + # Create strategy with the specified adapter + strategy = AsyncPlaywrightCrawlerStrategy( + browser_config=cfg, browser_adapter=adapter or PlaywrightAdapter() + ) + + crawler = AsyncWebCrawler( + config=cfg, crawler_strategy=strategy, thread_safe=False + ) await crawler.start() - POOL[sig] = crawler; LAST_USED[sig] = time.time() + POOL[sig] = crawler + LAST_USED[sig] = time.time() return crawler except MemoryError as e: raise MemoryError(f"RAM pressure – new browser denied: {e}") @@ -44,10 +80,16 @@ async def get_crawler(cfg: BrowserConfig) -> AsyncWebCrawler: POOL.pop(sig, None) LAST_USED.pop(sig, None) # If we failed to start the browser, we should remove it from the pool + + async def close_all(): async with LOCK: - await asyncio.gather(*(c.close() for c in POOL.values()), return_exceptions=True) - POOL.clear(); LAST_USED.clear() + await asyncio.gather( + *(c.close() for c in POOL.values()), return_exceptions=True + ) + POOL.clear() + LAST_USED.clear() + async def janitor(): while True: @@ -56,5 +98,7 @@ async def janitor(): async with LOCK: for sig, crawler in list(POOL.items()): if now - LAST_USED[sig] > IDLE_TTL: - with suppress(Exception): await crawler.close() - POOL.pop(sig, None); LAST_USED.pop(sig, None) + with suppress(Exception): + await crawler.close() + POOL.pop(sig, None) + LAST_USED.pop(sig, None) diff --git a/deploy/docker/schemas.py b/deploy/docker/schemas.py index 5263bfeb..d9384eac 100644 --- a/deploy/docker/schemas.py +++ b/deploy/docker/schemas.py @@ -1,5 +1,5 @@ from enum import Enum -from typing import Any, Dict, List, Optional +from typing import Any, Dict, List, Literal, Optional from pydantic import BaseModel, Field from utils import FilterType @@ -10,6 +10,11 @@ class CrawlRequest(BaseModel): browser_config: Optional[Dict] = Field(default_factory=dict) crawler_config: Optional[Dict] = Field(default_factory=dict) + anti_bot_strategy: Literal["default", "stealth", "undetected", "max_evasion"] = ( + Field("default", description="The anti-bot strategy to use for the crawl.") + ) + headless: bool = Field(True, description="Run the browser in headless mode.") + class HookConfig(BaseModel): """Configuration for user-provided hooks""" diff --git a/deploy/docker/server.py b/deploy/docker/server.py index 1fe2783f..dd31a093 100644 --- a/deploy/docker/server.py +++ b/deploy/docker/server.py @@ -49,6 +49,7 @@ from rank_bm25 import BM25Okapi from redis import asyncio as aioredis from routers import adaptive, scripts from schemas import ( + CrawlRequest, CrawlRequestWithHooks, HTMLRequest, JSEndpointRequest, @@ -575,7 +576,7 @@ async def metrics(): @mcp_tool("crawl") async def crawl( request: Request, - crawl_request: CrawlRequestWithHooks, + 
crawl_request: CrawlRequest | CrawlRequestWithHooks, _td: Dict = Depends(token_dep), ): """ @@ -592,7 +593,7 @@ async def crawl( # Prepare hooks config if provided hooks_config = None - if crawl_request.hooks: + if hasattr(crawl_request, 'hooks') and crawl_request.hooks: hooks_config = { "code": crawl_request.hooks.code, "timeout": crawl_request.hooks.timeout, @@ -604,6 +605,8 @@ async def crawl( crawler_config=crawl_request.crawler_config, config=config, hooks_config=hooks_config, + anti_bot_strategy=crawl_request.anti_bot_strategy, + headless=crawl_request.headless, ) # check if all of the results are not successful if all(not result["success"] for result in results["results"]): @@ -627,9 +630,9 @@ async def crawl_stream( async def stream_process(crawl_request: CrawlRequestWithHooks): - # Prepare hooks config if provided# Prepare hooks config if provided + # Prepare hooks config if provided hooks_config = None - if crawl_request.hooks: + if hasattr(crawl_request, 'hooks') and crawl_request.hooks: hooks_config = { "code": crawl_request.hooks.code, "timeout": crawl_request.hooks.timeout, @@ -641,6 +644,8 @@ async def stream_process(crawl_request: CrawlRequestWithHooks): crawler_config=crawl_request.crawler_config, config=config, hooks_config=hooks_config, + anti_bot_strategy=crawl_request.anti_bot_strategy, + headless=crawl_request.headless, ) # Add hooks info to response headers if available
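
Below is a minimal client-side sketch of the request surface this diff introduces. The "anti_bot_strategy" and "headless" fields and the "results"/"success" response keys come from the CrawlRequest schema and handle_crawl_request above; the host/port, the bearer token, and the use of the requests library are assumptions for illustration only, not part of the change.

# Illustrative only: exercises the new CrawlRequest fields against a running server.
# Host, port, and token are assumed values for a local Docker deployment.
import requests

payload = {
    "urls": ["https://example.com"],
    "browser_config": {},
    "crawler_config": {},
    "anti_bot_strategy": "undetected",  # one of: default, stealth, undetected, max_evasion
    "headless": False,                  # run a headed browser, e.g. while tuning evasion
}

resp = requests.post(
    "http://localhost:11235/crawl",     # assumed local endpoint
    json=payload,
    headers={"Authorization": "Bearer <token>"},  # only if the server enforces token_dep
    timeout=120,
)
resp.raise_for_status()
for result in resp.json()["results"]:
    print(result["url"], result["success"])

Note that a body without a "hooks" key, as above, validates against the plain CrawlRequest now accepted by the /crawl endpoint, which is why the handler guards hook access with hasattr() before building hooks_config.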