Creating the API server component
This commit is contained in:
550
main.py
550
main.py
@@ -1,254 +1,346 @@
|
||||
import os
|
||||
import importlib
|
||||
import asyncio
|
||||
from functools import lru_cache
|
||||
import logging
|
||||
logging.basicConfig(level=logging.DEBUG)
|
||||
|
||||
from fastapi import FastAPI, HTTPException, Request
|
||||
from fastapi.responses import HTMLResponse, JSONResponse
|
||||
from fastapi.staticfiles import StaticFiles
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from fastapi.templating import Jinja2Templates
|
||||
from fastapi.exceptions import RequestValidationError
|
||||
from starlette.middleware.base import BaseHTTPMiddleware
|
||||
from starlette.responses import FileResponse
|
||||
from fastapi.responses import RedirectResponse
|
||||
|
||||
from pydantic import BaseModel, HttpUrl
|
||||
from concurrent.futures import ThreadPoolExecutor, as_completed
|
||||
from typing import List, Optional
|
||||
|
||||
from crawl4ai.web_crawler import WebCrawler
|
||||
from crawl4ai.database import get_total_count, clear_db
|
||||
|
||||
from fastapi import FastAPI, HTTPException, BackgroundTasks, Request
|
||||
from fastapi.responses import JSONResponse
|
||||
from pydantic import BaseModel, HttpUrl, Field
|
||||
from typing import Optional, List, Dict, Any, Union
|
||||
import psutil
|
||||
import time
|
||||
from slowapi import Limiter, _rate_limit_exceeded_handler
|
||||
from slowapi.util import get_remote_address
|
||||
from slowapi.errors import RateLimitExceeded
|
||||
|
||||
# load .env file
|
||||
from dotenv import load_dotenv
|
||||
load_dotenv()
|
||||
|
||||
# Configuration
|
||||
__location__ = os.path.realpath(os.path.join(os.getcwd(), os.path.dirname(__file__)))
|
||||
MAX_CONCURRENT_REQUESTS = 10 # Adjust this to change the maximum concurrent requests
|
||||
current_requests = 0
|
||||
lock = asyncio.Lock()
|
||||
|
||||
app = FastAPI()
|
||||
|
||||
# Initialize rate limiter
|
||||
def rate_limit_key_func(request: Request):
    """Return the rate-limit key for ``request``.

    A request whose ``access-token`` header equals the ``ACCESS_TOKEN``
    environment variable gets ``None`` as its key, as before; every other
    request is keyed by its remote address.

    The token match is only honoured when ``ACCESS_TOKEN`` is actually set.
    Previously an unset variable made ``None == None`` true for every request
    that sent no header, so anonymous clients were treated as token holders.
    """
    import hmac  # local import keeps this fix self-contained

    expected = os.environ.get('ACCESS_TOKEN')
    supplied = request.headers.get("access-token")
    # compare_digest avoids leaking the token through comparison timing.
    if expected and supplied and hmac.compare_digest(
        supplied.encode("utf-8"), expected.encode("utf-8")
    ):
        return None
    return get_remote_address(request)
|
||||
|
||||
limiter = Limiter(key_func=rate_limit_key_func)
|
||||
app.state.limiter = limiter
|
||||
|
||||
# Dictionary to store last request times for each client
|
||||
last_request_times = {}
|
||||
last_rate_limit = {}
|
||||
|
||||
|
||||
def get_rate_limit():
    """Return the per-minute limit string passed to ``limiter.limit``.

    The count is read from the ``ACCESS_PER_MIN`` environment variable and
    falls back to ``"5"`` when the variable is not set.
    """
    per_minute = os.environ.get('ACCESS_PER_MIN', "5")
    return "{}/minute".format(per_minute)
|
||||
|
||||
# Custom rate limit exceeded handler
|
||||
async def custom_rate_limit_exceeded_handler(request: Request, exc: RateLimitExceeded) -> JSONResponse:
    """Build the 429 JSON response for a ``RateLimitExceeded`` error.

    The first rejection for a client opens a 60-second window, remembered in
    ``last_rate_limit``; ``retry_after`` and ``reset_at`` are reported relative
    to that window.

    Args:
        request: The rejected request; its client host identifies the window.
        exc: The raised error; ``exc.limit.limit`` is echoed to the client.

    Returns:
        A ``JSONResponse`` with status 429.
    """
    # request.client can be None when the peer address is unknown.
    client_host = request.client.host if request.client else "unknown"
    # Read the clock once so every derived value is consistent.
    now = time.time()

    window_start = last_rate_limit.get(client_host)
    if window_start is None or now - window_start > 60:
        window_start = now
        last_rate_limit[client_host] = now

    retry_after = max(0.0, 60 - (now - window_start))
    reset_at = now + retry_after
    return JSONResponse(
        status_code=429,
        content={
            "detail": "Rate limit exceeded",
            "limit": str(exc.limit.limit),
            "retry_after": retry_after,
            'reset_at': reset_at,
            "message": f"You have exceeded the rate limit of {exc.limit.limit}."
        }
    )
|
||||
|
||||
app.add_exception_handler(RateLimitExceeded, custom_rate_limit_exceeded_handler)
|
||||
|
||||
|
||||
# Middleware for token-based bypass and per-request limit
|
||||
class RateLimitMiddleware(BaseHTTPMiddleware):
    """Enforce a minimum interval between requests to ``/crawl`` and ``/old``.

    The interval in seconds comes from ``ACCESS_TIME_SPAN`` (default 10) and is
    tracked per client host in ``last_request_times``. Requests that present
    the configured ``ACCESS_TOKEN`` in the ``access-token`` header skip the
    check.
    """

    async def dispatch(self, request: Request, call_next):
        """Reject the request with 429 if it arrives too soon, else forward it."""
        import hmac  # local import keeps this fix self-contained

        span = int(os.environ.get('ACCESS_TIME_SPAN', 10))

        # Only honour the bypass when a token is configured; otherwise a
        # missing header would equal the missing variable (None == None).
        expected_token = os.environ.get('ACCESS_TOKEN')
        supplied_token = request.headers.get("access-token")
        if expected_token and supplied_token and hmac.compare_digest(
            supplied_token.encode("utf-8"), expected_token.encode("utf-8")
        ):
            return await call_next(request)

        if request.url.path in ("/crawl", "/old"):
            # request.client can be None when the peer address is unknown.
            client_ip = request.client.host if request.client else "unknown"
            now = time.time()

            previous = last_request_times.get(client_ip)
            if previous is not None:
                elapsed = now - previous
                if elapsed < span:
                    remaining = max(0, span - elapsed)
                    return JSONResponse(
                        status_code=429,
                        content={
                            "detail": "Too many requests",
                            # Report the configured span, not a fixed 10.
                            "message": f"Rate limit exceeded. Please wait {span} seconds between requests.",
                            "retry_after": remaining,
                            "reset_at": now + remaining,
                        }
                    )

            # Only accepted requests move the window forward.
            last_request_times[client_ip] = now

        return await call_next(request)
