refactor(docker): improve server architecture and configuration

Complete overhaul of Docker deployment setup with improved architecture:
- Add Redis integration for task management
- Implement rate limiting and security middleware
- Add Prometheus metrics and health checks
- Improve error handling and logging
- Add support for streaming responses
- Implement proper configuration management
- Add platform-specific optimizations for ARM64/AMD64

BREAKING CHANGE: Docker deployment now requires Redis and a new config.yml structure.
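The config.yml file itself is not part of this file's diff. Judging from the config[...] lookups in the server code below, load_config() presumably returns a mapping shaped roughly like the sketch that follows; every key appears somewhere in the server module, but the values (and the EXAMPLE_CONFIG name) are illustrative placeholders, not the project's shipped defaults.

# Hypothetical shape of the dict returned by load_config(); key names mirror
# the config lookups in the server code, values are placeholders only.
EXAMPLE_CONFIG = {
    "app": {
        "title": "Crawl4AI API",
        "version": "1.0.0",
        "host": "0.0.0.0",
        "port": 8000,
        "reload": False,
        "timeout_keep_alive": 300,
    },
    "redis": {"uri": "redis://redis:6379"},
    "rate_limiting": {
        "default_limit": "100/minute",
        "storage_uri": "redis://redis:6379",
    },
    "security": {
        "enabled": False,
        "https_redirect": False,
        "trusted_hosts": ["*"],
        "headers": {"X-Frame-Options": "DENY"},
    },
    "observability": {
        "prometheus": {"enabled": True, "endpoint": "/metrics"},
        "health_check": {"endpoint": "/health"},
    },
    "crawler": {
        "memory_threshold_percent": 95.0,
        "rate_limiter": {"base_delay": [1.0, 2.0]},
        "timeouts": {"stream_init": 30.0, "batch_process": 300.0},
    },
}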
UncleCode authored on 2025-02-02 20:19:51 +08:00
commit 33a21d6a7a (parent 7b1ef07c41)
16 changed files with 1918 additions and 344 deletions


@@ -1,120 +1,237 @@
import os
import sys
import time
from typing import List, Optional
sys.path.append(os.path.dirname(os.path.realpath(__file__)))
import json
import asyncio
from typing import AsyncGenerator
from crawl4ai import (
BrowserConfig,
CrawlerRunConfig,
AsyncWebCrawler,
MemoryAdaptiveDispatcher,
    RateLimiter,
)
from redis import asyncio as aioredis
from fastapi import FastAPI, HTTPException, Request, status
from fastapi.responses import StreamingResponse, RedirectResponse
from fastapi.middleware.httpsredirect import HTTPSRedirectMiddleware
from fastapi.middleware.trustedhost import TrustedHostMiddleware
from pydantic import BaseModel, Field
from slowapi import Limiter
from slowapi.util import get_remote_address
from prometheus_fastapi_instrumentator import Instrumentator
from fastapi.responses import PlainTextResponse
from fastapi.responses import JSONResponse
from fastapi.background import BackgroundTasks
from typing import Dict
from utils import (
FilterType,
load_config,
setup_logging
)
from api import (
handle_markdown_request,
handle_llm_request
)
# Load configuration and setup
config = load_config()
setup_logging(config)
# Initialize Redis
redis = aioredis.from_url(config["redis"].get("uri", "redis://localhost"))
# Initialize rate limiter
limiter = Limiter(
key_func=get_remote_address,
default_limits=[config["rate_limiting"]["default_limit"]],
storage_uri=config["rate_limiting"]["storage_uri"]
)
app = FastAPI(
    title=config["app"]["title"],
    version=config["app"]["version"]
)

# Wire the limiter into the app so the @limiter.limit decorators below are
# enforced and RateLimitExceeded is returned as a 429 response rather than
# surfacing as an unhandled error.
from slowapi import _rate_limit_exceeded_handler
from slowapi.errors import RateLimitExceeded
app.state.limiter = limiter
app.add_exception_handler(RateLimitExceeded, _rate_limit_exceeded_handler)
# Configure middleware
if config["security"]["enabled"]:
if config["security"]["https_redirect"]:
app.add_middleware(HTTPSRedirectMiddleware)
if config["security"]["trusted_hosts"] and config["security"]["trusted_hosts"] != ["*"]:
app.add_middleware(
TrustedHostMiddleware,
allowed_hosts=config["security"]["trusted_hosts"]
)
# Prometheus instrumentation
if config["observability"]["prometheus"]["enabled"]:
Instrumentator().instrument(app).expose(app)
class CrawlRequest(BaseModel):
    urls: List[str] = Field(
        min_length=1,
        max_length=100,
        json_schema_extra={
            "items": {"type": "string", "maxLength": 2000, "pattern": "\\S"}
        }
    )
    browser_config: Optional[Dict] = Field(
        default_factory=dict,
        example={"headless": True, "viewport": {"width": 1200}}
    )
    crawler_config: Optional[Dict] = Field(
        default_factory=dict,
        example={"stream": True, "cache_mode": "aggressive"}
    )
async def stream_results(crawler: AsyncWebCrawler, results_gen: AsyncGenerator) -> AsyncGenerator[bytes, None]:
"""Stream results and manage crawler lifecycle"""
def datetime_handler(obj):
"""Custom handler for datetime objects during JSON serialization"""
if hasattr(obj, 'isoformat'):
return obj.isoformat()
raise TypeError(f"Object of type {type(obj)} is not JSON serializable")
try:
async for result in results_gen:
try:
# Use dump method for serialization
result_dict = result.model_dump()
print(f"Streaming result for URL: {result_dict['url']}, Success: {result_dict['success']}")
# Use custom JSON encoder with datetime handler
yield (json.dumps(result_dict, default=datetime_handler) + "\n").encode('utf-8')
except Exception as e:
print(f"Error serializing result: {e}")
error_response = {
"error": str(e),
"url": getattr(result, 'url', 'unknown')
}
yield (json.dumps(error_response, default=datetime_handler) + "\n").encode('utf-8')
except asyncio.CancelledError:
print("Client disconnected, cleaning up...")
finally:
try:
await crawler.close()
except Exception as e:
print(f"Error closing crawler: {e}")
@app.post("/crawl")
async def crawl(request: CrawlRequest):
# Load configs using our new utilities
browser_config = BrowserConfig.load(request.browser_config)
crawler_config = CrawlerRunConfig.load(request.crawler_config)
dispatcher = MemoryAdaptiveDispatcher(
memory_threshold_percent=95.0,
rate_limiter=RateLimiter(base_delay=(1.0, 2.0)),
urls: List[str] = Field(
min_length=1,
max_length=100,
json_schema_extra={
"items": {"type": "string", "maxLength": 2000, "pattern": "\\S"}
}
)
browser_config: Optional[Dict] = Field(
default_factory=dict,
example={"headless": True, "viewport": {"width": 1200}}
)
crawler_config: Optional[Dict] = Field(
default_factory=dict,
example={"stream": True, "cache_mode": "aggressive"}
)
try:
if crawler_config.stream:
crawler = AsyncWebCrawler(config=browser_config)
await crawler.start()
@app.middleware("http")
async def add_security_headers(request: Request, call_next):
response = await call_next(request)
if config["security"]["enabled"]:
response.headers.update(config["security"]["headers"])
return response
@app.get("/md/{url:path}")
@limiter.limit(config["rate_limiting"]["default_limit"])
async def get_markdown(
request: Request,
url: str,
f: FilterType = FilterType.FIT,
q: Optional[str] = None,
c: Optional[str] = "0"
):
"""Get markdown from URL with optional filtering."""
result = await handle_markdown_request(url, f, q, c, config)
return PlainTextResponse(result)
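# Usage note (illustrative, not part of this commit): the route above is a
# plain GET where the target URL is embedded in the path; f selects the
# FilterType (e.g. "fit"), while q and c look like an optional query string
# and a cache flag forwarded to handle_markdown_request. Assuming the service
# is reachable on localhost:8000:
#
#   import requests
#   md = requests.get(
#       "http://localhost:8000/md/https://example.com",
#       params={"f": "fit", "q": "pricing", "c": "0"},
#   ).text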
@app.get("/llm/{input:path}")
@limiter.limit(config["rate_limiting"]["default_limit"])
async def llm_endpoint(
request: Request,
background_tasks: BackgroundTasks,
input: str,
q: Optional[str] = None,
s: Optional[str] = None,
c: Optional[str] = "0"
):
"""Handle LLM extraction requests."""
return await handle_llm_request(
redis, background_tasks, request, input, q, s, c, config
)
@app.get("/schema")
async def get_schema():
"""Return config schemas for client validation"""
"""Endpoint for client-side validation schema."""
from crawl4ai import BrowserConfig, CrawlerRunConfig
return {
"browser": BrowserConfig.model_json_schema(),
"crawler": CrawlerRunConfig.model_json_schema()
}
@app.get("/health")
@app.get(config["observability"]["health_check"]["endpoint"])
async def health():
return {"status": "ok"}
"""Health check endpoint."""
return {"status": "ok", "timestamp": time.time()}
@app.get(config["observability"]["prometheus"]["endpoint"])
async def metrics():
"""Prometheus metrics endpoint."""
return RedirectResponse(url=config["observability"]["prometheus"]["endpoint"])
@app.post("/crawl")
@limiter.limit(config["rate_limiting"]["default_limit"])
async def crawl(request: Request, crawl_request: CrawlRequest):
"""Handle crawl requests."""
from crawl4ai import (
AsyncWebCrawler,
BrowserConfig,
CrawlerRunConfig,
MemoryAdaptiveDispatcher,
RateLimiter
)
import asyncio
import logging
logger = logging.getLogger(__name__)
crawler = None
try:
if not crawl_request.urls:
logger.error("Empty URL list received")
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="At least one URL required"
)
browser_config = BrowserConfig.load(crawl_request.browser_config)
crawler_config = CrawlerRunConfig.load(crawl_request.crawler_config)
dispatcher = MemoryAdaptiveDispatcher(
memory_threshold_percent=config["crawler"]["memory_threshold_percent"],
rate_limiter=RateLimiter(
base_delay=tuple(config["crawler"]["rate_limiter"]["base_delay"])
)
)
if crawler_config.stream:
crawler = AsyncWebCrawler(config=browser_config)
await crawler.start()
results_gen = await asyncio.wait_for(
crawler.arun_many(
urls=crawl_request.urls,
config=crawler_config,
dispatcher=dispatcher
),
timeout=config["crawler"]["timeouts"]["stream_init"]
)
from api import stream_results
            response = StreamingResponse(
                stream_results(crawler, results_gen),
                media_type='application/x-ndjson',
                headers={
                    'Cache-Control': 'no-cache',
                    'Connection': 'keep-alive',
                    'X-Stream-Status': 'active'
                }
            )
            # Ownership of the crawler passes to stream_results, which closes
            # it when the stream finishes; drop the local reference so the
            # finally block below does not shut the browser down before the
            # stream has been consumed.
            crawler = None
            return response
else:
async with AsyncWebCrawler(config=browser_config) as crawler:
results = await asyncio.wait_for(
crawler.arun_many(
urls=crawl_request.urls,
config=crawler_config,
dispatcher=dispatcher
),
timeout=config["crawler"]["timeouts"]["batch_process"]
)
return JSONResponse({
"success": True,
"results": [result.model_dump() for result in results]
})
except asyncio.TimeoutError as e:
logger.error(f"Operation timed out: {str(e)}")
raise HTTPException(
status_code=status.HTTP_504_GATEWAY_TIMEOUT,
detail="Processing timeout"
)
    except HTTPException:
        # Propagate explicit HTTP errors (such as the 400 above) unchanged
        # instead of masking them as 500s in the generic handler below.
        raise
    except Exception as e:
logger.error(f"Server error: {str(e)}", exc_info=True)
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="Internal server error"
)
finally:
if crawler:
try:
await crawler.close()
except Exception as e:
logger.error(f"Final crawler cleanup error: {e}")
if __name__ == "__main__":
import uvicorn
uvicorn.run("server:app", host="0.0.0.0", port=8000, reload=True)
uvicorn.run(
"server:app",
host=config["app"]["host"],
port=config["app"]["port"],
reload=config["app"]["reload"],
timeout_keep_alive=config["app"]["timeout_keep_alive"]
)
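
For reference, a minimal client sketch for the new /crawl endpoint, assuming the service is reachable at http://localhost:8000 and using the third-party requests library (neither is part of this commit). With "stream": true in crawler_config the endpoint answers with application/x-ndjson, one JSON object per line; without it, a single JSON body with success and results comes back.

import json
import requests

payload = {
    "urls": ["https://example.com"],
    "browser_config": {"headless": True},
    "crawler_config": {"stream": True},  # set False for a single JSON response
}

# Streaming mode: results arrive as newline-delimited JSON while the crawl runs.
with requests.post("http://localhost:8000/crawl", json=payload, stream=True) as resp:
    resp.raise_for_status()
    for line in resp.iter_lines():
        if line:
            result = json.loads(line)
            print(result.get("url"), result.get("success"))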