diff --git a/main.py b/main.py
index 71d5eeee..7c9976a3 100644
--- a/main.py
+++ b/main.py
@@ -1,254 +1,346 @@
-import os
-import importlib
 import asyncio
-from functools import lru_cache
-import logging
-logging.basicConfig(level=logging.DEBUG)
-
-from fastapi import FastAPI, HTTPException, Request
-from fastapi.responses import HTMLResponse, JSONResponse
-from fastapi.staticfiles import StaticFiles
-from fastapi.middleware.cors import CORSMiddleware
-from fastapi.templating import Jinja2Templates
-from fastapi.exceptions import RequestValidationError
-from starlette.middleware.base import BaseHTTPMiddleware
-from starlette.responses import FileResponse
-from fastapi.responses import RedirectResponse
-
-from pydantic import BaseModel, HttpUrl
-from concurrent.futures import ThreadPoolExecutor, as_completed
-from typing import List, Optional
-
-from crawl4ai.web_crawler import WebCrawler
-from crawl4ai.database import get_total_count, clear_db
-
+from fastapi import FastAPI, HTTPException, BackgroundTasks, Request
+from fastapi.responses import JSONResponse
+from pydantic import BaseModel, HttpUrl, Field
+from typing import Optional, List, Dict, Any, Union
+import psutil
 import time
-from slowapi import Limiter, _rate_limit_exceeded_handler
-from slowapi.util import get_remote_address
-from slowapi.errors import RateLimitExceeded
-
-# load .env file
-from dotenv import load_dotenv
-load_dotenv()
-
-# Configuration
-__location__ = os.path.realpath(os.path.join(os.getcwd(), os.path.dirname(__file__)))
-MAX_CONCURRENT_REQUESTS = 10 # Adjust this to change the maximum concurrent requests
-current_requests = 0
-lock = asyncio.Lock()
-
-app = FastAPI()
-
-# Initialize rate limiter
-def rate_limit_key_func(request: Request):
-    access_token = request.headers.get("access-token")
-    if access_token == os.environ.get('ACCESS_TOKEN'):
-        return None
-    return get_remote_address(request)
-
-limiter = Limiter(key_func=rate_limit_key_func)
-app.state.limiter = limiter
-
-# Dictionary to store last request times for each client
-last_request_times = {}
-last_rate_limit = {}
-
-
-def get_rate_limit():
-    limit = os.environ.get('ACCESS_PER_MIN', "5")
-    return f"{limit}/minute"
-
-# Custom rate limit exceeded handler
-async def custom_rate_limit_exceeded_handler(request: Request, exc: RateLimitExceeded) -> JSONResponse:
-    if request.client.host not in last_rate_limit or time.time() - last_rate_limit[request.client.host] > 60:
-        last_rate_limit[request.client.host] = time.time()
-    retry_after = 60 - (time.time() - last_rate_limit[request.client.host])
-    reset_at = time.time() + retry_after
-    return JSONResponse(
-        status_code=429,
-        content={
-            "detail": "Rate limit exceeded",
-            "limit": str(exc.limit.limit),
-            "retry_after": retry_after,
-            'reset_at': reset_at,
-            "message": f"You have exceeded the rate limit of {exc.limit.limit}."
-        }
-    )
-
-app.add_exception_handler(RateLimitExceeded, custom_rate_limit_exceeded_handler)
-
-
-# Middleware for token-based bypass and per-request limit
-class RateLimitMiddleware(BaseHTTPMiddleware):
-    async def dispatch(self, request: Request, call_next):
-        SPAN = int(os.environ.get('ACCESS_TIME_SPAN', 10))
-        access_token = request.headers.get("access-token")
-        if access_token == os.environ.get('ACCESS_TOKEN'):
-            return await call_next(request)
-
-        path = request.url.path
-        if path in ["/crawl", "/old"]:
-            client_ip = request.client.host
-            current_time = time.time()
-
-            # Check time since last request
-            if client_ip in last_request_times:
-                time_since_last_request = current_time - last_request_times[client_ip]
-                if time_since_last_request < SPAN:
-                    return JSONResponse(
-                        status_code=429,
-                        content={
-                            "detail": "Too many requests",
-                            "message": "Rate limit exceeded. Please wait 10 seconds between requests.",
-                            "retry_after": max(0, SPAN - time_since_last_request),
-                            "reset_at": current_time + max(0, SPAN - time_since_last_request),
-                        }
-                    )
-
-            last_request_times[client_ip] = current_time
-
-        return await call_next(request)
-
-app.add_middleware(RateLimitMiddleware)
-
-# CORS configuration
-origins = ["*"] # Allow all origins
-app.add_middleware(
-    CORSMiddleware,
-    allow_origins=origins, # List of origins that are allowed to make requests
-    allow_credentials=True,
-    allow_methods=["*"], # Allows all methods
-    allow_headers=["*"], # Allows all headers
+import uuid
+from collections import defaultdict
+from urllib.parse import urlparse
+import math
+import logging
+from enum import Enum
+from dataclasses import dataclass, field
+import json
+from crawl4ai import AsyncWebCrawler, CrawlResult
+from crawl4ai.extraction_strategy import (
+    LLMExtractionStrategy,
+    CosineStrategy,
+    JsonCssExtractionStrategy,
 )
-
-# Mount the pages directory as a static directory
-app.mount("/pages", StaticFiles(directory=__location__ + "/pages"), name="pages")
-app.mount("/mkdocs", StaticFiles(directory="site", html=True), name="mkdocs")
-site_templates = Jinja2Templates(directory=__location__ + "/site")
-templates = Jinja2Templates(directory=__location__ + "/pages")
+logging.basicConfig(level=logging.INFO)
+logger = logging.getLogger(__name__)

-@lru_cache()
-def get_crawler():
-    # Initialize and return a WebCrawler instance
-    crawler = WebCrawler(verbose = True)
-    crawler.warmup()
-    return crawler
+class TaskStatus(str, Enum):
+    PENDING = "pending"
+    PROCESSING = "processing"
+    COMPLETED = "completed"
+    FAILED = "failed"
+
+class CrawlerType(str, Enum):
+    BASIC = "basic"
+    LLM = "llm"
+    COSINE = "cosine"
+    JSON_CSS = "json_css"
+
+class ExtractionConfig(BaseModel):
+    type: CrawlerType
+    params: Dict[str, Any] = {}
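+# The raw `params` above are handed straight to the matching crawl4ai strategy
+# class (see CrawlerService._create_extraction_strategy below), so any keyword
+# the chosen strategy accepts can be supplied by the API caller.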

 class CrawlRequest(BaseModel):
-    urls: List[str]
-    include_raw_html: Optional[bool] = False
-    bypass_cache: bool = False
-    extract_blocks: bool = True
-    word_count_threshold: Optional[int] = 5
-    extraction_strategy: Optional[str] = "NoExtractionStrategy"
-    extraction_strategy_args: Optional[dict] = {}
-    chunking_strategy: Optional[str] = "RegexChunking"
-    chunking_strategy_args: Optional[dict] = {}
+    urls: Union[HttpUrl, List[HttpUrl]]
+    extraction_config: Optional[ExtractionConfig] = None
+    crawler_params: Dict[str, Any] = {}
+    priority: int = Field(default=5, ge=1, le=10)
+    ttl: Optional[int] = 3600
+    js_code: Optional[List[str]] = None
+    wait_for: Optional[str] = None
     css_selector: Optional[str] = None
-    screenshot: Optional[bool] = False
-    user_agent: Optional[str] = None
-    verbose: Optional[bool] = True
+    screenshot: bool = False
+    magic: bool = False

-@app.get("/")
-def read_root():
-    return RedirectResponse(url="/mkdocs")
+@dataclass
+class TaskInfo:
+    id: str
+    status: TaskStatus
+    request: Optional[CrawlRequest] = None # attached by CrawlerService.submit_task
+    result: Optional[Union[CrawlResult, List[CrawlResult]]] = None
+    error: Optional[str] = None
+    # default_factory so each task records its own creation time; a bare
+    # time.time() default would be evaluated once, at import time.
+    created_at: float = field(default_factory=time.time)
+    ttl: int = 3600

-@app.get("/old", response_class=HTMLResponse)
-@limiter.limit(get_rate_limit())
-async def read_index(request: Request):
-    partials_dir = os.path.join(__location__, "pages", "partial")
-    partials = {}
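+# ResourceMonitor decides how many tasks may run at once: the slot count
+# scales linearly with the headroom left under the memory (85%) and CPU (90%)
+# thresholds, and the more constrained of the two resources wins. Readings
+# are cached for _check_interval seconds to avoid hammering psutil.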
+class ResourceMonitor:
+    def __init__(self, max_concurrent_tasks: int = 10):
+        self.max_concurrent_tasks = max_concurrent_tasks
+        self.memory_threshold = 0.85
+        self.cpu_threshold = 0.90
+        self._last_check = 0
+        self._check_interval = 1 # seconds
+        self._last_available_slots = max_concurrent_tasks

-    for filename in os.listdir(partials_dir):
-        if filename.endswith(".html"):
-            with open(os.path.join(partials_dir, filename), "r", encoding="utf8") as file:
-                partials[filename[:-5]] = file.read()
+    async def get_available_slots(self) -> int:
+        current_time = time.time()
+        if current_time - self._last_check < self._check_interval:
+            return self._last_available_slots

-    return templates.TemplateResponse("index.html", {"request": request, **partials})
+        mem_usage = psutil.virtual_memory().percent / 100
+        cpu_usage = psutil.cpu_percent() / 100

-@app.get("/total-count")
-async def get_total_url_count():
-    count = get_total_count()
-    return JSONResponse(content={"count": count})
+        memory_factor = max(0, (self.memory_threshold - mem_usage) / self.memory_threshold)
+        cpu_factor = max(0, (self.cpu_threshold - cpu_usage) / self.cpu_threshold)

-@app.get("/clear-db")
-async def clear_database():
-    # clear_db()
-    return JSONResponse(content={"message": "Database cleared."})
+        self._last_available_slots = math.floor(
+            self.max_concurrent_tasks * min(memory_factor, cpu_factor)
+        )
+        self._last_check = current_time

-def import_strategy(module_name: str, class_name: str, *args, **kwargs):
-    try:
-        module = importlib.import_module(module_name)
-        strategy_class = getattr(module, class_name)
-        return strategy_class(*args, **kwargs)
-    except ImportError:
-        print("ImportError: Module not found.")
-        raise HTTPException(status_code=400, detail=f"Module {module_name} not found.")
-    except AttributeError:
-        print("AttributeError: Class not found.")
-        raise HTTPException(status_code=400, detail=f"Class {class_name} not found in {module_name}.")
+        return self._last_available_slots
+
+class TaskManager:
+    def __init__(self, cleanup_interval: int = 300):
+        self.tasks: Dict[str, TaskInfo] = {}
+        self.high_priority = asyncio.PriorityQueue()
+        self.low_priority = asyncio.PriorityQueue()
+        self.cleanup_interval = cleanup_interval
+        self.cleanup_task = None
+
+    async def start(self):
+        self.cleanup_task = asyncio.create_task(self._cleanup_loop())
+
+    async def stop(self):
+        if self.cleanup_task:
+            self.cleanup_task.cancel()
+            try:
+                await self.cleanup_task
+            except asyncio.CancelledError:
+                pass
+
+    async def add_task(self, task_id: str, priority: int, ttl: int) -> None:
+        task_info = TaskInfo(id=task_id, status=TaskStatus.PENDING, ttl=ttl)
+        self.tasks[task_id] = task_info
+        queue = self.high_priority if priority > 5 else self.low_priority
+        await queue.put((-priority, task_id)) # Negative for proper priority ordering
+
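+    # Polling both queues with short timeouts (instead of blocking on either
+    # one) lets a single consumer drain both while always giving high-priority
+    # work first refusal; the 0.1s timeout bounds the added latency.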
+    async def get_next_task(self) -> Optional[str]:
+        try:
+            # Try high priority first
+            _, task_id = await asyncio.wait_for(self.high_priority.get(), timeout=0.1)
+            return task_id
+        except asyncio.TimeoutError:
+            try:
+                # Then try low priority
+                _, task_id = await asyncio.wait_for(self.low_priority.get(), timeout=0.1)
+                return task_id
+            except asyncio.TimeoutError:
+                return None
+
+    def update_task(self, task_id: str, status: TaskStatus, result: Any = None, error: str = None):
+        if task_id in self.tasks:
+            task_info = self.tasks[task_id]
+            task_info.status = status
+            task_info.result = result
+            task_info.error = error
+
+    def get_task(self, task_id: str) -> Optional[TaskInfo]:
+        return self.tasks.get(task_id)
+
+    async def _cleanup_loop(self):
+        while True:
+            try:
+                await asyncio.sleep(self.cleanup_interval)
+                current_time = time.time()
+                expired_tasks = [
+                    task_id
+                    for task_id, task in self.tasks.items()
+                    if current_time - task.created_at > task.ttl
+                    and task.status in [TaskStatus.COMPLETED, TaskStatus.FAILED]
+                ]
+                for task_id in expired_tasks:
+                    del self.tasks[task_id]
+            except Exception as e:
+                logger.error(f"Error in cleanup loop: {e}")
+
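+# CrawlerPool keeps a bounded set of live AsyncWebCrawler instances (each one
+# owns a browser) and hands back the least recently used instance once the
+# pool is full; crawlers idle for over ten minutes are torn down on the next
+# acquire(). Note that a reused crawler keeps its original constructor kwargs.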
+class CrawlerPool:
+    def __init__(self, max_size: int = 10):
+        self.max_size = max_size
+        self.active_crawlers: Dict[AsyncWebCrawler, float] = {}
+        self._lock = asyncio.Lock()
+
+    async def acquire(self, **kwargs) -> AsyncWebCrawler:
+        async with self._lock:
+            # Clean up inactive crawlers
+            current_time = time.time()
+            inactive = [
+                crawler
+                for crawler, last_used in self.active_crawlers.items()
+                if current_time - last_used > 600 # 10 minutes timeout
+            ]
+            for crawler in inactive:
+                await crawler.__aexit__(None, None, None)
+                del self.active_crawlers[crawler]
+
+            # Create new crawler if needed
+            if len(self.active_crawlers) < self.max_size:
+                crawler = AsyncWebCrawler(**kwargs)
+                await crawler.__aenter__()
+                self.active_crawlers[crawler] = current_time
+                return crawler
+
+            # Reuse least recently used crawler
+            crawler = min(self.active_crawlers.items(), key=lambda x: x[1])[0]
+            self.active_crawlers[crawler] = current_time
+            return crawler
+
+    async def release(self, crawler: AsyncWebCrawler):
+        async with self._lock:
+            if crawler in self.active_crawlers:
+                self.active_crawlers[crawler] = time.time()
+
+    async def cleanup(self):
+        async with self._lock:
+            for crawler in list(self.active_crawlers.keys()):
+                await crawler.__aexit__(None, None, None)
+            self.active_crawlers.clear()
+
+class CrawlerService:
+    def __init__(self, max_concurrent_tasks: int = 10):
+        self.resource_monitor = ResourceMonitor(max_concurrent_tasks)
+        self.task_manager = TaskManager()
+        self.crawler_pool = CrawlerPool(max_concurrent_tasks)
+        self._processing_task = None
+
+    async def start(self):
+        await self.task_manager.start()
+        self._processing_task = asyncio.create_task(self._process_queue())
+
+    async def stop(self):
+        if self._processing_task:
+            self._processing_task.cancel()
+            try:
+                await self._processing_task
+            except asyncio.CancelledError:
+                pass
+        await self.task_manager.stop()
+        await self.crawler_pool.cleanup()
+
+    def _create_extraction_strategy(self, config: ExtractionConfig):
+        if not config:
+            return None
+
+        if config.type == CrawlerType.LLM:
+            return LLMExtractionStrategy(**config.params)
+        elif config.type == CrawlerType.COSINE:
+            return CosineStrategy(**config.params)
+        elif config.type == CrawlerType.JSON_CSS:
+            return JsonCssExtractionStrategy(**config.params)
+        return None
+
+    async def submit_task(self, request: CrawlRequest) -> str:
+        task_id = str(uuid.uuid4())
+        await self.task_manager.add_task(task_id, request.priority, request.ttl or 3600)
+
+        # Store request data with task
+        self.task_manager.tasks[task_id].request = request
+
+        return task_id
+
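+    # Single background worker: check resource headroom, pull one task, crawl
+    # it to completion, then loop. Failures are recorded on the task rather
+    # than raised, so the loop only exits on cancellation.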
+    async def _process_queue(self):
+        while True:
+            try:
+                available_slots = await self.resource_monitor.get_available_slots()
+                if available_slots <= 0:
+                    await asyncio.sleep(1)
+                    continue
+
+                task_id = await self.task_manager.get_next_task()
+                if not task_id:
+                    await asyncio.sleep(1)
+                    continue
+
+                task_info = self.task_manager.get_task(task_id)
+                if not task_info:
+                    continue
+
+                request = task_info.request
+                self.task_manager.update_task(task_id, TaskStatus.PROCESSING)
+
+                try:
+                    crawler = await self.crawler_pool.acquire(**request.crawler_params)
+
+                    extraction_strategy = self._create_extraction_strategy(request.extraction_config)
+
+                    if isinstance(request.urls, list):
+                        results = await crawler.arun_many(
+                            urls=[str(url) for url in request.urls],
+                            extraction_strategy=extraction_strategy,
+                            js_code=request.js_code,
+                            wait_for=request.wait_for,
+                            css_selector=request.css_selector,
+                            screenshot=request.screenshot,
+                            magic=request.magic,
+                        )
+                    else:
+                        results = await crawler.arun(
+                            url=str(request.urls),
+                            extraction_strategy=extraction_strategy,
+                            js_code=request.js_code,
+                            wait_for=request.wait_for,
+                            css_selector=request.css_selector,
+                            screenshot=request.screenshot,
+                            magic=request.magic,
+                        )
+
+                    await self.crawler_pool.release(crawler)
+                    self.task_manager.update_task(task_id, TaskStatus.COMPLETED, results)
+
+                except Exception as e:
+                    logger.error(f"Error processing task {task_id}: {str(e)}")
+                    self.task_manager.update_task(task_id, TaskStatus.FAILED, error=str(e))
+
+            except Exception as e:
+                logger.error(f"Error in queue processing: {str(e)}")
+                await asyncio.sleep(1)
+
+app = FastAPI(title="Crawl4AI API")
+crawler_service = CrawlerService()
+
+@app.on_event("startup")
+async def startup_event():
+    await crawler_service.start()
+
+@app.on_event("shutdown")
+async def shutdown_event():
+    await crawler_service.stop()

 @app.post("/crawl")
-@limiter.limit(get_rate_limit())
-async def crawl_urls(crawl_request: CrawlRequest, request: Request):
-    logging.debug(f"[LOG] Crawl request for URL: {crawl_request.urls}")
-    global current_requests
-    async with lock:
-        if current_requests >= MAX_CONCURRENT_REQUESTS:
-            raise HTTPException(status_code=429, detail="Too many requests - please try again later.")
-        current_requests += 1
+async def crawl(request: CrawlRequest) -> Dict[str, str]:
+    task_id = await crawler_service.submit_task(request)
+    return {"task_id": task_id}

-    try:
-        logging.debug("[LOG] Loading extraction and chunking strategies...")
-        crawl_request.extraction_strategy_args['verbose'] = True
-        crawl_request.chunking_strategy_args['verbose'] = True
-
-        extraction_strategy = import_strategy("crawl4ai.extraction_strategy", crawl_request.extraction_strategy, **crawl_request.extraction_strategy_args)
-        chunking_strategy = import_strategy("crawl4ai.chunking_strategy", crawl_request.chunking_strategy, **crawl_request.chunking_strategy_args)
+@app.get("/task/{task_id}")
+async def get_task_status(task_id: str):
+    task_info = crawler_service.task_manager.get_task(task_id)
+    if not task_info:
+        raise HTTPException(status_code=404, detail="Task not found")

-        # Use ThreadPoolExecutor to run the synchronous WebCrawler in async manner
-        logging.debug("[LOG] Running the WebCrawler...")
-        with ThreadPoolExecutor() as executor:
-            loop = asyncio.get_event_loop()
-            futures = [
-                loop.run_in_executor(
-                    executor,
-                    get_crawler().run,
-                    str(url),
-                    crawl_request.word_count_threshold,
-                    extraction_strategy,
-                    chunking_strategy,
-                    crawl_request.bypass_cache,
-                    crawl_request.css_selector,
-                    crawl_request.screenshot,
-                    crawl_request.user_agent,
-                    crawl_request.verbose
-                )
-                for url in crawl_request.urls
-            ]
-            results = await asyncio.gather(*futures)
+    response = {
+        "status": task_info.status,
+        "created_at": task_info.created_at,
+    }

-        # if include_raw_html is False, remove the raw HTML content from the results
-        if not crawl_request.include_raw_html:
-            for result in results:
-                result.html = None
+    if task_info.status == TaskStatus.COMPLETED:
+        # Convert CrawlResult to dict for JSON response
+        if isinstance(task_info.result, list):
+            response["results"] = [result.model_dump() for result in task_info.result]
+        else:
+            response["result"] = task_info.result.model_dump()
+    elif task_info.status == TaskStatus.FAILED:
+        response["error"] = task_info.error

-        return {"results": [result.model_dump() for result in results]}
-    finally:
-        async with lock:
-            current_requests -= 1
-
-@app.get("/strategies/extraction", response_class=JSONResponse)
-async def get_extraction_strategies():
-    with open(f"{__location__}/docs/extraction_strategies.json", "r") as file:
-        return JSONResponse(content=file.read())
-
-@app.get("/strategies/chunking", response_class=JSONResponse)
-async def get_chunking_strategies():
-    with open(f"{__location__}/docs/chunking_strategies.json", "r") as file:
-        return JSONResponse(content=file.read())
+    return response

+@app.get("/health")
+async def health_check():
+    available_slots = await crawler_service.resource_monitor.get_available_slots()
+    memory = psutil.virtual_memory()
+    return {
+        "status": "healthy",
+        "available_slots": available_slots,
+        "memory_usage": memory.percent,
+        "cpu_usage": psutil.cpu_percent(),
+    }

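+# Example round trip against a local instance (port per uvicorn.run below):
+#
+#   curl -X POST http://localhost:8000/crawl \
+#        -H 'Content-Type: application/json' \
+#        -d '{"urls": "https://example.com", "priority": 5}'
+#   # -> {"task_id": "..."}
+#
+#   curl http://localhost:8000/task/<task_id>
+#   curl http://localhost:8000/health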
 if __name__ == "__main__":
     import uvicorn
-    uvicorn.run(app, host="0.0.0.0", port=8888)
+    uvicorn.run(app, host="0.0.0.0", port=8000)
\ No newline at end of file
diff --git a/main_v0.py b/main_v0.py
new file mode 100644
index 00000000..71d5eeee
--- /dev/null
+++ b/main_v0.py
@@ -0,0 +1,254 @@
+import os
+import importlib
+import asyncio
+from functools import lru_cache
+import logging
+logging.basicConfig(level=logging.DEBUG)
+
+from fastapi import FastAPI, HTTPException, Request
+from fastapi.responses import HTMLResponse, JSONResponse
+from fastapi.staticfiles import StaticFiles
+from fastapi.middleware.cors import CORSMiddleware
+from fastapi.templating import Jinja2Templates
+from fastapi.exceptions import RequestValidationError
+from starlette.middleware.base import BaseHTTPMiddleware
+from starlette.responses import FileResponse
+from fastapi.responses import RedirectResponse
+
+from pydantic import BaseModel, HttpUrl
+from concurrent.futures import ThreadPoolExecutor, as_completed
+from typing import List, Optional
+
+from crawl4ai.web_crawler import WebCrawler
+from crawl4ai.database import get_total_count, clear_db
+
+import time
+from slowapi import Limiter, _rate_limit_exceeded_handler
+from slowapi.util import get_remote_address
+from slowapi.errors import RateLimitExceeded
+
+# load .env file
+from dotenv import load_dotenv
+load_dotenv()
+
+# Configuration
+__location__ = os.path.realpath(os.path.join(os.getcwd(), os.path.dirname(__file__)))
+MAX_CONCURRENT_REQUESTS = 10 # Adjust this to change the maximum concurrent requests
+current_requests = 0
+lock = asyncio.Lock()
+
+app = FastAPI()
+
+# Initialize rate limiter
+def rate_limit_key_func(request: Request):
+    access_token = request.headers.get("access-token")
+    if access_token == os.environ.get('ACCESS_TOKEN'):
+        return None
+    return get_remote_address(request)
+
+limiter = Limiter(key_func=rate_limit_key_func)
+app.state.limiter = limiter
+
+# Dictionary to store last request times for each client
+last_request_times = {}
+last_rate_limit = {}
+
+
+def get_rate_limit():
+    limit = os.environ.get('ACCESS_PER_MIN', "5")
+    return f"{limit}/minute"
+
+# Custom rate limit exceeded handler
+async def custom_rate_limit_exceeded_handler(request: Request, exc: RateLimitExceeded) -> JSONResponse:
+    if request.client.host not in last_rate_limit or time.time() - last_rate_limit[request.client.host] > 60:
+        last_rate_limit[request.client.host] = time.time()
+    retry_after = 60 - (time.time() - last_rate_limit[request.client.host])
+    reset_at = time.time() + retry_after
+    return JSONResponse(
+        status_code=429,
+        content={
+            "detail": "Rate limit exceeded",
+            "limit": str(exc.limit.limit),
+            "retry_after": retry_after,
+            'reset_at': reset_at,
+            "message": f"You have exceeded the rate limit of {exc.limit.limit}."
+        }
+    )
+
+app.add_exception_handler(RateLimitExceeded, custom_rate_limit_exceeded_handler)
+
+
+# Middleware for token-based bypass and per-request limit
+class RateLimitMiddleware(BaseHTTPMiddleware):
+    async def dispatch(self, request: Request, call_next):
+        SPAN = int(os.environ.get('ACCESS_TIME_SPAN', 10))
+        access_token = request.headers.get("access-token")
+        if access_token == os.environ.get('ACCESS_TOKEN'):
+            return await call_next(request)
+
+        path = request.url.path
+        if path in ["/crawl", "/old"]:
+            client_ip = request.client.host
+            current_time = time.time()
+
+            # Check time since last request
+            if client_ip in last_request_times:
+                time_since_last_request = current_time - last_request_times[client_ip]
+                if time_since_last_request < SPAN:
+                    return JSONResponse(
+                        status_code=429,
+                        content={
+                            "detail": "Too many requests",
+                            "message": "Rate limit exceeded. Please wait 10 seconds between requests.",
+                            "retry_after": max(0, SPAN - time_since_last_request),
+                            "reset_at": current_time + max(0, SPAN - time_since_last_request),
+                        }
+                    )
+
+            last_request_times[client_ip] = current_time
+
+        return await call_next(request)
+
+app.add_middleware(RateLimitMiddleware)
+
+# CORS configuration
+origins = ["*"] # Allow all origins
+app.add_middleware(
+    CORSMiddleware,
+    allow_origins=origins, # List of origins that are allowed to make requests
+    allow_credentials=True,
+    allow_methods=["*"], # Allows all methods
+    allow_headers=["*"], # Allows all headers
+)
+
+# Mount the pages directory as a static directory
+app.mount("/pages", StaticFiles(directory=__location__ + "/pages"), name="pages")
+app.mount("/mkdocs", StaticFiles(directory="site", html=True), name="mkdocs")
+site_templates = Jinja2Templates(directory=__location__ + "/site")
+templates = Jinja2Templates(directory=__location__ + "/pages")
+
+@lru_cache()
+def get_crawler():
+    # Initialize and return a WebCrawler instance
+    crawler = WebCrawler(verbose = True)
+    crawler.warmup()
+    return crawler
+
+class CrawlRequest(BaseModel):
+    urls: List[str]
+    include_raw_html: Optional[bool] = False
+    bypass_cache: bool = False
+    extract_blocks: bool = True
+    word_count_threshold: Optional[int] = 5
+    extraction_strategy: Optional[str] = "NoExtractionStrategy"
+    extraction_strategy_args: Optional[dict] = {}
+    chunking_strategy: Optional[str] = "RegexChunking"
+    chunking_strategy_args: Optional[dict] = {}
+    css_selector: Optional[str] = None
+    screenshot: Optional[bool] = False
+    user_agent: Optional[str] = None
+    verbose: Optional[bool] = True
+
+@app.get("/")
+def read_root():
+    return RedirectResponse(url="/mkdocs")
+
+@app.get("/old", response_class=HTMLResponse)
+@limiter.limit(get_rate_limit())
+async def read_index(request: Request):
+    partials_dir = os.path.join(__location__, "pages", "partial")
+    partials = {}
+
+    for filename in os.listdir(partials_dir):
+        if filename.endswith(".html"):
+            with open(os.path.join(partials_dir, filename), "r", encoding="utf8") as file:
+                partials[filename[:-5]] = file.read()
+
+    return templates.TemplateResponse("index.html", {"request": request, **partials})
+
+@app.get("/total-count")
+async def get_total_url_count():
+    count = get_total_count()
+    return JSONResponse(content={"count": count})
+
+@app.get("/clear-db")
+async def clear_database():
+    # clear_db()
+    return JSONResponse(content={"message": "Database cleared."})
+
+def import_strategy(module_name: str, class_name: str, *args, **kwargs):
+    try:
+        module = importlib.import_module(module_name)
+        strategy_class = getattr(module, class_name)
+        return strategy_class(*args, **kwargs)
+    except ImportError:
+        print("ImportError: Module not found.")
+        raise HTTPException(status_code=400, detail=f"Module {module_name} not found.")
+    except AttributeError:
+        print("AttributeError: Class not found.")
+        raise HTTPException(status_code=400, detail=f"Class {class_name} not found in {module_name}.")
+
+@app.post("/crawl")
+@limiter.limit(get_rate_limit())
+async def crawl_urls(crawl_request: CrawlRequest, request: Request):
+    logging.debug(f"[LOG] Crawl request for URL: {crawl_request.urls}")
+    global current_requests
+    async with lock:
+        if current_requests >= MAX_CONCURRENT_REQUESTS:
+            raise HTTPException(status_code=429, detail="Too many requests - please try again later.")
+        current_requests += 1
+
+    try:
+        logging.debug("[LOG] Loading extraction and chunking strategies...")
+        crawl_request.extraction_strategy_args['verbose'] = True
+        crawl_request.chunking_strategy_args['verbose'] = True
+
+        extraction_strategy = import_strategy("crawl4ai.extraction_strategy", crawl_request.extraction_strategy, **crawl_request.extraction_strategy_args)
+        chunking_strategy = import_strategy("crawl4ai.chunking_strategy", crawl_request.chunking_strategy, **crawl_request.chunking_strategy_args)
+
+        # Use ThreadPoolExecutor to run the synchronous WebCrawler in async manner
+        logging.debug("[LOG] Running the WebCrawler...")
+        with ThreadPoolExecutor() as executor:
+            loop = asyncio.get_event_loop()
+            futures = [
+                loop.run_in_executor(
+                    executor,
+                    get_crawler().run,
+                    str(url),
+                    crawl_request.word_count_threshold,
+                    extraction_strategy,
+                    chunking_strategy,
+                    crawl_request.bypass_cache,
+                    crawl_request.css_selector,
+                    crawl_request.screenshot,
+                    crawl_request.user_agent,
+                    crawl_request.verbose
+                )
+                for url in crawl_request.urls
+            ]
+            results = await asyncio.gather(*futures)
+
+        # if include_raw_html is False, remove the raw HTML content from the results
+        if not crawl_request.include_raw_html:
+            for result in results:
+                result.html = None
+
+        return {"results": [result.model_dump() for result in results]}
+    finally:
+        async with lock:
+            current_requests -= 1
+
+@app.get("/strategies/extraction", response_class=JSONResponse)
+async def get_extraction_strategies():
+    with open(f"{__location__}/docs/extraction_strategies.json", "r") as file:
+        return JSONResponse(content=file.read())
+
+@app.get("/strategies/chunking", response_class=JSONResponse)
+async def get_chunking_strategies():
+    with open(f"{__location__}/docs/chunking_strategies.json", "r") as file:
+        return JSONResponse(content=file.read())
+
+
+if __name__ == "__main__":
+    import uvicorn
+    uvicorn.run(app, host="0.0.0.0", port=8888)
diff --git a/tests/test_main.py b/tests/test_main.py
new file mode 100644
index 00000000..19f938c8
--- /dev/null
+++ b/tests/test_main.py
@@ -0,0 +1,281 @@
+import asyncio
+import aiohttp
+import json
+import time
+import os
+from typing import Optional, Dict, Any
+from pydantic import BaseModel, HttpUrl
+
+class NBCNewsAPITest:
+    def __init__(self, base_url: str = "http://localhost:8000"):
+        self.base_url = base_url
+        self.session = None
+
+    async def __aenter__(self):
+        self.session = aiohttp.ClientSession()
+        return self
+
+    async def __aexit__(self, exc_type, exc_val, exc_tb):
+        if self.session:
+            await self.session.close()
+
+    async def submit_crawl(self, request_data: Dict[str, Any]) -> str:
+        async with self.session.post(f"{self.base_url}/crawl", json=request_data) as response:
+            result = await response.json()
+            return result["task_id"]
+
+    async def get_task_status(self, task_id: str) -> Dict[str, Any]:
+        async with self.session.get(f"{self.base_url}/task/{task_id}") as response:
+            return await response.json()
+
+    async def wait_for_task(self, task_id: str, timeout: int = 300, poll_interval: int = 2) -> Dict[str, Any]:
+        start_time = time.time()
+        while True:
+            if time.time() - start_time > timeout:
+                raise TimeoutError(f"Task {task_id} did not complete within {timeout} seconds")
+
+            status = await self.get_task_status(task_id)
+            if status["status"] in ["completed", "failed"]:
+                return status
+
+            await asyncio.sleep(poll_interval)
+
+    async def check_health(self) -> Dict[str, Any]:
+        async with self.session.get(f"{self.base_url}/health") as response:
+            return await response.json()
+
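+# The tests below are integration tests: they expect the API from main.py to
+# be running on localhost:8000 and they crawl live nbcnews.com pages, so the
+# outcome depends on network access and the current page markup.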
+async def test_basic_crawl():
+    print("\n=== Testing Basic Crawl ===")
+    async with NBCNewsAPITest() as api:
+        request = {
+            "urls": "https://www.nbcnews.com/business",
+            "priority": 10
+        }
+        task_id = await api.submit_crawl(request)
+        result = await api.wait_for_task(task_id)
+        print(f"Basic crawl result length: {len(result['result']['markdown'])}")
+        assert result["status"] == "completed"
+        assert "result" in result
+        assert result["result"]["success"]
+
+async def test_js_execution():
+    print("\n=== Testing JS Execution ===")
+    async with NBCNewsAPITest() as api:
+        request = {
+            "urls": "https://www.nbcnews.com/business",
+            "priority": 8,
+            "js_code": [
+                "const loadMoreButton = Array.from(document.querySelectorAll('button')).find(button => button.textContent.includes('Load More')); loadMoreButton && loadMoreButton.click();"
+            ],
+            "wait_for": "article.tease-card:nth-child(10)",
+            "crawler_params": {
+                "headless": True
+            }
+        }
+        task_id = await api.submit_crawl(request)
+        result = await api.wait_for_task(task_id)
+        print(f"JS execution result length: {len(result['result']['markdown'])}")
+        assert result["status"] == "completed"
+        assert result["result"]["success"]
+
+async def test_css_selector():
+    print("\n=== Testing CSS Selector ===")
+    async with NBCNewsAPITest() as api:
+        request = {
+            "urls": "https://www.nbcnews.com/business",
+            "priority": 7,
+            "css_selector": ".wide-tease-item__description"
+        }
+        task_id = await api.submit_crawl(request)
+        result = await api.wait_for_task(task_id)
+        print(f"CSS selector result length: {len(result['result']['markdown'])}")
+        assert result["status"] == "completed"
+        assert result["result"]["success"]
+
+async def test_structured_extraction():
+    print("\n=== Testing Structured Extraction ===")
+    async with NBCNewsAPITest() as api:
+        schema = {
+            "name": "NBC News Articles",
+            "baseSelector": "article.tease-card",
+            "fields": [
+                {
+                    "name": "title",
+                    "selector": "h2",
+                    "type": "text"
+                },
+                {
+                    "name": "description",
+                    "selector": ".tease-card__description",
+                    "type": "text"
+                },
+                {
+                    "name": "link",
+                    "selector": "a",
+                    "type": "attribute",
+                    "attribute": "href"
+                }
+            ]
+        }
+
+        request = {
+            "urls": "https://www.nbcnews.com/business",
+            "priority": 9,
+            "extraction_config": {
+                "type": "json_css",
+                "params": {
+                    "schema": schema
+                }
+            }
+        }
+        task_id = await api.submit_crawl(request)
+        result = await api.wait_for_task(task_id)
+        extracted = json.loads(result["result"]["extracted_content"])
+        print(f"Extracted {len(extracted)} articles")
+        assert result["status"] == "completed"
+        assert result["result"]["success"]
+        assert len(extracted) > 0
+
+async def test_batch_crawl():
+    print("\n=== Testing Batch Crawl ===")
+    async with NBCNewsAPITest() as api:
+        request = {
+            "urls": [
+                "https://www.nbcnews.com/business",
+                "https://www.nbcnews.com/business/consumer",
+                "https://www.nbcnews.com/business/economy"
+            ],
+            "priority": 6,
+            "crawler_params": {
+                "headless": True
+            }
+        }
+        task_id = await api.submit_crawl(request)
+        result = await api.wait_for_task(task_id)
+        print(f"Batch crawl completed, got {len(result['results'])} results")
+        assert result["status"] == "completed"
+        assert "results" in result
+        assert len(result["results"]) == 3
+
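+# Heads-up: the banner below says "Ollama" while the request is configured for
+# provider "openai/gpt-4o-mini" and reads OLLAMA_API_KEY; point provider and
+# api_key at the same backend before running this test.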
+ "items": {"type": "string"}, + "description": "Main topics or themes discussed in the article" + } + }, + "required": ["article_title", "summary", "main_topics"] + } + + request = { + "urls": "https://www.nbcnews.com/business", + "priority": 8, + "extraction_config": { + "type": "llm", + "params": { + "provider": "openai/gpt-4o-mini", + "api_key": os.getenv("OLLAMA_API_KEY"), + "schema": schema, + "extraction_type": "schema", + "instruction": """Extract the main article information including title, a brief summary, and main topics discussed. + Focus on the primary business news article on the page.""" + } + }, + "crawler_params": { + "headless": True, + "word_count_threshold": 1 + } + } + + task_id = await api.submit_crawl(request) + result = await api.wait_for_task(task_id) + + if result["status"] == "completed": + extracted = json.loads(result["result"]["extracted_content"]) + print(f"Extracted article analysis:") + print(json.dumps(extracted, indent=2)) + + assert result["status"] == "completed" + assert result["result"]["success"] + +async def test_screenshot(): + print("\n=== Testing Screenshot ===") + async with NBCNewsAPITest() as api: + request = { + "urls": "https://www.nbcnews.com/business", + "priority": 5, + "screenshot": True, + "crawler_params": { + "headless": True + } + } + task_id = await api.submit_crawl(request) + result = await api.wait_for_task(task_id) + print("Screenshot captured:", bool(result["result"]["screenshot"])) + assert result["status"] == "completed" + assert result["result"]["success"] + assert result["result"]["screenshot"] is not None + +async def test_priority_handling(): + print("\n=== Testing Priority Handling ===") + async with NBCNewsAPITest() as api: + # Submit low priority task first + low_priority = { + "urls": "https://www.nbcnews.com/business", + "priority": 1, + "crawler_params": {"headless": True} + } + low_task_id = await api.submit_crawl(low_priority) + + # Submit high priority task + high_priority = { + "urls": "https://www.nbcnews.com/business/consumer", + "priority": 10, + "crawler_params": {"headless": True} + } + high_task_id = await api.submit_crawl(high_priority) + + # Get both results + high_result = await api.wait_for_task(high_task_id) + low_result = await api.wait_for_task(low_task_id) + + print("Both tasks completed") + assert high_result["status"] == "completed" + assert low_result["status"] == "completed" + +async def main(): + try: + # Start with health check + async with NBCNewsAPITest() as api: + health = await api.check_health() + print("Server health:", health) + + # Run all tests + # await test_basic_crawl() + # await test_js_execution() + # await test_css_selector() + # await test_structured_extraction() + await test_llm_extraction() + # await test_batch_crawl() + # await test_screenshot() + # await test_priority_handling() + + except Exception as e: + print(f"Test failed: {str(e)}") + raise + +if __name__ == "__main__": + asyncio.run(main()) \ No newline at end of file