|
||||
|
||||
app.add_middleware(RateLimitMiddleware)
|
||||
|
||||
# CORS configuration
|
||||
origins = ["*"] # Allow all origins
|
||||
app.add_middleware(
|
||||
CORSMiddleware,
|
||||
allow_origins=origins, # List of origins that are allowed to make requests
|
||||
allow_credentials=True,
|
||||
allow_methods=["*"], # Allows all methods
|
||||
allow_headers=["*"], # Allows all headers
|
||||
import uuid
|
||||
from collections import defaultdict
|
||||
from urllib.parse import urlparse
|
||||
import math
|
||||
import logging
|
||||
from enum import Enum
|
||||
from dataclasses import dataclass, field
|
||||
import json
|
||||
from crawl4ai import AsyncWebCrawler, CrawlResult
|
||||
from crawl4ai.extraction_strategy import (
|
||||
LLMExtractionStrategy,
|
||||
CosineStrategy,
|
||||
JsonCssExtractionStrategy,
|
||||
)
|
||||
|
||||
# Mount the pages directory as a static directory
|
||||
app.mount("/pages", StaticFiles(directory=__location__ + "/pages"), name="pages")
|
||||
app.mount("/mkdocs", StaticFiles(directory="site", html=True), name="mkdocs")
|
||||
site_templates = Jinja2Templates(directory=__location__ + "/site")
|
||||
templates = Jinja2Templates(directory=__location__ + "/pages")
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@lru_cache()
def get_crawler():
    """Return the shared ``WebCrawler``.

    ``lru_cache`` memoises the zero-argument call, so the crawler is built and
    warmed up on first use and the same instance is returned afterwards.
    """
    shared = WebCrawler(verbose=True)
    shared.warmup()
    return shared
|
||||
class TaskStatus(str, Enum):
    """States recorded on a task as it moves through the service.

    Members are also plain strings because the class mixes in ``str``.
    """

    PENDING = "pending"
    PROCESSING = "processing"
    COMPLETED = "completed"
    FAILED = "failed"
|
||||
|
||||
class CrawlerType(str, Enum):
    """Extraction strategy selector used as the ``type`` of an extraction config.

    Members are also plain strings because the class mixes in ``str``.
    """

    BASIC = "basic"
    LLM = "llm"
    COSINE = "cosine"
    JSON_CSS = "json_css"
|
||||
|
||||
class ExtractionConfig(BaseModel):
    """Request fragment that selects an extraction strategy and its arguments."""

    # Which strategy to build.
    type: CrawlerType
    # Keyword arguments for the chosen strategy's constructor.
    params: Dict[str, Any] = {}
|
||||
|
||||
class CrawlRequest(BaseModel):
    """Body of a crawl request.

    ``urls`` and ``screenshot`` were each declared twice in this class; in a
    class body the later annotation silently replaces the earlier one. Each
    field is now declared once, with the type that was previously in effect
    (the later declaration), and keeps its original position.
    """

    # One URL or a list of URLs to crawl.
    urls: Union[HttpUrl, List[HttpUrl]]
    include_raw_html: Optional[bool] = False
    bypass_cache: bool = False
    extract_blocks: bool = True
    word_count_threshold: Optional[int] = 5
    # Strategy class names and their keyword arguments.
    extraction_strategy: Optional[str] = "NoExtractionStrategy"
    extraction_strategy_args: Optional[dict] = {}
    chunking_strategy: Optional[str] = "RegexChunking"
    chunking_strategy_args: Optional[dict] = {}
    # Structured strategy selection and crawler constructor arguments.
    extraction_config: Optional[ExtractionConfig] = None
    crawler_params: Dict[str, Any] = {}
    # Scheduling priority, validated to the range 1..10.
    priority: int = Field(default=5, ge=1, le=10)
    # Task time-to-live; compared against seconds elapsed when cleaning up.
    ttl: Optional[int] = 3600
    js_code: Optional[List[str]] = None
    wait_for: Optional[str] = None
    css_selector: Optional[str] = None
    screenshot: bool = False
    user_agent: Optional[str] = None
    verbose: Optional[bool] = True
    magic: bool = False
|
||||
|
||||
@app.get("/")
def read_root():
    """Redirect the site root to the ``/mkdocs`` mount."""
    docs_url = "/mkdocs"
    return RedirectResponse(url=docs_url)
|
||||
@dataclass
class TaskInfo:
    """Bookkeeping record for one submitted crawl task."""

    # Unique task identifier.
    id: str
    # Current lifecycle state.
    status: TaskStatus
    # Crawl output once the task has completed.
    result: Optional[Union[CrawlResult, List[CrawlResult]]] = None
    # Error text once the task has failed.
    error: Optional[str] = None
    # Use default_factory so each task is stamped when it is created. A plain
    # ``= time.time()`` default is evaluated once at class definition, giving
    # every task the same timestamp and breaking TTL-based expiry.
    created_at: float = field(default_factory=time.time)
    # Seconds after ``created_at`` before a finished task may be removed.
    ttl: int = 3600
    # The originating request; the service attaches it after queueing.
    request: Any = None
|
||||
|
||||
@app.get("/old", response_class=HTMLResponse)
|
||||
@limiter.limit(get_rate_limit())
|
||||
async def read_index(request: Request):
|
||||
partials_dir = os.path.join(__location__, "pages", "partial")
|
||||
partials = {}
|
||||
class ResourceMonitor:
|
||||
def __init__(self, max_concurrent_tasks: int = 10):
|
||||
self.max_concurrent_tasks = max_concurrent_tasks
|
||||
self.memory_threshold = 0.85
|
||||
self.cpu_threshold = 0.90
|
||||
self._last_check = 0
|
||||
self._check_interval = 1 # seconds
|
||||
self._last_available_slots = max_concurrent_tasks
|
||||
|
||||
for filename in os.listdir(partials_dir):
|
||||
if filename.endswith(".html"):
|
||||
with open(os.path.join(partials_dir, filename), "r", encoding="utf8") as file:
|
||||
partials[filename[:-5]] = file.read()
|
||||
async def get_available_slots(self) -> int:
|
||||
current_time = time.time()
|
||||
if current_time - self._last_check < self._check_interval:
|
||||
return self._last_available_slots
|
||||
|
||||
return templates.TemplateResponse("index.html", {"request": request, **partials})
|
||||
mem_usage = psutil.virtual_memory().percent / 100
|
||||
cpu_usage = psutil.cpu_percent() / 100
|
||||
|
||||
@app.get("/total-count")
|
||||
async def get_total_url_count():
|
||||
count = get_total_count()
|
||||
return JSONResponse(content={"count": count})
|
||||
memory_factor = max(0, (self.memory_threshold - mem_usage) / self.memory_threshold)
|
||||
cpu_factor = max(0, (self.cpu_threshold - cpu_usage) / self.cpu_threshold)
|
||||
|
||||
@app.get("/clear-db")
|
||||
async def clear_database():
|
||||
# clear_db()
|
||||
return JSONResponse(content={"message": "Database cleared."})
|
||||
self._last_available_slots = math.floor(
|
||||
self.max_concurrent_tasks * min(memory_factor, cpu_factor)
|
||||
)
|
||||
self._last_check = current_time
|
||||
|
||||
def import_strategy(module_name: str, class_name: str, *args, **kwargs):
|
||||
try:
|
||||
module = importlib.import_module(module_name)
|
||||
strategy_class = getattr(module, class_name)
|
||||
return strategy_class(*args, **kwargs)
|
||||
except ImportError:
|
||||
print("ImportError: Module not found.")
|
||||
raise HTTPException(status_code=400, detail=f"Module {module_name} not found.")
|
||||
except AttributeError:
|
||||
print("AttributeError: Class not found.")
|
||||
raise HTTPException(status_code=400, detail=f"Class {class_name} not found in {module_name}.")
|
||||
return self._last_available_slots
|
||||
|
||||
class TaskManager:
    """In-memory task registry with a high- and a low-priority queue.

    Tasks with priority above 5 go to the high-priority queue, all others to
    the low-priority one. A background loop periodically drops finished tasks
    whose TTL has elapsed.
    """

    def __init__(self, cleanup_interval: int = 300):
        """Create empty queues; ``cleanup_interval`` is seconds between sweeps."""
        self.tasks: Dict[str, TaskInfo] = {}
        self.high_priority = asyncio.PriorityQueue()
        self.low_priority = asyncio.PriorityQueue()
        self.cleanup_interval = cleanup_interval
        self.cleanup_task = None

    async def start(self):
        """Start the background cleanup loop if it is not already running."""
        # A second start() must not spawn a second, untracked loop.
        if self.cleanup_task is None or self.cleanup_task.done():
            self.cleanup_task = asyncio.create_task(self._cleanup_loop())

    async def stop(self):
        """Cancel the cleanup loop and wait for it to finish."""
        if self.cleanup_task:
            self.cleanup_task.cancel()
            try:
                await self.cleanup_task
            except asyncio.CancelledError:
                pass
            self.cleanup_task = None

    async def add_task(self, task_id: str, priority: int, ttl: int) -> None:
        """Register a pending task and enqueue it according to ``priority``."""
        task_info = TaskInfo(id=task_id, status=TaskStatus.PENDING, ttl=ttl)
        self.tasks[task_id] = task_info
        queue = self.high_priority if priority > 5 else self.low_priority
        await queue.put((-priority, task_id))  # Negative for proper priority ordering

    async def get_next_task(self) -> Optional[str]:
        """Return the next queued task id, or ``None`` if both queues are empty.

        The high-priority queue is drained first. The queues are polled with
        ``get_nowait`` rather than timed waits, so dequeuing a low-priority
        task no longer pays a fixed delay while the high queue is empty.
        """
        for queue in (self.high_priority, self.low_priority):
            try:
                _, task_id = queue.get_nowait()
            except asyncio.QueueEmpty:
                continue
            return task_id
        return None

    def update_task(self, task_id: str, status: TaskStatus, result: Any = None, error: Optional[str] = None):
        """Overwrite status, result and error of a known task; ignore unknown ids."""
        task_info = self.tasks.get(task_id)
        if task_info is None:
            return
        task_info.status = status
        task_info.result = result
        task_info.error = error

    def get_task(self, task_id: str) -> Optional[TaskInfo]:
        """Return the record for ``task_id`` or ``None``."""
        return self.tasks.get(task_id)

    async def _cleanup_loop(self):
        """Every ``cleanup_interval`` seconds, delete expired finished tasks."""
        finished = (TaskStatus.COMPLETED, TaskStatus.FAILED)
        while True:
            try:
                await asyncio.sleep(self.cleanup_interval)
                current_time = time.time()
                expired_tasks = [
                    task_id
                    for task_id, task in self.tasks.items()
                    if current_time - task.created_at > task.ttl
                    and task.status in finished
                ]
                for task_id in expired_tasks:
                    del self.tasks[task_id]
            except Exception as e:
                # Keep the loop alive; cancellation is not an Exception and
                # still propagates.
                logger.error(f"Error in cleanup loop: {e}")
|
||||
|
||||
class CrawlerPool:
    """Bounded pool of entered ``AsyncWebCrawler`` instances.

    ``active_crawlers`` maps each crawler to the time it was last handed out
    or released. Crawlers untouched for more than 10 minutes are closed on the
    next ``acquire``.
    """

    def __init__(self, max_size: int = 10):
        """Create an empty pool holding at most ``max_size`` crawlers."""
        self.max_size = max_size
        self.active_crawlers: Dict[AsyncWebCrawler, float] = {}
        self._lock = asyncio.Lock()

    async def _close(self, crawler) -> None:
        """Exit a crawler's async context, logging a failure instead of raising."""
        try:
            await crawler.__aexit__(None, None, None)
        except Exception:
            logging.getLogger(__name__).exception("Error closing crawler")

    async def acquire(self, **kwargs) -> AsyncWebCrawler:
        """Return a crawler, creating one with ``kwargs`` while under capacity.

        When the pool is full the least recently used crawler is returned.
        NOTE(review): in that case ``kwargs`` are not applied and the crawler
        may still be in use by another caller — confirm this is acceptable.
        """
        async with self._lock:
            current_time = time.time()

            # Clean up inactive crawlers (10 minutes timeout).
            inactive = [
                crawler
                for crawler, last_used in self.active_crawlers.items()
                if current_time - last_used > 600
            ]
            for crawler in inactive:
                # Unregister first so a failing close cannot abort acquire or
                # leave a dead crawler in the pool.
                del self.active_crawlers[crawler]
                await self._close(crawler)

            # Create new crawler if needed.
            if len(self.active_crawlers) < self.max_size:
                crawler = AsyncWebCrawler(**kwargs)
                await crawler.__aenter__()
                self.active_crawlers[crawler] = current_time
                return crawler

            # Reuse least recently used crawler.
            crawler = min(self.active_crawlers, key=self.active_crawlers.get)
            self.active_crawlers[crawler] = current_time
            return crawler

    async def release(self, crawler: AsyncWebCrawler):
        """Refresh the last-used time of a crawler that is still pooled."""
        async with self._lock:
            if crawler in self.active_crawlers:
                self.active_crawlers[crawler] = time.time()

    async def cleanup(self):
        """Close every pooled crawler and empty the pool."""
        async with self._lock:
            crawlers = list(self.active_crawlers)
            # Clear up front so the pool is empty even if a close fails.
            self.active_crawlers.clear()
            for crawler in crawlers:
                await self._close(crawler)
|
||||
|
||||
class CrawlerService:
    """Queue-driven crawl executor.

    Combines a resource monitor, a task manager and a crawler pool, and runs a
    background loop that takes queued tasks and executes them one at a time.
    """

    def __init__(self, max_concurrent_tasks: int = 10):
        """Build the collaborators; nothing runs until ``start`` is awaited."""
        self.resource_monitor = ResourceMonitor(max_concurrent_tasks)
        self.task_manager = TaskManager()
        self.crawler_pool = CrawlerPool(max_concurrent_tasks)
        self._processing_task = None

    async def start(self):
        """Start task cleanup and the queue-processing loop."""
        await self.task_manager.start()
        self._processing_task = asyncio.create_task(self._process_queue())

    async def stop(self):
        """Cancel the processing loop, then stop cleanup and close crawlers."""
        if self._processing_task:
            self._processing_task.cancel()
            try:
                await self._processing_task
            except asyncio.CancelledError:
                pass
        await self.task_manager.stop()
        await self.crawler_pool.cleanup()

    def _create_extraction_strategy(self, config: ExtractionConfig):
        """Instantiate the strategy named by ``config``, or return ``None``.

        ``None`` is returned for a missing config and for any type without a
        dedicated strategy class (e.g. ``CrawlerType.BASIC``).
        """
        if not config:
            return None

        if config.type == CrawlerType.LLM:
            return LLMExtractionStrategy(**config.params)
        elif config.type == CrawlerType.COSINE:
            return CosineStrategy(**config.params)
        elif config.type == CrawlerType.JSON_CSS:
            return JsonCssExtractionStrategy(**config.params)
        return None

    async def submit_task(self, request: CrawlRequest) -> str:
        """Queue ``request`` and return the id of the new task."""
        task_id = str(uuid.uuid4())
        await self.task_manager.add_task(task_id, request.priority, request.ttl or 3600)

        # Store request data with task
        self.task_manager.tasks[task_id].request = request

        return task_id

    async def _run_crawl(self, crawler, request):
        """Crawl ``request.urls`` with ``crawler`` and return the result(s)."""
        crawl_kwargs = dict(
            extraction_strategy=self._create_extraction_strategy(request.extraction_config),
            js_code=request.js_code,
            wait_for=request.wait_for,
            css_selector=request.css_selector,
            screenshot=request.screenshot,
            magic=request.magic,
        )
        if isinstance(request.urls, list):
            return await crawler.arun_many(
                urls=[str(url) for url in request.urls],
                **crawl_kwargs,
            )
        return await crawler.arun(url=str(request.urls), **crawl_kwargs)

    async def _execute_task(self, task_id, request):
        """Run one task and record COMPLETED or FAILED on it."""
        try:
            crawler = await self.crawler_pool.acquire(**request.crawler_params)
            try:
                results = await self._run_crawl(crawler, request)
            finally:
                # Release on failure too; previously a crawl error skipped
                # the release.
                await self.crawler_pool.release(crawler)
            self.task_manager.update_task(task_id, TaskStatus.COMPLETED, results)
        except Exception as e:
            logger.error(f"Error processing task {task_id}: {str(e)}")
            self.task_manager.update_task(task_id, TaskStatus.FAILED, error=str(e))

    async def _process_queue(self):
        """Forever: wait for capacity, take the next task, execute it."""
        while True:
            try:
                available_slots = await self.resource_monitor.get_available_slots()
                if available_slots <= 0:
                    await asyncio.sleep(1)
                    continue

                task_id = await self.task_manager.get_next_task()
                if not task_id:
                    await asyncio.sleep(1)
                    continue

                task_info = self.task_manager.get_task(task_id)
                if not task_info:
                    continue

                # A task without request data cannot run; fail it explicitly
                # instead of leaving it pending forever.
                request = getattr(task_info, "request", None)
                if request is None:
                    self.task_manager.update_task(
                        task_id, TaskStatus.FAILED, error="Task has no request data"
                    )
                    continue

                self.task_manager.update_task(task_id, TaskStatus.PROCESSING)
                await self._execute_task(task_id, request)

            except Exception as e:
                logger.error(f"Error in queue processing: {str(e)}")
                await asyncio.sleep(1)
|
||||
|
||||
app = FastAPI(title="Crawl4AI API")
|
||||
crawler_service = CrawlerService()
|
||||
|
||||
@app.on_event("startup")
async def startup_event():
    """Start the crawler service when the application starts."""
    await crawler_service.start()
|
||||
|
||||
@app.on_event("shutdown")
async def shutdown_event():
    """Stop the crawler service when the application shuts down."""
    await crawler_service.stop()
|
||||
|
||||
@app.post("/crawl")
|
||||
@limiter.limit(get_rate_limit())
|
||||
async def crawl_urls(crawl_request: CrawlRequest, request: Request):
|
||||
logging.debug(f"[LOG] Crawl request for URL: {crawl_request.urls}")
|
||||
global current_requests
|
||||
async with lock:
|
||||
if current_requests >= MAX_CONCURRENT_REQUESTS:
|
||||
raise HTTPException(status_code=429, detail="Too many requests - please try again later.")
|
||||
current_requests += 1
|
||||
async def crawl(request: CrawlRequest) -> Dict[str, str]:
|
||||
task_id = await crawler_service.submit_task(request)
|
||||
return {"task_id": task_id}
|
||||
|
||||
try:
|
||||
logging.debug("[LOG] Loading extraction and chunking strategies...")
|
||||
crawl_request.extraction_strategy_args['verbose'] = True
|
||||
crawl_request.chunking_strategy_args['verbose'] = True
|
||||
|
||||
extraction_strategy = import_strategy("crawl4ai.extraction_strategy", crawl_request.extraction_strategy, **crawl_request.extraction_strategy_args)
|
||||
chunking_strategy = import_strategy("crawl4ai.chunking_strategy", crawl_request.chunking_strategy, **crawl_request.chunking_strategy_args)
|
||||
@app.get("/task/{task_id}")
|
||||
async def get_task_status(task_id: str):
|
||||
task_info = crawler_service.task_manager.get_task(task_id)
|
||||
if not task_info:
|
||||
raise HTTPException(status_code=404, detail="Task not found")
|
||||
|
||||
# Use ThreadPoolExecutor to run the synchronous WebCrawler in async manner
|
||||
logging.debug("[LOG] Running the WebCrawler...")
|
||||
with ThreadPoolExecutor() as executor:
|
||||
loop = asyncio.get_event_loop()
|
||||
futures = [
|
||||
loop.run_in_executor(
|
||||
executor,
|
||||
get_crawler().run,
|
||||
str(url),
|
||||
crawl_request.word_count_threshold,
|
||||
extraction_strategy,
|
||||
chunking_strategy,
|
||||
crawl_request.bypass_cache,
|
||||
crawl_request.css_selector,
|
||||
crawl_request.screenshot,
|
||||
crawl_request.user_agent,
|
||||
crawl_request.verbose
|
||||
)
|
||||
for url in crawl_request.urls
|
||||
]
|
||||
results = await asyncio.gather(*futures)
|
||||
response = {
|
||||
"status": task_info.status,
|
||||
"created_at": task_info.created_at,
|
||||
}
|
||||
|
||||
# if include_raw_html is False, remove the raw HTML content from the results
|
||||
if not crawl_request.include_raw_html:
|
||||
for result in results:
|
||||
result.html = None
|
||||
if task_info.status == TaskStatus.COMPLETED:
|
||||
# Convert CrawlResult to dict for JSON response
|
||||
if isinstance(task_info.result, list):
|
||||
response["results"] = [result.dict() for result in task_info.result]
|
||||
else:
|
||||
response["result"] = task_info.result.dict()
|
||||
elif task_info.status == TaskStatus.FAILED:
|
||||
response["error"] = task_info.error
|
||||
|
||||
return {"results": [result.model_dump() for result in results]}
|
||||
finally:
|
||||
async with lock:
|
||||
current_requests -= 1
|
||||
|
||||
@app.get("/strategies/extraction", response_class=JSONResponse)
|
||||
async def get_extraction_strategies():
|
||||
with open(f"{__location__}/docs/extraction_strategies.json", "r") as file:
|
||||
return JSONResponse(content=file.read())
|
||||
|
||||
@app.get("/strategies/chunking", response_class=JSONResponse)
|
||||
async def get_chunking_strategies():
|
||||
with open(f"{__location__}/docs/chunking_strategies.json", "r") as file:
|
||||
return JSONResponse(content=file.read())
|
||||
return response
|
||||
|
||||
@app.get("/health")
async def health_check():
    """Report available crawl slots plus current memory and CPU usage."""
    slots = await crawler_service.resource_monitor.get_available_slots()
    memory_percent = psutil.virtual_memory().percent
    cpu_percent = psutil.cpu_percent()
    report = {
        "status": "healthy",
        "available_slots": slots,
        "memory_usage": memory_percent,
        "cpu_usage": cpu_percent,
    }
    return report
|
||||
|
||||
if __name__ == "__main__":
    import uvicorn

    # Serve on one port only. There were two consecutive uvicorn.run calls
    # (8888, then 8000); the second could only run after the first server
    # exited. Port 8000 is the one that was listed last.
    uvicorn.run(app, host="0.0.0.0", port=8000)
|
||||
254
main_v0.py
Normal file
254
main_v0.py
Normal file
@@ -0,0 +1,254 @@
|
||||
import os
|
||||
import importlib
|
||||
import asyncio
|
||||
from functools import lru_cache
|
||||
import logging
|
||||
logging.basicConfig(level=logging.DEBUG)
|
||||
|
||||
from fastapi import FastAPI, HTTPException, Request
|
||||
from fastapi.responses import HTMLResponse, JSONResponse
|
||||
from fastapi.staticfiles import StaticFiles
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from fastapi.templating import Jinja2Templates
|
||||
from fastapi.exceptions import RequestValidationError
|
||||
from starlette.middleware.base import BaseHTTPMiddleware
|
||||
from starlette.responses import FileResponse
|
||||
from fastapi.responses import RedirectResponse
|
||||
|
||||
from pydantic import BaseModel, HttpUrl
|
||||
from concurrent.futures import ThreadPoolExecutor, as_completed
|
||||
from typing import List, Optional
|
||||
|
||||
from crawl4ai.web_crawler import WebCrawler
|
||||
from crawl4ai.database import get_total_count, clear_db
|
||||
|
||||
import time
|
||||
from slowapi import Limiter, _rate_limit_exceeded_handler
|
||||
from slowapi.util import get_remote_address
|
||||
from slowapi.errors import RateLimitExceeded
|
||||
|
||||
# load .env file
|
||||
from dotenv import load_dotenv
|
||||
load_dotenv()
|
||||
|
||||
# Configuration
|
||||
__location__ = os.path.realpath(os.path.join(os.getcwd(), os.path.dirname(__file__)))
|
||||
MAX_CONCURRENT_REQUESTS = 10 # Adjust this to change the maximum concurrent requests
|
||||
current_requests = 0
|
||||
lock = asyncio.Lock()
|
||||
|
||||
app = FastAPI()
|
||||
|
||||
# Initialize rate limiter
|
||||
def rate_limit_key_func(request: Request):
    """Return the rate-limit key for ``request``.

    A request whose ``access-token`` header equals the ``ACCESS_TOKEN``
    environment variable gets ``None`` as its key, as before; every other
    request is keyed by its remote address.

    The token match is only honoured when ``ACCESS_TOKEN`` is actually set.
    Previously an unset variable made ``None == None`` true for every request
    that sent no header, so anonymous clients were treated as token holders.
    """
    import hmac  # local import keeps this fix self-contained

    expected = os.environ.get('ACCESS_TOKEN')
    supplied = request.headers.get("access-token")
    # compare_digest avoids leaking the token through comparison timing.
    if expected and supplied and hmac.compare_digest(
        supplied.encode("utf-8"), expected.encode("utf-8")
    ):
        return None
    return get_remote_address(request)
|
||||
|
||||
limiter = Limiter(key_func=rate_limit_key_func)
|
||||
app.state.limiter = limiter
|
||||
|
||||
# Dictionary to store last request times for each client
|
||||
last_request_times = {}
|
||||
last_rate_limit = {}
|
||||
|
||||
|
||||
def get_rate_limit():
    """Return the per-minute limit string passed to ``limiter.limit``.

    The count is read from the ``ACCESS_PER_MIN`` environment variable and
    falls back to ``"5"`` when the variable is not set.
    """
    per_minute = os.environ.get('ACCESS_PER_MIN', "5")
    return "{}/minute".format(per_minute)
|
||||
|
||||
# Custom rate limit exceeded handler
|
||||
async def custom_rate_limit_exceeded_handler(request: Request, exc: RateLimitExceeded) -> JSONResponse:
    """Build the 429 JSON response for a ``RateLimitExceeded`` error.

    The first rejection for a client opens a 60-second window, remembered in
    ``last_rate_limit``; ``retry_after`` and ``reset_at`` are reported relative
    to that window.

    Args:
        request: The rejected request; its client host identifies the window.
        exc: The raised error; ``exc.limit.limit`` is echoed to the client.

    Returns:
        A ``JSONResponse`` with status 429.
    """
    # request.client can be None when the peer address is unknown.
    client_host = request.client.host if request.client else "unknown"
    # Read the clock once so every derived value is consistent.
    now = time.time()

    window_start = last_rate_limit.get(client_host)
    if window_start is None or now - window_start > 60:
        window_start = now
        last_rate_limit[client_host] = now

    retry_after = max(0.0, 60 - (now - window_start))
    reset_at = now + retry_after
    return JSONResponse(
        status_code=429,
        content={
            "detail": "Rate limit exceeded",
            "limit": str(exc.limit.limit),
            "retry_after": retry_after,
            'reset_at': reset_at,
            "message": f"You have exceeded the rate limit of {exc.limit.limit}."
        }
    )
|
||||
|
||||
app.add_exception_handler(RateLimitExceeded, custom_rate_limit_exceeded_handler)
|
||||
|
||||
|
||||
# Middleware for token-based bypass and per-request limit
|
||||
class RateLimitMiddleware(BaseHTTPMiddleware):
    """Enforce a minimum interval between requests to ``/crawl`` and ``/old``.

    The interval in seconds comes from ``ACCESS_TIME_SPAN`` (default 10) and is
    tracked per client host in ``last_request_times``. Requests that present
    the configured ``ACCESS_TOKEN`` in the ``access-token`` header skip the
    check.
    """

    async def dispatch(self, request: Request, call_next):
        """Reject the request with 429 if it arrives too soon, else forward it."""
        import hmac  # local import keeps this fix self-contained

        span = int(os.environ.get('ACCESS_TIME_SPAN', 10))

        # Only honour the bypass when a token is configured; otherwise a
        # missing header would equal the missing variable (None == None).
        expected_token = os.environ.get('ACCESS_TOKEN')
        supplied_token = request.headers.get("access-token")
        if expected_token and supplied_token and hmac.compare_digest(
            supplied_token.encode("utf-8"), expected_token.encode("utf-8")
        ):
            return await call_next(request)

        if request.url.path in ("/crawl", "/old"):
            # request.client can be None when the peer address is unknown.
            client_ip = request.client.host if request.client else "unknown"
            now = time.time()

            previous = last_request_times.get(client_ip)
            if previous is not None:
                elapsed = now - previous
                if elapsed < span:
                    remaining = max(0, span - elapsed)
                    return JSONResponse(
                        status_code=429,
                        content={
                            "detail": "Too many requests",
                            # Report the configured span, not a fixed 10.
                            "message": f"Rate limit exceeded. Please wait {span} seconds between requests.",
                            "retry_after": remaining,
                            "reset_at": now + remaining,
                        }
                    )

            # Only accepted requests move the window forward.
            last_request_times[client_ip] = now

        return await call_next(request)
|
||||
|
||||
app.add_middleware(RateLimitMiddleware)
|
||||
|
||||
# CORS configuration
|
||||
origins = ["*"] # Allow all origins
|
||||
app.add_middleware(
|
||||
CORSMiddleware,
|
||||
allow_origins=origins, # List of origins that are allowed to make requests
|
||||
allow_credentials=True,
|
||||
allow_methods=["*"], # Allows all methods
|
||||
allow_headers=["*"], # Allows all headers
|
||||
)
|
||||
|
||||
# Mount the pages directory as a static directory
|
||||
app.mount("/pages", StaticFiles(directory=__location__ + "/pages"), name="pages")
|
||||
app.mount("/mkdocs", StaticFiles(directory="site", html=True), name="mkdocs")
|
||||
site_templates = Jinja2Templates(directory=__location__ + "/site")
|
||||
templates = Jinja2Templates(directory=__location__ + "/pages")
|
||||
|
||||
@lru_cache()
def get_crawler():
    """Return the shared ``WebCrawler``.

    ``lru_cache`` memoises the zero-argument call, so the crawler is built and
    warmed up on first use and the same instance is returned afterwards.
    """
    shared = WebCrawler(verbose=True)
    shared.warmup()
    return shared
|
||||
|
||||
class CrawlRequest(BaseModel):
    """Body of a ``POST /crawl`` request."""

    # URLs to crawl; each one is crawled separately.
    urls: List[str]
    # When false, the raw HTML is blanked out of each result.
    include_raw_html: Optional[bool] = False
    bypass_cache: bool = False
    extract_blocks: bool = True
    word_count_threshold: Optional[int] = 5
    # Class names looked up in the extraction / chunking strategy modules,
    # with the keyword arguments passed to their constructors.
    extraction_strategy: Optional[str] = "NoExtractionStrategy"
    extraction_strategy_args: Optional[dict] = {}
    chunking_strategy: Optional[str] = "RegexChunking"
    chunking_strategy_args: Optional[dict] = {}
    css_selector: Optional[str] = None
    screenshot: Optional[bool] = False
    user_agent: Optional[str] = None
    verbose: Optional[bool] = True
|
||||
|
||||
@app.get("/")
def read_root():
    """Redirect the site root to the ``/mkdocs`` mount."""
    docs_url = "/mkdocs"
    return RedirectResponse(url=docs_url)
|
||||
|
||||
@app.get("/old", response_class=HTMLResponse)
@limiter.limit(get_rate_limit())
async def read_index(request: Request):
    """Render ``index.html`` with every HTML partial injected into the context.

    Each ``*.html`` file in ``pages/partial`` is read and exposed to the
    template under its file name without the ``.html`` suffix.
    """
    partials_dir = os.path.join(__location__, "pages", "partial")

    context = {"request": request}
    for entry in os.listdir(partials_dir):
        if not entry.endswith(".html"):
            continue
        partial_path = os.path.join(partials_dir, entry)
        with open(partial_path, "r", encoding="utf8") as handle:
            # Strip the 5-character ".html" suffix to form the context key.
            context[entry[:-5]] = handle.read()

    return templates.TemplateResponse("index.html", context)
|
||||
|
||||
@app.get("/total-count")
async def get_total_url_count():
    """Return the value of ``get_total_count()`` as ``{"count": ...}``."""
    payload = {"count": get_total_count()}
    return JSONResponse(content=payload)
|
||||
|
||||
@app.get("/clear-db")
async def clear_database():
    """Return a "Database cleared." message.

    NOTE(review): the ``clear_db()`` call is commented out, so nothing is
    deleted even though the response says so — confirm this is intended.
    """
    # clear_db()
    payload = {"message": "Database cleared."}
    return JSONResponse(content=payload)
|
||||
|
||||
def import_strategy(module_name: str, class_name: str, *args, **kwargs):
    """Import ``module_name``, look up ``class_name`` and instantiate it.

    Args:
        module_name: Dotted path of the module to import.
        class_name: Name of a class defined in, or exported by, that module.
        *args, **kwargs: Forwarded to the class constructor.

    Returns:
        A new instance of the requested class.

    Raises:
        HTTPException: 400 if the module cannot be imported, or if
            ``class_name`` does not name a class in it.
    """
    import inspect  # local import keeps this fix self-contained

    log = logging.getLogger(__name__)

    try:
        module = importlib.import_module(module_name)
    except ImportError as err:
        log.warning("ImportError: Module not found.")
        raise HTTPException(status_code=400, detail=f"Module {module_name} not found.") from err

    # class_name comes from the request body, so accept only real classes:
    # functions, submodules and other attributes must not be callable by name.
    strategy_class = getattr(module, class_name, None) if isinstance(class_name, str) else None
    if not inspect.isclass(strategy_class):
        log.warning("AttributeError: Class not found.")
        raise HTTPException(status_code=400, detail=f"Class {class_name} not found in {module_name}.")

    # Instantiate outside the try block so an ImportError or AttributeError
    # raised by the constructor is not misreported as a missing module/class.
    return strategy_class(*args, **kwargs)
|
||||
|
||||
@app.post("/crawl")
@limiter.limit(get_rate_limit())
async def crawl_urls(crawl_request: CrawlRequest, request: Request):
    """Crawl every URL in the request and return the results.

    At most ``MAX_CONCURRENT_REQUESTS`` calls may be in flight; beyond that a
    429 is raised. The synchronous crawler runs in a thread pool, one job per
    URL.

    Returns:
        ``{"results": [...]}`` with one dumped result per URL, in request
        order. Raw HTML is blanked unless ``include_raw_html`` is set.

    Raises:
        HTTPException: 429 when the concurrency cap is reached; 400 (from
            ``import_strategy``) for an unknown strategy.
    """
    global current_requests

    logging.debug(f"[LOG] Crawl request for URL: {crawl_request.urls}")
    async with lock:
        if current_requests >= MAX_CONCURRENT_REQUESTS:
            raise HTTPException(status_code=429, detail="Too many requests - please try again later.")
        current_requests += 1

    try:
        logging.debug("[LOG] Loading extraction and chunking strategies...")
        # The *_args fields are Optional, so a client may send null. Build
        # fresh dicts instead of mutating the request model.
        extraction_args = {**(crawl_request.extraction_strategy_args or {}), 'verbose': True}
        chunking_args = {**(crawl_request.chunking_strategy_args or {}), 'verbose': True}

        extraction_strategy = import_strategy("crawl4ai.extraction_strategy", crawl_request.extraction_strategy, **extraction_args)
        chunking_strategy = import_strategy("crawl4ai.chunking_strategy", crawl_request.chunking_strategy, **chunking_args)

        # Use ThreadPoolExecutor to run the synchronous WebCrawler in async manner
        logging.debug("[LOG] Running the WebCrawler...")
        # Inside a coroutine the running loop is the one to schedule on.
        loop = asyncio.get_running_loop()
        with ThreadPoolExecutor() as executor:
            futures = [
                loop.run_in_executor(
                    executor,
                    get_crawler().run,
                    str(url),
                    crawl_request.word_count_threshold,
                    extraction_strategy,
                    chunking_strategy,
                    crawl_request.bypass_cache,
                    crawl_request.css_selector,
                    crawl_request.screenshot,
                    crawl_request.user_agent,
                    crawl_request.verbose
                )
                for url in crawl_request.urls
            ]
            results = await asyncio.gather(*futures)

        # if include_raw_html is False, remove the raw HTML content from the results
        if not crawl_request.include_raw_html:
            for result in results:
                result.html = None

        return {"results": [result.model_dump() for result in results]}
    finally:
        # Always give the slot back, whether the crawl succeeded or not.
        async with lock:
            current_requests -= 1
|
||||
|
||||
@app.get("/strategies/extraction", response_class=JSONResponse)
async def get_extraction_strategies():
    """Return the contents of ``docs/extraction_strategies.json``.

    NOTE(review): the file text is passed to ``JSONResponse`` as a string,
    not as parsed JSON — confirm clients expect that shape.
    """
    strategies_path = f"{__location__}/docs/extraction_strategies.json"
    with open(strategies_path, "r") as handle:
        raw_text = handle.read()
        return JSONResponse(content=raw_text)
|
||||
|
||||
@app.get("/strategies/chunking", response_class=JSONResponse)
async def get_chunking_strategies():
    """Return the contents of ``docs/chunking_strategies.json``.

    NOTE(review): the file text is passed to ``JSONResponse`` as a string,
    not as parsed JSON — confirm clients expect that shape.
    """
    strategies_path = f"{__location__}/docs/chunking_strategies.json"
    with open(strategies_path, "r") as handle:
        raw_text = handle.read()
        return JSONResponse(content=raw_text)
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Run the app with uvicorn when this file is executed as a script.
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=8888)
|
||||
281
tests/test_main.py
Normal file
281
tests/test_main.py
Normal file
@@ -0,0 +1,281 @@
|
||||
import asyncio
|
||||
import aiohttp
|
||||
import json
|
||||
import time
|
||||
import os
|
||||
from typing import Optional, Dict, Any
|
||||
from pydantic import BaseModel, HttpUrl
|
||||
|
||||
class NBCNewsAPITest:
    """Small async HTTP client for the crawl API, used by the tests below.

    Intended to be used as an async context manager so the underlying
    ``aiohttp`` session is opened on entry and closed on exit::

        async with NBCNewsAPITest() as api:
            task_id = await api.submit_crawl({"urls": "https://example.com"})
            result = await api.wait_for_task(task_id)
    """

    def __init__(self, base_url: str = "http://localhost:8000"):
        """Remember the server root URL; the session is created lazily."""
        self.base_url = base_url
        # Populated by __aenter__; stays None until the context is entered.
        self.session = None

    async def __aenter__(self):
        """Open the HTTP session and return this client."""
        self.session = aiohttp.ClientSession()
        return self

    async def __aexit__(self, exc_type, exc_val, exc_tb):
        """Close the HTTP session if one was opened."""
        if self.session:
            await self.session.close()

    async def submit_crawl(self, request_data: Dict[str, Any]) -> str:
        """POST ``request_data`` to ``/crawl`` and return the ``task_id``.

        Raises whatever ``response.raise_for_status()`` raises on a 4xx/5xx
        reply, so an HTTP error is reported as such instead of as a
        ``KeyError`` on the missing ``task_id`` field.
        """
        async with self.session.post(f"{self.base_url}/crawl", json=request_data) as response:
            response.raise_for_status()
            result = await response.json()
            return result["task_id"]

    async def get_task_status(self, task_id: str) -> Dict[str, Any]:
        """GET ``/task/{task_id}`` and return the decoded JSON body.

        Raises on a 4xx/5xx reply via ``response.raise_for_status()``.
        """
        async with self.session.get(f"{self.base_url}/task/{task_id}") as response:
            response.raise_for_status()
            return await response.json()

    async def wait_for_task(self, task_id: str, timeout: int = 300, poll_interval: int = 2) -> Dict[str, Any]:
        """Poll the task until its status is ``completed`` or ``failed``.

        Args:
            task_id: Identifier returned by ``submit_crawl``.
            timeout: Maximum number of seconds to keep polling.
            poll_interval: Seconds to sleep between two polls.

        Returns:
            The last status payload returned by ``get_task_status``.

        Raises:
            TimeoutError: If more than ``timeout`` seconds elapsed before a
                terminal status was observed.
        """
        # A monotonic clock keeps the elapsed-time check immune to system
        # clock adjustments while the test is waiting.
        start_time = time.monotonic()
        while True:
            if time.monotonic() - start_time > timeout:
                raise TimeoutError(f"Task {task_id} did not complete within {timeout} seconds")

            status = await self.get_task_status(task_id)
            if status["status"] in ("completed", "failed"):
                return status

            await asyncio.sleep(poll_interval)

    async def check_health(self) -> Dict[str, Any]:
        """GET ``/health`` and return the decoded JSON body.

        Raises on a 4xx/5xx reply via ``response.raise_for_status()``.
        """
        async with self.session.get(f"{self.base_url}/health") as response:
            response.raise_for_status()
            return await response.json()
