Files
crawl4ai/deploy/docker/server-secure.py
UncleCode 2864015469 feat(docker): implement supervisor and secure API endpoints
Add supervisor configuration for managing Redis and Gunicorn processes
Replace direct process management with supervisord
Add secure and token-free API server variants
Implement JWT authentication for protected endpoints
Update datetime handling in async dispatcher
Add email domain verification

BREAKING CHANGE: Server startup now uses supervisord instead of direct process management
2025-02-17 20:31:20 +08:00

325 lines
11 KiB
Python

import os
import sys
import time
import base64
from typing import List, Optional, Dict
from datetime import datetime, timedelta, timezone
from jwt import JWT, jwk_from_dict
from jwt.utils import get_int_from_datetime
from fastapi import FastAPI, HTTPException, Request, status, Depends, Query, Path
from fastapi.responses import StreamingResponse, RedirectResponse, PlainTextResponse, JSONResponse
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
from fastapi.middleware.httpsredirect import HTTPSRedirectMiddleware
from fastapi.middleware.trustedhost import TrustedHostMiddleware
from pydantic import BaseModel, Field
from slowapi import Limiter
from slowapi.util import get_remote_address
from prometheus_fastapi_instrumentator import Instrumentator
from redis import asyncio as aioredis
from pydantic import EmailStr
sys.path.append(os.path.dirname(os.path.realpath(__file__)))
from utils import FilterType, load_config, setup_logging, verify_email_domain
from api import (
handle_markdown_request,
handle_llm_qa
)
__version__ = "0.1.2"
class CrawlRequest(BaseModel):
urls: List[str] = Field(
min_length=1,
max_length=100,
json_schema_extra={
"items": {"type": "string", "maxLength": 2000, "pattern": "\\S"}
}
)
browser_config: Optional[Dict] = Field(
default_factory=dict,
example={"headless": True, "viewport": {"width": 1200}}
)
crawler_config: Optional[Dict] = Field(
default_factory=dict,
example={"stream": True, "cache_mode": "aggressive"}
)
# Load configuration and setup
config = load_config()
setup_logging(config)
# Initialize Redis
redis = aioredis.from_url(config["redis"].get("uri", "redis://localhost"))
# Initialize rate limiter
limiter = Limiter(
key_func=get_remote_address,
default_limits=[config["rate_limiting"]["default_limit"]],
storage_uri=config["rate_limiting"]["storage_uri"]
)
app = FastAPI(
title=config["app"]["title"],
version=config["app"]["version"]
)
# Configure middleware
if config["security"]["enabled"]:
if config["security"]["https_redirect"]:
app.add_middleware(HTTPSRedirectMiddleware)
if config["security"]["trusted_hosts"] and config["security"]["trusted_hosts"] != ["*"]:
app.add_middleware(
TrustedHostMiddleware,
allowed_hosts=config["security"]["trusted_hosts"]
)
# Prometheus instrumentation
if config["observability"]["prometheus"]["enabled"]:
Instrumentator().instrument(app).expose(app)
# -------------------------------
# JWT Token Authentication Setup
# -------------------------------
instance = JWT()
# Use a secret key for symmetric signing (HS256)
SECRET_KEY = os.environ.get("SECRET_KEY", "mysecret")
ACCESS_TOKEN_EXPIRE_MINUTES = 60
# FastAPI security scheme for extracting the Authorization header
security = HTTPBearer()
def get_jwk_from_secret(secret: str):
"""
Convert a simple secret string into a JWK object.
The secret is base64 URL-safe encoded (without padding) as required.
"""
secret_bytes = secret.encode('utf-8')
b64_secret = base64.urlsafe_b64encode(secret_bytes).rstrip(b'=').decode('utf-8')
return jwk_from_dict({"kty": "oct", "k": b64_secret})
def create_access_token(data: dict, expires_delta: timedelta = None):
"""
Create a JWT access token with an expiration.
"""
to_encode = data.copy()
expire = datetime.now(timezone.utc) + (expires_delta if expires_delta else timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES))
to_encode.update({"exp": get_int_from_datetime(expire)})
# Convert the secret into a JWK object
signing_key = get_jwk_from_secret(SECRET_KEY)
encoded_jwt = instance.encode(to_encode, signing_key, alg='HS256')
return encoded_jwt
def verify_token(credentials: HTTPAuthorizationCredentials = Depends(security)):
"""
Verify the JWT token extracted from the Authorization header.
"""
token = credentials.credentials
# Convert the secret into a JWK object for verification
verifying_key = get_jwk_from_secret(SECRET_KEY)
try:
payload = instance.decode(token, verifying_key, do_time_check=True, algorithms='HS256')
return payload
except Exception as e:
raise HTTPException(status_code=401, detail="Invalid or expired token")
# -------------------------------
# Endpoints
# -------------------------------
@app.middleware("http")
async def add_security_headers(request: Request, call_next):
response = await call_next(request)
if config["security"]["enabled"]:
response.headers.update(config["security"]["headers"])
return response
class TokenRequest(BaseModel):
email: EmailStr
@app.post("/token")
async def get_token(request_data: TokenRequest):
"""
Minimal endpoint to generate a JWT token.
In a real-world scenario, you'd validate credentials here.
"""
# token = create_access_token({"sub": "user1"})
# return {"access_token": token, "token_type": "bearer"}
# Verify that the email domain likely exists (has MX records)
if not verify_email_domain(request_data.email):
raise HTTPException(
status_code=400,
detail="Email domain verification failed. Please use a valid email address."
)
token = create_access_token({"sub": request_data.email})
return {"email": request_data.email, "access_token": token, "token_type": "bearer"}
@app.get("/md/{url:path}")
@limiter.limit(config["rate_limiting"]["default_limit"])
async def get_markdown(
request: Request,
url: str,
f: FilterType = FilterType.FIT,
q: Optional[str] = None,
c: Optional[str] = "0",
token_data: dict = Depends(verify_token)
):
"""Get markdown from URL with optional filtering."""
result = await handle_markdown_request(url, f, q, c, config)
return PlainTextResponse(result)
@app.get("/llm/{url:path}", description="URL should be without http/https prefix")
async def llm_endpoint(
request: Request,
url: str = Path(..., description="Domain and path without protocol"),
q: Optional[str] = Query(None, description="Question to ask about the page content"),
token_data: dict = Depends(verify_token)
):
"""QA endpoint that uses LLM with crawled content as context."""
if not q:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="Query parameter 'q' is required"
)
# Ensure URL starts with http/https
if not url.startswith(('http://', 'https://')):
url = 'https://' + url
try:
answer = await handle_llm_qa(url, q, config)
return JSONResponse({"answer": answer})
except Exception as e:
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=str(e)
)
@app.get("/schema")
async def get_schema():
"""Endpoint for client-side validation schema."""
from crawl4ai import BrowserConfig, CrawlerRunConfig
return {
"browser": BrowserConfig.model_json_schema(),
"crawler": CrawlerRunConfig.model_json_schema()
}
@app.get(config["observability"]["health_check"]["endpoint"])
async def health():
"""Health check endpoint."""
return {"status": "ok", "timestamp": time.time(), "version": __version__}
@app.get(config["observability"]["prometheus"]["endpoint"])
async def metrics():
"""Prometheus metrics endpoint."""
return RedirectResponse(url=config["observability"]["prometheus"]["endpoint"])
# -------------------------------
# Protected Endpoint Example: /crawl
# -------------------------------
@app.post("/crawl")
@limiter.limit(config["rate_limiting"]["default_limit"])
async def crawl(request: Request, crawl_request: CrawlRequest, token_data: dict = Depends(verify_token)):
"""Handle crawl requests. Protected by JWT authentication."""
from crawl4ai import (
AsyncWebCrawler,
BrowserConfig,
CrawlerRunConfig,
MemoryAdaptiveDispatcher,
RateLimiter
)
import asyncio
import logging
logger = logging.getLogger(__name__)
crawler = None
try:
if not crawl_request.urls:
logger.error("Empty URL list received")
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="At least one URL required"
)
browser_config = BrowserConfig.load(crawl_request.browser_config)
crawler_config = CrawlerRunConfig.load(crawl_request.crawler_config)
dispatcher = MemoryAdaptiveDispatcher(
memory_threshold_percent=config["crawler"]["memory_threshold_percent"],
rate_limiter=RateLimiter(
base_delay=tuple(config["crawler"]["rate_limiter"]["base_delay"])
)
)
if crawler_config.stream:
crawler = AsyncWebCrawler(config=browser_config)
await crawler.start()
results_gen = await asyncio.wait_for(
crawler.arun_many(
urls=crawl_request.urls,
config=crawler_config,
dispatcher=dispatcher
),
timeout=config["crawler"]["timeouts"]["stream_init"]
)
from api import stream_results
return StreamingResponse(
stream_results(crawler, results_gen),
media_type='application/x-ndjson',
headers={
'Cache-Control': 'no-cache',
'Connection': 'keep-alive',
'X-Stream-Status': 'active'
}
)
else:
async with AsyncWebCrawler(config=browser_config) as crawler:
results = await asyncio.wait_for(
crawler.arun_many(
urls=crawl_request.urls,
config=crawler_config,
dispatcher=dispatcher
),
timeout=config["crawler"]["timeouts"]["batch_process"]
)
return JSONResponse({
"success": True,
"results": [result.model_dump() for result in results]
})
except asyncio.TimeoutError as e:
logger.error(f"Operation timed out: {str(e)}")
raise HTTPException(
status_code=status.HTTP_504_GATEWAY_TIMEOUT,
detail="Processing timeout"
)
except Exception as e:
logger.error(f"Server error: {str(e)}", exc_info=True)
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="Internal server error"
)
finally:
if crawler:
try:
await crawler.close()
except Exception as e:
logger.error(f"Final crawler cleanup error: {e}")
if __name__ == "__main__":
import uvicorn
uvicorn.run(
"server-secure:app",
host=config["app"]["host"],
port=config["app"]["port"],
reload=config["app"]["reload"],
timeout_keep_alive=config["app"]["timeout_keep_alive"]
)