feat(docker): add JWT authentication and improve server architecture
Add JWT token-based authentication to Docker server and client. Refactor
server architecture for better code organization and error handling. Move
Dockerfile to root deploy directory and update configuration. Add
comprehensive documentation and examples.

BREAKING CHANGE: Docker server now requires authentication by default.
Endpoints require JWT tokens when security.jwt_enabled is true in config.
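Reviewer note: a minimal client flow under the new auth scheme is sketched below. It is not part of the diff; the host/port and the email accepted by verify_email_domain depend on your config.yml and are assumptions here.

import requests

BASE = "http://localhost:11235"  # assumed host/port for the Docker server

# 1. Request a JWT from the new /token endpoint; the email must pass
#    the server's verify_email_domain check.
resp = requests.post(f"{BASE}/token", json={"email": "user@example.com"})
resp.raise_for_status()
token = resp.json()["access_token"]

# 2. Call a protected endpoint with the Bearer token.
headers = {"Authorization": f"Bearer {token}"}
print(requests.get(f"{BASE}/md/example.com", headers=headers).text)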
@@ -1,36 +1,34 @@
 import os
 import sys
 import time
-from typing import List, Optional
-
-sys.path.append(os.path.dirname(os.path.realpath(__file__)))
-
+from typing import List, Optional, Dict
+
 from redis import asyncio as aioredis
-from fastapi import FastAPI, HTTPException, Request, status
-from fastapi.responses import StreamingResponse, RedirectResponse
+from fastapi import FastAPI, HTTPException, Request, Query, Path, Depends
+from fastapi.responses import StreamingResponse, RedirectResponse, PlainTextResponse, JSONResponse
 from fastapi.middleware.httpsredirect import HTTPSRedirectMiddleware
 from fastapi.middleware.trustedhost import TrustedHostMiddleware
 from pydantic import BaseModel, Field
 from slowapi import Limiter
 from slowapi.util import get_remote_address
 from prometheus_fastapi_instrumentator import Instrumentator
-from fastapi.responses import PlainTextResponse
-from fastapi.responses import JSONResponse
-from fastapi.background import BackgroundTasks
-from typing import Dict
-from fastapi import Query, Path
-import os
-from redis import asyncio as aioredis
 
-from utils import (
-    FilterType,
-    load_config,
-    setup_logging
-)
+sys.path.append(os.path.dirname(os.path.realpath(__file__)))
+from utils import FilterType, load_config, setup_logging, verify_email_domain
 from api import (
     handle_markdown_request,
     handle_llm_request,
-    handle_llm_qa
+    handle_llm_qa,
+    handle_stream_crawl_request,
+    handle_crawl_request,
+    stream_results
 )
+from auth import create_access_token, get_token_dependency, TokenRequest  # Import from auth.py
+
+__version__ = "0.2.6"
+
+class CrawlRequest(BaseModel):
+    urls: List[str] = Field(min_length=1, max_length=100)
+    browser_config: Optional[Dict] = Field(default_factory=dict)
+    crawler_config: Optional[Dict] = Field(default_factory=dict)
 
 # Load configuration and setup
 config = load_config()
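Reviewer note: auth.py itself is not shown in this diff. For context, here is a minimal sketch of the two entry points imported above; the PyJWT usage, the SECRET_KEY environment variable, the HS256 algorithm, and the one-hour expiry are all assumptions, not the module's confirmed contents.

# Hypothetical sketch of auth.py (not part of this diff).
import os
from datetime import datetime, timedelta, timezone
from typing import Dict, Optional

import jwt  # PyJWT, assumed
from fastapi import Depends, HTTPException
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
from pydantic import BaseModel

SECRET_KEY = os.environ.get("SECRET_KEY", "change-me")  # assumed env var
ALGORITHM = "HS256"

class TokenRequest(BaseModel):
    email: str

def create_access_token(data: Dict, expires_delta: timedelta = timedelta(hours=1)) -> str:
    """Sign a short-lived JWT carrying the given claims."""
    payload = dict(data)
    payload["exp"] = datetime.now(timezone.utc) + expires_delta
    return jwt.encode(payload, SECRET_KEY, algorithm=ALGORITHM)

def get_token_dependency(config: Dict):
    """Return a FastAPI dependency: a JWT verifier when security.jwt_enabled
    is true, otherwise a no-op so the endpoints stay open."""
    if not config.get("security", {}).get("jwt_enabled", False):
        async def allow_all() -> Optional[Dict]:
            return None
        return allow_all

    bearer = HTTPBearer()

    async def verify(credentials: HTTPAuthorizationCredentials = Depends(bearer)) -> Dict:
        try:
            return jwt.decode(credentials.credentials, SECRET_KEY, algorithms=[ALGORITHM])
        except jwt.PyJWTError:
            raise HTTPException(status_code=401, detail="Invalid or expired token")

    return verify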
@@ -52,36 +50,24 @@ app = FastAPI(
 )
 
-# Configure middleware
-if config["security"]["enabled"]:
-    if config["security"]["https_redirect"]:
-        app.add_middleware(HTTPSRedirectMiddleware)
-    if config["security"]["trusted_hosts"] and config["security"]["trusted_hosts"] != ["*"]:
-        app.add_middleware(
-            TrustedHostMiddleware,
-            allowed_hosts=config["security"]["trusted_hosts"]
-        )
+def setup_security_middleware(app, config):
+    sec_config = config.get("security", {})
+    if sec_config.get("enabled", False):
+        if sec_config.get("https_redirect", False):
+            app.add_middleware(HTTPSRedirectMiddleware)
+        if sec_config.get("trusted_hosts", []) != ["*"]:
+            app.add_middleware(TrustedHostMiddleware, allowed_hosts=sec_config["trusted_hosts"])
+
+setup_security_middleware(app, config)
 
 # Prometheus instrumentation
 if config["observability"]["prometheus"]["enabled"]:
     Instrumentator().instrument(app).expose(app)
 
-class CrawlRequest(BaseModel):
-    urls: List[str] = Field(
-        min_length=1,
-        max_length=100,
-        json_schema_extra={
-            "items": {"type": "string", "maxLength": 2000, "pattern": "\\S"}
-        }
-    )
-    browser_config: Optional[Dict] = Field(
-        default_factory=dict,
-        example={"headless": True, "viewport": {"width": 1200}}
-    )
-    crawler_config: Optional[Dict] = Field(
-        default_factory=dict,
-        example={"stream": True, "cache_mode": "aggressive"}
-    )
-
+# Get token dependency based on config
+token_dependency = get_token_dependency(config)
 
 # Middleware for security headers
 @app.middleware("http")
 async def add_security_headers(request: Request, call_next):
     response = await call_next(request)
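Reviewer note: one point of this refactor worth calling out is that the old block indexed the config directly, so a config.yml without a security section crashed at import time, while the new accessors degrade to a no-op. A self-contained illustration (hypothetical configs; only the shapes matter):

# Old style: direct indexing raises on a minimal config.
minimal_config = {}
try:
    minimal_config["security"]["enabled"]
except KeyError:
    print("old inline checks raise KeyError at startup")

# New style: .get() defaults make middleware setup a quiet no-op.
sec_config = minimal_config.get("security", {})
print(sec_config.get("enabled", False))  # False -> nothing is installed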
@@ -89,6 +75,15 @@ async def add_security_headers(request: Request, call_next):
     response.headers.update(config["security"]["headers"])
     return response
 
+# Token endpoint (always available, but usage depends on config)
+@app.post("/token")
+async def get_token(request_data: TokenRequest):
+    if not verify_email_domain(request_data.email):
+        raise HTTPException(status_code=400, detail="Invalid email domain")
+    token = create_access_token({"sub": request_data.email})
+    return {"email": request_data.email, "access_token": token, "token_type": "bearer"}
+
+# Endpoints with conditional auth
 @app.get("/md/{url:path}")
 @limiter.limit(config["rate_limiting"]["default_limit"])
 async def get_markdown(
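Reviewer note: verify_email_domain lives in utils and is not shown in this diff. A plausible minimal implementation, purely for illustration (the allow-list and where it comes from are assumptions):

# Hypothetical sketch of utils.verify_email_domain (not part of this diff).
ALLOWED_DOMAINS = {"example.com"}  # assumed; the real policy may differ

def verify_email_domain(email: str) -> bool:
    """Accept only emails whose domain is on the allow-list."""
    _, _, domain = email.partition("@")
    return bool(domain) and domain.lower() in ALLOWED_DOMAINS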
@@ -96,165 +91,84 @@ async def get_markdown(
     url: str,
     f: FilterType = FilterType.FIT,
     q: Optional[str] = None,
-    c: Optional[str] = "0"
+    c: Optional[str] = "0",
+    token_data: Optional[Dict] = Depends(token_dependency)
 ):
     """Get markdown from URL with optional filtering."""
     result = await handle_markdown_request(url, f, q, c, config)
     return PlainTextResponse(result)
 
 @app.get("/llm/{url:path}", description="URL should be without http/https prefix")
 async def llm_endpoint(
     request: Request,
-    url: str = Path(..., description="Domain and path without protocol"),
-    q: Optional[str] = Query(None, description="Question to ask about the page content"),
+    url: str = Path(...),
+    q: Optional[str] = Query(None),
+    token_data: Optional[Dict] = Depends(token_dependency)
 ):
     """QA endpoint that uses LLM with crawled content as context."""
     if not q:
-        raise HTTPException(
-            status_code=status.HTTP_400_BAD_REQUEST,
-            detail="Query parameter 'q' is required"
-        )
-
-    # Ensure URL starts with http/https
+        raise HTTPException(status_code=400, detail="Query parameter 'q' is required")
     if not url.startswith(('http://', 'https://')):
         url = 'https://' + url
-
     try:
         answer = await handle_llm_qa(url, q, config)
         return JSONResponse({"answer": answer})
     except Exception as e:
-        raise HTTPException(
-            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
-            detail=str(e)
-        )
-
-# @app.get("/llm/{input:path}")
-# @limiter.limit(config["rate_limiting"]["default_limit"])
-# async def llm_endpoint(
-#     request: Request,
-#     background_tasks: BackgroundTasks,
-#     input: str,
-#     q: Optional[str] = None,
-#     s: Optional[str] = None,
-#     c: Optional[str] = "0"
-# ):
-#     """Handle LLM extraction requests."""
-#     return await handle_llm_request(
-#         redis, background_tasks, request, input, q, s, c, config
-#     )
+        raise HTTPException(status_code=500, detail=str(e))
 
 @app.get("/schema")
 async def get_schema():
     """Endpoint for client-side validation schema."""
     from crawl4ai import BrowserConfig, CrawlerRunConfig
-    return {
-        "browser": BrowserConfig().dump(),
-        "crawler": CrawlerRunConfig().dump()
-    }
+    return {"browser": BrowserConfig().dump(), "crawler": CrawlerRunConfig().dump()}
 
 @app.get(config["observability"]["health_check"]["endpoint"])
 async def health():
     """Health check endpoint."""
-    return {"status": "ok", "timestamp": time.time()}
+    return {"status": "ok", "timestamp": time.time(), "version": __version__}
 
 @app.get(config["observability"]["prometheus"]["endpoint"])
 async def metrics():
     """Prometheus metrics endpoint."""
     return RedirectResponse(url=config["observability"]["prometheus"]["endpoint"])
 
 @app.post("/crawl")
 @limiter.limit(config["rate_limiting"]["default_limit"])
-async def crawl(request: Request, crawl_request: CrawlRequest):
-    """Handle crawl requests."""
-    from crawl4ai import (
-        AsyncWebCrawler,
-        BrowserConfig,
-        CrawlerRunConfig,
-        MemoryAdaptiveDispatcher,
-        RateLimiter
-    )
-    import asyncio
-    import logging
-
-    logger = logging.getLogger(__name__)
-    crawler = None
-
-    try:
-        if not crawl_request.urls:
-            logger.error("Empty URL list received")
-            raise HTTPException(
-                status_code=status.HTTP_400_BAD_REQUEST,
-                detail="At least one URL required"
-            )
-
-        browser_config = BrowserConfig.load(crawl_request.browser_config)
-        crawler_config = CrawlerRunConfig.load(crawl_request.crawler_config)
-
-        dispatcher = MemoryAdaptiveDispatcher(
-            memory_threshold_percent=config["crawler"]["memory_threshold_percent"],
-            rate_limiter=RateLimiter(
-                base_delay=tuple(config["crawler"]["rate_limiter"]["base_delay"])
-            )
-        )
-
-        if crawler_config.stream:
-            crawler = AsyncWebCrawler(config=browser_config)
-            await crawler.start()
-
-            results_gen = await asyncio.wait_for(
-                crawler.arun_many(
-                    urls=crawl_request.urls,
-                    config=crawler_config,
-                    dispatcher=dispatcher
-                ),
-                timeout=config["crawler"]["timeouts"]["stream_init"]
-            )
-
-            from api import stream_results
-            return StreamingResponse(
-                stream_results(crawler, results_gen),
-                media_type='application/x-ndjson',
-                headers={
-                    'Cache-Control': 'no-cache',
-                    'Connection': 'keep-alive',
-                    'X-Stream-Status': 'active'
-                }
-            )
-        else:
-            async with AsyncWebCrawler(config=browser_config) as crawler:
-                results = await asyncio.wait_for(
-                    crawler.arun_many(
-                        urls=crawl_request.urls,
-                        config=crawler_config,
-                        dispatcher=dispatcher
-                    ),
-                    timeout=config["crawler"]["timeouts"]["batch_process"]
-                )
-                return JSONResponse({
-                    "success": True,
-                    "results": [result.model_dump() for result in results]
-                })
-
-    except asyncio.TimeoutError as e:
-        logger.error(f"Operation timed out: {str(e)}")
-        raise HTTPException(
-            status_code=status.HTTP_504_GATEWAY_TIMEOUT,
-            detail="Processing timeout"
-        )
-    except Exception as e:
-        logger.error(f"Server error: {str(e)}", exc_info=True)
-        raise HTTPException(
-            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
-            detail="Internal server error"
-        )
-    finally:
-        if crawler:
-            try:
-                await crawler.close()
-            except Exception as e:
-                logger.error(f"Final crawler cleanup error: {e}")
+async def crawl(
+    request: Request,
+    crawl_request: CrawlRequest,
+    token_data: Optional[Dict] = Depends(token_dependency)
+):
+    if not crawl_request.urls:
+        raise HTTPException(status_code=400, detail="At least one URL required")
+
+    results = await handle_crawl_request(
+        urls=crawl_request.urls,
+        browser_config=crawl_request.browser_config,
+        crawler_config=crawl_request.crawler_config,
+        config=config
+    )
+    return JSONResponse(results)
+
+@app.post("/crawl/stream")
+@limiter.limit(config["rate_limiting"]["default_limit"])
+async def crawl_stream(
+    request: Request,
+    crawl_request: CrawlRequest,
+    token_data: Optional[Dict] = Depends(token_dependency)
+):
+    if not crawl_request.urls:
+        raise HTTPException(status_code=400, detail="At least one URL required")
+
+    crawler, results_gen = await handle_stream_crawl_request(
+        urls=crawl_request.urls,
+        browser_config=crawl_request.browser_config,
+        crawler_config=crawl_request.crawler_config,
+        config=config
+    )
+    return StreamingResponse(
+        stream_results(crawler, results_gen),
+        media_type='application/x-ndjson',
+        headers={'Cache-Control': 'no-cache', 'Connection': 'keep-alive', 'X-Stream-Status': 'active'}
+    )
 
 if __name__ == "__main__":
     import uvicorn
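Reviewer note: the new /crawl/stream endpoint emits newline-delimited JSON, so clients can consume results as they arrive. A sketch of a streaming client, using the same assumed base URL and token flow as the example near the top:

import json
import requests

BASE = "http://localhost:11235"  # assumed host/port
token = requests.post(f"{BASE}/token", json={"email": "user@example.com"}).json()["access_token"]

payload = {"urls": ["https://example.com"], "browser_config": {}, "crawler_config": {}}
with requests.post(
    f"{BASE}/crawl/stream",
    json=payload,
    headers={"Authorization": f"Bearer {token}"},
    stream=True,
) as resp:
    resp.raise_for_status()
    for line in resp.iter_lines():
        if line:  # each non-empty line is one JSON-encoded crawl result
            print(json.loads(line))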