|
||||
|
||||
async def test_basic_crawl():
    """Submit a single-URL crawl and check that the task completes successfully."""
    print("\n=== Testing Basic Crawl ===")
    async with NBCNewsAPITest() as api:
        request = {
            "urls": "https://www.nbcnews.com/business",
            "priority": 10
        }
        task_id = await api.submit_crawl(request)
        result = await api.wait_for_task(task_id)
        # Assert on the status and payload before indexing into them, so a
        # failed task is reported as an assertion failure rather than as a
        # KeyError raised while formatting the log line.
        assert result["status"] == "completed"
        assert "result" in result
        assert result["result"]["success"]
        print(f"Basic crawl result length: {len(result['result']['markdown'])}")
|
||||
|
||||
async def test_js_execution():
    """Crawl a page with custom JavaScript and a wait-for selector, then check success."""
    print("\n=== Testing JS Execution ===")
    async with NBCNewsAPITest() as api:
        request = {
            "urls": "https://www.nbcnews.com/business",
            "priority": 8,
            "js_code": [
                "const loadMoreButton = Array.from(document.querySelectorAll('button')).find(button => button.textContent.includes('Load More')); loadMoreButton && loadMoreButton.click();"
            ],
            "wait_for": "article.tease-card:nth-child(10)",
            "crawler_params": {
                "headless": True
            }
        }
        task_id = await api.submit_crawl(request)
        result = await api.wait_for_task(task_id)
        # Assert before reading the payload so a failed task shows up as an
        # assertion failure instead of a KeyError from the log line.
        assert result["status"] == "completed"
        assert result["result"]["success"]
        print(f"JS execution result length: {len(result['result']['markdown'])}")
