feat(api): implement crawler pool manager for improved resource handling
Adds a new CrawlerManager class to handle browser instance pooling and failover:
- Implements auto-scaling based on system resources
- Adds primary/backup crawler management
- Integrates memory monitoring and throttling
- Adds streaming support with memory tracking
- Updates API endpoints to use pooled crawlers

BREAKING CHANGE: API endpoints now require CrawlerManager initialization
This commit is contained in:
@@ -542,9 +542,9 @@ class AsyncWebCrawler:
|
||||
markdown_input_html = source_lambda()
|
||||
|
||||
# Log which source is being used (optional, but helpful for debugging)
|
||||
if self.logger and verbose:
|
||||
actual_source_used = selected_html_source if selected_html_source in html_source_selector else 'cleaned_html (default)'
|
||||
self.logger.debug(f"Using '{actual_source_used}' as source for Markdown generation for {url}", tag="MARKDOWN_SRC")
|
||||
# if self.logger and verbose:
|
||||
# actual_source_used = selected_html_source if selected_html_source in html_source_selector else 'cleaned_html (default)'
|
||||
# self.logger.debug(f"Using '{actual_source_used}' as source for Markdown generation for {url}", tag="MARKDOWN_SRC")
|
||||
|
||||
except Exception as e:
|
||||
# Handle potential errors, especially from preprocess_html_for_schema
|
||||
|
||||
503
deploy/docker/api copy.py
Normal file
503
deploy/docker/api copy.py
Normal file
@@ -0,0 +1,503 @@
|
||||
import os
|
||||
import json
|
||||
import asyncio
|
||||
from typing import List, Tuple
|
||||
from functools import partial
|
||||
|
||||
import logging
|
||||
from typing import Optional, AsyncGenerator
|
||||
from urllib.parse import unquote
|
||||
from fastapi import HTTPException, Request, status
|
||||
from fastapi.background import BackgroundTasks
|
||||
from fastapi.responses import JSONResponse
|
||||
from redis import asyncio as aioredis
|
||||
|
||||
from crawl4ai import (
|
||||
AsyncWebCrawler,
|
||||
CrawlerRunConfig,
|
||||
LLMExtractionStrategy,
|
||||
CacheMode,
|
||||
BrowserConfig,
|
||||
MemoryAdaptiveDispatcher,
|
||||
RateLimiter,
|
||||
LLMConfig
|
||||
)
|
||||
from crawl4ai.utils import perform_completion_with_backoff
|
||||
from crawl4ai.content_filter_strategy import (
|
||||
PruningContentFilter,
|
||||
BM25ContentFilter,
|
||||
LLMContentFilter
|
||||
)
|
||||
from crawl4ai.markdown_generation_strategy import DefaultMarkdownGenerator
|
||||
from crawl4ai.content_scraping_strategy import LXMLWebScrapingStrategy
|
||||
|
||||
from utils import (
|
||||
TaskStatus,
|
||||
FilterType,
|
||||
get_base_url,
|
||||
is_task_id,
|
||||
should_cleanup_task,
|
||||
decode_redis_hash
|
||||
)
|
||||
|
||||
import psutil, time
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# --- Helper to get memory ---
|
||||
def _get_memory_mb():
    """Return the current process's resident set size in MiB.

    Returns:
        The RSS as a float number of mebibytes, or ``None`` when the value
        could not be read (the failure is logged as a warning).
    """
    bytes_per_mib = 1024 * 1024
    try:
        rss_bytes = psutil.Process().memory_info().rss
        return rss_bytes / bytes_per_mib
    except Exception as e:
        # Memory figures are diagnostic only, so never let this fail a request.
        logger.warning(f"Could not get memory info: {e}")
        return None
|
||||
|
||||
|
||||
async def handle_llm_qa(
    url: str,
    query: str,
    config: dict
) -> str:
    """Answer ``query`` with an LLM, using the crawled page at ``url`` as context.

    Args:
        url: Page to crawl. Everything from the last ``?q=`` onwards is
            stripped before crawling.
        query: The question put to the LLM.
        config: Application config; ``config["llm"]["provider"]`` and the
            optional ``config["llm"]["api_key_env"]`` are read.

    Returns:
        The message content of the first choice in the LLM response.

    Raises:
        HTTPException: 500 when the crawl reports failure (detail is the
            crawler's error message) or when any other error occurs
            (detail is the error text).
    """
    try:
        # Extract base URL by finding last '?q=' occurrence
        last_q_index = url.rfind('?q=')
        if last_q_index != -1:
            url = url[:last_q_index]

        # Get markdown content
        async with AsyncWebCrawler() as crawler:
            result = await crawler.arun(url)
            if not result.success:
                raise HTTPException(
                    status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
                    detail=result.error_message
                )
            content = result.markdown.fit_markdown

        # Create prompt and get LLM response
        prompt = f"""Use the following content as context to answer the question.
Content:
{content}

Question: {query}

Answer:"""

        response = perform_completion_with_backoff(
            provider=config["llm"]["provider"],
            prompt_with_variables=prompt,
            api_token=os.environ.get(config["llm"].get("api_key_env", ""))
        )

        return response.choices[0].message.content
    except HTTPException as http_exc:
        # Already a well-formed HTTP error: re-wrapping it below would replace
        # its detail with str(exception) instead of the crawler's message.
        logger.error(f"QA processing error: {http_exc.detail}")
        raise
    except Exception as e:
        logger.error(f"QA processing error: {str(e)}", exc_info=True)
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=str(e)
        )
|
||||
|
||||
async def process_llm_extraction(
    redis: aioredis.Redis,
    config: dict,
    task_id: str,
    url: str,
    instruction: str,
    schema: Optional[str] = None,
    cache: str = "0"
) -> None:
    """Run an LLM extraction crawl and record the outcome in Redis.

    The result is written to the hash ``task:{task_id}``: on success
    ``status`` is COMPLETED and ``result`` holds the JSON-encoded content;
    on any failure ``status`` is FAILED and ``error`` holds the message.
    Exceptions are logged and recorded rather than propagated.

    Args:
        redis: Async Redis client used to store the task state.
        config: Application config; ``config["llm"]`` supplies ``provider``
            and either ``api_key`` or ``api_key_env``.
        task_id: Identifier of the task hash to update.
        url: Page to crawl.
        instruction: Extraction instruction passed to the LLM strategy.
        schema: Optional JSON string describing the extraction schema.
        cache: ``"1"`` enables the cache; anything else uses write-only mode.
    """
    try:
        llm_settings = config["llm"]
        # If config['llm'] has api_key then ignore the api_key_env
        if "api_key" in llm_settings:
            api_key = llm_settings["api_key"]
        else:
            # os.environ.get(None) raises TypeError, so a missing or null
            # api_key_env must fall back to "" (which simply is not found).
            api_key = os.environ.get(llm_settings.get("api_key_env") or "", "")

        llm_strategy = LLMExtractionStrategy(
            llm_config=LLMConfig(
                provider=llm_settings["provider"],
                api_token=api_key
            ),
            instruction=instruction,
            schema=json.loads(schema) if schema else None,
        )

        cache_mode = CacheMode.ENABLED if cache == "1" else CacheMode.WRITE_ONLY

        async with AsyncWebCrawler() as crawler:
            result = await crawler.arun(
                url=url,
                config=CrawlerRunConfig(
                    extraction_strategy=llm_strategy,
                    scraping_strategy=LXMLWebScrapingStrategy(),
                    cache_mode=cache_mode
                )
            )

        if not result.success:
            await redis.hset(f"task:{task_id}", mapping={
                "status": TaskStatus.FAILED,
                "error": result.error_message
            })
            return

        # The strategy normally yields JSON text; keep the raw text if it does not parse.
        try:
            content = json.loads(result.extracted_content)
        except json.JSONDecodeError:
            content = result.extracted_content

        await redis.hset(f"task:{task_id}", mapping={
            "status": TaskStatus.COMPLETED,
            "result": json.dumps(content)
        })

    except Exception as e:
        logger.error(f"LLM extraction error: {str(e)}", exc_info=True)
        await redis.hset(f"task:{task_id}", mapping={
            "status": TaskStatus.FAILED,
            "error": str(e)
        })
|
||||
|
||||
async def handle_markdown_request(
    url: str,
    filter_type: FilterType,
    query: Optional[str] = None,
    cache: str = "0",
    config: Optional[dict] = None
) -> str:
    """Crawl ``url`` and return its markdown, optionally content-filtered.

    Args:
        url: URL-encoded target; ``https://`` is prepended when no scheme is
            present after decoding.
        filter_type: RAW returns the unfiltered markdown; FIT, BM25 and LLM
            select the content filter applied before generating markdown.
        query: User query for BM25, or the instruction for the LLM filter.
        cache: ``"1"`` enables the cache; anything else uses write-only mode.
        config: Application config. Only read for the LLM filter
            (``config["llm"]``).

    Returns:
        ``raw_markdown`` for RAW, otherwise ``fit_markdown``.

    Raises:
        HTTPException: 500 when the crawl fails or any other error occurs.
    """
    try:
        decoded_url = unquote(url)
        if not decoded_url.startswith(('http://', 'https://')):
            decoded_url = 'https://' + decoded_url

        if filter_type == FilterType.RAW:
            md_generator = DefaultMarkdownGenerator()
        else:
            # Build only the requested filter. Constructing all of them up
            # front made FIT/BM25 requests fail whenever the LLM settings
            # (or the whole config) were missing.
            if filter_type == FilterType.FIT:
                content_filter = PruningContentFilter()
            elif filter_type == FilterType.BM25:
                content_filter = BM25ContentFilter(user_query=query or "")
            elif filter_type == FilterType.LLM:
                content_filter = LLMContentFilter(
                    llm_config=LLMConfig(
                        provider=config["llm"]["provider"],
                        # os.environ.get(None) raises TypeError; fall back to "".
                        api_token=os.environ.get(config["llm"].get("api_key_env") or "", ""),
                    ),
                    instruction=query or "Extract main content"
                )
            else:
                # Same failure an unknown key produced in the former lookup table.
                raise KeyError(filter_type)
            md_generator = DefaultMarkdownGenerator(content_filter=content_filter)

        cache_mode = CacheMode.ENABLED if cache == "1" else CacheMode.WRITE_ONLY

        async with AsyncWebCrawler() as crawler:
            result = await crawler.arun(
                url=decoded_url,
                config=CrawlerRunConfig(
                    markdown_generator=md_generator,
                    scraping_strategy=LXMLWebScrapingStrategy(),
                    cache_mode=cache_mode
                )
            )

            if not result.success:
                raise HTTPException(
                    status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
                    detail=result.error_message
                )

            return (result.markdown.raw_markdown
                    if filter_type == FilterType.RAW
                    else result.markdown.fit_markdown)

    except HTTPException as http_exc:
        # Keep the crawler's own error message instead of re-wrapping it.
        logger.error(f"Markdown error: {http_exc.detail}")
        raise
    except Exception as e:
        logger.error(f"Markdown error: {str(e)}", exc_info=True)
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=str(e)
        )
|
||||
|
||||
async def handle_llm_request(
    redis: aioredis.Redis,
    background_tasks: BackgroundTasks,
    request: Request,
    input_path: str,
    query: Optional[str] = None,
    schema: Optional[str] = None,
    cache: str = "0",
    config: Optional[dict] = None
) -> JSONResponse:
    """Route an ``/llm/...`` request to a status check or a new extraction task.

    Args:
        redis: Async Redis client holding task state.
        background_tasks: FastAPI background task queue for the extraction.
        request: Incoming request; used for the base URL and the retry link.
        input_path: Either a task id (status check) or the URL to extract from.
        query: Extraction instruction; when missing, a hint response is returned.
        schema: Optional JSON schema string for the extraction.
        cache: Cache flag forwarded to the extraction task.
        config: Application config forwarded to the extraction task.

    Returns:
        The task status, a hint asking for an instruction, the newly created
        task descriptor, or a 500 JSON error body with a retry link.

    Raises:
        HTTPException: Propagated unchanged from the helpers, e.g. the 404
            raised by ``handle_task_status`` for an unknown task.
    """
    base_url = get_base_url(request)

    try:
        if is_task_id(input_path):
            return await handle_task_status(
                redis, input_path, base_url
            )

        if not query:
            return JSONResponse({
                "message": "Please provide an instruction",
                "_links": {
                    "example": {
                        "href": f"{base_url}/llm/{input_path}?q=Extract+main+content",
                        "title": "Try this example"
                    }
                }
            })

        return await create_new_task(
            redis,
            background_tasks,
            input_path,
            query,
            schema,
            cache,
            base_url,
            config
        )

    except HTTPException:
        # Deliberate HTTP errors (such as "Task not found" -> 404) must keep
        # their status code rather than being reported as a generic 500.
        raise
    except Exception as e:
        logger.error(f"LLM endpoint error: {str(e)}", exc_info=True)
        return JSONResponse({
            "error": str(e),
            "_links": {
                "retry": {"href": str(request.url)}
            }
        }, status_code=status.HTTP_500_INTERNAL_SERVER_ERROR)
|
||||
|
||||
async def handle_task_status(
    redis: aioredis.Redis,
    task_id: str,
    base_url: str
) -> JSONResponse:
    """Look up a task in Redis and return its status as JSON.

    Finished tasks (COMPLETED or FAILED) are deleted from Redis after the
    response is built when ``should_cleanup_task`` says so for their
    ``created_at`` value.

    Raises:
        HTTPException: 404 when no hash exists for ``task_id``.
    """
    redis_key = f"task:{task_id}"

    stored = await redis.hgetall(redis_key)
    if not stored:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="Task not found"
        )

    task = decode_redis_hash(stored)
    response = create_task_response(task, task_id, base_url)

    finished = task["status"] in [TaskStatus.COMPLETED, TaskStatus.FAILED]
    if finished and should_cleanup_task(task["created_at"]):
        await redis.delete(redis_key)

    return JSONResponse(response)
|
||||
|
||||
async def create_new_task(
    redis: aioredis.Redis,
    background_tasks: BackgroundTasks,
    input_path: str,
    query: str,
    schema: Optional[str],
    cache: str,
    base_url: str,
    config: dict
) -> JSONResponse:
    """Register a new LLM extraction task and schedule it in the background.

    The task is stored in the Redis hash ``task:{task_id}`` with status
    PROCESSING, its creation time and the target URL, then
    ``process_llm_extraction`` is queued on ``background_tasks``.

    Returns:
        A JSON descriptor with the task id, status, URL and status links.
    """
    from datetime import datetime

    target_url = unquote(input_path)
    has_scheme = target_url.startswith(('http://', 'https://'))
    if not has_scheme:
        target_url = 'https://' + target_url

    # Id combines the current Unix second with the id() of the task queue object.
    task_id = f"llm_{int(datetime.now().timestamp())}_{id(background_tasks)}"

    initial_state = {
        "status": TaskStatus.PROCESSING,
        "created_at": datetime.now().isoformat(),
        "url": target_url
    }
    await redis.hset(f"task:{task_id}", mapping=initial_state)

    background_tasks.add_task(
        process_llm_extraction,
        redis, config, task_id, target_url, query, schema, cache
    )

    payload = {
        "task_id": task_id,
        "status": TaskStatus.PROCESSING,
        "url": target_url,
        "_links": {
            "self": {"href": f"{base_url}/llm/{task_id}"},
            "status": {"href": f"{base_url}/llm/{task_id}"}
        }
    }
    return JSONResponse(payload)
|
||||
|
||||
def create_task_response(task: dict, task_id: str, base_url: str) -> dict:
    """Build the JSON-serialisable body for a task status check.

    Args:
        task: Decoded task hash; ``status``, ``created_at`` and ``url`` are
            always read, ``result`` (JSON text) for completed tasks and
            ``error`` for failed ones.
        task_id: Identifier used in the response and its links.
        base_url: Prefix for the ``self`` / ``refresh`` links.

    Returns:
        The response dict, with ``result`` or ``error`` added when the task
        has completed or failed.
    """
    state = task["status"]
    response = {
        "task_id": task_id,
        "status": state,
        "created_at": task["created_at"],
        "url": task["url"],
        "_links": {
            "self": {"href": f"{base_url}/llm/{task_id}"},
            "refresh": {"href": f"{base_url}/llm/{task_id}"}
        }
    }

    if state == TaskStatus.COMPLETED:
        response["result"] = json.loads(task["result"])
        return response

    if state == TaskStatus.FAILED:
        response["error"] = task["error"]

    return response
|
||||
|
||||
async def stream_results(crawler: AsyncWebCrawler, results_gen: AsyncGenerator) -> AsyncGenerator[bytes, None]:
    """Yield crawl results as newline-delimited JSON, then a completion marker.

    Each result is dumped to a dict, tagged with the server's current memory
    usage (``server_memory_mb``) and emitted as one UTF-8 JSON line. A result
    that fails to serialise is replaced by an ``{"error", "url"}`` line. After
    the last result ``{"status": "completed"}`` is emitted. The crawler is
    closed when the stream ends, including on client disconnect.
    """
    import json
    from utils import datetime_handler

    try:
        async for result in results_gen:
            try:
                memory_now_mb = _get_memory_mb()
                record = result.model_dump()
                record['server_memory_mb'] = memory_now_mb
                logger.info(f"Streaming result for {record.get('url', 'unknown')}")
                line = json.dumps(record, default=datetime_handler) + "\n"
                yield line.encode('utf-8')
            except Exception as e:
                # One bad result must not abort the stream; report it in-band.
                logger.error(f"Serialization error: {e}")
                failure = {"error": str(e), "url": getattr(result, 'url', 'unknown')}
                yield (json.dumps(failure) + "\n").encode('utf-8')

        completion_marker = json.dumps({"status": "completed"})
        yield completion_marker.encode('utf-8')

    except asyncio.CancelledError:
        logger.warning("Client disconnected during streaming")
    finally:
        try:
            await crawler.close()
        except Exception as e:
            logger.error(f"Crawler cleanup error: {e}")
|
||||
|
||||
async def handle_crawl_request(
    urls: List[str],
    browser_config: dict,
    crawler_config: dict,
    config: dict
) -> dict:
    """Crawl ``urls`` with a fresh crawler and return all results at once.

    A single URL is crawled with ``arun``; any other number with
    ``arun_many``. Process memory is sampled before and after the crawl.

    Args:
        urls: URLs to crawl.
        browser_config: Serialised browser config, loaded via ``BrowserConfig.load``.
        crawler_config: Serialised run config, loaded via ``CrawlerRunConfig.load``.
        config: Application config; ``config["crawler"]`` supplies the memory
            threshold and rate limiter base delay for the dispatcher.

    Returns:
        Dict with ``success``, the dumped ``results``, processing time in
        seconds, memory delta and peak memory in MB.

    Raises:
        HTTPException: 500 whose detail is a JSON string carrying the error
            text and the memory figures.
    """
    start_mem_mb = _get_memory_mb()
    start_time = time.time()
    mem_delta_mb = None
    peak_mem_mb = start_mem_mb
    crawler = None  # stays None until the crawler object has been constructed

    try:
        browser_config = BrowserConfig.load(browser_config)
        crawler_config = CrawlerRunConfig.load(crawler_config)

        crawl_settings = config["crawler"]
        dispatcher = MemoryAdaptiveDispatcher(
            memory_threshold_percent=crawl_settings["memory_threshold_percent"],
            rate_limiter=RateLimiter(
                base_delay=tuple(crawl_settings["rate_limiter"]["base_delay"])
            )
        )

        crawler = AsyncWebCrawler(config=browser_config)
        await crawler.start()

        if len(urls) == 1:
            results = await crawler.arun(urls[0], config=crawler_config, dispatcher=dispatcher)
        else:
            results = await crawler.arun_many(urls, config=crawler_config, dispatcher=dispatcher)
        await crawler.close()

        end_mem_mb = _get_memory_mb()
        end_time = time.time()

        memory_known = start_mem_mb is not None and end_mem_mb is not None
        if memory_known:
            mem_delta_mb = end_mem_mb - start_mem_mb
            peak_mem_mb = max(peak_mem_mb if peak_mem_mb else 0, end_mem_mb)
        logger.info(f"Memory usage: Start: {start_mem_mb} MB, End: {end_mem_mb} MB, Delta: {mem_delta_mb} MB, Peak: {peak_mem_mb} MB")

        dumped = [result.model_dump() for result in results]
        return {
            "success": True,
            "results": dumped,
            "server_processing_time_s": end_time - start_time,
            "server_memory_delta_mb": mem_delta_mb,
            "server_peak_memory_mb": peak_mem_mb
        }

    except Exception as e:
        logger.error(f"Crawl error: {str(e)}", exc_info=True)
        # Only close a crawler that was actually constructed and started.
        if crawler is not None and crawler.ready:
            try:
                await crawler.close()
            except Exception as close_e:
                logger.error(f"Error closing crawler during exception handling: {close_e}")

        # Measure memory even on error if possible
        end_mem_mb_error = _get_memory_mb()
        if start_mem_mb is not None and end_mem_mb_error is not None:
            mem_delta_mb = end_mem_mb_error - start_mem_mb

        error_detail = json.dumps({
            "error": str(e),
            "server_memory_delta_mb": mem_delta_mb,
            "server_peak_memory_mb": max(peak_mem_mb if peak_mem_mb else 0, end_mem_mb_error or 0)
        })
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=error_detail
        )
|
||||
|
||||
async def handle_stream_crawl_request(
    urls: List[str],
    browser_config: dict,
    crawler_config: dict,
    config: dict
) -> Tuple[AsyncWebCrawler, AsyncGenerator]:
    """Start a streaming crawl and hand back the crawler with its result stream.

    The loaded browser config is forced to non-verbose, and the run config is
    forced to the LXML scraping strategy with ``stream = True``.

    Args:
        urls: URLs to crawl.
        browser_config: Serialised browser config, loaded via ``BrowserConfig.load``.
        crawler_config: Serialised run config, loaded via ``CrawlerRunConfig.load``.
        config: Application config; ``config["crawler"]`` supplies the
            dispatcher's memory threshold and rate limiter base delay.

    Returns:
        ``(crawler, results_gen)``. The crawler is left running; it is not
        closed here on success.

    Raises:
        HTTPException: 500 when setup fails; a started crawler is closed first.
    """
    crawler = None  # stays None until the crawler object has been constructed

    try:
        browser_config = BrowserConfig.load(browser_config)
        browser_config.verbose = False

        crawler_config = CrawlerRunConfig.load(crawler_config)
        crawler_config.scraping_strategy = LXMLWebScrapingStrategy()
        crawler_config.stream = True

        crawl_settings = config["crawler"]
        dispatcher = MemoryAdaptiveDispatcher(
            memory_threshold_percent=crawl_settings["memory_threshold_percent"],
            rate_limiter=RateLimiter(
                base_delay=tuple(crawl_settings["rate_limiter"]["base_delay"])
            )
        )

        crawler = AsyncWebCrawler(config=browser_config)
        await crawler.start()

        results_gen = await crawler.arun_many(
            urls=urls,
            config=crawler_config,
            dispatcher=dispatcher
        )
        return crawler, results_gen

    except Exception as e:
        # Make sure to close crawler if started during an error here
        if crawler is not None and crawler.ready:
            try:
                await crawler.close()
            except Exception as close_e:
                logger.error(f"Error closing crawler during stream setup exception: {close_e}")
        logger.error(f"Stream crawl error: {str(e)}", exc_info=True)
        # Raising HTTPException here will prevent streaming response
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=str(e)
        )
|
||||
@@ -40,8 +40,19 @@ from utils import (
|
||||
decode_redis_hash
|
||||
)
|
||||
|
||||
import psutil, time
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# --- Helper to get memory ---
|
||||
def _get_memory_mb():
|
||||
try:
|
||||
return psutil.Process().memory_info().rss / (1024 * 1024)
|
||||
except Exception as e:
|
||||
logger.warning(f"Could not get memory info: {e}")
|
||||
return None
|
||||
|
||||
|
||||
async def handle_llm_qa(
|
||||
url: str,
|
||||
query: str,
|
||||
@@ -351,7 +362,9 @@ async def stream_results(crawler: AsyncWebCrawler, results_gen: AsyncGenerator)
|
||||
try:
|
||||
async for result in results_gen:
|
||||
try:
|
||||
server_memory_mb = _get_memory_mb()
|
||||
result_dict = result.model_dump()
|
||||
result_dict['server_memory_mb'] = server_memory_mb
|
||||
logger.info(f"Streaming result for {result_dict.get('url', 'unknown')}")
|
||||
data = json.dumps(result_dict, default=datetime_handler) + "\n"
|
||||
yield data.encode('utf-8')
|
||||
@@ -364,19 +377,25 @@ async def stream_results(crawler: AsyncWebCrawler, results_gen: AsyncGenerator)
|
||||
|
||||
except asyncio.CancelledError:
|
||||
logger.warning("Client disconnected during streaming")
|
||||
finally:
|
||||
try:
|
||||
await crawler.close()
|
||||
except Exception as e:
|
||||
logger.error(f"Crawler cleanup error: {e}")
|
||||
# finally:
|
||||
# try:
|
||||
# await crawler.close()
|
||||
# except Exception as e:
|
||||
# logger.error(f"Crawler cleanup error: {e}")
|
||||
|
||||
async def handle_crawl_request(
|
||||
crawler: AsyncWebCrawler,
|
||||
urls: List[str],
|
||||
browser_config: dict,
|
||||
crawler_config: dict,
|
||||
config: dict
|
||||
) -> dict:
|
||||
"""Handle non-streaming crawl requests."""
|
||||
start_mem_mb = _get_memory_mb() # <--- Get memory before
|
||||
start_time = time.time()
|
||||
mem_delta_mb = None
|
||||
peak_mem_mb = start_mem_mb
|
||||
|
||||
try:
|
||||
browser_config = BrowserConfig.load(browser_config)
|
||||
crawler_config = CrawlerRunConfig.load(crawler_config)
|
||||
@@ -388,31 +407,63 @@ async def handle_crawl_request(
|
||||
)
|
||||
)
|
||||
|
||||
crawler: AsyncWebCrawler = AsyncWebCrawler(config=browser_config)
|
||||
await crawler.start()
|
||||
# crawler: AsyncWebCrawler = AsyncWebCrawler(config=browser_config)
|
||||
# await crawler.start()
|
||||
results = []
|
||||
func = getattr(crawler, "arun" if len(urls) == 1 else "arun_many")
|
||||
partial_func = partial(func,
|
||||
urls[0] if len(urls) == 1 else urls,
|
||||
config=crawler_config,
|
||||
dispatcher=dispatcher)
|
||||
|
||||
# Simulate work being done by the crawler
|
||||
# logger.debug(f"Request (URLs: {len(urls)}) starting simulated work...") # Add log
|
||||
# await asyncio.sleep(2) # <--- ADD ARTIFICIAL DELAY (e.g., 0.5 seconds)
|
||||
# logger.debug(f"Request (URLs: {len(urls)}) finished simulated work.")
|
||||
|
||||
results = await partial_func()
|
||||
await crawler.close()
|
||||
# await crawler.close()
|
||||
|
||||
end_mem_mb = _get_memory_mb() # <--- Get memory after
|
||||
end_time = time.time()
|
||||
|
||||
if start_mem_mb is not None and end_mem_mb is not None:
|
||||
mem_delta_mb = end_mem_mb - start_mem_mb # <--- Calculate delta
|
||||
peak_mem_mb = max(peak_mem_mb if peak_mem_mb else 0, end_mem_mb) # <--- Get peak memory
|
||||
logger.info(f"Memory usage: Start: {start_mem_mb} MB, End: {end_mem_mb} MB, Delta: {mem_delta_mb} MB, Peak: {peak_mem_mb} MB")
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"results": [result.model_dump() for result in results]
|
||||
"results": [result.model_dump() for result in results],
|
||||
"server_processing_time_s": end_time - start_time,
|
||||
"server_memory_delta_mb": mem_delta_mb,
|
||||
"server_peak_memory_mb": peak_mem_mb
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Crawl error: {str(e)}", exc_info=True)
|
||||
if 'crawler' in locals():
|
||||
await crawler.close()
|
||||
# if 'crawler' in locals() and crawler.ready: # Check if crawler was initialized and started
|
||||
# try:
|
||||
# await crawler.close()
|
||||
# except Exception as close_e:
|
||||
# logger.error(f"Error closing crawler during exception handling: {close_e}")
|
||||
|
||||
# Measure memory even on error if possible
|
||||
end_mem_mb_error = _get_memory_mb()
|
||||
if start_mem_mb is not None and end_mem_mb_error is not None:
|
||||
mem_delta_mb = end_mem_mb_error - start_mem_mb
|
||||
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail=str(e)
|
||||
detail=json.dumps({ # Send structured error
|
||||
"error": str(e),
|
||||
"server_memory_delta_mb": mem_delta_mb,
|
||||
"server_peak_memory_mb": max(peak_mem_mb if peak_mem_mb else 0, end_mem_mb_error or 0)
|
||||
})
|
||||
)
|
||||
|
||||
async def handle_stream_crawl_request(
|
||||
crawler: AsyncWebCrawler,
|
||||
urls: List[str],
|
||||
browser_config: dict,
|
||||
crawler_config: dict,
|
||||
@@ -421,9 +472,11 @@ async def handle_stream_crawl_request(
|
||||
"""Handle streaming crawl requests."""
|
||||
try:
|
||||
browser_config = BrowserConfig.load(browser_config)
|
||||
browser_config.verbose = True
|
||||
# browser_config.verbose = True # Set to False or remove for production stress testing
|
||||
browser_config.verbose = False
|
||||
crawler_config = CrawlerRunConfig.load(crawler_config)
|
||||
crawler_config.scraping_strategy = LXMLWebScrapingStrategy()
|
||||
crawler_config.stream = True
|
||||
|
||||
dispatcher = MemoryAdaptiveDispatcher(
|
||||
memory_threshold_percent=config["crawler"]["memory_threshold_percent"],
|
||||
@@ -432,8 +485,8 @@ async def handle_stream_crawl_request(
|
||||
)
|
||||
)
|
||||
|
||||
crawler = AsyncWebCrawler(config=browser_config)
|
||||
await crawler.start()
|
||||
# crawler = AsyncWebCrawler(config=browser_config)
|
||||
# await crawler.start()
|
||||
|
||||
results_gen = await crawler.arun_many(
|
||||
urls=urls,
|
||||
@@ -441,12 +494,19 @@ async def handle_stream_crawl_request(
|
||||
dispatcher=dispatcher
|
||||
)
|
||||
|
||||
# Return the *same* crawler instance and the generator
|
||||
# The caller (server.py) manages the crawler lifecycle via the pool context
|
||||
return crawler, results_gen
|
||||
|
||||
except Exception as e:
|
||||
if 'crawler' in locals():
|
||||
await crawler.close()
|
||||
# Make sure to close crawler if started during an error here
|
||||
# if 'crawler' in locals() and crawler.ready:
|
||||
# try:
|
||||
# await crawler.close()
|
||||
# except Exception as close_e:
|
||||
# logger.error(f"Error closing crawler during stream setup exception: {close_e}")
|
||||
logger.error(f"Stream crawl error: {str(e)}", exc_info=True)
|
||||
# Raising HTTPException here will prevent streaming response
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail=str(e)
|
||||
|
||||
@@ -48,6 +48,38 @@ security:
|
||||
content_security_policy: "default-src 'self'"
|
||||
strict_transport_security: "max-age=63072000; includeSubDomains"
|
||||
|
||||
# Crawler Pool Configuration
|
||||
crawler_pool:
|
||||
enabled: true # Set to false to disable the pool
|
||||
|
||||
# --- Option 1: Auto-calculate size ---
|
||||
auto_calculate_size: true
|
||||
calculation_params:
|
||||
mem_headroom_mb: 512 # Memory reserved for OS/other apps
|
||||
avg_page_mem_mb: 150 # Estimated MB per concurrent "tab"/page in browsers
|
||||
fd_per_page: 20 # Estimated file descriptors per page
|
||||
core_multiplier: 4 # Max crawlers per CPU core
|
||||
min_pool_size: 2 # Minimum number of primary crawlers
|
||||
max_pool_size: 16 # Maximum number of primary crawlers
|
||||
|
||||
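# --- Option 2: Manual size (only relevant when auto_calculate_size is false) ---
# NOTE: CrawlerManagerConfig declares no pool_size field; with auto-calculation
# off the manager uses calculation_params.max_pool_size as the limit instead.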
|
||||
# pool_size: 8
|
||||
|
||||
# --- Other Pool Settings ---
|
||||
backup_pool_size: 1 # Number of backup crawlers
|
||||
max_wait_time_s: 30.0 # Max seconds a request waits for a free crawler
|
||||
throttle_threshold_percent: 70.0 # Start throttling delay above this % usage
|
||||
throttle_delay_min_s: 0.1 # Min throttle delay
|
||||
throttle_delay_max_s: 0.5 # Max throttle delay
|
||||
|
||||
# --- Browser Config for Pooled Crawlers ---
|
||||
browser_config:
|
||||
# No need for "type": "BrowserConfig" here, just params
|
||||
headless: true
|
||||
verbose: false # Keep pool crawlers less verbose in production
|
||||
# user_agent: "MyPooledCrawler/1.0" # Example
|
||||
# Add other BrowserConfig params as needed (e.g., proxy, viewport)
|
||||
|
||||
# Crawler Configuration
|
||||
crawler:
|
||||
memory_threshold_percent: 95.0
|
||||
@@ -61,6 +93,8 @@ crawler:
|
||||
logging:
|
||||
level: "INFO"
|
||||
format: "%(asctime)s - %(name)s - %(levelname)s - %(message)s"
|
||||
file: "logs/app.log"
|
||||
verbose: true
|
||||
|
||||
# Observability Configuration
|
||||
observability:
|
||||
|
||||
556
deploy/docker/crawler_manager.py
Normal file
556
deploy/docker/crawler_manager.py
Normal file
@@ -0,0 +1,556 @@
|
||||
# crawler_manager.py
|
||||
import asyncio
|
||||
import time
|
||||
import uuid
|
||||
import psutil
|
||||
import os
|
||||
import resource # For FD limit
|
||||
import random
|
||||
import math
|
||||
from typing import Optional, Tuple, Any, List, Dict, AsyncGenerator
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
from contextlib import asynccontextmanager
|
||||
import logging
|
||||
|
||||
from crawl4ai import AsyncWebCrawler, BrowserConfig, CrawlerRunConfig, AsyncLogger
|
||||
# Assuming api.py handlers are accessible or refactored slightly if needed
|
||||
# We might need to import the specific handler functions if we call them directly
|
||||
# from api import handle_crawl_request, handle_stream_crawl_request, _get_memory_mb, stream_results
|
||||
|
||||
# --- Custom Exceptions ---
|
||||
class PoolTimeoutError(Exception):
    """Signals that waiting for a crawler resource exceeded the allowed time."""
|
||||
|
||||
class PoolConfigurationError(Exception):
    """Signals a problem with the pool configuration."""
|
||||
|
||||
class NoHealthyCrawlerError(Exception):
    """Signals that no healthy crawler is available to serve a request."""
|
||||
|
||||
|
||||
# --- Configuration Models ---
|
||||
class CalculationParams(BaseModel):
    """Inputs for auto-calculating the number of safe concurrent pages.

    The class used to be defined twice in a row with identical bodies (and a
    repeated pydantic import between them); this is the single definition.
    """

    # Memory (MB) subtracted from total RAM before sizing the pool.
    mem_headroom_mb: int = 512
    # Estimated memory (MB) per concurrent page; must be positive.
    avg_page_mem_mb: int = 150
    # Estimated file descriptors per page; must be positive.
    fd_per_page: int = 20
    core_multiplier: int = 4
    min_pool_size: int = 1  # Min safe pages should be at least 1
    max_pool_size: int = 16

    # V2 validation for avg_page_mem_mb
    @field_validator('avg_page_mem_mb')
    @classmethod
    def check_avg_page_mem(cls, v: int) -> int:
        """Reject zero or negative per-page memory estimates."""
        if v <= 0:
            raise ValueError("avg_page_mem_mb must be positive")
        return v

    # V2 validation for fd_per_page
    @field_validator('fd_per_page')
    @classmethod
    def check_fd_per_page(cls, v: int) -> int:
        """Reject zero or negative per-page file descriptor estimates."""
        if v <= 0:
            raise ValueError("fd_per_page must be positive")
        return v
|
||||
|
||||
class CrawlerManagerConfig(BaseModel):
    """Settings for ``CrawlerManager``: sizing, backups, waiting and throttling."""

    enabled: bool = True
    # When False, calculation_params.max_pool_size is used as the page limit.
    auto_calculate_size: bool = True
    calculation_params: CalculationParams = Field(default_factory=CalculationParams)
    # Number of backup crawlers; zero is allowed.
    backup_pool_size: int = Field(default=1, ge=0)
    max_wait_time_s: float = 30.0
    # Percentage in the range 0-100.
    throttle_threshold_percent: float = Field(default=70.0, ge=0, le=100)
    throttle_delay_min_s: float = 0.1
    throttle_delay_max_s: float = 0.5
    # Keyword parameters for the crawlers' browser config.
    browser_config: Dict[str, Any] = Field(
        default_factory=lambda: {"headless": True, "verbose": False}
    )
    primary_reload_delay_s: float = 60.0
|
||||
|
||||
# --- Crawler Manager ---
|
||||
class CrawlerManager:
|
||||
"""Manages shared AsyncWebCrawler instances, concurrency, and failover."""
|
||||
|
||||
def __init__(self, config: CrawlerManagerConfig, logger = None):
    """Set up manager state; no crawler is started here.

    Args:
        config: Manager settings. When ``config.enabled`` is false the
            manager is still fully constructed but logs that it is disabled.
        logger: Logger to use. When omitted, the module logger is used and
            its level is set to INFO.
    """
    # The logger must exist before anything is logged. Previously the
    # disabled path logged first and failed with AttributeError.
    if logger is None:
        self.logger = logging.getLogger(__name__)
        self.logger.setLevel(logging.INFO)
    else:
        self.logger = logger

    self.config = config
    self._primary_crawler: Optional[AsyncWebCrawler] = None
    self._secondary_crawlers: List[AsyncWebCrawler] = []
    self._active_crawler_index: int = 0  # 0 for primary, 1+ for secondary index
    self._primary_healthy: bool = False
    self._secondary_healthy_flags: List[bool] = []

    self._safe_pages: int = 1  # Default, calculated in initialize
    self._semaphore: Optional[asyncio.Semaphore] = None
    self._state_lock = asyncio.Lock()  # Protects active_crawler, health flags
    self._reload_tasks: List[Optional[asyncio.Task]] = []  # Track reload background tasks

    # A real bool: the disabled path used to store the tuple ``(False,)``,
    # which is truthy and defeated is_enabled().
    self._initialized = False
    self._shutting_down = False

    if not config.enabled:
        # State above is still set so the server can run with the manager idle.
        self.logger.warning("CrawlerManager is disabled by configuration.")
        return

    self.logger.info("CrawlerManager initialized with config.")
    self.logger.debug(f"Config: {self.config.model_dump_json(indent=2)}")
|
||||
|
||||
def is_enabled(self) -> bool:
    """Tell whether the manager is enabled in its config and has been initialized."""
    enabled_in_config = self.config.enabled
    # Same result as `enabled and initialized`: a falsy config flag is
    # returned as-is, otherwise the initialization flag decides.
    return self._initialized if enabled_in_config else enabled_in_config
|
||||
|
||||
def _get_system_resources(self) -> Tuple[int, int, int]:
    """Probe total RAM (MB), CPU core count and the open-file soft limit.

    Each probe falls back to a fixed default (2048 MB, 2 cores, 1024
    descriptors) when the value cannot be read, so this method always
    returns usable positive-or-zero integers.

    Returns:
        ``(total_ram_mb, cpu_cores, fd_limit)``.
    """
    default_ram_mb, default_cores, default_fd_limit = 2048, 2, 1024

    try:
        total_ram_mb = psutil.virtual_memory().total // (1024 * 1024)
        # Prefer physical cores. psutil may return None for either query, so
        # end the chain with the default instead of propagating None.
        cpu_cores = (
            psutil.cpu_count(logical=False)
            or psutil.cpu_count(logical=True)
            or default_cores
        )
    except Exception as e:
        self.logger.warning(f"Could not get RAM/CPU info via psutil: {e}")
        total_ram_mb = default_ram_mb
        cpu_cores = default_cores

    fd_limit = default_fd_limit
    try:
        # Imported here so that a platform without the `resource` module
        # raises ImportError inside this try block and takes the fallback.
        import resource

        soft_limit, _hard_limit = resource.getrlimit(resource.RLIMIT_NOFILE)
        if soft_limit >= 0:
            fd_limit = soft_limit  # Use the soft limit
        else:
            # A negative value is how some platforms spell "unlimited"; using
            # it directly would produce a negative page budget downstream.
            self.logger.warning(
                f"File descriptor soft limit reported as {soft_limit}. Using default: {fd_limit}"
            )
    except (ImportError, ValueError, OSError, AttributeError) as e:
        self.logger.warning(f"Could not get file descriptor limit (common on Windows): {e}. Using default: {fd_limit}")

    self.logger.info(f"System Resources: RAM={total_ram_mb}MB, Cores={cpu_cores}, FD Limit={fd_limit}")
    return total_ram_mb, cpu_cores, fd_limit
|
||||
|
||||
def _calculate_safe_pages(self) -> int:
    """Work out how many pages may be open concurrently.

    With auto-calculation disabled, ``max_pool_size`` is returned unchanged.
    Otherwise the budget is the tightest of three limits (RAM left after
    headroom, file descriptors, CPU cores times ``core_multiplier``), clamped
    to ``[min_pool_size, max_pool_size]``. A per-page cost of zero, or an
    unknown core count, means that resource imposes no limit.

    Returns:
        The number of pages used to size the concurrency semaphore.
    """
    params = self.config.calculation_params

    if not self.config.auto_calculate_size:
        # No dedicated manual setting exists, so max_pool_size doubles as
        # the manual override.
        self.logger.warning("Auto-calculation disabled. Using max_pool_size as safe_pages limit.")
        return params.max_pool_size

    total_ram_mb, cpu_cores, fd_limit = self._get_system_resources()

    available_ram_mb = total_ram_mb - params.mem_headroom_mb
    if available_ram_mb <= 0:
        self.logger.error(f"Not enough RAM ({total_ram_mb}MB) after headroom ({params.mem_headroom_mb}MB). Cannot calculate safe pages.")
        return params.min_pool_size  # Fallback to minimum

    unlimited = float('inf')
    # The `> 0` guards make division by zero impossible, so no
    # ZeroDivisionError handler is needed.
    mem_limit = available_ram_mb // params.avg_page_mem_mb if params.avg_page_mem_mb > 0 else unlimited
    fd_limit_pages = fd_limit // params.fd_per_page if params.fd_per_page > 0 else unlimited
    # `cpu_cores` can be None or 0 when detection failed.
    cpu_limit = cpu_cores * params.core_multiplier if cpu_cores and cpu_cores > 0 else unlimited

    tightest = min(mem_limit, fd_limit_pages, cpu_limit)
    if math.isinf(tightest):
        # math.floor(inf) raises OverflowError; with no resource constraint
        # at all the configured maximum is the only bound.
        calculated_limit = params.max_pool_size
    else:
        calculated_limit = math.floor(tightest)

    # Clamp the result within min/max bounds
    safe_pages = max(params.min_pool_size, min(calculated_limit, params.max_pool_size))

    self.logger.info(f"Calculated safe pages: MemoryLimit={mem_limit}, FDLimit={fd_limit_pages}, CPULimit={cpu_limit} -> RawCalc={calculated_limit} -> Clamped={safe_pages}")
    return safe_pages
|
||||
|
||||
async def _create_and_start_crawler(self, crawler_id: str) -> Optional[AsyncWebCrawler]:
    """Build a crawler from the configured browser settings and start it.

    Args:
        crawler_id: Label used only in log messages.

    Returns:
        The started crawler, or ``None`` when construction or start-up raised
        (the error is logged, not propagated).
    """
    try:
        # The manager config stores browser settings as a plain dict.
        settings = BrowserConfig(**self.config.browser_config)
        instance = AsyncWebCrawler(config=settings)
        await instance.start()
        self.logger.info(f"Successfully started crawler instance: {crawler_id}")
    except Exception as e:
        self.logger.error(f"Failed to start crawler instance {crawler_id}: {e}", exc_info=True)
        instance = None
    return instance
|
||||
|
||||
async def initialize(self):
    """Start the primary and backup crawlers and create the page semaphore.

    Does nothing when the manager is disabled or already initialized. The
    manager is marked initialized even if no crawler came up; that case is
    logged as critical.
    """
    if not self.config.enabled or self._initialized:
        return

    self.logger.info("Initializing CrawlerManager...")
    self._safe_pages = self._calculate_safe_pages()
    self._semaphore = asyncio.Semaphore(self._safe_pages)

    self._primary_crawler = await self._create_and_start_crawler("Primary")
    self._primary_healthy = bool(self._primary_crawler)
    if not self._primary_healthy:
        self.logger.critical("Primary crawler failed to initialize!")

    backup_count = self.config.backup_pool_size
    self._secondary_crawlers = []
    self._secondary_healthy_flags = []
    self._reload_tasks = [None] * (1 + backup_count)  # slot 0 = primary, 1.. = backups

    # Backups are started one at a time; a failed start is stored as None so
    # list positions keep matching the 1-based secondary numbering.
    for slot in range(1, backup_count + 1):
        sec_id = f"Secondary-{slot}"
        backup = await self._create_and_start_crawler(sec_id)
        self._secondary_crawlers.append(backup)
        self._secondary_healthy_flags.append(backup is not None)
        if backup is None:
            self.logger.error(f"{sec_id} crawler failed to initialize!")

    # Pick the initial active crawler, preferring the primary.
    if self._primary_healthy:
        self._active_crawler_index = 0
        self.logger.info("Primary crawler is active.")
    else:
        first_healthy = next(
            (slot for slot, ok in enumerate(self._secondary_healthy_flags, start=1) if ok),
            None,
        )
        if first_healthy is None:
            self.logger.critical("FATAL: No healthy crawlers available after initialization!")
            # Server should probably refuse connections in this state
        else:
            self._active_crawler_index = first_healthy
            self.logger.warning(f"Primary failed, Secondary-{first_healthy} is active.")

    self._initialized = True
    self.logger.info(f"CrawlerManager initialized. Safe Pages: {self._safe_pages}. Active Crawler Index: {self._active_crawler_index}")
|
||||
|
||||
async def shutdown(self):
    """Cancel pending reloads and close every crawler instance.

    Does nothing when the manager was never initialized or a shutdown is
    already running. Errors while closing a crawler are logged and do not
    stop the remaining crawlers from being closed. Afterwards the manager
    holds no crawler references and reports every crawler as unhealthy.
    """
    if not self._initialized or self._shutting_down:
        return

    self._shutting_down = True
    self.logger.info("Shutting down CrawlerManager...")

    # Cancel any ongoing reload tasks
    for i, task in enumerate(self._reload_tasks):
        if task is None or task.done():
            continue
        task.cancel()
        try:
            await task  # Wait for the cancellation to take effect
        except asyncio.CancelledError:
            # Awaiting a task we just cancelled normally ends here.
            self.logger.info(f"Cancelled reload task for crawler index {i}.")
        except Exception as e:
            self.logger.warning(f"Error cancelling reload task for crawler index {i}: {e}")
        else:
            self.logger.info(f"Reload task for crawler index {i} finished before it could be cancelled.")
    self._reload_tasks = []

    # Close primary
    if self._primary_crawler:
        try:
            self.logger.info("Closing primary crawler...")
            await self._primary_crawler.close()
        except Exception as e:
            self.logger.error(f"Error closing primary crawler: {e}", exc_info=True)
        finally:
            # Drop the reference even when close() failed so a stale
            # instance is never handed out again.
            self._primary_crawler = None
    self._primary_healthy = False

    # Close secondaries
    for i, crawler in enumerate(self._secondary_crawlers):
        if crawler:
            try:
                self.logger.info(f"Closing secondary crawler {i+1}...")
                await crawler.close()
            except Exception as e:
                self.logger.error(f"Error closing secondary crawler {i+1}: {e}", exc_info=True)
    self._secondary_crawlers = []
    self._secondary_healthy_flags = []

    self._initialized = False
    # Shutdown has finished; leaving this flag set would make a later
    # initialize() produce a manager that rejects every request.
    self._shutting_down = False
    self.logger.info("CrawlerManager shut down complete.")
|
||||
|
||||
@asynccontextmanager
async def get_crawler(self) -> AsyncGenerator[AsyncWebCrawler, None]:
    """Reserve a page slot and yield the currently active crawler.

    Steps: optional throttle delay when slot usage is at or above the
    configured threshold, semaphore acquisition with a timeout, selection of
    the active crawler under the state lock (failing over to a backup when
    the primary is unhealthy), then the yield. The slot is always released on
    exit.

    Any ``Exception`` raised by the caller's ``async with`` body is reported
    to ``_handle_crawler_failure`` and re-raised unchanged.

    Raises:
        NoHealthyCrawlerError: manager disabled, shutting down, or no usable
            crawler could be selected.
        PoolTimeoutError: no slot became free within ``max_wait_time_s``.
    """
    if not self.is_enabled():
        raise NoHealthyCrawlerError("CrawlerManager is disabled or not initialized.")

    if self._shutting_down:
        raise NoHealthyCrawlerError("CrawlerManager is shutting down.")

    active_crawler: Optional[AsyncWebCrawler] = None
    request_id = uuid.uuid4()
    start_wait = time.time()

    # --- Throttling ---
    try:
        # Read the semaphore counter without acquiring it.
        current_usage = self._safe_pages - self._semaphore._value
        usage_percent = (current_usage / self._safe_pages) * 100 if self._safe_pages > 0 else 0

        if usage_percent >= self.config.throttle_threshold_percent:
            delay = random.uniform(self.config.throttle_delay_min_s, self.config.throttle_delay_max_s)
            self.logger.debug(f"Throttling: Usage {usage_percent:.1f}% >= {self.config.throttle_threshold_percent}%. Delaying {delay:.3f}s")
            await asyncio.sleep(delay)
    except Exception as e:
        # Throttling is best-effort; continue even if the check fails.
        self.logger.warning(f"Error during throttling check: {e}")

    # --- Acquire Semaphore ---
    sem_value = self._semaphore._value if self._semaphore else 'N/A'
    sem_waiters = len(self._semaphore._waiters) if self._semaphore and self._semaphore._waiters else 0
    self.logger.debug(f"Req {request_id}: Attempting acquire. Available={sem_value}/{self._safe_pages}, Waiters={sem_waiters}, Timeout={self.config.max_wait_time_s}s")

    # The timeout handler wraps ONLY the acquisition. A timeout raised later
    # by the caller's crawl must not be reported as a pool timeout.
    try:
        await asyncio.wait_for(
            self._semaphore.acquire(), timeout=self.config.max_wait_time_s
        )
    except asyncio.TimeoutError:
        self.logger.warning(f"Timeout waiting for semaphore after {self.config.max_wait_time_s}s.")
        raise PoolTimeoutError(f"Timed out waiting for available crawler resource after {self.config.max_wait_time_s}s") from None

    try:
        wait_duration = time.time() - start_wait
        if wait_duration > 1:
            self.logger.warning(f"Semaphore acquired after {wait_duration:.3f}s. (Available: {self._semaphore._value}/{self._safe_pages})")
        self.logger.debug(f"Semaphore acquired successfully after {wait_duration:.3f}s. (Available: {self._semaphore._value}/{self._safe_pages})")

        # --- Select Active Crawler (Critical Section) ---
        try:
            async with self._state_lock:
                current_active_index = self._active_crawler_index

                if current_active_index == 0:
                    if self._primary_healthy and self._primary_crawler:
                        active_crawler = self._primary_crawler
                    else:
                        # Primary is supposed to be active but isn't healthy
                        self.logger.warning("Primary crawler unhealthy, attempting immediate failover...")
                        if not await self._try_failover_sync():
                            raise NoHealthyCrawlerError("Primary unhealthy and no healthy backup available.")
                        # Failover updated the active index; re-read it and
                        # fall through to the secondary selection below.
                        current_active_index = self._active_crawler_index

                if current_active_index > 0:
                    secondary_idx = current_active_index - 1
                    if (
                        secondary_idx < len(self._secondary_crawlers)
                        and self._secondary_healthy_flags[secondary_idx]
                        and self._secondary_crawlers[secondary_idx]
                    ):
                        active_crawler = self._secondary_crawlers[secondary_idx]
                    else:
                        self.logger.error(f"Selected Secondary-{current_active_index} is unhealthy or missing.")
                        raise NoHealthyCrawlerError(f"Selected Secondary-{current_active_index} is unavailable.")

                if active_crawler is None:
                    # Safeguard; should be unreachable if the logic above holds.
                    raise NoHealthyCrawlerError("Failed to select a healthy active crawler.")
        except NoHealthyCrawlerError:
            raise  # Already logged by the selection logic
        except Exception as e:
            self.logger.error(f"Unexpected error in get_crawler context manager: {e}", exc_info=True)
            raise

        # --- Yield Crawler (outside the state lock, so crawls run concurrently) ---
        try:
            yield active_crawler
        except Exception as crawl_error:
            self.logger.error(f"Error during crawl execution using {active_crawler}: {crawl_error}", exc_info=True)
            # For now any exception from the caller's body triggers a health
            # check / failover attempt for the crawler that was handed out.
            await self._handle_crawler_failure(active_crawler)
            raise  # Re-raise the original error for the API handler
    finally:
        self._semaphore.release()
        self.logger.debug(f"Semaphore released. (Available: {self._semaphore._value}/{self._safe_pages})")
|
||||
|
||||
|
||||
async def _try_failover_sync(self) -> bool:
    """Point the active-crawler index at a usable crawler.

    Must be called while holding ``_state_lock``; it performs no awaits.

    * Primary healthy: the primary becomes (or stays) active. This covers
      the case where the active secondary has just failed while the primary
      is healthy again.
    * Primary unhealthy: secondaries are scanned circularly, starting with
      the one after the currently active secondary, and the first healthy
      one becomes active.
    * Nothing healthy: the index is reset to 0 (primary), which is still
      unhealthy.

    Returns:
        ``True`` when a healthy crawler is active afterwards, else ``False``.
    """
    if self._primary_healthy:
        if self._active_crawler_index != 0:
            self.logger.warning("Failover: primary is healthy, switching active crawler back to primary.")
            self._active_crawler_index = 0
        return True

    # Use the real list sizes rather than the configured pool size so a
    # mismatch can never index out of range (and zero backups never divides).
    backup_count = min(len(self._secondary_healthy_flags), len(self._secondary_crawlers))
    start_idx = self._active_crawler_index  # 1-based active secondary == 0-based "next" slot

    for offset in range(backup_count):
        check_idx = (start_idx + offset) % backup_count  # Circular check
        if self._secondary_healthy_flags[check_idx] and self._secondary_crawlers[check_idx]:
            self._active_crawler_index = check_idx + 1
            self.logger.warning(f"Failover successful: Switched active crawler to Secondary-{self._active_crawler_index}")
            return True

    # Primary is down and no backup is healthy.
    self._active_crawler_index = 0
    self.logger.error("Failover failed: No healthy secondary crawlers available.")
    return False
|
||||
|
||||
async def _handle_crawler_failure(self, failed_crawler: AsyncWebCrawler):
    """Mark the crawler that just failed as unhealthy and schedule its reload.

    Ignored during shutdown. Under the state lock the failed instance is
    matched by identity against the primary, then the secondaries; only a
    crawler currently flagged healthy is processed. A reload task is started
    for its slot unless one is already running.

    Args:
        failed_crawler: The instance that was in use when the error occurred.
    """
    if self._shutting_down:
        return  # Don't handle failures during shutdown

    async with self._state_lock:
        if failed_crawler is self._primary_crawler and self._primary_healthy:
            self.logger.warning("Primary crawler reported failure.")
            self._primary_healthy = False
            # Switch traffic away immediately, while still holding the lock.
            await self._try_failover_sync()
            pending = self._reload_tasks[0]
            if pending is None or pending.done():
                self.logger.info("Initiating primary crawler reload task.")
                self._reload_tasks[0] = asyncio.create_task(self._reload_crawler(0))
            return

        # Not the (healthy) primary: look for it among the secondaries,
        # numbered from 1 to match the active-crawler index.
        for position, candidate in enumerate(self._secondary_crawlers, start=1):
            if failed_crawler is not candidate or not self._secondary_healthy_flags[position - 1]:
                continue

            self.logger.warning(f"Secondary-{position} crawler reported failure.")
            self._secondary_healthy_flags[position - 1] = False

            # Only the active crawler needs an immediate replacement.
            if self._active_crawler_index == position:
                self.logger.warning(f"Active secondary {position} failed, attempting failover...")
                await self._try_failover_sync()

            pending = self._reload_tasks[position]
            if pending is None or pending.done():
                self.logger.info(f"Initiating Secondary-{position} crawler reload task.")
                self._reload_tasks[position] = asyncio.create_task(self._reload_crawler(position))
            return

        self.logger.debug("Failure reported by an unknown or already unhealthy crawler instance. Ignoring.")
|
||||
|
||||
|
||||
async def _reload_crawler(self, crawler_index_to_reload: int):
    """Background task: wait, close the old crawler, start a replacement.

    Args:
        crawler_index_to_reload: 0 for the primary, ``n`` for Secondary-n.

    After the delay (full ``primary_reload_delay_s`` for the primary, half of
    it for a backup) the old instance is closed on a best-effort basis and a
    new one is started. Health flags, the stored instance and possibly the
    active index are then updated under the state lock, and this slot's
    reload-task reference is cleared.
    """
    reloading_primary = crawler_index_to_reload == 0
    slot = crawler_index_to_reload - 1  # position in the secondary lists
    if reloading_primary:
        crawler_id = "Primary"
        stale = self._primary_crawler
    else:
        crawler_id = f"Secondary-{crawler_index_to_reload}"
        stale = self._secondary_crawlers[slot]

    self.logger.info(f"Starting reload process for {crawler_id}...")

    # 1. Give transient problems time to clear before retrying.
    base_delay = self.config.primary_reload_delay_s
    await asyncio.sleep(base_delay if reloading_primary else base_delay / 2)

    # 2. Best-effort close of the old instance; failures are only logged.
    if stale:
        try:
            self.logger.info(f"Attempting to close existing {crawler_id} instance...")
            await stale.close()
            self.logger.info(f"Successfully closed old {crawler_id} instance.")
        except Exception as e:
            self.logger.warning(f"Error closing old {crawler_id} instance during reload: {e}")

    # 3. Start the replacement.
    self.logger.info(f"Attempting to start new {crawler_id} instance...")
    replacement = await self._create_and_start_crawler(crawler_id)

    # 4. Publish the outcome.
    async with self._state_lock:
        if not replacement:
            self.logger.error(f"Failed to reload {crawler_id}. It remains unhealthy.")
            if reloading_primary:
                self._primary_healthy = False  # Ensure it stays false
            else:
                self._secondary_healthy_flags[slot] = False
        elif reloading_primary:
            self.logger.info(f"Successfully reloaded {crawler_id}. Marking as healthy.")
            self._primary_crawler = replacement
            self._primary_healthy = True
            # Return to the primary unless a healthy secondary is currently
            # serving traffic.
            on_backup = self._active_crawler_index > 0
            if not on_backup or not self._secondary_healthy_flags[self._active_crawler_index - 1]:
                self.logger.info("Switching active crawler back to primary.")
                self._active_crawler_index = 0
        else:
            self.logger.info(f"Successfully reloaded {crawler_id}. Marking as healthy.")
            self._secondary_crawlers[slot] = replacement
            self._secondary_healthy_flags[slot] = True
            # Index 0 with an unhealthy primary means nothing usable was
            # active, so put the reloaded backup in service.
            if not self._primary_healthy and self._active_crawler_index == 0:
                self.logger.info(f"Primary still down, activating reloaded Secondary-{crawler_index_to_reload}.")
                self._active_crawler_index = crawler_index_to_reload

        # Clear the reload task reference for this index
        self._reload_tasks[crawler_index_to_reload] = None
|
||||
|
||||
|
||||
async def get_status(self) -> Dict:
    """Build a snapshot of the manager's state.

    Returns:
        ``{"status": "disabled"}`` when the manager is not enabled or not
        initialized; otherwise a dict with slot counts, the active crawler
        label, per-crawler health and the indices of running reload tasks.
    """
    if not self.is_enabled():
        return {"status": "disabled"}

    def health_label(ok):
        return 'Healthy' if ok else 'Unhealthy'

    # Read everything under the state lock so the snapshot is consistent.
    async with self._state_lock:
        active_index = self._active_crawler_index
        sem = self._semaphore

        backups = []
        for number, ok in enumerate(self._secondary_healthy_flags, start=1):
            backups.append(f"Secondary-{number}: {health_label(ok)}")

        running_reloads = []
        for index, task in enumerate(self._reload_tasks):
            if task and not task.done():
                running_reloads.append(index)

        snapshot = {
            "status": "enabled",
            "safe_pages": self._safe_pages,
            "semaphore_available": sem._value if sem else 'N/A',
            "semaphore_waiters": len(sem._waiters) if sem and sem._waiters else 0,
            "active_crawler": "Primary" if active_index == 0 else f"Secondary-{active_index}",
            "primary_status": health_label(self._primary_healthy),
            "secondary_statuses": backups,
            "reloading_tasks": running_reloads,
        }
    return snapshot
|
||||
@@ -1,8 +1,20 @@
|
||||
# Import from auth.py
|
||||
from auth import create_access_token, get_token_dependency, TokenRequest
|
||||
from api import (
|
||||
handle_markdown_request,
|
||||
handle_llm_qa,
|
||||
handle_stream_crawl_request,
|
||||
handle_crawl_request,
|
||||
stream_results,
|
||||
_get_memory_mb
|
||||
)
|
||||
from utils import FilterType, load_config, setup_logging, verify_email_domain
|
||||
import os
|
||||
import sys
|
||||
import time
|
||||
from typing import List, Optional, Dict
|
||||
from fastapi import FastAPI, HTTPException, Request, Query, Path, Depends
|
||||
from typing import List, Optional, Dict, AsyncGenerator
|
||||
from contextlib import asynccontextmanager
|
||||
from fastapi import FastAPI, HTTPException, Request, Query, Path, Depends, status
|
||||
from fastapi.responses import StreamingResponse, RedirectResponse, PlainTextResponse, JSONResponse
|
||||
from fastapi.middleware.httpsredirect import HTTPSRedirectMiddleware
|
||||
from fastapi.middleware.trustedhost import TrustedHostMiddleware
|
||||
@@ -11,28 +23,39 @@ from slowapi import Limiter
|
||||
from slowapi.util import get_remote_address
|
||||
from prometheus_fastapi_instrumentator import Instrumentator
|
||||
from redis import asyncio as aioredis
|
||||
from crawl4ai import (
|
||||
BrowserConfig,
|
||||
CrawlerRunConfig,
|
||||
AsyncLogger
|
||||
)
|
||||
|
||||
from crawler_manager import (
|
||||
CrawlerManager,
|
||||
CrawlerManagerConfig,
|
||||
PoolTimeoutError,
|
||||
NoHealthyCrawlerError
|
||||
)
|
||||
|
||||
|
||||
sys.path.append(os.path.dirname(os.path.realpath(__file__)))
|
||||
from utils import FilterType, load_config, setup_logging, verify_email_domain
|
||||
from api import (
|
||||
handle_markdown_request,
|
||||
handle_llm_qa,
|
||||
handle_stream_crawl_request,
|
||||
handle_crawl_request,
|
||||
stream_results
|
||||
)
|
||||
from auth import create_access_token, get_token_dependency, TokenRequest # Import from auth.py
|
||||
|
||||
__version__ = "0.2.6"
|
||||
|
||||
|
||||
class CrawlRequest(BaseModel):
    """Request body accepted by the crawl endpoints."""

    # Between 1 and 100 target URLs per request.
    urls: List[str] = Field(
        min_length=1,
        max_length=100,
    )
    # Optional per-request browser settings; defaults to an empty dict.
    browser_config: Optional[Dict] = Field(
        default_factory=dict,
    )
    # Optional per-request crawler run settings; defaults to an empty dict.
    crawler_config: Optional[Dict] = Field(
        default_factory=dict,
    )
|
||||
|
||||
|
||||
# Load configuration and setup
|
||||
config = load_config()
|
||||
setup_logging(config)
|
||||
logger = AsyncLogger(
|
||||
log_file=config["logging"].get("log_file", "app.log"),
|
||||
verbose=config["logging"].get("verbose", False),
|
||||
tag_width=10,
|
||||
)
|
||||
|
||||
# Initialize Redis
|
||||
redis = aioredis.from_url(config["redis"].get("uri", "redis://localhost"))
|
||||
@@ -44,9 +67,43 @@ limiter = Limiter(
|
||||
storage_uri=config["rate_limiting"]["storage_uri"]
|
||||
)
|
||||
|
||||
# --- Create Manager (crawlers are started later, in lifespan) ---
|
||||
# Load manager config from the main config
|
||||
manager_config_dict = config.get("crawler_pool", {})
|
||||
# Use Pydantic to parse and validate
|
||||
manager_config = CrawlerManagerConfig(**manager_config_dict)
|
||||
crawler_manager = CrawlerManager(config=manager_config, logger=logger)
|
||||
|
||||
# --- FastAPI App and Lifespan ---
|
||||
|
||||
|
||||
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Start the crawler manager before serving and shut it down afterwards.

    On startup the manager is initialized and stored on
    ``app.state.crawler_manager`` (``None`` when disabled by config). The
    shutdown half runs in a ``finally`` block so crawlers are closed even if
    an exception is raised into this context while the server is running.
    """
    # Startup
    logger.info("Starting up the server...")
    if manager_config.enabled:
        logger.info("Initializing Crawler Manager...")
        await crawler_manager.initialize()
        app.state.crawler_manager = crawler_manager  # Store manager in app state
        logger.info("Crawler Manager is enabled.")
    else:
        logger.warning("Crawler Manager is disabled.")
        app.state.crawler_manager = None  # Indicate disabled state

    try:
        yield  # Server runs here
    finally:
        # Shutdown
        logger.info("Shutting down server...")
        if app.state.crawler_manager:
            logger.info("Shutting down Crawler Manager...")
            await app.state.crawler_manager.shutdown()
            logger.info("Crawler Manager shut down.")
        logger.info("Server shut down.")
|
||||
|
||||
# Application instance; `lifespan` starts and stops the crawler manager.
app = FastAPI(
    title=config["app"]["title"],
    version=config["app"]["version"],
    lifespan=lifespan,
)
|
||||
|
||||
# Configure middleware
|
||||
@@ -56,7 +113,9 @@ def setup_security_middleware(app, config):
|
||||
if sec_config.get("https_redirect", False):
|
||||
app.add_middleware(HTTPSRedirectMiddleware)
|
||||
if sec_config.get("trusted_hosts", []) != ["*"]:
|
||||
app.add_middleware(TrustedHostMiddleware, allowed_hosts=sec_config["trusted_hosts"])
|
||||
app.add_middleware(TrustedHostMiddleware,
|
||||
allowed_hosts=sec_config["trusted_hosts"])
|
||||
|
||||
|
||||
setup_security_middleware(app, config)
|
||||
|
||||
@@ -68,6 +127,8 @@ if config["observability"]["prometheus"]["enabled"]:
|
||||
token_dependency = get_token_dependency(config)
|
||||
|
||||
# Middleware for security headers
|
||||
|
||||
|
||||
@app.middleware("http")
|
||||
async def add_security_headers(request: Request, call_next):
|
||||
response = await call_next(request)
|
||||
@@ -75,7 +136,24 @@ async def add_security_headers(request: Request, call_next):
|
||||
response.headers.update(config["security"]["headers"])
|
||||
return response
|
||||
|
||||
|
||||
async def get_manager() -> CrawlerManager:
    """FastAPI dependency returning the crawler manager stored on ``app.state``.

    Raises:
        HTTPException: 503 when no manager is stored on the app state, or
            when the stored manager reports that it is not enabled.
    """
    manager = getattr(app.state, 'crawler_manager', None)
    if manager is None:
        raise HTTPException(
            status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
            detail="Crawler service is disabled or not initialized"
        )
    if manager.is_enabled():
        return manager
    raise HTTPException(
        status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
        detail="Crawler service is currently disabled"
    )
|
||||
|
||||
# Token endpoint (always available, but usage depends on config)
|
||||
|
||||
|
||||
@app.post("/token")
|
||||
async def get_token(request_data: TokenRequest):
|
||||
if not verify_email_domain(request_data.email):
|
||||
@@ -84,6 +162,8 @@ async def get_token(request_data: TokenRequest):
|
||||
return {"email": request_data.email, "access_token": token, "token_type": "bearer"}
|
||||
|
||||
# Endpoints with conditional auth
|
||||
|
||||
|
||||
@app.get("/md/{url:path}")
|
||||
@limiter.limit(config["rate_limiting"]["default_limit"])
|
||||
async def get_markdown(
|
||||
@@ -97,6 +177,7 @@ async def get_markdown(
|
||||
result = await handle_markdown_request(url, f, q, c, config)
|
||||
return PlainTextResponse(result)
|
||||
|
||||
|
||||
@app.get("/llm/{url:path}", description="URL should be without http/https prefix")
|
||||
async def llm_endpoint(
|
||||
request: Request,
|
||||
@@ -105,7 +186,8 @@ async def llm_endpoint(
|
||||
token_data: Optional[Dict] = Depends(token_dependency)
|
||||
):
|
||||
if not q:
|
||||
raise HTTPException(status_code=400, detail="Query parameter 'q' is required")
|
||||
raise HTTPException(
|
||||
status_code=400, detail="Query parameter 'q' is required")
|
||||
if not url.startswith(('http://', 'https://')):
|
||||
url = 'https://' + url
|
||||
try:
|
||||
@@ -114,37 +196,89 @@ async def llm_endpoint(
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
@app.get("/schema")
async def get_schema():
    """Return the dumped default browser and crawler-run configurations."""
    from crawl4ai import BrowserConfig, CrawlerRunConfig

    browser_defaults = BrowserConfig().dump()
    crawler_defaults = CrawlerRunConfig().dump()
    return {"browser": browser_defaults, "crawler": crawler_defaults}
|
||||
|
||||
|
||||
@app.get(config["observability"]["health_check"]["endpoint"])
async def health():
    """Liveness probe: fixed "ok" status, current timestamp and server version."""
    payload = {"status": "ok"}
    payload["timestamp"] = time.time()
    payload["version"] = __version__
    return payload
|
||||
|
||||
|
||||
@app.get(config["observability"]["prometheus"]["endpoint"])
async def metrics():
    """Redirect to the configured Prometheus endpoint."""
    # NOTE(review): the redirect target is the same config value as this
    # route's own path -- confirm this cannot loop.
    target = config["observability"]["prometheus"]["endpoint"]
    return RedirectResponse(url=target)
|
||||
|
||||
|
||||
@app.get("/browswers")
async def health(manager: Optional[CrawlerManager] = Depends(get_manager, use_cache=False)):
    """Report server status together with the crawler manager's status.

    NOTE(review): this reuses the name ``health`` of the health-check
    handler, and the route path looks like a typo of "/browsers" -- confirm
    both are intended. The "disabled" branch below also appears unreachable
    if ``get_manager`` answers 503 instead of returning ``None`` -- confirm.
    """
    report = dict(status="ok", timestamp=time.time(), version=__version__)

    if not manager:
        pool_report = {"status": "disabled"}
    else:
        try:
            pool_report = await manager.get_status()
        except Exception as e:
            # A failing status query is reported in-band rather than as an
            # HTTP error.
            pool_report = {"status": "error", "detail": str(e)}

    report["crawler_manager"] = pool_report
    return report
|
||||
|
||||
|
||||
@app.post("/crawl")
@limiter.limit(config["rate_limiting"]["default_limit"])
async def crawl(
    request: Request,
    crawl_request: CrawlRequest,
    manager: CrawlerManager = Depends(get_manager),  # Use dependency
    token_data: Optional[Dict] = Depends(token_dependency)  # Keep auth
):
    """Crawl a batch of URLs using a crawler borrowed from the pool.

    Raises:
        HTTPException: 400 when no URL is given; 503 when no pool slot frees
            up in time or no healthy crawler exists; 500 for any other
            unexpected error. HTTPExceptions raised by the handler pass
            through unchanged.
    """
    if not crawl_request.urls:
        raise HTTPException(
            status_code=400, detail="At least one URL required")

    try:
        # The pool slot is held for exactly the duration of the crawl.
        async with manager.get_crawler() as active_crawler:
            results_dict = await handle_crawl_request(
                crawler=active_crawler,  # Pass the live crawler instance
                urls=crawl_request.urls,
                # User-provided configs; how they merge with the pool's own
                # settings is up to the handler.
                browser_config=crawl_request.browser_config or {},  # Ensure dict
                crawler_config=crawl_request.crawler_config or {},  # Ensure dict
                config=config  # Pass the global server config
            )
        return JSONResponse(results_dict)

    except PoolTimeoutError as e:
        logger.warning(f"Request rejected due to pool timeout: {e}")
        raise HTTPException(
            status_code=status.HTTP_503_SERVICE_UNAVAILABLE,  # Or 429
            detail=f"Crawler resources busy. Please try again later. Timeout: {e}"
        ) from e
    except NoHealthyCrawlerError as e:
        logger.error(f"Request failed as no healthy crawler available: {e}")
        raise HTTPException(
            status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
            detail=f"Crawler service temporarily unavailable: {e}"
        ) from e
    except HTTPException:  # Re-raise HTTP exceptions from handler
        raise
    except Exception as e:
        logger.error(
            f"Unexpected error during batch crawl processing: {e}", exc_info=True)
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"An unexpected error occurred: {e}"
        ) from e
|
||||
|
||||
|
||||
@app.post("/crawl/stream")
|
||||
@@ -152,23 +286,114 @@ async def crawl(
|
||||
async def crawl_stream(
    request: Request,
    crawl_request: CrawlRequest,
    manager: CrawlerManager = Depends(get_manager),
    token_data: Optional[Dict] = Depends(token_dependency)
):
    """Stream crawl results for the requested URLs as NDJSON.

    A crawler is acquired from ``manager`` inside the streaming generator, so
    the acquisition stays active while the response body is being produced.
    Failures that happen after streaming has started are reported as a final
    ``{"status": "error", ...}`` line in the stream.

    Raises:
        HTTPException: 400 when no URLs are supplied, 500 when the stream
            cannot be set up.
    """
    if not crawl_request.urls:
        raise HTTPException(
            status_code=400, detail="At least one URL required")

    try:
        # StreamingResponse consumes its generator *after* this endpoint
        # returns. Entering `manager.get_crawler()` here would therefore
        # release the crawler before any result is produced, so the context
        # is entered inside a wrapper generator instead.
        async def stream_wrapper(manager: CrawlerManager, crawl_request: CrawlRequest, config: dict) -> AsyncGenerator[bytes, None]:
            try:
                async with manager.get_crawler() as acquired_crawler:
                    # Call the handler which returns the raw result generator
                    _crawler_ref, results_gen = await handle_stream_crawl_request(
                        crawler=acquired_crawler,
                        urls=crawl_request.urls,
                        browser_config=crawl_request.browser_config or {},
                        crawler_config=crawl_request.crawler_config or {},
                        config=config
                    )
                    # Use the stream_results utility to format and yield
                    async for data_bytes in stream_results(_crawler_ref, results_gen):
                        yield data_bytes
            except (PoolTimeoutError, NoHealthyCrawlerError) as e:
                # Yield a final error message in the stream
                error_payload = {"status": "error", "detail": str(e)}
                yield (json.dumps(error_payload) + "\n").encode('utf-8')
                logger.warning(f"Stream request failed: {e}")
            except HTTPException as e:  # Catch HTTP exceptions from handler setup
                error_payload = {"status": "error",
                                 "detail": e.detail, "status_code": e.status_code}
                yield (json.dumps(error_payload) + "\n").encode('utf-8')
                logger.warning(
                    f"Stream request failed with HTTPException: {e.detail}")
            except Exception as e:
                error_payload = {"status": "error",
                                 "detail": f"Unexpected stream error: {e}"}
                yield (json.dumps(error_payload) + "\n").encode('utf-8')
                logger.error(
                    f"Unexpected error during stream processing: {e}", exc_info=True)

        # Create the generator using the wrapper
        streaming_generator = stream_wrapper(manager, crawl_request, config)

        return StreamingResponse(
            streaming_generator,  # Use the wrapper
            media_type='application/x-ndjson',
            headers={'Cache-Control': 'no-cache',
                     'Connection': 'keep-alive', 'X-Stream-Status': 'active'}
        )

    except (PoolTimeoutError, NoHealthyCrawlerError) as e:
        # Only reached if acquisition fails before the stream starts.
        logger.warning(f"Stream request rejected before starting: {e}")
        status_code = status.HTTP_503_SERVICE_UNAVAILABLE  # Or 429 for timeout
        # Capture the message now: `e` is unbound when this except block
        # exits, which is before the generator below is consumed.
        error_detail = str(e)

        async def _error_stream():
            error_payload = {"status": "error", "detail": error_detail}
            yield (json.dumps(error_payload) + "\n").encode('utf-8')

        return StreamingResponse(_error_stream(), status_code=status_code, media_type='application/x-ndjson')

    except HTTPException:  # Re-raise HTTP exceptions from setup
        raise
    except Exception as e:
        logger.error(
            f"Unexpected error setting up stream crawl: {e}", exc_info=True)
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"An unexpected error occurred setting up the stream: {e}"
        )
|
||||
|
||||
if __name__ == "__main__":
|
||||
import uvicorn
|
||||
@@ -178,4 +403,4 @@ if __name__ == "__main__":
|
||||
port=config["app"]["port"],
|
||||
reload=config["app"]["reload"],
|
||||
timeout_keep_alive=config["app"]["timeout_keep_alive"]
|
||||
)
|
||||
)
|
||||
|
||||
516
tests/memory/test_stress_api.py
Normal file
516
tests/memory/test_stress_api.py
Normal file
@@ -0,0 +1,516 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Stress test for Crawl4AI's Docker API server (/crawl and /crawl/stream endpoints).
|
||||
|
||||
This version targets a running Crawl4AI API server, sending concurrent requests
|
||||
to test its ability to handle multiple crawl jobs simultaneously.
|
||||
It uses httpx for async HTTP requests and logs results per batch of requests,
|
||||
including server-side memory usage reported by the API.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import time
|
||||
import uuid
|
||||
import argparse
|
||||
import json
|
||||
import sys
|
||||
import os
|
||||
import shutil
|
||||
from typing import List, Dict, Optional, Union, AsyncGenerator, Tuple
|
||||
import httpx
|
||||
import pathlib # Import pathlib explicitly
|
||||
from rich.console import Console
|
||||
from rich.panel import Panel
|
||||
from rich.syntax import Syntax
|
||||
|
||||
# --- Constants ---
# Defaults for the command-line options of this stress test.
# DEFAULT_API_URL = "http://localhost:11235" # Default port
DEFAULT_API_URL = "http://localhost:8020" # Default port
# Total number of URLs generated for a run.
DEFAULT_URL_COUNT = 1000
# Upper bound on API requests in flight from this client.
DEFAULT_MAX_CONCURRENT_REQUESTS = 5
# Number of URLs sent in each API request payload.
DEFAULT_CHUNK_SIZE = 10
# Directory where the JSON summary is written.
DEFAULT_REPORT_PATH = "reports_api"
# False targets /crawl (batch); True targets /crawl/stream.
DEFAULT_STREAM_MODE = False
# Timeout, in seconds, passed to the httpx clients.
REQUEST_TIMEOUT = 180.0

# Initialize Rich console
console = Console()
|
||||
|
||||
# --- API Health Check (Unchanged) ---
|
||||
async def check_server_health(client: httpx.AsyncClient, health_endpoint: str = "/health"):
    """Probe the server's health endpoint and report the outcome on the console.

    Returns:
        True when the endpoint answers with a successful status and a JSON
        body; False on any request, HTTP-status or unexpected error.
    """
    console.print(f"[bold cyan]Checking API server health at {client.base_url}{health_endpoint}...[/]", end="")
    try:
        reply = await client.get(health_endpoint, timeout=10.0)
        reply.raise_for_status()
        server_version = reply.json().get('version', 'N/A')
        console.print(f"[bold green] Server OK! Version: {server_version}[/]")
        return True
    except (httpx.RequestError, httpx.HTTPStatusError) as exc:
        # Connection problems and non-2xx answers get a targeted hint.
        console.print("\n[bold red]Server health check FAILED:[/]")
        console.print(f"Error: {exc}")
        console.print(f"Is the server running and accessible at {client.base_url}?")
    except Exception as exc:
        console.print("\n[bold red]An unexpected error occurred during health check:[/]")
        console.print(exc)
    return False
|
||||
|
||||
# --- API Stress Test Class ---
|
||||
class ApiStressTest:
    """Orchestrates the stress test by sending concurrent requests to the API.

    A run generates ``url_count`` unique httpbin URLs, splits them into chunks
    of ``chunk_size`` and posts each chunk to ``/crawl`` (batch mode) or
    ``/crawl/stream`` (stream mode) with at most ``max_concurrent_requests``
    requests in flight. One progress row is printed per request and a JSON
    summary is written under ``report_path`` at the end.
    """

    def __init__(
        self,
        api_url: str,
        url_count: int,
        max_concurrent_requests: int,
        chunk_size: int,
        report_path: str,
        stream_mode: bool,
    ):
        """Store the configuration, create the report directory and the HTTP client."""
        self.api_base_url = api_url.rstrip('/')
        self.url_count = url_count
        self.max_concurrent_requests = max_concurrent_requests
        self.chunk_size = chunk_size
        self.report_path = pathlib.Path(report_path)
        self.report_path.mkdir(parents=True, exist_ok=True)
        self.stream_mode = stream_mode

        self.test_id = time.strftime("%Y%m%d_%H%M%S")
        self.results_summary = {
            "test_id": self.test_id, "api_url": api_url, "url_count": url_count,
            "max_concurrent_requests": max_concurrent_requests, "chunk_size": chunk_size,
            "stream_mode": stream_mode, "start_time": "", "end_time": "",
            "total_time_seconds": 0, "successful_requests": 0, "failed_requests": 0,
            "successful_urls": 0, "failed_urls": 0, "total_urls_processed": 0,
            "total_api_calls": 0,
            "server_memory_metrics": {  # To store aggregated server memory info
                "batch_mode_avg_delta_mb": None,
                "batch_mode_max_delta_mb": None,
                "stream_mode_avg_max_snapshot_mb": None,
                "stream_mode_max_max_snapshot_mb": None,
                "samples": []  # Store individual request memory results
            }
        }
        self.http_client = httpx.AsyncClient(
            base_url=self.api_base_url,
            timeout=REQUEST_TIMEOUT,
            limits=httpx.Limits(
                max_connections=max_concurrent_requests + 5,
                max_keepalive_connections=max_concurrent_requests,
            ),
        )

    async def close_client(self):
        """Close the httpx client."""
        await self.http_client.aclose()

    async def run(self) -> Dict:
        """Run the API stress test.

        Returns:
            The results summary dictionary, which is also saved to disk.
        """
        urls_to_process = [f"https://httpbin.org/anything/{uuid.uuid4()}" for _ in range(self.url_count)]
        url_chunks = [
            urls_to_process[i:i + self.chunk_size]
            for i in range(0, len(urls_to_process), self.chunk_size)
        ]

        self.results_summary["start_time"] = time.strftime("%Y-%m-%d %H:%M:%S")
        start_time = time.time()

        console.print(f"\n[bold cyan]Crawl4AI API Stress Test - {self.url_count} URLs, {self.max_concurrent_requests} concurrent requests[/bold cyan]")
        console.print(f"[bold cyan]Target API:[/bold cyan] {self.api_base_url}, [bold cyan]Mode:[/bold cyan] {'Streaming' if self.stream_mode else 'Batch'}, [bold cyan]URLs per Request:[/bold cyan] {self.chunk_size}")

        # Caps the number of requests in flight; every task waits on it.
        semaphore = asyncio.Semaphore(self.max_concurrent_requests)

        console.print("\n[bold]API Request Batch Progress:[/bold]")
        console.print("[bold] Batch | Progress | SrvMem Peak / Δ|Max (MB) | Reqs/sec | S/F URLs | Time (s) | Status [/bold]")
        console.print("─" * 95)

        total_api_calls = len(url_chunks)
        self.results_summary["total_api_calls"] = total_api_calls

        try:
            tasks = [
                asyncio.create_task(self._make_api_request(
                    chunk=chunk,
                    batch_idx=i + 1,
                    total_batches=total_api_calls,
                    semaphore=semaphore
                ))
                for i, chunk in enumerate(url_chunks)
            ]
            api_results = await asyncio.gather(*tasks)

            # Process aggregated results including server memory
            total_successful_requests = sum(1 for r in api_results if r['request_success'])
            total_failed_requests = total_api_calls - total_successful_requests
            total_successful_urls = sum(r['success_urls'] for r in api_results)
            total_failed_urls = sum(r['failed_urls'] for r in api_results)
            total_urls_processed = total_successful_urls + total_failed_urls

            # Only requests that reported a delta (batch) or snapshot (stream)
            # take part in the memory aggregation.
            memory_metrics = self.results_summary["server_memory_metrics"]
            valid_samples = [r for r in api_results if r.get('server_delta_or_max_mb') is not None]
            memory_metrics["samples"] = valid_samples

            if valid_samples:
                delta_or_max_values = [r['server_delta_or_max_mb'] for r in valid_samples]
                average = sum(delta_or_max_values) / len(delta_or_max_values)
                if self.stream_mode:
                    # Stream mode: delta_or_max holds max snapshot
                    memory_metrics["stream_mode_avg_max_snapshot_mb"] = average
                    memory_metrics["stream_mode_max_max_snapshot_mb"] = max(delta_or_max_values)
                else:
                    # Batch mode: delta_or_max holds delta
                    memory_metrics["batch_mode_avg_delta_mb"] = average
                    memory_metrics["batch_mode_max_delta_mb"] = max(delta_or_max_values)

                    # Aggregate peak values for batch mode
                    peak_values = [
                        r['server_peak_memory_mb'] for r in valid_samples
                        if r.get('server_peak_memory_mb') is not None
                    ]
                    if peak_values:
                        memory_metrics["batch_mode_avg_peak_mb"] = sum(peak_values) / len(peak_values)
                        memory_metrics["batch_mode_max_peak_mb"] = max(peak_values)

            self.results_summary.update({
                "successful_requests": total_successful_requests,
                "failed_requests": total_failed_requests,
                "successful_urls": total_successful_urls,
                "failed_urls": total_failed_urls,
                "total_urls_processed": total_urls_processed,
            })

        except Exception as e:
            console.print(f"[bold red]An error occurred during task execution: {e}[/bold red]")
            import traceback
            traceback.print_exc()

        end_time = time.time()
        self.results_summary.update({
            "end_time": time.strftime("%Y-%m-%d %H:%M:%S"),
            "total_time_seconds": end_time - start_time,
        })
        self._save_results()
        return self.results_summary

    @staticmethod
    def _parse_error_detail(detail) -> Dict:
        """Return an error ``detail`` value as a dict, or ``{}`` if it is not one.

        ``detail`` may already be a dict or may be a JSON-encoded string;
        anything else (including JSON that decodes to a non-dict) yields ``{}``.
        """
        if isinstance(detail, str):
            try:
                detail = json.loads(detail)
            except ValueError:  # json.JSONDecodeError is a ValueError
                return {}
        return detail if isinstance(detail, dict) else {}

    def _format_memory(self, server_memory_metric, server_peak_mem_mb) -> Tuple[str, Optional[float], Optional[float]]:
        """Build the memory column for a progress row.

        Returns:
            ``(display_text, peak_value, delta_or_max_value)``. ``peak_value``
            is always None in stream mode; ``delta_or_max_value`` is the max
            snapshot in stream mode and the delta in batch mode.
        """
        mem_display = " N/A "  # Default
        peak_mem_value = None
        delta_or_max_value = None

        if self.stream_mode:
            # server_memory_metric holds max snapshot for stream
            if server_memory_metric is not None:
                mem_display = f"{server_memory_metric:.1f} (Max)"
                delta_or_max_value = server_memory_metric
        else:
            # Batch mode - expect peak and delta
            peak_mem_value = server_peak_mem_mb
            delta_value = server_memory_metric
            if peak_mem_value is not None and delta_value is not None:
                mem_display = f"{peak_mem_value:.1f} / {delta_value:+.1f}"
                delta_or_max_value = delta_value
            elif peak_mem_value is not None:
                mem_display = f"{peak_mem_value:.1f} / N/A"
            elif delta_value is not None:
                mem_display = f"N/A / {delta_value:+.1f}"
                delta_or_max_value = delta_value

        return mem_display, peak_mem_value, delta_or_max_value

    async def _make_api_request(
        self,
        chunk: List[str],
        batch_idx: int,
        total_batches: int,
        semaphore: asyncio.Semaphore
    ) -> Dict:
        """Make a single API request for a chunk of URLs and log one progress row.

        Args:
            chunk: URLs sent in this request.
            batch_idx: 1-based index of this request, used for logging.
            total_batches: Total number of requests in the run (currently unused).
            semaphore: Limits how many requests run concurrently.

        Returns:
            A dict with the request outcome, URL counts, elapsed time and the
            server memory values extracted from the response (None if absent).
        """
        request_success = False
        success_urls = 0
        failed_urls = 0
        server_memory_metric = None  # Delta (batch) or max snapshot (stream)
        server_peak_mem_mb = None  # Only reported in batch mode
        # NOTE: started before the semaphore is acquired, so the reported time
        # includes time spent waiting for a free slot.
        api_call_start_time = time.time()

        async with semaphore:
            try:
                endpoint = "/crawl/stream" if self.stream_mode else "/crawl"
                payload = {
                    "urls": chunk,
                    "browser_config": {"type": "BrowserConfig", "params": {"headless": True}},
                    "crawler_config": {
                        "type": "CrawlerRunConfig",
                        "params": {"cache_mode": "BYPASS", "stream": self.stream_mode}
                    }
                }

                if self.stream_mode:
                    max_server_mem_snapshot = 0.0  # Track max memory seen in this stream
                    async with self.http_client.stream("POST", endpoint, json=payload) as response:
                        if response.is_error:
                            # A streamed body is not loaded automatically; read
                            # it so the error handler below can inspect it.
                            await response.aread()
                        response.raise_for_status()

                        completed_marker_received = False
                        async for line in response.aiter_lines():
                            if not line:
                                continue
                            try:
                                data = json.loads(line)
                            except json.JSONDecodeError:
                                console.print(f"[Batch {batch_idx}] [red]Stream decode error for line:[/red] {line}")
                                failed_urls = len(chunk)
                                break
                            if data.get("status") == "completed":
                                completed_marker_received = True
                                break
                            if data.get("url"):
                                if data.get("success"):
                                    success_urls += 1
                                else:
                                    failed_urls += 1
                                # Extract server memory snapshot per result
                                mem_snapshot = data.get('server_memory_mb')
                                if mem_snapshot is not None:
                                    max_server_mem_snapshot = max(max_server_mem_snapshot, float(mem_snapshot))

                    request_success = completed_marker_received
                    if not request_success:
                        # Without the completion marker, everything that did
                        # not explicitly succeed counts as failed.
                        failed_urls = len(chunk) - success_urls
                    server_memory_metric = max_server_mem_snapshot

                else:  # Batch mode
                    response = await self.http_client.post(endpoint, json=payload)
                    response.raise_for_status()
                    data = response.json()

                    # Extract server memory delta and peak from the response
                    server_memory_metric = data.get('server_memory_delta_mb')
                    server_peak_mem_mb = data.get('server_peak_memory_mb')

                    if data.get("success") and "results" in data:
                        request_success = True
                        results_list = data.get("results", [])
                        for result_item in results_list:
                            if result_item.get("success"):
                                success_urls += 1
                            else:
                                failed_urls += 1
                        if len(results_list) != len(chunk):
                            console.print(f"[Batch {batch_idx}] [yellow]Warning: Result count ({len(results_list)}) doesn't match URL count ({len(chunk)})[/yellow]")
                            failed_urls = len(chunk) - success_urls
                    else:
                        request_success = False
                        failed_urls = len(chunk)
                        # Try to get memory from error detail if available
                        detail_json = self._parse_error_detail(data.get('detail'))
                        server_peak_mem_mb = detail_json.get('server_peak_memory_mb', None)
                        server_memory_metric = detail_json.get('server_memory_delta_mb', None)
                        console.print(f"[Batch {batch_idx}] [red]API request failed:[/red] {detail_json.get('error', 'No details')}")

            except httpx.HTTPStatusError as e:
                request_success = False
                failed_urls = len(chunk)
                console.print(f"[Batch {batch_idx}] [bold red]HTTP Error {e.response.status_code}:[/] {e.request.url}")
                try:
                    error_detail = e.response.json()
                    # Attempt to extract memory info even from error responses
                    detail_content = self._parse_error_detail(error_detail.get('detail', {}))
                    server_memory_metric = detail_content.get('server_memory_delta_mb', None)
                    server_peak_mem_mb = detail_content.get('server_peak_memory_mb', None)
                    console.print(f"Response: {error_detail}")
                except Exception:
                    console.print(f"Response Text: {e.response.text[:200]}...")
            except httpx.RequestError as e:
                request_success = False
                failed_urls = len(chunk)
                # "[/]" closes the most recent tag; "[/bold]" would not match
                # the "[bold red]" tag and makes Rich raise MarkupError.
                console.print(f"[Batch {batch_idx}] [bold red]Request Error:[/] {e.request.url} - {e}")
            except Exception as e:
                request_success = False
                failed_urls = len(chunk)
                console.print(f"[Batch {batch_idx}] [bold red]Unexpected Error:[/] {e}")
                import traceback
                traceback.print_exc()

            # Every ordinary exception is handled above, so this always runs
            # for them; cancellation is allowed to propagate untouched.
            api_call_time = time.time() - api_call_start_time

            if request_success and failed_urls == 0:
                status_color, status = "green", "Success"
            elif request_success and success_urls > 0:
                status_color, status = "yellow", "Partial"
            else:
                status_color, status = "red", "Failed"

            # Progress is derived from the batch index, not completion order.
            current_total_urls = batch_idx * self.chunk_size
            progress_pct = min(100.0, (current_total_urls / self.url_count) * 100)
            reqs_per_sec = 1.0 / api_call_time if api_call_time > 0 else float('inf')

            mem_display, peak_mem_value, delta_or_max_value = self._format_memory(
                server_memory_metric, server_peak_mem_mb
            )

            console.print(
                f" {batch_idx:<5} | {progress_pct:6.1f}% | {mem_display:>24} | {reqs_per_sec:8.1f} | "
                f"{success_urls:^7}/{failed_urls:<6} | {api_call_time:8.2f} | [{status_color}]{status:<7}[/{status_color}] "
            )

            return {
                "batch_idx": batch_idx,
                "request_success": request_success,
                "success_urls": success_urls,
                "failed_urls": failed_urls,
                "time": api_call_time,
                "server_peak_memory_mb": peak_mem_value,  # Will be None for stream mode
                "server_delta_or_max_mb": delta_or_max_value  # Delta for batch, Max for stream
            }

    def _save_results(self) -> None:
        """Save the results summary to a JSON file in the report directory."""
        results_path = self.report_path / f"api_test_summary_{self.test_id}.json"
        try:
            with open(results_path, 'w', encoding='utf-8') as f:
                json.dump(self.results_summary, f, indent=2, default=str)
        except Exception as e:
            console.print(f"[bold red]Failed to save results summary: {e}[/bold red]")
|
||||
|
||||
|
||||
# --- run_full_test Function ---
|
||||
async def run_full_test(args):
    """Run the full API stress test process.

    Checks server health first, then runs an ``ApiStressTest`` configured from
    ``args`` and prints a summary of requests, URLs, throughput and the server
    memory figures reported by the API.

    Args:
        args: Parsed command-line namespace providing ``api_url``, ``urls``,
            ``max_concurrent_requests``, ``chunk_size``, ``report_path`` and
            ``stream``.
    """
    # The context manager closes the probe client even if the health check is
    # interrupted, which the explicit aclose() calls did not guarantee.
    async with httpx.AsyncClient(base_url=args.api_url, timeout=REQUEST_TIMEOUT) as client:
        if not await check_server_health(client):
            console.print("[bold red]Aborting test due to server health check failure.[/]")
            return

    test = ApiStressTest(
        api_url=args.api_url,
        url_count=args.urls,
        max_concurrent_requests=args.max_concurrent_requests,
        chunk_size=args.chunk_size,
        report_path=args.report_path,
        stream_mode=args.stream,
    )
    results = {}
    try:
        results = await test.run()
    finally:
        await test.close_client()

    if not results:
        console.print("[bold red]Test did not produce results.[/bold red]")
        return

    console.print("\n" + "=" * 80)
    console.print("[bold green]API Stress Test Completed[/bold green]")
    console.print("=" * 80)

    # Guard every ratio against a zero denominator (empty run, instant finish).
    success_rate_reqs = results["successful_requests"] / results["total_api_calls"] * 100 if results["total_api_calls"] > 0 else 0
    success_rate_urls = results["successful_urls"] / results["url_count"] * 100 if results["url_count"] > 0 else 0
    urls_per_second = results["total_urls_processed"] / results["total_time_seconds"] if results["total_time_seconds"] > 0 else 0
    reqs_per_second = results["total_api_calls"] / results["total_time_seconds"] if results["total_time_seconds"] > 0 else 0

    console.print(f"[bold cyan]Test ID:[/bold cyan] {results['test_id']}")
    console.print(f"[bold cyan]Target API:[/bold cyan] {results['api_url']}")
    console.print(f"[bold cyan]Configuration:[/bold cyan] {results['url_count']} URLs, {results['max_concurrent_requests']} concurrent client requests, URLs/Req: {results['chunk_size']}, Stream: {results['stream_mode']}")
    console.print(f"[bold cyan]API Requests:[/bold cyan] {results['successful_requests']} successful, {results['failed_requests']} failed ({results['total_api_calls']} total, {success_rate_reqs:.1f}% success)")
    console.print(f"[bold cyan]URL Processing:[/bold cyan] {results['successful_urls']} successful, {results['failed_urls']} failed ({results['total_urls_processed']} processed, {success_rate_urls:.1f}% success)")
    console.print(f"[bold cyan]Performance:[/bold cyan] {results['total_time_seconds']:.2f}s total | Avg Reqs/sec: {reqs_per_second:.2f} | Avg URLs/sec: {urls_per_second:.2f}")

    # Report Server Memory
    mem_metrics = results.get("server_memory_metrics", {})
    mem_samples = mem_metrics.get("samples", [])
    if mem_samples:
        num_samples = len(mem_samples)
        if results['stream_mode']:
            avg_mem = mem_metrics.get("stream_mode_avg_max_snapshot_mb")
            max_mem = mem_metrics.get("stream_mode_max_max_snapshot_mb")
            avg_str = f"{avg_mem:.1f}" if avg_mem is not None else "N/A"
            max_str = f"{max_mem:.1f}" if max_mem is not None else "N/A"
            console.print(f"[bold cyan]Server Memory (Stream):[/bold cyan] Avg Max Snapshot: {avg_str} MB | Max Max Snapshot: {max_str} MB (across {num_samples} requests)")
        else:  # Batch mode
            avg_delta = mem_metrics.get("batch_mode_avg_delta_mb")
            max_delta = mem_metrics.get("batch_mode_max_delta_mb")
            avg_peak = mem_metrics.get("batch_mode_avg_peak_mb")
            max_peak = mem_metrics.get("batch_mode_max_peak_mb")

            avg_delta_str = f"{avg_delta:.1f}" if avg_delta is not None else "N/A"
            max_delta_str = f"{max_delta:.1f}" if max_delta is not None else "N/A"
            avg_peak_str = f"{avg_peak:.1f}" if avg_peak is not None else "N/A"
            max_peak_str = f"{max_peak:.1f}" if max_peak is not None else "N/A"

            console.print(f"[bold cyan]Server Memory (Batch):[/bold cyan] Avg Peak: {avg_peak_str} MB | Max Peak: {max_peak_str} MB | Avg Delta: {avg_delta_str} MB | Max Delta: {max_delta_str} MB (across {num_samples} requests)")
    else:
        console.print("[bold cyan]Server Memory:[/bold cyan] No memory data reported by server.")

    summary_path = pathlib.Path(args.report_path) / f"api_test_summary_{results['test_id']}.json"
    console.print(f"[bold green]Results summary saved to {summary_path}[/bold green]")

    if results["failed_requests"] > 0:
        console.print(f"\n[bold yellow]Warning: {results['failed_requests']} API requests failed ({100-success_rate_reqs:.1f}% failure rate)[/bold yellow]")
    if results["failed_urls"] > 0:
        console.print(f"[bold yellow]Warning: {results['failed_urls']} URLs failed to process ({100-success_rate_urls:.1f}% URL failure rate)[/bold yellow]")
    if results["total_urls_processed"] < results["url_count"]:
        console.print(f"\n[bold red]Error: Only {results['total_urls_processed']} out of {results['url_count']} target URLs were processed![/bold red]")
|
||||
|
||||
|
||||
# --- main Function (Argument parsing mostly unchanged) ---
|
||||
def main():
    """Parse the command line, echo the configuration and run the stress test."""
    cli = argparse.ArgumentParser(description="Crawl4AI API Server Stress Test")

    cli.add_argument(
        "--api-url", type=str, default=DEFAULT_API_URL,
        help=f"Base URL of the Crawl4AI API server (default: {DEFAULT_API_URL})",
    )
    cli.add_argument(
        "--urls", type=int, default=DEFAULT_URL_COUNT,
        help=f"Total number of unique URLs to process via API calls (default: {DEFAULT_URL_COUNT})",
    )
    cli.add_argument(
        "--max-concurrent-requests", type=int, default=DEFAULT_MAX_CONCURRENT_REQUESTS,
        help=f"Maximum concurrent API requests from this client (default: {DEFAULT_MAX_CONCURRENT_REQUESTS})",
    )
    cli.add_argument(
        "--chunk-size", type=int, default=DEFAULT_CHUNK_SIZE,
        help=f"Number of URLs per API request payload (default: {DEFAULT_CHUNK_SIZE})",
    )
    cli.add_argument(
        "--stream", action="store_true", default=DEFAULT_STREAM_MODE,
        help=f"Use the /crawl/stream endpoint instead of /crawl (default: {DEFAULT_STREAM_MODE})",
    )
    cli.add_argument(
        "--report-path", type=str, default=DEFAULT_REPORT_PATH,
        help=f"Path to save reports and logs (default: {DEFAULT_REPORT_PATH})",
    )
    cli.add_argument("--clean-reports", action="store_true", help="Clean up report directory before running")

    args = cli.parse_args()
    divider = "-" * 40
    mode_label = 'Streaming' if args.stream else 'Batch'

    console.print("[bold underline]Crawl4AI API Stress Test Configuration[/bold underline]")
    console.print(f"API URL: {args.api_url}")
    console.print(f"Total URLs: {args.urls}, Concurrent Client Requests: {args.max_concurrent_requests}, URLs per Request: {args.chunk_size}")
    console.print(f"Mode: {mode_label}")
    console.print(f"Report Path: {args.report_path}")
    console.print(divider)
    if args.clean_reports:
        console.print("[cyan]Option: Clean reports before test[/cyan]")
    console.print(divider)

    if args.clean_reports:
        # Wipe any previous reports, then recreate an empty directory.
        reports_dir = pathlib.Path(args.report_path)
        if reports_dir.exists():
            console.print(f"[yellow]Cleaning up reports directory: {args.report_path}[/yellow]")
            shutil.rmtree(reports_dir)
        reports_dir.mkdir(parents=True, exist_ok=True)

    try:
        asyncio.run(run_full_test(args))
    except KeyboardInterrupt:
        console.print("\n[bold yellow]Test interrupted by user.[/bold yellow]")
    except Exception as exc:
        console.print(f"\n[bold red]An unexpected error occurred:[/bold red] {exc}")
        import traceback
        traceback.print_exc()
|
||||
|
||||
# Script entry point: run the CLI only when this file is executed directly.
if __name__ == "__main__":
    # No need to modify sys.path for SimpleMemoryTracker as it's removed
    main()
|
||||
129
tests/memory/test_stress_docker_api.py
Normal file
129
tests/memory/test_stress_docker_api.py
Normal file
@@ -0,0 +1,129 @@
|
||||
"""
|
||||
Crawl4AI Docker API stress tester.
|
||||
|
||||
Examples
|
||||
--------
|
||||
python test_stress_docker_api.py --urls 1000 --concurrency 32
|
||||
python test_stress_docker_api.py --urls 1000 --concurrency 32 --stream
|
||||
python test_stress_docker_api.py --base-url http://10.0.0.42:11235 --http2
|
||||
"""
|
||||
|
||||
import argparse, asyncio, json, secrets, statistics, time
|
||||
from typing import List, Tuple
|
||||
import httpx
|
||||
from rich.console import Console
|
||||
from rich.progress import Progress, BarColumn, TimeElapsedColumn, TimeRemainingColumn
|
||||
from rich.table import Table
|
||||
|
||||
# Shared rich console: prints the summary table and the "aborted by user" notice.
console = Console()
|
||||
|
||||
|
||||
# ───────────────────────── helpers ─────────────────────────
|
||||
def make_fake_urls(n: int) -> List[str]:
    """Build ``n`` unique-looking httpbin URLs to use as crawl targets.

    Each URL is ``https://httpbin.org/anything/`` followed by 8 random bytes
    rendered as hex, so repeated runs do not request the same paths.
    A non-positive ``n`` yields an empty list.
    """
    prefix = "https://httpbin.org/anything/"
    urls: List[str] = []
    for _ in range(n):
        urls.append(prefix + secrets.token_hex(8))
    return urls
|
||||
|
||||
|
||||
async def fire(
    client: httpx.AsyncClient, endpoint: str, payload: dict, sem: asyncio.Semaphore
) -> Tuple[bool, float]:
    """POST one payload and report whether it succeeded and how long it took.

    The request is issued only once a slot in ``sem`` is free; the clock
    starts after the slot is acquired, so time spent queueing is not counted.
    Endpoints ending in ``/stream`` are read line by line until the server
    closes the stream; every other endpoint is a plain POST.

    Returns:
        ``(ok, seconds)`` where ``ok`` is False if the request raised or
        returned an error status, and ``seconds`` is the elapsed wall time.
    """
    async with sem:
        print(f"POST {endpoint} with {len(payload['urls'])} URLs")
        started = time.perf_counter()
        succeeded = True
        try:
            if not endpoint.endswith("/stream"):
                response = await client.post(endpoint, json=payload)
                response.raise_for_status()
            else:
                async with client.stream("POST", endpoint, json=payload) as response:
                    response.raise_for_status()
                    # Drain the body; only the total duration matters here.
                    async for _line in response.aiter_lines():
                        pass
        except Exception:
            # Any failure is counted as an error by the caller, not raised.
            succeeded = False
        return succeeded, time.perf_counter() - started
|
||||
|
||||
|
||||
def pct(lat: List[float], p: float) -> str:
    """Return the ``p``-th percentile of ``lat`` formatted as ``"<x.xx>s"``.

    Uses linear interpolation between the two nearest ranks, so it gives a
    sensible answer even for very small samples.

    Args:
        lat: Latencies in seconds; need not be sorted and is not modified.
        p: Percentile to compute, in the closed range 0..100.

    Returns:
        ``"-"`` when ``lat`` is empty, otherwise the percentile with two
        decimals and an ``s`` suffix.

    Raises:
        ValueError: If ``p`` is outside 0..100 (or NaN). Previously such
            values either raised an opaque ``IndexError`` or silently
            interpolated from the wrong end of the list.
    """
    if not 0 <= p <= 100:
        raise ValueError(f"percentile must be between 0 and 100, got {p!r}")
    if not lat:
        return "-"
    if len(lat) == 1:
        return f"{lat[0]:.2f}s"

    lat_sorted = sorted(lat)
    last = len(lat_sorted) - 1
    # Fractional rank of the requested percentile within the sorted sample.
    k = (p / 100) * last
    lo = int(k)
    hi = min(lo + 1, last)  # clamp so p == 100 does not index past the end
    frac = k - lo
    val = lat_sorted[lo] * (1 - frac) + lat_sorted[hi] * frac
    return f"{val:.2f}s"
|
||||
|
||||
|
||||
# ───────────────────────── main ─────────────────────────
|
||||
def _positive_int(value: str) -> int:
    """argparse ``type=`` callback that accepts only integers >= 1."""
    try:
        number = int(value)
    except ValueError:
        raise argparse.ArgumentTypeError(f"invalid int value: {value!r}") from None
    if number < 1:
        raise argparse.ArgumentTypeError(f"must be a positive integer, got {number}")
    return number


def parse_args() -> argparse.Namespace:
    """Parse the stress tester's command-line options from ``sys.argv``.

    Returns:
        Namespace with ``urls``, ``concurrency``, ``chunk_size``,
        ``base_url``, ``stream``, ``http2`` and ``headless``.

    ``--concurrency`` and ``--chunk-size`` must be >= 1: a concurrency of 0
    would create a semaphore nobody can acquire (the run would hang), and a
    chunk size of 0 is not a valid batching step. argparse reports invalid
    values and exits with its usual usage error.
    """
    p = argparse.ArgumentParser(description="Stress test Crawl4AI Docker API")
    p.add_argument("--urls", type=int, default=100, help="number of URLs")
    p.add_argument("--concurrency", type=_positive_int, default=1, help="max POSTs in flight")
    p.add_argument("--chunk-size", type=_positive_int, default=50, help="URLs per request")
    p.add_argument("--base-url", default="http://localhost:11235", help="API root")
    # p.add_argument("--base-url", default="http://localhost:8020", help="API root")
    p.add_argument("--stream", action="store_true", help="use /crawl/stream")
    p.add_argument("--http2", action="store_true", help="enable HTTP/2")
    # --headless alone could never be turned off (store_true with default=True);
    # --no-headless shares the same destination and flips it to False.
    p.add_argument(
        "--headless",
        dest="headless",
        action="store_true",
        default=True,
        help="send headless=true in the browser config (default)",
    )
    p.add_argument(
        "--no-headless",
        dest="headless",
        action="store_false",
        help="send headless=false in the browser config",
    )
    return p.parse_args()
|
||||
|
||||
|
||||
async def main() -> None:
    """Run the stress test end to end and print a latency summary.

    Generates ``args.urls`` random URLs, splits them into requests of
    ``args.chunk_size`` URLs and POSTs each request to ``/crawl`` (or
    ``/crawl/stream`` with ``--stream``), keeping at most
    ``args.concurrency`` requests in flight. When every request has
    finished, prints a table with the number of requests, the number of
    failures, and the p50 / p95 / max latency of the successful ones.
    """
    args = parse_args()

    urls = make_fake_urls(args.urls)
    batches = [urls[i : i + args.chunk_size] for i in range(0, len(urls), args.chunk_size)]
    endpoint = "/crawl/stream" if args.stream else "/crawl"
    sem = asyncio.Semaphore(args.concurrency)

    # timeout=None: a crawl of a whole batch can legitimately take a long
    # time, so the client must not give up on its own.
    async with httpx.AsyncClient(base_url=args.base_url, http2=args.http2, timeout=None) as client:
        with Progress(
            "[progress.description]{task.description}",
            BarColumn(),
            "[progress.percentage]{task.percentage:>3.0f}%",
            TimeElapsedColumn(),
            TimeRemainingColumn(),
        ) as progress:
            task_id = progress.add_task("[cyan]bombarding…", total=len(batches))

            async def fire_and_report(payload: dict) -> Tuple[bool, float]:
                """Send one request, then tick the progress bar."""
                outcome = await fire(client, endpoint, payload, sem)
                # Advance only once the request has actually finished, so the
                # bar reflects completed work rather than scheduled work.
                progress.advance(task_id)
                return outcome

            payloads = [
                {
                    "urls": chunk,
                    "browser_config": {"type": "BrowserConfig", "params": {"headless": args.headless}},
                    "crawler_config": {"type": "CrawlerRunConfig", "params": {"cache_mode": "BYPASS", "stream": args.stream}},
                }
                for chunk in batches
            ]
            results = await asyncio.gather(*(fire_and_report(p) for p in payloads))

    ok_latencies = [dt for ok, dt in results if ok]
    err_count = sum(1 for ok, _ in results if not ok)

    table = Table(title="Docker API Stress‑Test Summary")
    for heading in ("total", "errors", "p50", "p95", "max"):
        table.add_column(heading, justify="right")

    table.add_row(
        str(len(results)),
        str(err_count),
        pct(ok_latencies, 50),
        pct(ok_latencies, 95),
        # Latency figures cover successful requests only.
        f"{max(ok_latencies):.2f}s" if ok_latencies else "-",
    )
    console.print(table)
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Script entry point: drive the async main() to completion. Ctrl-C ends
    # the run with a short notice instead of a traceback.
    try:
        asyncio.run(main())
    except KeyboardInterrupt:
        console.print("\n[yellow]aborted by user[/]")
|
||||
@@ -37,8 +37,8 @@ from crawl4ai import (
|
||||
DEFAULT_SITE_PATH = "test_site"
|
||||
DEFAULT_PORT = 8000
|
||||
DEFAULT_MAX_SESSIONS = 16
|
||||
DEFAULT_URL_COUNT = 100
|
||||
DEFAULT_CHUNK_SIZE = 10 # Define chunk size for batch logging
|
||||
DEFAULT_URL_COUNT = 1
|
||||
DEFAULT_CHUNK_SIZE = 1 # Define chunk size for batch logging
|
||||
DEFAULT_REPORT_PATH = "reports"
|
||||
DEFAULT_STREAM_MODE = False
|
||||
DEFAULT_MONITOR_MODE = "DETAILED"
|
||||
|
||||
Reference in New Issue
Block a user