Apply Ruff Corrections
This commit is contained in:
106
main.py
106
main.py
@@ -1,14 +1,9 @@
|
||||
import asyncio, os
|
||||
from fastapi import FastAPI, HTTPException, BackgroundTasks, Request
|
||||
from fastapi.responses import JSONResponse
|
||||
from fastapi import FastAPI, HTTPException, Request
|
||||
from fastapi.responses import HTMLResponse, JSONResponse
|
||||
from fastapi import FastAPI, HTTPException
|
||||
from fastapi import FastAPI, HTTPException
|
||||
from fastapi.staticfiles import StaticFiles
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from fastapi.templating import Jinja2Templates
|
||||
from fastapi.exceptions import RequestValidationError
|
||||
from starlette.middleware.base import BaseHTTPMiddleware
|
||||
from starlette.responses import FileResponse
|
||||
from fastapi.responses import RedirectResponse
|
||||
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
|
||||
from fastapi import Depends, Security
|
||||
@@ -18,13 +13,10 @@ from typing import Optional, List, Dict, Any, Union
|
||||
import psutil
|
||||
import time
|
||||
import uuid
|
||||
from collections import defaultdict
|
||||
from urllib.parse import urlparse
|
||||
import math
|
||||
import logging
|
||||
from enum import Enum
|
||||
from dataclasses import dataclass
|
||||
import json
|
||||
from crawl4ai import AsyncWebCrawler, CrawlResult, CacheMode
|
||||
from crawl4ai.config import MIN_WORD_THRESHOLD
|
||||
from crawl4ai.extraction_strategy import (
|
||||
@@ -38,30 +30,36 @@ __location__ = os.path.realpath(os.path.join(os.getcwd(), os.path.dirname(__file
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class TaskStatus(str, Enum):
    """Lifecycle states of a crawl task.

    Mixing in ``str`` makes every member compare equal to its plain string
    value, so a member can be placed directly into a response payload.
    """

    PENDING = "pending"
    PROCESSING = "processing"
    COMPLETED = "completed"
    FAILED = "failed"
|
||||
|
||||
|
||||
class CrawlerType(str, Enum):
    """Names of the extraction strategies a request may select.

    Used as the type of ``ExtractionConfig.type``; the ``str`` mix-in lets
    members compare equal to their plain string values.
    """

    BASIC = "basic"
    LLM = "llm"
    COSINE = "cosine"
    JSON_CSS = "json_css"
|
||||
|
||||
|
||||
class ExtractionConfig(BaseModel):
    """Request-side description of the extraction strategy to apply."""

    # Strategy selector; must be one of the CrawlerType members.
    type: CrawlerType
    # Free-form settings for the selected strategy.
    # NOTE(review): presumably forwarded to the strategy constructor — confirm
    # in _create_extraction_strategy. The mutable ``{}`` default relies on
    # pydantic copying field defaults per instance; verify for the pinned version.
    params: Dict[str, Any] = {}
|
||||
|
||||
|
||||
class ChunkingStrategy(BaseModel):
    """Request-side description of a chunking strategy."""

    # Strategy name as a plain string (no enum validation is applied here).
    type: str
    # Free-form settings for the named strategy.
    # NOTE(review): the mutable ``{}`` default relies on pydantic copying field
    # defaults per instance; verify for the pinned version.
    params: Dict[str, Any] = {}
|
||||
|
||||
|
||||
class ContentFilter(BaseModel):
    """Request-side description of a content filter."""

    # Filter name; "bm25" is used when the client does not supply one.
    type: str = "bm25"
    # Free-form settings for the named filter.
    # NOTE(review): the mutable ``{}`` default relies on pydantic copying field
    # defaults per instance; verify for the pinned version.
    params: Dict[str, Any] = {}
|
||||
|
||||
|
||||
class CrawlRequest(BaseModel):
|
||||
urls: Union[HttpUrl, List[HttpUrl]]
|
||||
word_count_threshold: int = MIN_WORD_THRESHOLD
|
||||
@@ -77,9 +75,10 @@ class CrawlRequest(BaseModel):
|
||||
session_id: Optional[str] = None
|
||||
cache_mode: Optional[CacheMode] = CacheMode.ENABLED
|
||||
priority: int = Field(default=5, ge=1, le=10)
|
||||
ttl: Optional[int] = 3600
|
||||
ttl: Optional[int] = 3600
|
||||
crawler_params: Dict[str, Any] = {}
|
||||
|
||||
|
||||
@dataclass
|
||||
class TaskInfo:
|
||||
id: str
|
||||
@@ -89,6 +88,7 @@ class TaskInfo:
|
||||
created_at: float = time.time()
|
||||
ttl: int = 3600
|
||||
|
||||
|
||||
class ResourceMonitor:
|
||||
def __init__(self, max_concurrent_tasks: int = 10):
|
||||
self.max_concurrent_tasks = max_concurrent_tasks
|
||||
@@ -106,7 +106,9 @@ class ResourceMonitor:
|
||||
mem_usage = psutil.virtual_memory().percent / 100
|
||||
cpu_usage = psutil.cpu_percent() / 100
|
||||
|
||||
memory_factor = max(0, (self.memory_threshold - mem_usage) / self.memory_threshold)
|
||||
memory_factor = max(
|
||||
0, (self.memory_threshold - mem_usage) / self.memory_threshold
|
||||
)
|
||||
cpu_factor = max(0, (self.cpu_threshold - cpu_usage) / self.cpu_threshold)
|
||||
|
||||
self._last_available_slots = math.floor(
|
||||
@@ -116,6 +118,7 @@ class ResourceMonitor:
|
||||
|
||||
return self._last_available_slots
|
||||
|
||||
|
||||
class TaskManager:
|
||||
def __init__(self, cleanup_interval: int = 300):
|
||||
self.tasks: Dict[str, TaskInfo] = {}
|
||||
@@ -149,12 +152,16 @@ class TaskManager:
|
||||
except asyncio.TimeoutError:
|
||||
try:
|
||||
# Then try low priority
|
||||
_, task_id = await asyncio.wait_for(self.low_priority.get(), timeout=0.1)
|
||||
_, task_id = await asyncio.wait_for(
|
||||
self.low_priority.get(), timeout=0.1
|
||||
)
|
||||
return task_id
|
||||
except asyncio.TimeoutError:
|
||||
return None
|
||||
|
||||
def update_task(self, task_id: str, status: TaskStatus, result: Any = None, error: str = None):
|
||||
def update_task(
|
||||
self, task_id: str, status: TaskStatus, result: Any = None, error: str = None
|
||||
):
|
||||
if task_id in self.tasks:
|
||||
task_info = self.tasks[task_id]
|
||||
task_info.status = status
|
||||
@@ -180,6 +187,7 @@ class TaskManager:
|
||||
except Exception as e:
|
||||
logger.error(f"Error in cleanup loop: {e}")
|
||||
|
||||
|
||||
class CrawlerPool:
|
||||
def __init__(self, max_size: int = 10):
|
||||
self.max_size = max_size
|
||||
@@ -222,6 +230,7 @@ class CrawlerPool:
|
||||
await crawler.__aexit__(None, None, None)
|
||||
self.active_crawlers.clear()
|
||||
|
||||
|
||||
class CrawlerService:
|
||||
def __init__(self, max_concurrent_tasks: int = 10):
|
||||
self.resource_monitor = ResourceMonitor(max_concurrent_tasks)
|
||||
@@ -258,10 +267,10 @@ class CrawlerService:
|
||||
async def submit_task(self, request: CrawlRequest) -> str:
    """Register a new task for ``request`` and return its generated id.

    Args:
        request: The validated crawl request; its ``priority`` and ``ttl``
            are passed to the task manager.

    Returns:
        The UUID4 string identifying the newly added task.
    """
    task_id = str(uuid.uuid4())
    # A falsy ttl (None or 0) falls back to 3600.
    await self.task_manager.add_task(task_id, request.priority, request.ttl or 3600)

    # Store request data with task
    self.task_manager.tasks[task_id].request = request

    return task_id
|
||||
|
||||
async def _process_queue(self):
|
||||
@@ -286,9 +295,11 @@ class CrawlerService:
|
||||
|
||||
try:
|
||||
crawler = await self.crawler_pool.acquire(**request.crawler_params)
|
||||
|
||||
extraction_strategy = self._create_extraction_strategy(request.extraction_config)
|
||||
|
||||
|
||||
extraction_strategy = self._create_extraction_strategy(
|
||||
request.extraction_config
|
||||
)
|
||||
|
||||
if isinstance(request.urls, list):
|
||||
results = await crawler.arun_many(
|
||||
urls=[str(url) for url in request.urls],
|
||||
@@ -318,16 +329,21 @@ class CrawlerService:
|
||||
)
|
||||
|
||||
await self.crawler_pool.release(crawler)
|
||||
self.task_manager.update_task(task_id, TaskStatus.COMPLETED, results)
|
||||
self.task_manager.update_task(
|
||||
task_id, TaskStatus.COMPLETED, results
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error processing task {task_id}: {str(e)}")
|
||||
self.task_manager.update_task(task_id, TaskStatus.FAILED, error=str(e))
|
||||
self.task_manager.update_task(
|
||||
task_id, TaskStatus.FAILED, error=str(e)
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error in queue processing: {str(e)}")
|
||||
await asyncio.sleep(1)
|
||||
|
||||
|
||||
app = FastAPI(title="Crawl4AI API")
|
||||
|
||||
# CORS configuration
|
||||
@@ -344,6 +360,7 @@ app.add_middleware(
|
||||
security = HTTPBearer()
|
||||
CRAWL4AI_API_TOKEN = os.getenv("CRAWL4AI_API_TOKEN")
|
||||
|
||||
|
||||
async def verify_token(credentials: HTTPAuthorizationCredentials = Security(security)):
|
||||
if not CRAWL4AI_API_TOKEN:
|
||||
return credentials # No token verification if CRAWL4AI_API_TOKEN is not set
|
||||
@@ -351,10 +368,12 @@ async def verify_token(credentials: HTTPAuthorizationCredentials = Security(secu
|
||||
raise HTTPException(status_code=401, detail="Invalid token")
|
||||
return credentials
|
||||
|
||||
|
||||
def secure_endpoint():
    """Return the token-verification dependency, or ``None`` when auth is off.

    Returns:
        ``Depends(verify_token)`` if ``CRAWL4AI_API_TOKEN`` is set, otherwise
        ``None``. Callers that build a ``dependencies=[...]`` list must guard
        against the ``None`` case themselves.
    """
    return Depends(verify_token) if CRAWL4AI_API_TOKEN else None
|
||||
|
||||
|
||||
# Check if site directory exists
|
||||
if os.path.exists(__location__ + "/site"):
|
||||
# Mount the site directory as a static directory
|
||||
@@ -364,14 +383,17 @@ site_templates = Jinja2Templates(directory=__location__ + "/site")
|
||||
|
||||
crawler_service = CrawlerService()
|
||||
|
||||
|
||||
@app.on_event("startup")
async def startup_event():
    """Start the module-level crawler service when the application starts."""
    await crawler_service.start()
|
||||
|
||||
|
||||
@app.on_event("shutdown")
async def shutdown_event():
    """Stop the module-level crawler service when the application shuts down."""
    await crawler_service.stop()
|
||||
|
||||
|
||||
@app.get("/")
|
||||
def read_root():
|
||||
if os.path.exists(__location__ + "/site"):
|
||||
@@ -379,12 +401,16 @@ def read_root():
|
||||
# Return a json response
|
||||
return {"message": "Crawl4AI API service is running"}
|
||||
|
||||
|
||||
@app.post("/crawl", dependencies=[secure_endpoint()] if CRAWL4AI_API_TOKEN else [])
async def crawl(request: CrawlRequest) -> Dict[str, str]:
    """Submit a crawl request and return immediately with its task id.

    The token dependency is attached only when ``CRAWL4AI_API_TOKEN`` is set.

    Returns:
        ``{"task_id": <id>}`` for the newly submitted task.
    """
    task_id = await crawler_service.submit_task(request)
    return {"task_id": task_id}
|
||||
|
||||
@app.get("/task/{task_id}", dependencies=[secure_endpoint()] if CRAWL4AI_API_TOKEN else [])
|
||||
|
||||
@app.get(
|
||||
"/task/{task_id}", dependencies=[secure_endpoint()] if CRAWL4AI_API_TOKEN else []
|
||||
)
|
||||
async def get_task_status(task_id: str):
|
||||
task_info = crawler_service.task_manager.get_task(task_id)
|
||||
if not task_info:
|
||||
@@ -406,36 +432,45 @@ async def get_task_status(task_id: str):
|
||||
|
||||
return response
|
||||
|
||||
|
||||
@app.post("/crawl_sync", dependencies=[secure_endpoint()] if CRAWL4AI_API_TOKEN else [])
async def crawl_sync(request: CrawlRequest) -> Dict[str, Any]:
    """Submit a crawl task and poll until it finishes or the wait runs out.

    The token dependency is attached only when ``CRAWL4AI_API_TOKEN`` is set.

    Returns:
        ``{"status", "results"}`` when the task result is a list, otherwise
        ``{"status", "result"}``; result objects are serialised via ``.dict()``.

    Raises:
        HTTPException: 404 if the task cannot be found, 500 carrying the task's
            error text if it failed, 408 if it is still unfinished after the
            polling loop.
    """
    task_id = await crawler_service.submit_task(request)

    # Wait up to 60 seconds for task completion (60 polls, one second apart).
    for _ in range(60):
        task_info = crawler_service.task_manager.get_task(task_id)
        if not task_info:
            raise HTTPException(status_code=404, detail="Task not found")

        if task_info.status == TaskStatus.COMPLETED:
            # Return same format as /task/{task_id} endpoint
            if isinstance(task_info.result, list):
                return {
                    "status": task_info.status,
                    "results": [result.dict() for result in task_info.result],
                }
            return {"status": task_info.status, "result": task_info.result.dict()}

        if task_info.status == TaskStatus.FAILED:
            raise HTTPException(status_code=500, detail=task_info.error)

        await asyncio.sleep(1)

    # If we get here, task didn't complete within timeout
    raise HTTPException(status_code=408, detail="Task timed out")
|
||||
|
||||
@app.post("/crawl_direct", dependencies=[secure_endpoint()] if CRAWL4AI_API_TOKEN else [])
|
||||
|
||||
@app.post(
|
||||
"/crawl_direct", dependencies=[secure_endpoint()] if CRAWL4AI_API_TOKEN else []
|
||||
)
|
||||
async def crawl_direct(request: CrawlRequest) -> Dict[str, Any]:
|
||||
try:
|
||||
crawler = await crawler_service.crawler_pool.acquire(**request.crawler_params)
|
||||
extraction_strategy = crawler_service._create_extraction_strategy(request.extraction_config)
|
||||
|
||||
extraction_strategy = crawler_service._create_extraction_strategy(
|
||||
request.extraction_config
|
||||
)
|
||||
|
||||
try:
|
||||
if isinstance(request.urls, list):
|
||||
results = await crawler.arun_many(
|
||||
@@ -470,7 +505,8 @@ async def crawl_direct(request: CrawlRequest) -> Dict[str, Any]:
|
||||
except Exception as e:
|
||||
logger.error(f"Error in direct crawl: {str(e)}")
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
|
||||
@app.get("/health")
|
||||
async def health_check():
|
||||
available_slots = await crawler_service.resource_monitor.get_available_slots()
|
||||
@@ -482,6 +518,8 @@ async def health_check():
|
||||
"cpu_usage": psutil.cpu_percent(),
|
||||
}
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Imported here so uvicorn is only needed when the module is run as a script.
    import uvicorn

    # Single server start: listen on all interfaces, port 11235.
    uvicorn.run(app, host="0.0.0.0", port=11235)
|
||||
|
||||
Reference in New Issue
Block a user