|
||||
|
||||
async def test_css_selector():
    """Crawl a page restricted to one CSS selector and check that the task succeeds."""
    print("\n=== Testing CSS Selector ===")
    async with NBCNewsAPITest() as api:
        request = {
            "urls": "https://www.nbcnews.com/business",
            "priority": 7,
            "css_selector": ".wide-tease-item__description"
        }
        task_id = await api.submit_crawl(request)
        result = await api.wait_for_task(task_id)
        # Assert before reading the payload so a failed task shows up as an
        # assertion failure instead of a KeyError from the log line.
        assert result["status"] == "completed"
        assert result["result"]["success"]
        print(f"CSS selector result length: {len(result['result']['markdown'])}")
|
||||
|
||||
async def test_structured_extraction():
    """Run a ``json_css`` extraction and check that at least one item is returned."""
    print("\n=== Testing Structured Extraction ===")
    async with NBCNewsAPITest() as api:
        schema = {
            "name": "NBC News Articles",
            "baseSelector": "article.tease-card",
            "fields": [
                {
                    "name": "title",
                    "selector": "h2",
                    "type": "text"
                },
                {
                    "name": "description",
                    "selector": ".tease-card__description",
                    "type": "text"
                },
                {
                    "name": "link",
                    "selector": "a",
                    "type": "attribute",
                    "attribute": "href"
                }
            ]
        }

        request = {
            "urls": "https://www.nbcnews.com/business",
            "priority": 9,
            "extraction_config": {
                "type": "json_css",
                "params": {
                    "schema": schema
                }
            }
        }
        task_id = await api.submit_crawl(request)
        result = await api.wait_for_task(task_id)
        # Verify the task outcome first: decoding the extracted content of a
        # failed task would raise KeyError/TypeError and hide the real cause.
        assert result["status"] == "completed"
        assert result["result"]["success"]
        extracted = json.loads(result["result"]["extracted_content"])
        print(f"Extracted {len(extracted)} articles")
        assert len(extracted) > 0
