diff --git a/crawl4ai/async_webcrawler.py b/crawl4ai/async_webcrawler.py index 1eaea156..8940b8ab 100644 --- a/crawl4ai/async_webcrawler.py +++ b/crawl4ai/async_webcrawler.py @@ -542,9 +542,9 @@ class AsyncWebCrawler: markdown_input_html = source_lambda() # Log which source is being used (optional, but helpful for debugging) - if self.logger and verbose: - actual_source_used = selected_html_source if selected_html_source in html_source_selector else 'cleaned_html (default)' - self.logger.debug(f"Using '{actual_source_used}' as source for Markdown generation for {url}", tag="MARKDOWN_SRC") + # if self.logger and verbose: + # actual_source_used = selected_html_source if selected_html_source in html_source_selector else 'cleaned_html (default)' + # self.logger.debug(f"Using '{actual_source_used}' as source for Markdown generation for {url}", tag="MARKDOWN_SRC") except Exception as e: # Handle potential errors, especially from preprocess_html_for_schema diff --git a/deploy/docker/api copy.py b/deploy/docker/api copy.py new file mode 100644 index 00000000..341e23e1 --- /dev/null +++ b/deploy/docker/api copy.py @@ -0,0 +1,503 @@ +import os +import json +import asyncio +from typing import List, Tuple +from functools import partial + +import logging +from typing import Optional, AsyncGenerator +from urllib.parse import unquote +from fastapi import HTTPException, Request, status +from fastapi.background import BackgroundTasks +from fastapi.responses import JSONResponse +from redis import asyncio as aioredis + +from crawl4ai import ( + AsyncWebCrawler, + CrawlerRunConfig, + LLMExtractionStrategy, + CacheMode, + BrowserConfig, + MemoryAdaptiveDispatcher, + RateLimiter, + LLMConfig +) +from crawl4ai.utils import perform_completion_with_backoff +from crawl4ai.content_filter_strategy import ( + PruningContentFilter, + BM25ContentFilter, + LLMContentFilter +) +from crawl4ai.markdown_generation_strategy import DefaultMarkdownGenerator +from crawl4ai.content_scraping_strategy import LXMLWebScrapingStrategy + +from utils import ( + TaskStatus, + FilterType, + get_base_url, + is_task_id, + should_cleanup_task, + decode_redis_hash +) + +import psutil, time + +logger = logging.getLogger(__name__) + +# --- Helper to get memory --- +def _get_memory_mb(): + try: + return psutil.Process().memory_info().rss / (1024 * 1024) + except Exception as e: + logger.warning(f"Could not get memory info: {e}") + return None + + +async def handle_llm_qa( + url: str, + query: str, + config: dict +) -> str: + """Process QA using LLM with crawled content as context.""" + try: + # Extract base URL by finding last '?q=' occurrence + last_q_index = url.rfind('?q=') + if last_q_index != -1: + url = url[:last_q_index] + + # Get markdown content + async with AsyncWebCrawler() as crawler: + result = await crawler.arun(url) + if not result.success: + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail=result.error_message + ) + content = result.markdown.fit_markdown + + # Create prompt and get LLM response + prompt = f"""Use the following content as context to answer the question. + Content: + {content} + + Question: {query} + + Answer:""" + + response = perform_completion_with_backoff( + provider=config["llm"]["provider"], + prompt_with_variables=prompt, + api_token=os.environ.get(config["llm"].get("api_key_env", "")) + ) + + return response.choices[0].message.content + except Exception as e: + logger.error(f"QA processing error: {str(e)}", exc_info=True) + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail=str(e) + ) + +async def process_llm_extraction( + redis: aioredis.Redis, + config: dict, + task_id: str, + url: str, + instruction: str, + schema: Optional[str] = None, + cache: str = "0" +) -> None: + """Process LLM extraction in background.""" + try: + # If config['llm'] has api_key then ignore the api_key_env + api_key = "" + if "api_key" in config["llm"]: + api_key = config["llm"]["api_key"] + else: + api_key = os.environ.get(config["llm"].get("api_key_env", None), "") + llm_strategy = LLMExtractionStrategy( + llm_config=LLMConfig( + provider=config["llm"]["provider"], + api_token=api_key + ), + instruction=instruction, + schema=json.loads(schema) if schema else None, + ) + + cache_mode = CacheMode.ENABLED if cache == "1" else CacheMode.WRITE_ONLY + + async with AsyncWebCrawler() as crawler: + result = await crawler.arun( + url=url, + config=CrawlerRunConfig( + extraction_strategy=llm_strategy, + scraping_strategy=LXMLWebScrapingStrategy(), + cache_mode=cache_mode + ) + ) + + if not result.success: + await redis.hset(f"task:{task_id}", mapping={ + "status": TaskStatus.FAILED, + "error": result.error_message + }) + return + + try: + content = json.loads(result.extracted_content) + except json.JSONDecodeError: + content = result.extracted_content + await redis.hset(f"task:{task_id}", mapping={ + "status": TaskStatus.COMPLETED, + "result": json.dumps(content) + }) + + except Exception as e: + logger.error(f"LLM extraction error: {str(e)}", exc_info=True) + await redis.hset(f"task:{task_id}", mapping={ + "status": TaskStatus.FAILED, + "error": str(e) + }) + +async def handle_markdown_request( + url: str, + filter_type: FilterType, + query: Optional[str] = None, + cache: str = "0", + config: Optional[dict] = None +) -> str: + """Handle markdown generation requests.""" + try: + decoded_url = unquote(url) + if not decoded_url.startswith(('http://', 'https://')): + decoded_url = 'https://' + decoded_url + + if filter_type == FilterType.RAW: + md_generator = DefaultMarkdownGenerator() + else: + content_filter = { + FilterType.FIT: PruningContentFilter(), + FilterType.BM25: BM25ContentFilter(user_query=query or ""), + FilterType.LLM: LLMContentFilter( + llm_config=LLMConfig( + provider=config["llm"]["provider"], + api_token=os.environ.get(config["llm"].get("api_key_env", None), ""), + ), + instruction=query or "Extract main content" + ) + }[filter_type] + md_generator = DefaultMarkdownGenerator(content_filter=content_filter) + + cache_mode = CacheMode.ENABLED if cache == "1" else CacheMode.WRITE_ONLY + + async with AsyncWebCrawler() as crawler: + result = await crawler.arun( + url=decoded_url, + config=CrawlerRunConfig( + markdown_generator=md_generator, + scraping_strategy=LXMLWebScrapingStrategy(), + cache_mode=cache_mode + ) + ) + + if not result.success: + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail=result.error_message + ) + + return (result.markdown.raw_markdown + if filter_type == FilterType.RAW + else result.markdown.fit_markdown) + + except Exception as e: + logger.error(f"Markdown error: {str(e)}", exc_info=True) + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail=str(e) + ) + +async def handle_llm_request( + redis: aioredis.Redis, + background_tasks: BackgroundTasks, + request: Request, + input_path: str, + query: Optional[str] = None, + schema: Optional[str] = None, + cache: str = "0", + config: Optional[dict] = None +) -> JSONResponse: + """Handle LLM extraction requests.""" + base_url = get_base_url(request) + + try: + if is_task_id(input_path): + return await handle_task_status( + redis, input_path, base_url + ) + + if not query: + return JSONResponse({ + "message": "Please provide an instruction", + "_links": { + "example": { + "href": f"{base_url}/llm/{input_path}?q=Extract+main+content", + "title": "Try this example" + } + } + }) + + return await create_new_task( + redis, + background_tasks, + input_path, + query, + schema, + cache, + base_url, + config + ) + + except Exception as e: + logger.error(f"LLM endpoint error: {str(e)}", exc_info=True) + return JSONResponse({ + "error": str(e), + "_links": { + "retry": {"href": str(request.url)} + } + }, status_code=status.HTTP_500_INTERNAL_SERVER_ERROR) + +async def handle_task_status( + redis: aioredis.Redis, + task_id: str, + base_url: str +) -> JSONResponse: + """Handle task status check requests.""" + task = await redis.hgetall(f"task:{task_id}") + if not task: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail="Task not found" + ) + + task = decode_redis_hash(task) + response = create_task_response(task, task_id, base_url) + + if task["status"] in [TaskStatus.COMPLETED, TaskStatus.FAILED]: + if should_cleanup_task(task["created_at"]): + await redis.delete(f"task:{task_id}") + + return JSONResponse(response) + +async def create_new_task( + redis: aioredis.Redis, + background_tasks: BackgroundTasks, + input_path: str, + query: str, + schema: Optional[str], + cache: str, + base_url: str, + config: dict +) -> JSONResponse: + """Create and initialize a new task.""" + decoded_url = unquote(input_path) + if not decoded_url.startswith(('http://', 'https://')): + decoded_url = 'https://' + decoded_url + + from datetime import datetime + task_id = f"llm_{int(datetime.now().timestamp())}_{id(background_tasks)}" + + await redis.hset(f"task:{task_id}", mapping={ + "status": TaskStatus.PROCESSING, + "created_at": datetime.now().isoformat(), + "url": decoded_url + }) + + background_tasks.add_task( + process_llm_extraction, + redis, + config, + task_id, + decoded_url, + query, + schema, + cache + ) + + return JSONResponse({ + "task_id": task_id, + "status": TaskStatus.PROCESSING, + "url": decoded_url, + "_links": { + "self": {"href": f"{base_url}/llm/{task_id}"}, + "status": {"href": f"{base_url}/llm/{task_id}"} + } + }) + +def create_task_response(task: dict, task_id: str, base_url: str) -> dict: + """Create response for task status check.""" + response = { + "task_id": task_id, + "status": task["status"], + "created_at": task["created_at"], + "url": task["url"], + "_links": { + "self": {"href": f"{base_url}/llm/{task_id}"}, + "refresh": {"href": f"{base_url}/llm/{task_id}"} + } + } + + if task["status"] == TaskStatus.COMPLETED: + response["result"] = json.loads(task["result"]) + elif task["status"] == TaskStatus.FAILED: + response["error"] = task["error"] + + return response + +async def stream_results(crawler: AsyncWebCrawler, results_gen: AsyncGenerator) -> AsyncGenerator[bytes, None]: + """Stream results with heartbeats and completion markers.""" + import json + from utils import datetime_handler + + try: + async for result in results_gen: + try: + server_memory_mb = _get_memory_mb() + result_dict = result.model_dump() + result_dict['server_memory_mb'] = server_memory_mb + logger.info(f"Streaming result for {result_dict.get('url', 'unknown')}") + data = json.dumps(result_dict, default=datetime_handler) + "\n" + yield data.encode('utf-8') + except Exception as e: + logger.error(f"Serialization error: {e}") + error_response = {"error": str(e), "url": getattr(result, 'url', 'unknown')} + yield (json.dumps(error_response) + "\n").encode('utf-8') + + yield json.dumps({"status": "completed"}).encode('utf-8') + + except asyncio.CancelledError: + logger.warning("Client disconnected during streaming") + finally: + try: + await crawler.close() + except Exception as e: + logger.error(f"Crawler cleanup error: {e}") + +async def handle_crawl_request( + urls: List[str], + browser_config: dict, + crawler_config: dict, + config: dict +) -> dict: + """Handle non-streaming crawl requests.""" + start_mem_mb = _get_memory_mb() # <--- Get memory before + start_time = time.time() + mem_delta_mb = None + peak_mem_mb = start_mem_mb + + try: + browser_config = BrowserConfig.load(browser_config) + crawler_config = CrawlerRunConfig.load(crawler_config) + + dispatcher = MemoryAdaptiveDispatcher( + memory_threshold_percent=config["crawler"]["memory_threshold_percent"], + rate_limiter=RateLimiter( + base_delay=tuple(config["crawler"]["rate_limiter"]["base_delay"]) + ) + ) + + crawler: AsyncWebCrawler = AsyncWebCrawler(config=browser_config) + await crawler.start() + results = [] + func = getattr(crawler, "arun" if len(urls) == 1 else "arun_many") + partial_func = partial(func, + urls[0] if len(urls) == 1 else urls, + config=crawler_config, + dispatcher=dispatcher) + results = await partial_func() + await crawler.close() + + end_mem_mb = _get_memory_mb() # <--- Get memory after + end_time = time.time() + + if start_mem_mb is not None and end_mem_mb is not None: + mem_delta_mb = end_mem_mb - start_mem_mb # <--- Calculate delta + peak_mem_mb = max(peak_mem_mb if peak_mem_mb else 0, end_mem_mb) # <--- Get peak memory + logger.info(f"Memory usage: Start: {start_mem_mb} MB, End: {end_mem_mb} MB, Delta: {mem_delta_mb} MB, Peak: {peak_mem_mb} MB") + + return { + "success": True, + "results": [result.model_dump() for result in results], + "server_processing_time_s": end_time - start_time, + "server_memory_delta_mb": mem_delta_mb, + "server_peak_memory_mb": peak_mem_mb + } + + except Exception as e: + logger.error(f"Crawl error: {str(e)}", exc_info=True) + if 'crawler' in locals() and crawler.ready: # Check if crawler was initialized and started + try: + await crawler.close() + except Exception as close_e: + logger.error(f"Error closing crawler during exception handling: {close_e}") + + # Measure memory even on error if possible + end_mem_mb_error = _get_memory_mb() + if start_mem_mb is not None and end_mem_mb_error is not None: + mem_delta_mb = end_mem_mb_error - start_mem_mb + + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail=json.dumps({ # Send structured error + "error": str(e), + "server_memory_delta_mb": mem_delta_mb, + "server_peak_memory_mb": max(peak_mem_mb if peak_mem_mb else 0, end_mem_mb_error or 0) + }) + ) + +async def handle_stream_crawl_request( + urls: List[str], + browser_config: dict, + crawler_config: dict, + config: dict +) -> Tuple[AsyncWebCrawler, AsyncGenerator]: + """Handle streaming crawl requests.""" + try: + browser_config = BrowserConfig.load(browser_config) + # browser_config.verbose = True # Set to False or remove for production stress testing + browser_config.verbose = False + crawler_config = CrawlerRunConfig.load(crawler_config) + crawler_config.scraping_strategy = LXMLWebScrapingStrategy() + crawler_config.stream = True + + dispatcher = MemoryAdaptiveDispatcher( + memory_threshold_percent=config["crawler"]["memory_threshold_percent"], + rate_limiter=RateLimiter( + base_delay=tuple(config["crawler"]["rate_limiter"]["base_delay"]) + ) + ) + + crawler = AsyncWebCrawler(config=browser_config) + await crawler.start() + + results_gen = await crawler.arun_many( + urls=urls, + config=crawler_config, + dispatcher=dispatcher + ) + + return crawler, results_gen + + except Exception as e: + # Make sure to close crawler if started during an error here + if 'crawler' in locals() and crawler.ready: + try: + await crawler.close() + except Exception as close_e: + logger.error(f"Error closing crawler during stream setup exception: {close_e}") + logger.error(f"Stream crawl error: {str(e)}", exc_info=True) + # Raising HTTPException here will prevent streaming response + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail=str(e) + ) \ No newline at end of file diff --git a/deploy/docker/api.py b/deploy/docker/api.py index c01696b2..b226682f 100644 --- a/deploy/docker/api.py +++ b/deploy/docker/api.py @@ -40,8 +40,19 @@ from utils import ( decode_redis_hash ) +import psutil, time + logger = logging.getLogger(__name__) +# --- Helper to get memory --- +def _get_memory_mb(): + try: + return psutil.Process().memory_info().rss / (1024 * 1024) + except Exception as e: + logger.warning(f"Could not get memory info: {e}") + return None + + async def handle_llm_qa( url: str, query: str, @@ -351,7 +362,9 @@ async def stream_results(crawler: AsyncWebCrawler, results_gen: AsyncGenerator) try: async for result in results_gen: try: + server_memory_mb = _get_memory_mb() result_dict = result.model_dump() + result_dict['server_memory_mb'] = server_memory_mb logger.info(f"Streaming result for {result_dict.get('url', 'unknown')}") data = json.dumps(result_dict, default=datetime_handler) + "\n" yield data.encode('utf-8') @@ -364,19 +377,25 @@ async def stream_results(crawler: AsyncWebCrawler, results_gen: AsyncGenerator) except asyncio.CancelledError: logger.warning("Client disconnected during streaming") - finally: - try: - await crawler.close() - except Exception as e: - logger.error(f"Crawler cleanup error: {e}") + # finally: + # try: + # await crawler.close() + # except Exception as e: + # logger.error(f"Crawler cleanup error: {e}") async def handle_crawl_request( + crawler: AsyncWebCrawler, urls: List[str], browser_config: dict, crawler_config: dict, config: dict ) -> dict: """Handle non-streaming crawl requests.""" + start_mem_mb = _get_memory_mb() # <--- Get memory before + start_time = time.time() + mem_delta_mb = None + peak_mem_mb = start_mem_mb + try: browser_config = BrowserConfig.load(browser_config) crawler_config = CrawlerRunConfig.load(crawler_config) @@ -388,31 +407,63 @@ async def handle_crawl_request( ) ) - crawler: AsyncWebCrawler = AsyncWebCrawler(config=browser_config) - await crawler.start() + # crawler: AsyncWebCrawler = AsyncWebCrawler(config=browser_config) + # await crawler.start() results = [] func = getattr(crawler, "arun" if len(urls) == 1 else "arun_many") partial_func = partial(func, urls[0] if len(urls) == 1 else urls, config=crawler_config, dispatcher=dispatcher) + + # Simulate work being done by the crawler + # logger.debug(f"Request (URLs: {len(urls)}) starting simulated work...") # Add log + # await asyncio.sleep(2) # <--- ADD ARTIFICIAL DELAY (e.g., 0.5 seconds) + # logger.debug(f"Request (URLs: {len(urls)}) finished simulated work.") + results = await partial_func() - await crawler.close() + # await crawler.close() + + end_mem_mb = _get_memory_mb() # <--- Get memory after + end_time = time.time() + + if start_mem_mb is not None and end_mem_mb is not None: + mem_delta_mb = end_mem_mb - start_mem_mb # <--- Calculate delta + peak_mem_mb = max(peak_mem_mb if peak_mem_mb else 0, end_mem_mb) # <--- Get peak memory + logger.info(f"Memory usage: Start: {start_mem_mb} MB, End: {end_mem_mb} MB, Delta: {mem_delta_mb} MB, Peak: {peak_mem_mb} MB") + return { "success": True, - "results": [result.model_dump() for result in results] + "results": [result.model_dump() for result in results], + "server_processing_time_s": end_time - start_time, + "server_memory_delta_mb": mem_delta_mb, + "server_peak_memory_mb": peak_mem_mb } except Exception as e: logger.error(f"Crawl error: {str(e)}", exc_info=True) - if 'crawler' in locals(): - await crawler.close() + # if 'crawler' in locals() and crawler.ready: # Check if crawler was initialized and started + # try: + # await crawler.close() + # except Exception as close_e: + # logger.error(f"Error closing crawler during exception handling: {close_e}") + + # Measure memory even on error if possible + end_mem_mb_error = _get_memory_mb() + if start_mem_mb is not None and end_mem_mb_error is not None: + mem_delta_mb = end_mem_mb_error - start_mem_mb + raise HTTPException( status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, - detail=str(e) + detail=json.dumps({ # Send structured error + "error": str(e), + "server_memory_delta_mb": mem_delta_mb, + "server_peak_memory_mb": max(peak_mem_mb if peak_mem_mb else 0, end_mem_mb_error or 0) + }) ) async def handle_stream_crawl_request( + crawler: AsyncWebCrawler, urls: List[str], browser_config: dict, crawler_config: dict, @@ -421,9 +472,11 @@ async def handle_stream_crawl_request( """Handle streaming crawl requests.""" try: browser_config = BrowserConfig.load(browser_config) - browser_config.verbose = True + # browser_config.verbose = True # Set to False or remove for production stress testing + browser_config.verbose = False crawler_config = CrawlerRunConfig.load(crawler_config) crawler_config.scraping_strategy = LXMLWebScrapingStrategy() + crawler_config.stream = True dispatcher = MemoryAdaptiveDispatcher( memory_threshold_percent=config["crawler"]["memory_threshold_percent"], @@ -432,8 +485,8 @@ async def handle_stream_crawl_request( ) ) - crawler = AsyncWebCrawler(config=browser_config) - await crawler.start() + # crawler = AsyncWebCrawler(config=browser_config) + # await crawler.start() results_gen = await crawler.arun_many( urls=urls, @@ -441,12 +494,19 @@ async def handle_stream_crawl_request( dispatcher=dispatcher ) + # Return the *same* crawler instance and the generator + # The caller (server.py) manages the crawler lifecycle via the pool context return crawler, results_gen except Exception as e: - if 'crawler' in locals(): - await crawler.close() + # Make sure to close crawler if started during an error here + # if 'crawler' in locals() and crawler.ready: + # try: + # await crawler.close() + # except Exception as close_e: + # logger.error(f"Error closing crawler during stream setup exception: {close_e}") logger.error(f"Stream crawl error: {str(e)}", exc_info=True) + # Raising HTTPException here will prevent streaming response raise HTTPException( status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=str(e) diff --git a/deploy/docker/config.yml b/deploy/docker/config.yml index 3b5fead6..17848e99 100644 --- a/deploy/docker/config.yml +++ b/deploy/docker/config.yml @@ -48,6 +48,38 @@ security: content_security_policy: "default-src 'self'" strict_transport_security: "max-age=63072000; includeSubDomains" +# Crawler Pool Configuration +crawler_pool: + enabled: true # Set to false to disable the pool + + # --- Option 1: Auto-calculate size --- + auto_calculate_size: true + calculation_params: + mem_headroom_mb: 512 # Memory reserved for OS/other apps + avg_page_mem_mb: 150 # Estimated MB per concurrent "tab"/page in browsers + fd_per_page: 20 # Estimated file descriptors per page + core_multiplier: 4 # Max crawlers per CPU core + min_pool_size: 2 # Minimum number of primary crawlers + max_pool_size: 16 # Maximum number of primary crawlers + + # --- Option 2: Manual size (ignored if auto_calculate_size is true) --- + # pool_size: 8 + + # --- Other Pool Settings --- + backup_pool_size: 1 # Number of backup crawlers + max_wait_time_s: 30.0 # Max seconds a request waits for a free crawler + throttle_threshold_percent: 70.0 # Start throttling delay above this % usage + throttle_delay_min_s: 0.1 # Min throttle delay + throttle_delay_max_s: 0.5 # Max throttle delay + + # --- Browser Config for Pooled Crawlers --- + browser_config: + # No need for "type": "BrowserConfig" here, just params + headless: true + verbose: false # Keep pool crawlers less verbose in production + # user_agent: "MyPooledCrawler/1.0" # Example + # Add other BrowserConfig params as needed (e.g., proxy, viewport) + # Crawler Configuration crawler: memory_threshold_percent: 95.0 @@ -61,6 +93,8 @@ crawler: logging: level: "INFO" format: "%(asctime)s - %(name)s - %(levelname)s - %(message)s" + file: "logs/app.log" + verbose: true # Observability Configuration observability: diff --git a/deploy/docker/crawler_manager.py b/deploy/docker/crawler_manager.py new file mode 100644 index 00000000..b566e2d3 --- /dev/null +++ b/deploy/docker/crawler_manager.py @@ -0,0 +1,556 @@ +# crawler_manager.py +import asyncio +import time +import uuid +import psutil +import os +import resource # For FD limit +import random +import math +from typing import Optional, Tuple, Any, List, Dict, AsyncGenerator +from pydantic import BaseModel, Field, field_validator +from contextlib import asynccontextmanager +import logging + +from crawl4ai import AsyncWebCrawler, BrowserConfig, CrawlerRunConfig, AsyncLogger +# Assuming api.py handlers are accessible or refactored slightly if needed +# We might need to import the specific handler functions if we call them directly +# from api import handle_crawl_request, handle_stream_crawl_request, _get_memory_mb, stream_results + +# --- Custom Exceptions --- +class PoolTimeoutError(Exception): + """Raised when waiting for a crawler resource times out.""" + pass + +class PoolConfigurationError(Exception): + """Raised for configuration issues.""" + pass + +class NoHealthyCrawlerError(Exception): + """Raised when no healthy crawler is available.""" + pass + + +# --- Configuration Models --- +class CalculationParams(BaseModel): + mem_headroom_mb: int = 512 + avg_page_mem_mb: int = 150 + fd_per_page: int = 20 + core_multiplier: int = 4 + min_pool_size: int = 1 # Min safe pages should be at least 1 + max_pool_size: int = 16 + + # V2 validation for avg_page_mem_mb + @field_validator('avg_page_mem_mb') + @classmethod + def check_avg_page_mem(cls, v: int) -> int: + if v <= 0: + raise ValueError("avg_page_mem_mb must be positive") + return v + + # V2 validation for fd_per_page + @field_validator('fd_per_page') + @classmethod + def check_fd_per_page(cls, v: int) -> int: + if v <= 0: + raise ValueError("fd_per_page must be positive") + return v + +# crawler_manager.py +# ... (imports including BaseModel, Field from pydantic) ... +from pydantic import BaseModel, Field, field_validator # <-- Import field_validator + +# --- Configuration Models (Pydantic V2 Syntax) --- +class CalculationParams(BaseModel): + mem_headroom_mb: int = 512 + avg_page_mem_mb: int = 150 + fd_per_page: int = 20 + core_multiplier: int = 4 + min_pool_size: int = 1 # Min safe pages should be at least 1 + max_pool_size: int = 16 + + # V2 validation for avg_page_mem_mb + @field_validator('avg_page_mem_mb') + @classmethod + def check_avg_page_mem(cls, v: int) -> int: + if v <= 0: + raise ValueError("avg_page_mem_mb must be positive") + return v + + # V2 validation for fd_per_page + @field_validator('fd_per_page') + @classmethod + def check_fd_per_page(cls, v: int) -> int: + if v <= 0: + raise ValueError("fd_per_page must be positive") + return v + +class CrawlerManagerConfig(BaseModel): + enabled: bool = True + auto_calculate_size: bool = True + calculation_params: CalculationParams = Field(default_factory=CalculationParams) # Use Field for default_factory + backup_pool_size: int = Field(1, ge=0) # Allow 0 backups + max_wait_time_s: float = 30.0 + throttle_threshold_percent: float = Field(70.0, ge=0, le=100) + throttle_delay_min_s: float = 0.1 + throttle_delay_max_s: float = 0.5 + browser_config: Dict[str, Any] = Field(default_factory=lambda: {"headless": True, "verbose": False}) # Use Field for default_factory + primary_reload_delay_s: float = 60.0 + +# --- Crawler Manager --- +class CrawlerManager: + """Manages shared AsyncWebCrawler instances, concurrency, and failover.""" + + def __init__(self, config: CrawlerManagerConfig, logger = None): + if not config.enabled: + self.logger.warning("CrawlerManager is disabled by configuration.") + # Set defaults to allow server to run, but manager won't function + self.config = config + self._initialized = False, + return + + self.config = config + self._primary_crawler: Optional[AsyncWebCrawler] = None + self._secondary_crawlers: List[AsyncWebCrawler] = [] + self._active_crawler_index: int = 0 # 0 for primary, 1+ for secondary index + self._primary_healthy: bool = False + self._secondary_healthy_flags: List[bool] = [] + + self._safe_pages: int = 1 # Default, calculated in initialize + self._semaphore: Optional[asyncio.Semaphore] = None + self._state_lock = asyncio.Lock() # Protects active_crawler, health flags + self._reload_tasks: List[Optional[asyncio.Task]] = [] # Track reload background tasks + + self._initialized = False + self._shutting_down = False + + # Initialize logger if provided + if logger is None: + self.logger = logging.getLogger(__name__) + self.logger.setLevel(logging.INFO) + else: + self.logger = logger + + self.logger.info("CrawlerManager initialized with config.") + self.logger.debug(f"Config: {self.config.model_dump_json(indent=2)}") + + def is_enabled(self) -> bool: + return self.config.enabled and self._initialized + + def _get_system_resources(self) -> Tuple[int, int, int]: + """Gets RAM, CPU cores, and FD limit.""" + total_ram_mb = 0 + cpu_cores = 0 + try: + mem_info = psutil.virtual_memory() + total_ram_mb = mem_info.total // (1024 * 1024) + cpu_cores = psutil.cpu_count(logical=False) or psutil.cpu_count(logical=True) # Prefer physical cores + except Exception as e: + self.logger.warning(f"Could not get RAM/CPU info via psutil: {e}") + total_ram_mb = 2048 # Default fallback + cpu_cores = 2 # Default fallback + + fd_limit = 1024 # Default fallback + try: + soft_limit, hard_limit = resource.getrlimit(resource.RLIMIT_NOFILE) + fd_limit = soft_limit # Use the soft limit + except (ImportError, ValueError, OSError, AttributeError) as e: + self.logger.warning(f"Could not get file descriptor limit (common on Windows): {e}. Using default: {fd_limit}") + + self.logger.info(f"System Resources: RAM={total_ram_mb}MB, Cores={cpu_cores}, FD Limit={fd_limit}") + return total_ram_mb, cpu_cores, fd_limit + + def _calculate_safe_pages(self) -> int: + """Calculates the safe number of concurrent pages based on resources.""" + if not self.config.auto_calculate_size: + # If auto-calc is off, use max_pool_size as the hard limit + # This isn't ideal based on the prompt, but provides *some* manual override + # A dedicated `manual_safe_pages` might be better. Let's use max_pool_size for now. + self.logger.warning("Auto-calculation disabled. Using max_pool_size as safe_pages limit.") + return self.config.calculation_params.max_pool_size + + params = self.config.calculation_params + total_ram_mb, cpu_cores, fd_limit = self._get_system_resources() + + available_ram_mb = total_ram_mb - params.mem_headroom_mb + if available_ram_mb <= 0: + self.logger.error(f"Not enough RAM ({total_ram_mb}MB) after headroom ({params.mem_headroom_mb}MB). Cannot calculate safe pages.") + return params.min_pool_size # Fallback to minimum + + try: + # Calculate limits from each resource + mem_limit = available_ram_mb // params.avg_page_mem_mb if params.avg_page_mem_mb > 0 else float('inf') + fd_limit_pages = fd_limit // params.fd_per_page if params.fd_per_page > 0 else float('inf') + cpu_limit = cpu_cores * params.core_multiplier if cpu_cores > 0 else float('inf') + + # Determine the most constraining limit + calculated_limit = math.floor(min(mem_limit, fd_limit_pages, cpu_limit)) + + except ZeroDivisionError: + self.logger.error("Division by zero in safe_pages calculation (avg_page_mem_mb or fd_per_page is zero).") + calculated_limit = params.min_pool_size # Fallback + + # Clamp the result within min/max bounds + safe_pages = max(params.min_pool_size, min(calculated_limit, params.max_pool_size)) + + self.logger.info(f"Calculated safe pages: MemoryLimit={mem_limit}, FDLimit={fd_limit_pages}, CPULimit={cpu_limit} -> RawCalc={calculated_limit} -> Clamped={safe_pages}") + return safe_pages + + async def _create_and_start_crawler(self, crawler_id: str) -> Optional[AsyncWebCrawler]: + """Creates, starts, and returns a crawler instance.""" + try: + # Create BrowserConfig from the dictionary in manager config + browser_conf = BrowserConfig(**self.config.browser_config) + crawler = AsyncWebCrawler(config=browser_conf) + await crawler.start() + self.logger.info(f"Successfully started crawler instance: {crawler_id}") + return crawler + except Exception as e: + self.logger.error(f"Failed to start crawler instance {crawler_id}: {e}", exc_info=True) + return None + + async def initialize(self): + """Initializes crawlers and semaphore. Called at server startup.""" + if not self.config.enabled or self._initialized: + return + + self.logger.info("Initializing CrawlerManager...") + self._safe_pages = self._calculate_safe_pages() + self._semaphore = asyncio.Semaphore(self._safe_pages) + + self._primary_crawler = await self._create_and_start_crawler("Primary") + if self._primary_crawler: + self._primary_healthy = True + else: + self._primary_healthy = False + self.logger.critical("Primary crawler failed to initialize!") + + self._secondary_crawlers = [] + self._secondary_healthy_flags = [] + self._reload_tasks = [None] * (1 + self.config.backup_pool_size) # For primary + backups + + for i in range(self.config.backup_pool_size): + sec_id = f"Secondary-{i+1}" + crawler = await self._create_and_start_crawler(sec_id) + self._secondary_crawlers.append(crawler) # Add even if None + self._secondary_healthy_flags.append(crawler is not None) + if crawler is None: + self.logger.error(f"{sec_id} crawler failed to initialize!") + + # Set initial active crawler (prefer primary) + if self._primary_healthy: + self._active_crawler_index = 0 + self.logger.info("Primary crawler is active.") + else: + # Find the first healthy secondary + found_healthy_backup = False + for i, healthy in enumerate(self._secondary_healthy_flags): + if healthy: + self._active_crawler_index = i + 1 # 1-based index for secondaries + self.logger.warning(f"Primary failed, Secondary-{i+1} is active.") + found_healthy_backup = True + break + if not found_healthy_backup: + self.logger.critical("FATAL: No healthy crawlers available after initialization!") + # Server should probably refuse connections in this state + + self._initialized = True + self.logger.info(f"CrawlerManager initialized. Safe Pages: {self._safe_pages}. Active Crawler Index: {self._active_crawler_index}") + + async def shutdown(self): + """Shuts down all crawler instances. Called at server shutdown.""" + if not self._initialized or self._shutting_down: + return + + self._shutting_down = True + self.logger.info("Shutting down CrawlerManager...") + + # Cancel any ongoing reload tasks + for i, task in enumerate(self._reload_tasks): + if task and not task.done(): + try: + task.cancel() + await task # Wait for cancellation + self.logger.info(f"Cancelled reload task for crawler index {i}.") + except asyncio.CancelledError: + self.logger.info(f"Reload task for crawler index {i} was already cancelled.") + except Exception as e: + self.logger.warning(f"Error cancelling reload task for crawler index {i}: {e}") + self._reload_tasks = [] + + + # Close primary + if self._primary_crawler: + try: + self.logger.info("Closing primary crawler...") + await self._primary_crawler.close() + self._primary_crawler = None + except Exception as e: + self.logger.error(f"Error closing primary crawler: {e}", exc_info=True) + + # Close secondaries + for i, crawler in enumerate(self._secondary_crawlers): + if crawler: + try: + self.logger.info(f"Closing secondary crawler {i+1}...") + await crawler.close() + except Exception as e: + self.logger.error(f"Error closing secondary crawler {i+1}: {e}", exc_info=True) + self._secondary_crawlers = [] + + self._initialized = False + self.logger.info("CrawlerManager shut down complete.") + + @asynccontextmanager + async def get_crawler(self) -> AsyncGenerator[AsyncWebCrawler, None]: + """Acquires semaphore, yields active crawler, handles throttling & failover.""" + if not self.is_enabled(): + raise NoHealthyCrawlerError("CrawlerManager is disabled or not initialized.") + + if self._shutting_down: + raise NoHealthyCrawlerError("CrawlerManager is shutting down.") + + active_crawler: Optional[AsyncWebCrawler] = None + acquired = False + request_id = uuid.uuid4() + start_wait = time.time() + + # --- Throttling --- + try: + # Check semaphore value without acquiring + current_usage = self._safe_pages - self._semaphore._value + usage_percent = (current_usage / self._safe_pages) * 100 if self._safe_pages > 0 else 0 + + if usage_percent >= self.config.throttle_threshold_percent: + delay = random.uniform(self.config.throttle_delay_min_s, self.config.throttle_delay_max_s) + self.logger.debug(f"Throttling: Usage {usage_percent:.1f}% >= {self.config.throttle_threshold_percent}%. Delaying {delay:.3f}s") + await asyncio.sleep(delay) + except Exception as e: + self.logger.warning(f"Error during throttling check: {e}") # Continue attempt even if throttle check fails + + # --- Acquire Semaphore --- + try: + # self.logger.debug(f"Attempting to acquire semaphore (Available: {self._semaphore._value}/{self._safe_pages}). Wait Timeout: {self.config.max_wait_time_s}s") + + # --- Logging Before Acquire --- + sem_value = self._semaphore._value if self._semaphore else 'N/A' + sem_waiters = len(self._semaphore._waiters) if self._semaphore and self._semaphore._waiters else 0 + self.logger.debug(f"Req {request_id}: Attempting acquire. Available={sem_value}/{self._safe_pages}, Waiters={sem_waiters}, Timeout={self.config.max_wait_time_s}s") + + await asyncio.wait_for( + self._semaphore.acquire(), timeout=self.config.max_wait_time_s + ) + acquired = True + wait_duration = time.time() - start_wait + if wait_duration > 1: + self.logger.warning(f"Semaphore acquired after {wait_duration:.3f}s. (Available: {self._semaphore._value}/{self._safe_pages})") + + self.logger.debug(f"Semaphore acquired successfully after {wait_duration:.3f}s. (Available: {self._semaphore._value}/{self._safe_pages})") + + # --- Select Active Crawler (Critical Section) --- + async with self._state_lock: + current_active_index = self._active_crawler_index + is_primary_active = (current_active_index == 0) + + if is_primary_active: + if self._primary_healthy and self._primary_crawler: + active_crawler = self._primary_crawler + else: + # Primary is supposed to be active but isn't healthy + self.logger.warning("Primary crawler unhealthy, attempting immediate failover...") + if not await self._try_failover_sync(): # Try to switch active crawler NOW + raise NoHealthyCrawlerError("Primary unhealthy and no healthy backup available.") + # If failover succeeded, active_crawler_index is updated + current_active_index = self._active_crawler_index + # Fall through to select the new active secondary + + # Check if we need to use a secondary (either initially or after failover) + if current_active_index > 0: + secondary_idx = current_active_index - 1 + if secondary_idx < len(self._secondary_crawlers) and \ + self._secondary_healthy_flags[secondary_idx] and \ + self._secondary_crawlers[secondary_idx]: + active_crawler = self._secondary_crawlers[secondary_idx] + else: + self.logger.error(f"Selected Secondary-{current_active_index} is unhealthy or missing.") + # Attempt failover to *another* secondary if possible? (Adds complexity) + # For now, raise error if the selected one isn't good. + raise NoHealthyCrawlerError(f"Selected Secondary-{current_active_index} is unavailable.") + + if active_crawler is None: + # This shouldn't happen if logic above is correct, but safeguard + raise NoHealthyCrawlerError("Failed to select a healthy active crawler.") + + # --- Yield Crawler --- + try: + yield active_crawler + except Exception as crawl_error: + self.logger.error(f"Error during crawl execution using {active_crawler}: {crawl_error}", exc_info=True) + # Determine if this error warrants failover + # For now, let's assume any exception triggers a health check/failover attempt + await self._handle_crawler_failure(active_crawler) + raise # Re-raise the original error for the API handler + + except asyncio.TimeoutError: + self.logger.warning(f"Timeout waiting for semaphore after {self.config.max_wait_time_s}s.") + raise PoolTimeoutError(f"Timed out waiting for available crawler resource after {self.config.max_wait_time_s}s") + except NoHealthyCrawlerError: + # Logged within the selection logic + raise # Re-raise for API handler + except Exception as e: + self.logger.error(f"Unexpected error in get_crawler context manager: {e}", exc_info=True) + raise # Re-raise potentially unknown errors + finally: + if acquired: + self._semaphore.release() + self.logger.debug(f"Semaphore released. (Available: {self._semaphore._value}/{self._safe_pages})") + + + async def _try_failover_sync(self) -> bool: + """Synchronous part of failover logic (must be called under state_lock). Finds next healthy secondary.""" + if not self._primary_healthy: # Only failover if primary is already marked down + found_healthy_backup = False + start_idx = (self._active_crawler_index % (self.config.backup_pool_size +1)) # Start check after current + for i in range(self.config.backup_pool_size): + check_idx = (start_idx + i) % self.config.backup_pool_size # Circular check + if self._secondary_healthy_flags[check_idx] and self._secondary_crawlers[check_idx]: + self._active_crawler_index = check_idx + 1 + self.logger.warning(f"Failover successful: Switched active crawler to Secondary-{self._active_crawler_index}") + found_healthy_backup = True + break # Found one + if not found_healthy_backup: + # If primary is down AND no backups are healthy, mark primary as active index (0) but it's still unhealthy + self._active_crawler_index = 0 + self.logger.error("Failover failed: No healthy secondary crawlers available.") + return False + return True + return True # Primary is healthy, no failover needed + + async def _handle_crawler_failure(self, failed_crawler: AsyncWebCrawler): + """Handles marking a crawler as unhealthy and initiating recovery.""" + if self._shutting_down: return # Don't handle failures during shutdown + + async with self._state_lock: + crawler_index = -1 + is_primary = False + + if failed_crawler is self._primary_crawler and self._primary_healthy: + self.logger.warning("Primary crawler reported failure.") + self._primary_healthy = False + is_primary = True + crawler_index = 0 + # Try immediate failover within the lock + await self._try_failover_sync() + # Start reload task if not already running for primary + if self._reload_tasks[0] is None or self._reload_tasks[0].done(): + self.logger.info("Initiating primary crawler reload task.") + self._reload_tasks[0] = asyncio.create_task(self._reload_crawler(0)) + + else: + # Check if it was one of the secondaries + for i, crawler in enumerate(self._secondary_crawlers): + if failed_crawler is crawler and self._secondary_healthy_flags[i]: + self.logger.warning(f"Secondary-{i+1} crawler reported failure.") + self._secondary_healthy_flags[i] = False + is_primary = False + crawler_index = i + 1 + # If this *was* the active crawler, trigger failover check + if self._active_crawler_index == crawler_index: + self.logger.warning(f"Active secondary {crawler_index} failed, attempting failover...") + await self._try_failover_sync() + # Start reload task for this secondary + if self._reload_tasks[crawler_index] is None or self._reload_tasks[crawler_index].done(): + self.logger.info(f"Initiating Secondary-{i+1} crawler reload task.") + self._reload_tasks[crawler_index] = asyncio.create_task(self._reload_crawler(crawler_index)) + break # Found the failed secondary + + if crawler_index == -1: + self.logger.debug("Failure reported by an unknown or already unhealthy crawler instance. Ignoring.") + + + async def _reload_crawler(self, crawler_index_to_reload: int): + """Background task to close, recreate, and start a specific crawler.""" + is_primary = (crawler_index_to_reload == 0) + crawler_id = "Primary" if is_primary else f"Secondary-{crawler_index_to_reload}" + original_crawler = self._primary_crawler if is_primary else self._secondary_crawlers[crawler_index_to_reload - 1] + + self.logger.info(f"Starting reload process for {crawler_id}...") + + # 1. Delay before attempting reload (e.g., allow transient issues to clear) + if not is_primary: # Maybe shorter delay for backups? + await asyncio.sleep(self.config.primary_reload_delay_s / 2) + else: + await asyncio.sleep(self.config.primary_reload_delay_s) + + + # 2. Attempt to close the old instance cleanly + if original_crawler: + try: + self.logger.info(f"Attempting to close existing {crawler_id} instance...") + await original_crawler.close() + self.logger.info(f"Successfully closed old {crawler_id} instance.") + except Exception as e: + self.logger.warning(f"Error closing old {crawler_id} instance during reload: {e}") + + # 3. Create and start a new instance + self.logger.info(f"Attempting to start new {crawler_id} instance...") + new_crawler = await self._create_and_start_crawler(crawler_id) + + # 4. Update state if successful + async with self._state_lock: + if new_crawler: + self.logger.info(f"Successfully reloaded {crawler_id}. Marking as healthy.") + if is_primary: + self._primary_crawler = new_crawler + self._primary_healthy = True + # Switch back to primary if no other failures occurred + # Check if ANY secondary is currently active + secondary_is_active = self._active_crawler_index > 0 + if not secondary_is_active or not self._secondary_healthy_flags[self._active_crawler_index - 1]: + self.logger.info("Switching active crawler back to primary.") + self._active_crawler_index = 0 + else: # Is secondary + secondary_idx = crawler_index_to_reload - 1 + self._secondary_crawlers[secondary_idx] = new_crawler + self._secondary_healthy_flags[secondary_idx] = True + # Potentially switch back if primary is still down and this was needed? + if not self._primary_healthy and self._active_crawler_index == 0: + self.logger.info(f"Primary still down, activating reloaded Secondary-{crawler_index_to_reload}.") + self._active_crawler_index = crawler_index_to_reload + + else: + self.logger.error(f"Failed to reload {crawler_id}. It remains unhealthy.") + # Keep the crawler marked as unhealthy + if is_primary: + self._primary_healthy = False # Ensure it stays false + else: + self._secondary_healthy_flags[crawler_index_to_reload - 1] = False + + + # Clear the reload task reference for this index + self._reload_tasks[crawler_index_to_reload] = None + + + async def get_status(self) -> Dict: + """Returns the current status of the manager.""" + if not self.is_enabled(): + return {"status": "disabled"} + + async with self._state_lock: + active_id = "Primary" if self._active_crawler_index == 0 else f"Secondary-{self._active_crawler_index}" + primary_status = "Healthy" if self._primary_healthy else "Unhealthy" + secondary_statuses = [f"Secondary-{i+1}: {'Healthy' if healthy else 'Unhealthy'}" + for i, healthy in enumerate(self._secondary_healthy_flags)] + semaphore_available = self._semaphore._value if self._semaphore else 'N/A' + semaphore_locked = len(self._semaphore._waiters) if self._semaphore and self._semaphore._waiters else 0 + + return { + "status": "enabled", + "safe_pages": self._safe_pages, + "semaphore_available": semaphore_available, + "semaphore_waiters": semaphore_locked, + "active_crawler": active_id, + "primary_status": primary_status, + "secondary_statuses": secondary_statuses, + "reloading_tasks": [i for i, t in enumerate(self._reload_tasks) if t and not t.done()] + } \ No newline at end of file diff --git a/deploy/docker/server.py b/deploy/docker/server.py index edb55130..f577348b 100644 --- a/deploy/docker/server.py +++ b/deploy/docker/server.py @@ -1,8 +1,20 @@ +# Import from auth.py +from auth import create_access_token, get_token_dependency, TokenRequest +from api import ( + handle_markdown_request, + handle_llm_qa, + handle_stream_crawl_request, + handle_crawl_request, + stream_results, + _get_memory_mb +) +from utils import FilterType, load_config, setup_logging, verify_email_domain import os import sys import time -from typing import List, Optional, Dict -from fastapi import FastAPI, HTTPException, Request, Query, Path, Depends +from typing import List, Optional, Dict, AsyncGenerator +from contextlib import asynccontextmanager +from fastapi import FastAPI, HTTPException, Request, Query, Path, Depends, status from fastapi.responses import StreamingResponse, RedirectResponse, PlainTextResponse, JSONResponse from fastapi.middleware.httpsredirect import HTTPSRedirectMiddleware from fastapi.middleware.trustedhost import TrustedHostMiddleware @@ -11,28 +23,39 @@ from slowapi import Limiter from slowapi.util import get_remote_address from prometheus_fastapi_instrumentator import Instrumentator from redis import asyncio as aioredis +from crawl4ai import ( + BrowserConfig, + CrawlerRunConfig, + AsyncLogger +) + +from crawler_manager import ( + CrawlerManager, + CrawlerManagerConfig, + PoolTimeoutError, + NoHealthyCrawlerError +) + sys.path.append(os.path.dirname(os.path.realpath(__file__))) -from utils import FilterType, load_config, setup_logging, verify_email_domain -from api import ( - handle_markdown_request, - handle_llm_qa, - handle_stream_crawl_request, - handle_crawl_request, - stream_results -) -from auth import create_access_token, get_token_dependency, TokenRequest # Import from auth.py __version__ = "0.2.6" + class CrawlRequest(BaseModel): urls: List[str] = Field(min_length=1, max_length=100) browser_config: Optional[Dict] = Field(default_factory=dict) crawler_config: Optional[Dict] = Field(default_factory=dict) + # Load configuration and setup config = load_config() setup_logging(config) +logger = AsyncLogger( + log_file=config["logging"].get("log_file", "app.log"), + verbose=config["logging"].get("verbose", False), + tag_width=10, +) # Initialize Redis redis = aioredis.from_url(config["redis"].get("uri", "redis://localhost")) @@ -44,9 +67,43 @@ limiter = Limiter( storage_uri=config["rate_limiting"]["storage_uri"] ) +# --- Initialize Manager (will be done in lifespan) --- +# Load manager config from the main config +manager_config_dict = config.get("crawler_pool", {}) +# Use Pydantic to parse and validate +manager_config = CrawlerManagerConfig(**manager_config_dict) +crawler_manager = CrawlerManager(config=manager_config, logger=logger) + +# --- FastAPI App and Lifespan --- + + +@asynccontextmanager +async def lifespan(app: FastAPI): + # Startup + logger.info("Starting up the server...") + if manager_config.enabled: + logger.info("Initializing Crawler Manager...") + await crawler_manager.initialize() + app.state.crawler_manager = crawler_manager # Store manager in app state + logger.info("Crawler Manager is enabled.") + else: + logger.warning("Crawler Manager is disabled.") + app.state.crawler_manager = None # Indicate disabled state + + yield # Server runs here + + # Shutdown + logger.info("Shutting down server...") + if app.state.crawler_manager: + logger.info("Shutting down Crawler Manager...") + await app.state.crawler_manager.shutdown() + logger.info("Crawler Manager shut down.") + logger.info("Server shut down.") + app = FastAPI( title=config["app"]["title"], - version=config["app"]["version"] + version=config["app"]["version"], + lifespan=lifespan, ) # Configure middleware @@ -56,7 +113,9 @@ def setup_security_middleware(app, config): if sec_config.get("https_redirect", False): app.add_middleware(HTTPSRedirectMiddleware) if sec_config.get("trusted_hosts", []) != ["*"]: - app.add_middleware(TrustedHostMiddleware, allowed_hosts=sec_config["trusted_hosts"]) + app.add_middleware(TrustedHostMiddleware, + allowed_hosts=sec_config["trusted_hosts"]) + setup_security_middleware(app, config) @@ -68,6 +127,8 @@ if config["observability"]["prometheus"]["enabled"]: token_dependency = get_token_dependency(config) # Middleware for security headers + + @app.middleware("http") async def add_security_headers(request: Request, call_next): response = await call_next(request) @@ -75,7 +136,24 @@ async def add_security_headers(request: Request, call_next): response.headers.update(config["security"]["headers"]) return response + +async def get_manager() -> CrawlerManager: + # Ensure manager exists and is enabled before yielding + if not hasattr(app.state, 'crawler_manager') or app.state.crawler_manager is None: + raise HTTPException( + status_code=status.HTTP_503_SERVICE_UNAVAILABLE, + detail="Crawler service is disabled or not initialized" + ) + if not app.state.crawler_manager.is_enabled(): + raise HTTPException( + status_code=status.HTTP_503_SERVICE_UNAVAILABLE, + detail="Crawler service is currently disabled" + ) + return app.state.crawler_manager + # Token endpoint (always available, but usage depends on config) + + @app.post("/token") async def get_token(request_data: TokenRequest): if not verify_email_domain(request_data.email): @@ -84,6 +162,8 @@ async def get_token(request_data: TokenRequest): return {"email": request_data.email, "access_token": token, "token_type": "bearer"} # Endpoints with conditional auth + + @app.get("/md/{url:path}") @limiter.limit(config["rate_limiting"]["default_limit"]) async def get_markdown( @@ -97,6 +177,7 @@ async def get_markdown( result = await handle_markdown_request(url, f, q, c, config) return PlainTextResponse(result) + @app.get("/llm/{url:path}", description="URL should be without http/https prefix") async def llm_endpoint( request: Request, @@ -105,7 +186,8 @@ async def llm_endpoint( token_data: Optional[Dict] = Depends(token_dependency) ): if not q: - raise HTTPException(status_code=400, detail="Query parameter 'q' is required") + raise HTTPException( + status_code=400, detail="Query parameter 'q' is required") if not url.startswith(('http://', 'https://')): url = 'https://' + url try: @@ -114,37 +196,89 @@ async def llm_endpoint( except Exception as e: raise HTTPException(status_code=500, detail=str(e)) + @app.get("/schema") async def get_schema(): from crawl4ai import BrowserConfig, CrawlerRunConfig return {"browser": BrowserConfig().dump(), "crawler": CrawlerRunConfig().dump()} + @app.get(config["observability"]["health_check"]["endpoint"]) async def health(): return {"status": "ok", "timestamp": time.time(), "version": __version__} + @app.get(config["observability"]["prometheus"]["endpoint"]) async def metrics(): return RedirectResponse(url=config["observability"]["prometheus"]["endpoint"]) + +@app.get("/browswers") +# Optional dependency +async def health(manager: Optional[CrawlerManager] = Depends(get_manager, use_cache=False)): + base_status = {"status": "ok", "timestamp": time.time(), + "version": __version__} + if manager: + try: + manager_status = await manager.get_status() + base_status["crawler_manager"] = manager_status + except Exception as e: + base_status["crawler_manager"] = { + "status": "error", "detail": str(e)} + else: + base_status["crawler_manager"] = {"status": "disabled"} + return base_status + + @app.post("/crawl") @limiter.limit(config["rate_limiting"]["default_limit"]) async def crawl( request: Request, crawl_request: CrawlRequest, - token_data: Optional[Dict] = Depends(token_dependency) + manager: CrawlerManager = Depends(get_manager), # Use dependency + token_data: Optional[Dict] = Depends(token_dependency) # Keep auth ): if not crawl_request.urls: - raise HTTPException(status_code=400, detail="At least one URL required") - - results = await handle_crawl_request( - urls=crawl_request.urls, - browser_config=crawl_request.browser_config, - crawler_config=crawl_request.crawler_config, - config=config - ) + raise HTTPException( + status_code=400, detail="At least one URL required") - return JSONResponse(results) + try: + # Use the manager's context to get a crawler instance + async with manager.get_crawler() as active_crawler: + # Call the actual handler from api.py, passing the acquired crawler + results_dict = await handle_crawl_request( + crawler=active_crawler, # Pass the live crawler instance + urls=crawl_request.urls, + # Pass user-provided configs, these might override pool defaults if needed + # Or the manager/handler could decide how to merge them + browser_config=crawl_request.browser_config or {}, # Ensure dict + crawler_config=crawl_request.crawler_config or {}, # Ensure dict + config=config # Pass the global server config + ) + return JSONResponse(results_dict) + + except PoolTimeoutError as e: + logger.warning(f"Request rejected due to pool timeout: {e}") + raise HTTPException( + status_code=status.HTTP_503_SERVICE_UNAVAILABLE, # Or 429 + detail=f"Crawler resources busy. Please try again later. Timeout: {e}" + ) + except NoHealthyCrawlerError as e: + logger.error(f"Request failed as no healthy crawler available: {e}") + raise HTTPException( + status_code=status.HTTP_503_SERVICE_UNAVAILABLE, + detail=f"Crawler service temporarily unavailable: {e}" + ) + except HTTPException: # Re-raise HTTP exceptions from handler + raise + except Exception as e: + logger.error( + f"Unexpected error during batch crawl processing: {e}", exc_info=True) + # Return generic error, details might be logged by handle_crawl_request + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail=f"An unexpected error occurred: {e}" + ) @app.post("/crawl/stream") @@ -152,23 +286,114 @@ async def crawl( async def crawl_stream( request: Request, crawl_request: CrawlRequest, + manager: CrawlerManager = Depends(get_manager), token_data: Optional[Dict] = Depends(token_dependency) ): if not crawl_request.urls: - raise HTTPException(status_code=400, detail="At least one URL required") + raise HTTPException( + status_code=400, detail="At least one URL required") - crawler, results_gen = await handle_stream_crawl_request( - urls=crawl_request.urls, - browser_config=crawl_request.browser_config, - crawler_config=crawl_request.crawler_config, - config=config - ) + try: + # THIS IS A BIT WORK OF ART RATHER THAN ENGINEERING + # Acquire the crawler context from the manager + # IMPORTANT: The context needs to be active for the *duration* of the stream + # This structure might be tricky with FastAPI's StreamingResponse which consumes + # the generator *after* the endpoint function returns. - return StreamingResponse( - stream_results(crawler, results_gen), - media_type='application/x-ndjson', - headers={'Cache-Control': 'no-cache', 'Connection': 'keep-alive', 'X-Stream-Status': 'active'} - ) + # --- Option A: Acquire crawler, pass to handler, handler yields --- + # (Requires handler NOT to be async generator itself, but return one) + # async with manager.get_crawler() as active_crawler: + # # Handler returns the generator + # _, results_gen = await handle_stream_crawl_request( + # crawler=active_crawler, + # urls=crawl_request.urls, + # browser_config=crawl_request.browser_config or {}, + # crawler_config=crawl_request.crawler_config or {}, + # config=config + # ) + # # PROBLEM: `active_crawler` context exits before StreamingResponse uses results_gen + # # This releases the semaphore too early. + + # --- Option B: Pass manager to handler, handler uses context internally --- + # (Requires modifying handle_stream_crawl_request signature/logic) + # This seems cleaner. Let's assume api.py is adapted for this. + # We need a way for the generator yielded by stream_results to know when + # to release the semaphore. + + # --- Option C: Create a wrapper generator that handles context --- + async def stream_wrapper(manager: CrawlerManager, crawl_request: CrawlRequest, config: dict) -> AsyncGenerator[bytes, None]: + active_crawler = None + try: + async with manager.get_crawler() as acquired_crawler: + active_crawler = acquired_crawler # Keep reference for cleanup + # Call the handler which returns the raw result generator + _crawler_ref, results_gen = await handle_stream_crawl_request( + crawler=acquired_crawler, + urls=crawl_request.urls, + browser_config=crawl_request.browser_config or {}, + crawler_config=crawl_request.crawler_config or {}, + config=config + ) + # Use the stream_results utility to format and yield + async for data_bytes in stream_results(_crawler_ref, results_gen): + yield data_bytes + except (PoolTimeoutError, NoHealthyCrawlerError) as e: + # Yield a final error message in the stream + error_payload = {"status": "error", "detail": str(e)} + yield (json.dumps(error_payload) + "\n").encode('utf-8') + logger.warning(f"Stream request failed: {e}") + # Re-raise might be better if StreamingResponse handles it? Test needed. + except HTTPException as e: # Catch HTTP exceptions from handler setup + error_payload = {"status": "error", + "detail": e.detail, "status_code": e.status_code} + yield (json.dumps(error_payload) + "\n").encode('utf-8') + logger.warning( + f"Stream request failed with HTTPException: {e.detail}") + except Exception as e: + error_payload = {"status": "error", + "detail": f"Unexpected stream error: {e}"} + yield (json.dumps(error_payload) + "\n").encode('utf-8') + logger.error( + f"Unexpected error during stream processing: {e}", exc_info=True) + # finally: + # Ensure crawler cleanup if stream_results doesn't handle it? + # stream_results *should* call crawler.close(), but only on the + # instance it received. If we pass the *manager* instead, this gets complex. + # Let's stick to passing the acquired_crawler and rely on stream_results. + + # Create the generator using the wrapper + streaming_generator = stream_wrapper(manager, crawl_request, config) + + return StreamingResponse( + streaming_generator, # Use the wrapper + media_type='application/x-ndjson', + headers={'Cache-Control': 'no-cache', + 'Connection': 'keep-alive', 'X-Stream-Status': 'active'} + ) + + except (PoolTimeoutError, NoHealthyCrawlerError) as e: + # These might occur if get_crawler fails *before* stream starts + # Or if the wrapper re-raises them. + logger.warning(f"Stream request rejected before starting: {e}") + status_code = status.HTTP_503_SERVICE_UNAVAILABLE # Or 429 for timeout + # Don't raise HTTPException here, let the wrapper yield the error message. + # If we want to return a non-200 initial status, need more complex handling. + # Return an *empty* stream with error headers? Or just let wrapper yield error. + + async def _error_stream(): + error_payload = {"status": "error", "detail": str(e)} + yield (json.dumps(error_payload) + "\n").encode('utf-8') + return StreamingResponse(_error_stream(), status_code=status_code, media_type='application/x-ndjson') + + except HTTPException: # Re-raise HTTP exceptions from setup + raise + except Exception as e: + logger.error( + f"Unexpected error setting up stream crawl: {e}", exc_info=True) + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail=f"An unexpected error occurred setting up the stream: {e}" + ) if __name__ == "__main__": import uvicorn @@ -178,4 +403,4 @@ if __name__ == "__main__": port=config["app"]["port"], reload=config["app"]["reload"], timeout_keep_alive=config["app"]["timeout_keep_alive"] - ) \ No newline at end of file + ) diff --git a/tests/memory/test_stress_api.py b/tests/memory/test_stress_api.py new file mode 100644 index 00000000..232964c1 --- /dev/null +++ b/tests/memory/test_stress_api.py @@ -0,0 +1,516 @@ +#!/usr/bin/env python3 +""" +Stress test for Crawl4AI's Docker API server (/crawl and /crawl/stream endpoints). + +This version targets a running Crawl4AI API server, sending concurrent requests +to test its ability to handle multiple crawl jobs simultaneously. +It uses httpx for async HTTP requests and logs results per batch of requests, +including server-side memory usage reported by the API. +""" + +import asyncio +import time +import uuid +import argparse +import json +import sys +import os +import shutil +from typing import List, Dict, Optional, Union, AsyncGenerator, Tuple +import httpx +import pathlib # Import pathlib explicitly +from rich.console import Console +from rich.panel import Panel +from rich.syntax import Syntax + +# --- Constants --- +# DEFAULT_API_URL = "http://localhost:11235" # Default port +DEFAULT_API_URL = "http://localhost:8020" # Default port +DEFAULT_URL_COUNT = 1000 +DEFAULT_MAX_CONCURRENT_REQUESTS = 5 +DEFAULT_CHUNK_SIZE = 10 +DEFAULT_REPORT_PATH = "reports_api" +DEFAULT_STREAM_MODE = False +REQUEST_TIMEOUT = 180.0 + +# Initialize Rich console +console = Console() + +# --- API Health Check (Unchanged) --- +async def check_server_health(client: httpx.AsyncClient, health_endpoint: str = "/health"): + """Check if the API server is healthy.""" + console.print(f"[bold cyan]Checking API server health at {client.base_url}{health_endpoint}...[/]", end="") + try: + response = await client.get(health_endpoint, timeout=10.0) + response.raise_for_status() + health_data = response.json() + version = health_data.get('version', 'N/A') + console.print(f"[bold green] Server OK! Version: {version}[/]") + return True + except (httpx.RequestError, httpx.HTTPStatusError) as e: + console.print(f"\n[bold red]Server health check FAILED:[/]") + console.print(f"Error: {e}") + console.print(f"Is the server running and accessible at {client.base_url}?") + return False + except Exception as e: + console.print(f"\n[bold red]An unexpected error occurred during health check:[/]") + console.print(e) + return False + +# --- API Stress Test Class --- +class ApiStressTest: + """Orchestrates the stress test by sending concurrent requests to the API.""" + + def __init__( + self, + api_url: str, + url_count: int, + max_concurrent_requests: int, + chunk_size: int, + report_path: str, + stream_mode: bool, + ): + self.api_base_url = api_url.rstrip('/') + self.url_count = url_count + self.max_concurrent_requests = max_concurrent_requests + self.chunk_size = chunk_size + self.report_path = pathlib.Path(report_path) + self.report_path.mkdir(parents=True, exist_ok=True) + self.stream_mode = stream_mode + + self.test_id = time.strftime("%Y%m%d_%H%M%S") + self.results_summary = { + "test_id": self.test_id, "api_url": api_url, "url_count": url_count, + "max_concurrent_requests": max_concurrent_requests, "chunk_size": chunk_size, + "stream_mode": stream_mode, "start_time": "", "end_time": "", + "total_time_seconds": 0, "successful_requests": 0, "failed_requests": 0, + "successful_urls": 0, "failed_urls": 0, "total_urls_processed": 0, + "total_api_calls": 0, + "server_memory_metrics": { # To store aggregated server memory info + "batch_mode_avg_delta_mb": None, + "batch_mode_max_delta_mb": None, + "stream_mode_avg_max_snapshot_mb": None, + "stream_mode_max_max_snapshot_mb": None, + "samples": [] # Store individual request memory results + } + } + self.http_client = httpx.AsyncClient(base_url=self.api_base_url, timeout=REQUEST_TIMEOUT, limits=httpx.Limits(max_connections=max_concurrent_requests + 5, max_keepalive_connections=max_concurrent_requests)) + + async def close_client(self): + """Close the httpx client.""" + await self.http_client.aclose() + + async def run(self) -> Dict: + """Run the API stress test.""" + # No client memory tracker needed + urls_to_process = [f"https://httpbin.org/anything/{uuid.uuid4()}" for _ in range(self.url_count)] + url_chunks = [urls_to_process[i:i+self.chunk_size] for i in range(0, len(urls_to_process), self.chunk_size)] + + self.results_summary["start_time"] = time.strftime("%Y-%m-%d %H:%M:%S") + start_time = time.time() + + console.print(f"\n[bold cyan]Crawl4AI API Stress Test - {self.url_count} URLs, {self.max_concurrent_requests} concurrent requests[/bold cyan]") + console.print(f"[bold cyan]Target API:[/bold cyan] {self.api_base_url}, [bold cyan]Mode:[/bold cyan] {'Streaming' if self.stream_mode else 'Batch'}, [bold cyan]URLs per Request:[/bold cyan] {self.chunk_size}") + # Removed client memory log + + semaphore = asyncio.Semaphore(self.max_concurrent_requests) + + # Updated Batch logging header + console.print("\n[bold]API Request Batch Progress:[/bold]") + # Adjusted spacing and added Peak + console.print("[bold] Batch | Progress | SrvMem Peak / Δ|Max (MB) | Reqs/sec | S/F URLs | Time (s) | Status [/bold]") + # Adjust separator length if needed, looks okay for now + console.print("─" * 95) + + # No client memory monitor task needed + + tasks = [] + total_api_calls = len(url_chunks) + self.results_summary["total_api_calls"] = total_api_calls + + try: + for i, chunk in enumerate(url_chunks): + task = asyncio.create_task(self._make_api_request( + chunk=chunk, + batch_idx=i + 1, + total_batches=total_api_calls, + semaphore=semaphore + # No memory tracker passed + )) + tasks.append(task) + + api_results = await asyncio.gather(*tasks) + + # Process aggregated results including server memory + total_successful_requests = sum(1 for r in api_results if r['request_success']) + total_failed_requests = total_api_calls - total_successful_requests + total_successful_urls = sum(r['success_urls'] for r in api_results) + total_failed_urls = sum(r['failed_urls'] for r in api_results) + total_urls_processed = total_successful_urls + total_failed_urls + + # Aggregate server memory metrics + valid_samples = [r for r in api_results if r.get('server_delta_or_max_mb') is not None] # Filter results with valid mem data + self.results_summary["server_memory_metrics"]["samples"] = valid_samples # Store raw samples with both peak and delta/max + + if valid_samples: + delta_or_max_values = [r['server_delta_or_max_mb'] for r in valid_samples] + if self.stream_mode: + # Stream mode: delta_or_max holds max snapshot + self.results_summary["server_memory_metrics"]["stream_mode_avg_max_snapshot_mb"] = sum(delta_or_max_values) / len(delta_or_max_values) + self.results_summary["server_memory_metrics"]["stream_mode_max_max_snapshot_mb"] = max(delta_or_max_values) + else: # Batch mode + # delta_or_max holds delta + self.results_summary["server_memory_metrics"]["batch_mode_avg_delta_mb"] = sum(delta_or_max_values) / len(delta_or_max_values) + self.results_summary["server_memory_metrics"]["batch_mode_max_delta_mb"] = max(delta_or_max_values) + + # Aggregate peak values for batch mode + peak_values = [r['server_peak_memory_mb'] for r in valid_samples if r.get('server_peak_memory_mb') is not None] + if peak_values: + self.results_summary["server_memory_metrics"]["batch_mode_avg_peak_mb"] = sum(peak_values) / len(peak_values) + self.results_summary["server_memory_metrics"]["batch_mode_max_peak_mb"] = max(peak_values) + + + self.results_summary.update({ + "successful_requests": total_successful_requests, + "failed_requests": total_failed_requests, + "successful_urls": total_successful_urls, + "failed_urls": total_failed_urls, + "total_urls_processed": total_urls_processed, + }) + + except Exception as e: + console.print(f"[bold red]An error occurred during task execution: {e}[/bold red]") + import traceback + traceback.print_exc() + # No finally block needed for monitor task + + end_time = time.time() + self.results_summary.update({ + "end_time": time.strftime("%Y-%m-%d %H:%M:%S"), + "total_time_seconds": end_time - start_time, + # No client memory report + }) + self._save_results() + return self.results_summary + + async def _make_api_request( + self, + chunk: List[str], + batch_idx: int, + total_batches: int, + semaphore: asyncio.Semaphore + # No memory tracker + ) -> Dict: + """Makes a single API request for a chunk of URLs, handling concurrency and logging server memory.""" + request_success = False + success_urls = 0 + failed_urls = 0 + status = "Pending" + status_color = "grey" + server_memory_metric = None # Store delta (batch) or max snapshot (stream) + api_call_start_time = time.time() + + async with semaphore: + try: + # No client memory sampling + + endpoint = "/crawl/stream" if self.stream_mode else "/crawl" + payload = { + "urls": chunk, + "browser_config": {"type": "BrowserConfig", "params": {"headless": True}}, + "crawler_config": { + "type": "CrawlerRunConfig", + "params": {"cache_mode": "BYPASS", "stream": self.stream_mode} + } + } + + if self.stream_mode: + max_server_mem_snapshot = 0.0 # Track max memory seen in this stream + async with self.http_client.stream("POST", endpoint, json=payload) as response: + initial_status_code = response.status_code + response.raise_for_status() + + completed_marker_received = False + async for line in response.aiter_lines(): + if line: + try: + data = json.loads(line) + if data.get("status") == "completed": + completed_marker_received = True + break + elif data.get("url"): + if data.get("success"): success_urls += 1 + else: failed_urls += 1 + # Extract server memory snapshot per result + mem_snapshot = data.get('server_memory_mb') + if mem_snapshot is not None: + max_server_mem_snapshot = max(max_server_mem_snapshot, float(mem_snapshot)) + except json.JSONDecodeError: + console.print(f"[Batch {batch_idx}] [red]Stream decode error for line:[/red] {line}") + failed_urls = len(chunk) + break + request_success = completed_marker_received + if not request_success: + failed_urls = len(chunk) - success_urls + server_memory_metric = max_server_mem_snapshot # Use max snapshot for stream logging + + else: # Batch mode + response = await self.http_client.post(endpoint, json=payload) + response.raise_for_status() + data = response.json() + + # Extract server memory delta from the response + server_memory_metric = data.get('server_memory_delta_mb') + server_peak_mem_mb = data.get('server_peak_memory_mb') + + if data.get("success") and "results" in data: + request_success = True + results_list = data.get("results", []) + for result_item in results_list: + if result_item.get("success"): success_urls += 1 + else: failed_urls += 1 + if len(results_list) != len(chunk): + console.print(f"[Batch {batch_idx}] [yellow]Warning: Result count ({len(results_list)}) doesn't match URL count ({len(chunk)})[/yellow]") + failed_urls = len(chunk) - success_urls + else: + request_success = False + failed_urls = len(chunk) + # Try to get memory from error detail if available + detail = data.get('detail') + if isinstance(detail, str): + try: detail_json = json.loads(detail) + except: detail_json = {} + elif isinstance(detail, dict): + detail_json = detail + else: detail_json = {} + server_peak_mem_mb = detail_json.get('server_peak_memory_mb', None) + server_memory_metric = detail_json.get('server_memory_delta_mb', None) + console.print(f"[Batch {batch_idx}] [red]API request failed:[/red] {detail_json.get('error', 'No details')}") + + + except httpx.HTTPStatusError as e: + request_success = False + failed_urls = len(chunk) + console.print(f"[Batch {batch_idx}] [bold red]HTTP Error {e.response.status_code}:[/] {e.request.url}") + try: + error_detail = e.response.json() + # Attempt to extract memory info even from error responses + detail_content = error_detail.get('detail', {}) + if isinstance(detail_content, str): # Handle if detail is stringified JSON + try: detail_content = json.loads(detail_content) + except: detail_content = {} + server_memory_metric = detail_content.get('server_memory_delta_mb', None) + server_peak_mem_mb = detail_content.get('server_peak_memory_mb', None) + console.print(f"Response: {error_detail}") + except Exception: + console.print(f"Response Text: {e.response.text[:200]}...") + except httpx.RequestError as e: + request_success = False + failed_urls = len(chunk) + console.print(f"[Batch {batch_idx}] [bold red]Request Error:[/bold] {e.request.url} - {e}") + except Exception as e: + request_success = False + failed_urls = len(chunk) + console.print(f"[Batch {batch_idx}] [bold red]Unexpected Error:[/bold] {e}") + import traceback + traceback.print_exc() + + finally: + api_call_time = time.time() - api_call_start_time + total_processed_urls = success_urls + failed_urls + + if request_success and failed_urls == 0: status_color, status = "green", "Success" + elif request_success and success_urls > 0: status_color, status = "yellow", "Partial" + else: status_color, status = "red", "Failed" + + current_total_urls = batch_idx * self.chunk_size + progress_pct = min(100.0, (current_total_urls / self.url_count) * 100) + reqs_per_sec = 1.0 / api_call_time if api_call_time > 0 else float('inf') + + # --- New Memory Formatting --- + mem_display = " N/A " # Default + peak_mem_value = None + delta_or_max_value = None + + if self.stream_mode: + # server_memory_metric holds max snapshot for stream + if server_memory_metric is not None: + mem_display = f"{server_memory_metric:.1f} (Max)" + delta_or_max_value = server_memory_metric # Store for aggregation + else: # Batch mode - expect peak and delta + # We need to get peak and delta from the API response + peak_mem_value = locals().get('server_peak_mem_mb', None) # Get from response data if available + delta_value = server_memory_metric # server_memory_metric holds delta for batch + + if peak_mem_value is not None and delta_value is not None: + mem_display = f"{peak_mem_value:.1f} / {delta_value:+.1f}" + delta_or_max_value = delta_value # Store delta for aggregation + elif peak_mem_value is not None: + mem_display = f"{peak_mem_value:.1f} / N/A" + elif delta_value is not None: + mem_display = f"N/A / {delta_value:+.1f}" + delta_or_max_value = delta_value # Store delta for aggregation + + # --- Updated Print Statement with Adjusted Padding --- + console.print( + f" {batch_idx:<5} | {progress_pct:6.1f}% | {mem_display:>24} | {reqs_per_sec:8.1f} | " # Increased width for memory column + f"{success_urls:^7}/{failed_urls:<6} | {api_call_time:8.2f} | [{status_color}]{status:<7}[/{status_color}] " # Added trailing space + ) + + # --- Updated Return Dictionary --- + return_data = { + "batch_idx": batch_idx, + "request_success": request_success, + "success_urls": success_urls, + "failed_urls": failed_urls, + "time": api_call_time, + # Return both peak (if available) and delta/max + "server_peak_memory_mb": peak_mem_value, # Will be None for stream mode + "server_delta_or_max_mb": delta_or_max_value # Delta for batch, Max for stream + } + # Add back the specific batch mode delta if needed elsewhere, but delta_or_max covers it + # if not self.stream_mode: + # return_data["server_memory_delta_mb"] = delta_value + return return_data + + # No _periodic_memory_sample needed + + def _save_results(self) -> None: + """Saves the results summary to a JSON file.""" + results_path = self.report_path / f"api_test_summary_{self.test_id}.json" + try: + # No client memory path to convert + with open(results_path, 'w', encoding='utf-8') as f: + json.dump(self.results_summary, f, indent=2, default=str) + except Exception as e: + console.print(f"[bold red]Failed to save results summary: {e}[/bold red]") + + +# --- run_full_test Function --- +async def run_full_test(args): + """Runs the full API stress test process.""" + client = httpx.AsyncClient(base_url=args.api_url, timeout=REQUEST_TIMEOUT) + + if not await check_server_health(client): + console.print("[bold red]Aborting test due to server health check failure.[/]") + await client.aclose() + return + await client.aclose() + + test = ApiStressTest( + api_url=args.api_url, + url_count=args.urls, + max_concurrent_requests=args.max_concurrent_requests, + chunk_size=args.chunk_size, + report_path=args.report_path, + stream_mode=args.stream, + ) + results = {} + try: + results = await test.run() + finally: + await test.close_client() + + if not results: + console.print("[bold red]Test did not produce results.[/bold red]") + return + + console.print("\n" + "=" * 80) + console.print("[bold green]API Stress Test Completed[/bold green]") + console.print("=" * 80) + + success_rate_reqs = results["successful_requests"] / results["total_api_calls"] * 100 if results["total_api_calls"] > 0 else 0 + success_rate_urls = results["successful_urls"] / results["url_count"] * 100 if results["url_count"] > 0 else 0 + urls_per_second = results["total_urls_processed"] / results["total_time_seconds"] if results["total_time_seconds"] > 0 else 0 + reqs_per_second = results["total_api_calls"] / results["total_time_seconds"] if results["total_time_seconds"] > 0 else 0 + + + console.print(f"[bold cyan]Test ID:[/bold cyan] {results['test_id']}") + console.print(f"[bold cyan]Target API:[/bold cyan] {results['api_url']}") + console.print(f"[bold cyan]Configuration:[/bold cyan] {results['url_count']} URLs, {results['max_concurrent_requests']} concurrent client requests, URLs/Req: {results['chunk_size']}, Stream: {results['stream_mode']}") + console.print(f"[bold cyan]API Requests:[/bold cyan] {results['successful_requests']} successful, {results['failed_requests']} failed ({results['total_api_calls']} total, {success_rate_reqs:.1f}% success)") + console.print(f"[bold cyan]URL Processing:[/bold cyan] {results['successful_urls']} successful, {results['failed_urls']} failed ({results['total_urls_processed']} processed, {success_rate_urls:.1f}% success)") + console.print(f"[bold cyan]Performance:[/bold cyan] {results['total_time_seconds']:.2f}s total | Avg Reqs/sec: {reqs_per_second:.2f} | Avg URLs/sec: {urls_per_second:.2f}") + + # Report Server Memory + mem_metrics = results.get("server_memory_metrics", {}) + mem_samples = mem_metrics.get("samples", []) + if mem_samples: + num_samples = len(mem_samples) + if results['stream_mode']: + avg_mem = mem_metrics.get("stream_mode_avg_max_snapshot_mb") + max_mem = mem_metrics.get("stream_mode_max_max_snapshot_mb") + avg_str = f"{avg_mem:.1f}" if avg_mem is not None else "N/A" + max_str = f"{max_mem:.1f}" if max_mem is not None else "N/A" + console.print(f"[bold cyan]Server Memory (Stream):[/bold cyan] Avg Max Snapshot: {avg_str} MB | Max Max Snapshot: {max_str} MB (across {num_samples} requests)") + else: # Batch mode + avg_delta = mem_metrics.get("batch_mode_avg_delta_mb") + max_delta = mem_metrics.get("batch_mode_max_delta_mb") + avg_peak = mem_metrics.get("batch_mode_avg_peak_mb") + max_peak = mem_metrics.get("batch_mode_max_peak_mb") + + avg_delta_str = f"{avg_delta:.1f}" if avg_delta is not None else "N/A" + max_delta_str = f"{max_delta:.1f}" if max_delta is not None else "N/A" + avg_peak_str = f"{avg_peak:.1f}" if avg_peak is not None else "N/A" + max_peak_str = f"{max_peak:.1f}" if max_peak is not None else "N/A" + + console.print(f"[bold cyan]Server Memory (Batch):[/bold cyan] Avg Peak: {avg_peak_str} MB | Max Peak: {max_peak_str} MB | Avg Delta: {avg_delta_str} MB | Max Delta: {max_delta_str} MB (across {num_samples} requests)") + else: + console.print("[bold cyan]Server Memory:[/bold cyan] No memory data reported by server.") + + + # No client memory report + summary_path = pathlib.Path(args.report_path) / f"api_test_summary_{results['test_id']}.json" + console.print(f"[bold green]Results summary saved to {summary_path}[/bold green]") + + if results["failed_requests"] > 0: + console.print(f"\n[bold yellow]Warning: {results['failed_requests']} API requests failed ({100-success_rate_reqs:.1f}% failure rate)[/bold yellow]") + if results["failed_urls"] > 0: + console.print(f"[bold yellow]Warning: {results['failed_urls']} URLs failed to process ({100-success_rate_urls:.1f}% URL failure rate)[/bold yellow]") + if results["total_urls_processed"] < results["url_count"]: + console.print(f"\n[bold red]Error: Only {results['total_urls_processed']} out of {results['url_count']} target URLs were processed![/bold red]") + + +# --- main Function (Argument parsing mostly unchanged) --- +def main(): + """Main entry point for the script.""" + parser = argparse.ArgumentParser(description="Crawl4AI API Server Stress Test") + + parser.add_argument("--api-url", type=str, default=DEFAULT_API_URL, help=f"Base URL of the Crawl4AI API server (default: {DEFAULT_API_URL})") + parser.add_argument("--urls", type=int, default=DEFAULT_URL_COUNT, help=f"Total number of unique URLs to process via API calls (default: {DEFAULT_URL_COUNT})") + parser.add_argument("--max-concurrent-requests", type=int, default=DEFAULT_MAX_CONCURRENT_REQUESTS, help=f"Maximum concurrent API requests from this client (default: {DEFAULT_MAX_CONCURRENT_REQUESTS})") + parser.add_argument("--chunk-size", type=int, default=DEFAULT_CHUNK_SIZE, help=f"Number of URLs per API request payload (default: {DEFAULT_CHUNK_SIZE})") + parser.add_argument("--stream", action="store_true", default=DEFAULT_STREAM_MODE, help=f"Use the /crawl/stream endpoint instead of /crawl (default: {DEFAULT_STREAM_MODE})") + parser.add_argument("--report-path", type=str, default=DEFAULT_REPORT_PATH, help=f"Path to save reports and logs (default: {DEFAULT_REPORT_PATH})") + parser.add_argument("--clean-reports", action="store_true", help="Clean up report directory before running") + + args = parser.parse_args() + + console.print("[bold underline]Crawl4AI API Stress Test Configuration[/bold underline]") + console.print(f"API URL: {args.api_url}") + console.print(f"Total URLs: {args.urls}, Concurrent Client Requests: {args.max_concurrent_requests}, URLs per Request: {args.chunk_size}") + console.print(f"Mode: {'Streaming' if args.stream else 'Batch'}") + console.print(f"Report Path: {args.report_path}") + console.print("-" * 40) + if args.clean_reports: console.print("[cyan]Option: Clean reports before test[/cyan]") + console.print("-" * 40) + + if args.clean_reports: + report_dir = pathlib.Path(args.report_path) + if report_dir.exists(): + console.print(f"[yellow]Cleaning up reports directory: {args.report_path}[/yellow]") + shutil.rmtree(args.report_path) + report_dir.mkdir(parents=True, exist_ok=True) + + try: + asyncio.run(run_full_test(args)) + except KeyboardInterrupt: + console.print("\n[bold yellow]Test interrupted by user.[/bold yellow]") + except Exception as e: + console.print(f"\n[bold red]An unexpected error occurred:[/bold red] {e}") + import traceback + traceback.print_exc() + +if __name__ == "__main__": + # No need to modify sys.path for SimpleMemoryTracker as it's removed + main() \ No newline at end of file diff --git a/tests/memory/test_stress_docker_api.py b/tests/memory/test_stress_docker_api.py new file mode 100644 index 00000000..05b3bea8 --- /dev/null +++ b/tests/memory/test_stress_docker_api.py @@ -0,0 +1,129 @@ +""" +Crawl4AI Docker API stress tester. + +Examples +-------- +python test_stress_docker_api.py --urls 1000 --concurrency 32 +python test_stress_docker_api.py --urls 1000 --concurrency 32 --stream +python test_stress_docker_api.py --base-url http://10.0.0.42:11235 --http2 +""" + +import argparse, asyncio, json, secrets, statistics, time +from typing import List, Tuple +import httpx +from rich.console import Console +from rich.progress import Progress, BarColumn, TimeElapsedColumn, TimeRemainingColumn +from rich.table import Table + +console = Console() + + +# ───────────────────────── helpers ───────────────────────── +def make_fake_urls(n: int) -> List[str]: + base = "https://httpbin.org/anything/" + return [f"{base}{secrets.token_hex(8)}" for _ in range(n)] + + +async def fire( + client: httpx.AsyncClient, endpoint: str, payload: dict, sem: asyncio.Semaphore +) -> Tuple[bool, float]: + async with sem: + print(f"POST {endpoint} with {len(payload['urls'])} URLs") + t0 = time.perf_counter() + try: + if endpoint.endswith("/stream"): + async with client.stream("POST", endpoint, json=payload) as r: + r.raise_for_status() + async for _ in r.aiter_lines(): + pass + else: + r = await client.post(endpoint, json=payload) + r.raise_for_status() + return True, time.perf_counter() - t0 + except Exception: + return False, time.perf_counter() - t0 + + +def pct(lat: List[float], p: float) -> str: + """Return percentile string even for tiny samples.""" + if not lat: + return "-" + if len(lat) == 1: + return f"{lat[0]:.2f}s" + lat_sorted = sorted(lat) + k = (p / 100) * (len(lat_sorted) - 1) + lo = int(k) + hi = min(lo + 1, len(lat_sorted) - 1) + frac = k - lo + val = lat_sorted[lo] * (1 - frac) + lat_sorted[hi] * frac + return f"{val:.2f}s" + + +# ───────────────────────── main ───────────────────────── +def parse_args() -> argparse.Namespace: + p = argparse.ArgumentParser(description="Stress test Crawl4AI Docker API") + p.add_argument("--urls", type=int, default=100, help="number of URLs") + p.add_argument("--concurrency", type=int, default=1, help="max POSTs in flight") + p.add_argument("--chunk-size", type=int, default=50, help="URLs per request") + p.add_argument("--base-url", default="http://localhost:11235", help="API root") + # p.add_argument("--base-url", default="http://localhost:8020", help="API root") + p.add_argument("--stream", action="store_true", help="use /crawl/stream") + p.add_argument("--http2", action="store_true", help="enable HTTP/2") + p.add_argument("--headless", action="store_true", default=True) + return p.parse_args() + + +async def main() -> None: + args = parse_args() + + urls = make_fake_urls(args.urls) + batches = [urls[i : i + args.chunk_size] for i in range(0, len(urls), args.chunk_size)] + endpoint = "/crawl/stream" if args.stream else "/crawl" + sem = asyncio.Semaphore(args.concurrency) + + async with httpx.AsyncClient(base_url=args.base_url, http2=args.http2, timeout=None) as client: + with Progress( + "[progress.description]{task.description}", + BarColumn(), + "[progress.percentage]{task.percentage:>3.0f}%", + TimeElapsedColumn(), + TimeRemainingColumn(), + ) as progress: + task_id = progress.add_task("[cyan]bombarding…", total=len(batches)) + tasks = [] + for chunk in batches: + payload = { + "urls": chunk, + "browser_config": {"type": "BrowserConfig", "params": {"headless": args.headless}}, + "crawler_config": {"type": "CrawlerRunConfig", "params": {"cache_mode": "BYPASS", "stream": args.stream}}, + } + tasks.append(asyncio.create_task(fire(client, endpoint, payload, sem))) + progress.advance(task_id) + + results = await asyncio.gather(*tasks) + + ok_latencies = [dt for ok, dt in results if ok] + err_count = sum(1 for ok, _ in results if not ok) + + table = Table(title="Docker API Stress‑Test Summary") + table.add_column("total", justify="right") + table.add_column("errors", justify="right") + table.add_column("p50", justify="right") + table.add_column("p95", justify="right") + table.add_column("max", justify="right") + + table.add_row( + str(len(results)), + str(err_count), + pct(ok_latencies, 50), + pct(ok_latencies, 95), + f"{max(ok_latencies):.2f}s" if ok_latencies else "-", + ) + console.print(table) + + +if __name__ == "__main__": + try: + asyncio.run(main()) + except KeyboardInterrupt: + console.print("\n[yellow]aborted by user[/]") diff --git a/tests/memory/test_stress_sdk.py b/tests/memory/test_stress_sdk.py index 8000690c..14da94a4 100644 --- a/tests/memory/test_stress_sdk.py +++ b/tests/memory/test_stress_sdk.py @@ -37,8 +37,8 @@ from crawl4ai import ( DEFAULT_SITE_PATH = "test_site" DEFAULT_PORT = 8000 DEFAULT_MAX_SESSIONS = 16 -DEFAULT_URL_COUNT = 100 -DEFAULT_CHUNK_SIZE = 10 # Define chunk size for batch logging +DEFAULT_URL_COUNT = 1 +DEFAULT_CHUNK_SIZE = 1 # Define chunk size for batch logging DEFAULT_REPORT_PATH = "reports" DEFAULT_STREAM_MODE = False DEFAULT_MONITOR_MODE = "DETAILED"