|
||||
|
||||
async def test_batch_crawl():
    """Submit three URLs in one request and check that three results come back."""
    print("\n=== Testing Batch Crawl ===")
    async with NBCNewsAPITest() as api:
        request = {
            "urls": [
                "https://www.nbcnews.com/business",
                "https://www.nbcnews.com/business/consumer",
                "https://www.nbcnews.com/business/economy"
            ],
            "priority": 6,
            "crawler_params": {
                "headless": True
            }
        }
        task_id = await api.submit_crawl(request)
        result = await api.wait_for_task(task_id)
        # Check the status and the presence of "results" before indexing it;
        # otherwise the membership assert can never fail (a KeyError from the
        # log line would fire first).
        assert result["status"] == "completed"
        assert "results" in result
        print(f"Batch crawl completed, got {len(result['results'])} results")
        assert len(result["results"]) == 3
|
||||
|
||||
async def test_llm_extraction():
    """Run an LLM-backed schema extraction and check that the task succeeds."""
    print("\n=== Testing LLM Extraction with Ollama ===")
    async with NBCNewsAPITest() as api:
        # JSON schema describing the object the LLM is asked to produce.
        article_schema = {
            "type": "object",
            "properties": {
                "article_title": {
                    "type": "string",
                    "description": "The main title of the news article"
                },
                "summary": {
                    "type": "string",
                    "description": "A brief summary of the article content"
                },
                "main_topics": {
                    "type": "array",
                    "items": {"type": "string"},
                    "description": "Main topics or themes discussed in the article"
                }
            },
            "required": ["article_title", "summary", "main_topics"]
        }

        # NOTE(review): the provider names an OpenAI model while the key is
        # read from OLLAMA_API_KEY -- confirm which backend is intended.
        llm_params = {
            "provider": "openai/gpt-4o-mini",
            "api_key": os.getenv("OLLAMA_API_KEY"),
            "schema": article_schema,
            "extraction_type": "schema",
            "instruction": (
                "Extract the main article information including title, a brief summary, and main topics discussed.\n"
                "Focus on the primary business news article on the page."
            )
        }
        payload = {
            "urls": "https://www.nbcnews.com/business",
            "priority": 8,
            "extraction_config": {
                "type": "llm",
                "params": llm_params
            },
            "crawler_params": {
                "headless": True,
                "word_count_threshold": 1
            }
        }

        task_id = await api.submit_crawl(payload)
        outcome = await api.wait_for_task(task_id)

        # Only decode and dump the extraction when the task finished; the
        # assertions below report any other status.
        if outcome["status"] == "completed":
            analysis = json.loads(outcome["result"]["extracted_content"])
            print("Extracted article analysis:")
            print(json.dumps(analysis, indent=2))

        assert outcome["status"] == "completed"
        assert outcome["result"]["success"]
|
||||
|
||||
async def test_screenshot():
    """Request a crawl with ``screenshot`` enabled and check that one is returned."""
    print("\n=== Testing Screenshot ===")
    async with NBCNewsAPITest() as api:
        request = {
            "urls": "https://www.nbcnews.com/business",
            "priority": 5,
            "screenshot": True,
            "crawler_params": {
                "headless": True
            }
        }
        task_id = await api.submit_crawl(request)
        result = await api.wait_for_task(task_id)
        # Assert on the task outcome before reading the payload so a failed
        # task surfaces as an assertion failure, not a KeyError.
        assert result["status"] == "completed"
        assert result["result"]["success"]
        print("Screenshot captured:", bool(result["result"]["screenshot"]))
        assert result["result"]["screenshot"] is not None
|
||||
|
||||
async def test_priority_handling():
    """Queue a low-priority task, then a high-priority one, and check both complete."""
    print("\n=== Testing Priority Handling ===")
    async with NBCNewsAPITest() as api:
        # The low-priority request is queued first on purpose.
        low_task_id = await api.submit_crawl({
            "urls": "https://www.nbcnews.com/business",
            "priority": 1,
            "crawler_params": {"headless": True}
        })

        # The high-priority request is queued second.
        high_task_id = await api.submit_crawl({
            "urls": "https://www.nbcnews.com/business/consumer",
            "priority": 10,
            "crawler_params": {"headless": True}
        })

        # Wait for the high-priority task first, then the low-priority one.
        high_outcome = await api.wait_for_task(high_task_id)
        low_outcome = await api.wait_for_task(low_task_id)

        print("Both tasks completed")
        assert high_outcome["status"] == "completed"
        assert low_outcome["status"] == "completed"
|
||||
|
||||
async def main():
    """Print the server health, then run the currently enabled test(s).

    Any exception is reported on stdout and re-raised.
    """
    try:
        # Health check first.
        async with NBCNewsAPITest() as client:
            print("Server health:", await client.check_health())

        # Test selection: uncomment a line to enable that test.
        # await test_basic_crawl()
        # await test_js_execution()
        # await test_css_selector()
        # await test_structured_extraction()
        await test_llm_extraction()
        # await test_batch_crawl()
        # await test_screenshot()
        # await test_priority_handling()
    except Exception as exc:
        print(f"Test failed: {exc!s}")
        raise
|
||||
|
||||
if __name__ == "__main__":
    # Run the async test driver when this file is executed as a script.
    asyncio.run(main())
|
||||
Reference in New Issue
Block a user