refactor(server): migrate to pool-based crawler management

Replace crawler_manager.py with a simpler crawler_pool.py implementation:
- Add global page semaphore for hard concurrency cap
- Implement browser pool with idle cleanup
- Add playground UI for testing and stress testing
- Update API handlers to use pooled crawlers
- Enhance logging levels and symbols

BREAKING CHANGE: Removes the CrawlerManager class in favor of a simpler pool-based approach.
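
Note: crawler_pool.py itself is not part of the diff shown below; only the server.py side of the migration is visible. Purely as an orientation aid, and as an assumption inferred from the get_crawler / close_all / janitor imports used in server.py rather than the shipped module, the pool can be pictured along these lines:

# Hypothetical sketch of crawler_pool.py (illustrative assumption, not the actual module).
import asyncio, time
from crawl4ai import AsyncWebCrawler, BrowserConfig

_POOL: dict[str, AsyncWebCrawler] = {}      # live browsers, keyed by config signature
_LAST_USED: dict[str, float] = {}           # last-use timestamps, consumed by the janitor
_LOCK = asyncio.Lock()

async def get_crawler(cfg: BrowserConfig) -> AsyncWebCrawler:
    """Return a started crawler for this browser config, creating it on first use."""
    sig = repr(cfg.dump())                  # assumption: dump() gives a stable signature
    async with _LOCK:
        if sig not in _POOL:
            crawler = AsyncWebCrawler(config=cfg)
            await crawler.start()
            _POOL[sig] = crawler
        _LAST_USED[sig] = time.time()
        return _POOL[sig]

async def close_all() -> None:
    """Close every pooled browser (called from the FastAPI lifespan on shutdown)."""
    async with _LOCK:
        await asyncio.gather(*(c.close() for c in _POOL.values()), return_exceptions=True)
        _POOL.clear()
        _LAST_USED.clear()

async def janitor(idle_ttl: float = 600.0) -> None:
    """Background task: periodically close browsers idle for longer than idle_ttl seconds."""
    while True:
        await asyncio.sleep(60)
        async with _LOCK:
            now = time.time()
            for sig, last in list(_LAST_USED.items()):
                if now - last > idle_ttl and sig in _POOL:
                    await _POOL.pop(sig).close()
                    del _LAST_USED[sig]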
Author: UncleCode
Date: 2025-04-20 20:14:26 +08:00
Parent: 16b2318242
Commit: a58c8000aa
14 changed files with 1447 additions and 1435 deletions


@@ -1,167 +1,200 @@
# Import from auth.py
from auth import create_access_token, get_token_dependency, TokenRequest
from api import (
handle_markdown_request,
handle_llm_qa,
handle_stream_crawl_request,
handle_crawl_request,
stream_results,
_get_memory_mb
)
from utils import FilterType, load_config, setup_logging, verify_email_domain
import os
import sys
import time
from typing import List, Optional, Dict, AsyncGenerator
# ───────────────────────── server.py ─────────────────────────
"""
Crawl4AI FastAPI entrypoint
• Browser pool + global page cap
• Rate-limiting, security, metrics
• /crawl, /crawl/stream, /md, /llm endpoints
"""
# ── stdlib & 3rd-party imports ──────────────────────────────
import os, sys, time, asyncio
from typing import List, Optional, Dict
from contextlib import asynccontextmanager
from fastapi import FastAPI, HTTPException, Request, Query, Path, Depends, status
from fastapi.responses import StreamingResponse, RedirectResponse, PlainTextResponse, JSONResponse
import pathlib
from fastapi import (
FastAPI, HTTPException, Request, Path, Query, Depends
)
from fastapi.responses import (
StreamingResponse, RedirectResponse, PlainTextResponse, JSONResponse
)
from fastapi.middleware.httpsredirect import HTTPSRedirectMiddleware
from fastapi.middleware.trustedhost import TrustedHostMiddleware
from fastapi.staticfiles import StaticFiles
import ast, crawl4ai as _c4
from pydantic import BaseModel, Field
from slowapi import Limiter
from slowapi.util import get_remote_address
from prometheus_fastapi_instrumentator import Instrumentator
from redis import asyncio as aioredis
from crawl4ai import (
BrowserConfig,
CrawlerRunConfig,
AsyncLogger
)
from crawler_manager import (
CrawlerManager,
CrawlerManagerConfig,
PoolTimeoutError,
NoHealthyCrawlerError
)
# ── internal imports (after sys.path append) ─────────────────
sys.path.append(os.path.dirname(os.path.realpath(__file__)))
from utils import (
FilterType, load_config, setup_logging, verify_email_domain
)
from api import (
handle_markdown_request, handle_llm_qa,
handle_stream_crawl_request, handle_crawl_request,
stream_results
)
from auth import create_access_token, get_token_dependency, TokenRequest
from crawl4ai import AsyncWebCrawler, BrowserConfig, CrawlerRunConfig
from crawler_pool import get_crawler, close_all, janitor
__version__ = "0.2.6"
class CrawlRequest(BaseModel):
urls: List[str] = Field(min_length=1, max_length=100)
browser_config: Optional[Dict] = Field(default_factory=dict)
crawler_config: Optional[Dict] = Field(default_factory=dict)
# Load configuration and setup
# ────────────────── configuration / logging ──────────────────
config = load_config()
setup_logging(config)
logger = AsyncLogger(
log_file=config["logging"].get("log_file", "app.log"),
verbose=config["logging"].get("verbose", False),
tag_width=10,
)
# Initialize Redis
redis = aioredis.from_url(config["redis"].get("uri", "redis://localhost"))
__version__ = "0.5.1-d1"
# Initialize rate limiter
limiter = Limiter(
key_func=get_remote_address,
default_limits=[config["rate_limiting"]["default_limit"]],
storage_uri=config["rate_limiting"]["storage_uri"]
)
# ── global page semaphore (hard cap) ─────────────────────────
MAX_PAGES = config["crawler"]["pool"].get("max_pages", 30)
GLOBAL_SEM = asyncio.Semaphore(MAX_PAGES)
# --- Initialize Manager (will be done in lifespan) ---
# Load manager config from the main config
manager_config_dict = config.get("crawler_pool", {})
# Use Pydantic to parse and validate
manager_config = CrawlerManagerConfig(**manager_config_dict)
crawler_manager = CrawlerManager(config=manager_config, logger=logger)
# --- FastAPI App and Lifespan ---
# import logging
# page_log = logging.getLogger("page_cap")
# orig_arun = AsyncWebCrawler.arun
# async def capped_arun(self, *a, **kw):
# await GLOBAL_SEM.acquire() # ← take slot
# try:
# in_flight = MAX_PAGES - GLOBAL_SEM._value # used permits
# page_log.info("🕸️ pages_in_flight=%s / %s", in_flight, MAX_PAGES)
# return await orig_arun(self, *a, **kw)
# finally:
# GLOBAL_SEM.release() # ← free slot
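# Monkey-patch AsyncWebCrawler.arun so that every page fetch, no matter which
# pooled browser serves it, must first acquire a slot from GLOBAL_SEM. This is
# what turns MAX_PAGES into a process-wide hard cap on concurrent pages.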
orig_arun = AsyncWebCrawler.arun
async def capped_arun(self, *a, **kw):
async with GLOBAL_SEM:
return await orig_arun(self, *a, **kw)
AsyncWebCrawler.arun = capped_arun
# ───────────────────── FastAPI lifespan ──────────────────────
@asynccontextmanager
async def lifespan(app: FastAPI):
# Startup
logger.info("Starting up the server...")
if manager_config.enabled:
logger.info("Initializing Crawler Manager...")
await crawler_manager.initialize()
app.state.crawler_manager = crawler_manager # Store manager in app state
logger.info("Crawler Manager is enabled.")
else:
logger.warning("Crawler Manager is disabled.")
app.state.crawler_manager = None # Indicate disabled state
yield # Server runs here
# Shutdown
logger.info("Shutting down server...")
if app.state.crawler_manager:
logger.info("Shutting down Crawler Manager...")
await app.state.crawler_manager.shutdown()
logger.info("Crawler Manager shut down.")
logger.info("Server shut down.")
async def lifespan(_: FastAPI):
await get_crawler(BrowserConfig(
extra_args=config["crawler"]["browser"].get("extra_args", []),
**config["crawler"]["browser"].get("kwargs", {}),
)) # warmup
app.state.janitor = asyncio.create_task(janitor()) # idle GC
yield
app.state.janitor.cancel()
await close_all()
# ───────────────────── FastAPI instance ──────────────────────
app = FastAPI(
title=config["app"]["title"],
version=config["app"]["version"],
lifespan=lifespan,
)
# Configure middleware
def setup_security_middleware(app, config):
sec_config = config.get("security", {})
if sec_config.get("enabled", False):
if sec_config.get("https_redirect", False):
app.add_middleware(HTTPSRedirectMiddleware)
if sec_config.get("trusted_hosts", []) != ["*"]:
app.add_middleware(TrustedHostMiddleware,
allowed_hosts=sec_config["trusted_hosts"])
# ── static playground ──────────────────────────────────────
STATIC_DIR = pathlib.Path(__file__).parent / "static" / "playground"
if not STATIC_DIR.exists():
raise RuntimeError(f"Playground assets not found at {STATIC_DIR}")
app.mount(
"/playground",
StaticFiles(directory=STATIC_DIR, html=True),
name="play",
)
# Optional nice-to-have: opening the root shows the playground
@app.get("/")
async def root():
return RedirectResponse("/playground")
setup_security_middleware(app, config)
# ─────────────────── infra / middleware ─────────────────────
redis = aioredis.from_url(config["redis"].get("uri", "redis://localhost"))
limiter = Limiter(
key_func=get_remote_address,
default_limits=[config["rate_limiting"]["default_limit"]],
storage_uri=config["rate_limiting"]["storage_uri"],
)
def _setup_security(app_: FastAPI):
sec = config["security"]
if not sec["enabled"]:
return
if sec.get("https_redirect"):
app_.add_middleware(HTTPSRedirectMiddleware)
if sec.get("trusted_hosts", []) != ["*"]:
app_.add_middleware(
TrustedHostMiddleware, allowed_hosts=sec["trusted_hosts"]
)
_setup_security(app)
# Prometheus instrumentation
if config["observability"]["prometheus"]["enabled"]:
Instrumentator().instrument(app).expose(app)
# Get token dependency based on config
token_dependency = get_token_dependency(config)
# Middleware for security headers
token_dep = get_token_dependency(config)
@app.middleware("http")
async def add_security_headers(request: Request, call_next):
response = await call_next(request)
resp = await call_next(request)
if config["security"]["enabled"]:
response.headers.update(config["security"]["headers"])
return response
resp.headers.update(config["security"]["headers"])
return resp
# ───────────────── safe config-dump helper ─────────────────
ALLOWED_TYPES = {
"CrawlerRunConfig": CrawlerRunConfig,
"BrowserConfig": BrowserConfig,
}
def _safe_eval_config(expr: str) -> dict:
"""
Accept exactly one top-level call to CrawlerRunConfig(...) or BrowserConfig(...).
Whatever is inside the parentheses is fine *except* further function calls
(so no __import__('os') stuff). All public names from crawl4ai are available
when we eval.
"""
tree = ast.parse(expr, mode="eval")
# must be a single call
if not isinstance(tree.body, ast.Call):
raise ValueError("Expression must be a single constructor call")
call = tree.body
if not (isinstance(call.func, ast.Name) and call.func.id in {"CrawlerRunConfig", "BrowserConfig"}):
raise ValueError("Only CrawlerRunConfig(...) or BrowserConfig(...) are allowed")
# forbid nested calls to keep the surface tiny
for node in ast.walk(call):
if isinstance(node, ast.Call) and node is not call:
raise ValueError("Nested function calls are not permitted")
# expose everything that crawl4ai exports, nothing else
safe_env = {name: getattr(_c4, name) for name in dir(_c4) if not name.startswith("_")}
obj = eval(compile(tree, "<config>", "eval"), {"__builtins__": {}}, safe_env)
return obj.dump()
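# Illustrative calls (hypothetical argument values, not part of this commit)
# showing what the helper above accepts and rejects:
#
#   _safe_eval_config('CrawlerRunConfig(stream=True, check_robots_txt=True)')
#       -> dict produced by CrawlerRunConfig(...).dump()
#   _safe_eval_config("__import__('os').system('id')")
#       -> ValueError: only a bare CrawlerRunConfig(...)/BrowserConfig(...) call is allowed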
async def get_manager() -> CrawlerManager:
# Ensure manager exists and is enabled before yielding
if not hasattr(app.state, 'crawler_manager') or app.state.crawler_manager is None:
raise HTTPException(
status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
detail="Crawler service is disabled or not initialized"
)
if not app.state.crawler_manager.is_enabled():
raise HTTPException(
status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
detail="Crawler service is currently disabled"
)
return app.state.crawler_manager
# Token endpoint (always available, but usage depends on config)
# ───────────────────────── Schemas ───────────────────────────
class CrawlRequest(BaseModel):
urls: List[str] = Field(min_length=1, max_length=100)
browser_config: Optional[Dict] = Field(default_factory=dict)
crawler_config: Optional[Dict] = Field(default_factory=dict)
class RawCode(BaseModel):
code: str
# ──────────────────────── Endpoints ──────────────────────────
@app.post("/token")
async def get_token(request_data: TokenRequest):
if not verify_email_domain(request_data.email):
raise HTTPException(status_code=400, detail="Invalid email domain")
token = create_access_token({"sub": request_data.email})
return {"email": request_data.email, "access_token": token, "token_type": "bearer"}
async def get_token(req: TokenRequest):
if not verify_email_domain(req.email):
raise HTTPException(400, "Invalid email domain")
token = create_access_token({"sub": req.email})
return {"email": req.email, "access_token": token, "token_type": "bearer"}
# Endpoints with conditional auth
@app.post("/config/dump")
async def config_dump(raw: RawCode):
try:
return JSONResponse(_safe_eval_config(raw.code.strip()))
except Exception as e:
raise HTTPException(400, str(e))
@app.get("/md/{url:path}")
@@ -171,230 +204,83 @@ async def get_markdown(
url: str,
f: FilterType = FilterType.FIT,
q: Optional[str] = None,
c: Optional[str] = "0",
token_data: Optional[Dict] = Depends(token_dependency)
c: str = "0",
_td: Dict = Depends(token_dep),
):
result = await handle_markdown_request(url, f, q, c, config)
return PlainTextResponse(result)
md = await handle_markdown_request(url, f, q, c, config)
return PlainTextResponse(md)
@app.get("/llm/{url:path}", description="URL should be without http/https prefix")
@app.get("/llm/{url:path}")
async def llm_endpoint(
request: Request,
url: str = Path(...),
q: Optional[str] = Query(None),
token_data: Optional[Dict] = Depends(token_dependency)
_td: Dict = Depends(token_dep),
):
if not q:
raise HTTPException(
status_code=400, detail="Query parameter 'q' is required")
if not url.startswith(('http://', 'https://')):
url = 'https://' + url
try:
answer = await handle_llm_qa(url, q, config)
return JSONResponse({"answer": answer})
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
raise HTTPException(400, "Query parameter 'q' is required")
if not url.startswith(("http://", "https://")):
url = "https://" + url
answer = await handle_llm_qa(url, q, config)
return JSONResponse({"answer": answer})
@app.get("/schema")
async def get_schema():
from crawl4ai import BrowserConfig, CrawlerRunConfig
return {"browser": BrowserConfig().dump(), "crawler": CrawlerRunConfig().dump()}
return {"browser": BrowserConfig().dump(),
"crawler": CrawlerRunConfig().dump()}
@app.get(config["observability"]["health_check"]["endpoint"])
async def health():
return {"status": "ok", "timestamp": time.time(), "version": __version__}
@app.get(config["observability"]["prometheus"]["endpoint"])
async def metrics():
return RedirectResponse(url=config["observability"]["prometheus"]["endpoint"])
@app.get("/browswers")
# Optional dependency
async def health(manager: Optional[CrawlerManager] = Depends(get_manager, use_cache=False)):
base_status = {"status": "ok", "timestamp": time.time(),
"version": __version__}
if manager:
try:
manager_status = await manager.get_status()
base_status["crawler_manager"] = manager_status
except Exception as e:
base_status["crawler_manager"] = {
"status": "error", "detail": str(e)}
else:
base_status["crawler_manager"] = {"status": "disabled"}
return base_status
return RedirectResponse(config["observability"]["prometheus"]["endpoint"])
@app.post("/crawl")
@limiter.limit(config["rate_limiting"]["default_limit"])
async def crawl(
request: Request,
crawl_request: CrawlRequest,
manager: CrawlerManager = Depends(get_manager), # Use dependency
token_data: Optional[Dict] = Depends(token_dependency) # Keep auth
_td: Dict = Depends(token_dep),
):
if not crawl_request.urls:
raise HTTPException(
status_code=400, detail="At least one URL required")
try:
# Use the manager's context to get a crawler instance
async with manager.get_crawler() as active_crawler:
# Call the actual handler from api.py, passing the acquired crawler
results_dict = await handle_crawl_request(
crawler=active_crawler, # Pass the live crawler instance
urls=crawl_request.urls,
# Pass user-provided configs, these might override pool defaults if needed
# Or the manager/handler could decide how to merge them
browser_config=crawl_request.browser_config or {}, # Ensure dict
crawler_config=crawl_request.crawler_config or {}, # Ensure dict
config=config # Pass the global server config
)
return JSONResponse(results_dict)
except PoolTimeoutError as e:
logger.warning(f"Request rejected due to pool timeout: {e}")
raise HTTPException(
status_code=status.HTTP_503_SERVICE_UNAVAILABLE, # Or 429
detail=f"Crawler resources busy. Please try again later. Timeout: {e}"
)
except NoHealthyCrawlerError as e:
logger.error(f"Request failed as no healthy crawler available: {e}")
raise HTTPException(
status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
detail=f"Crawler service temporarily unavailable: {e}"
)
except HTTPException: # Re-raise HTTP exceptions from handler
raise
except Exception as e:
logger.error(
f"Unexpected error during batch crawl processing: {e}", exc_info=True)
# Return generic error, details might be logged by handle_crawl_request
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=f"An unexpected error occurred: {e}"
)
raise HTTPException(400, "At least one URL required")
res = await handle_crawl_request(
urls=crawl_request.urls,
browser_config=crawl_request.browser_config,
crawler_config=crawl_request.crawler_config,
config=config,
)
return JSONResponse(res)
@app.post("/crawl/stream")
@limiter.limit(config["rate_limiting"]["default_limit"])
async def crawl_stream(
request: Request,
crawl_request: CrawlRequest,
manager: CrawlerManager = Depends(get_manager),
token_data: Optional[Dict] = Depends(token_dependency)
_td: Dict = Depends(token_dep),
):
if not crawl_request.urls:
raise HTTPException(
status_code=400, detail="At least one URL required")
try:
# THIS IS A BIT OF A WORK OF ART RATHER THAN ENGINEERING
# Acquire the crawler context from the manager
# IMPORTANT: The context needs to be active for the *duration* of the stream
# This structure might be tricky with FastAPI's StreamingResponse which consumes
# the generator *after* the endpoint function returns.
# --- Option A: Acquire crawler, pass to handler, handler yields ---
# (Requires handler NOT to be async generator itself, but return one)
# async with manager.get_crawler() as active_crawler:
# # Handler returns the generator
# _, results_gen = await handle_stream_crawl_request(
# crawler=active_crawler,
# urls=crawl_request.urls,
# browser_config=crawl_request.browser_config or {},
# crawler_config=crawl_request.crawler_config or {},
# config=config
# )
# # PROBLEM: `active_crawler` context exits before StreamingResponse uses results_gen
# # This releases the semaphore too early.
# --- Option B: Pass manager to handler, handler uses context internally ---
# (Requires modifying handle_stream_crawl_request signature/logic)
# This seems cleaner. Let's assume api.py is adapted for this.
# We need a way for the generator yielded by stream_results to know when
# to release the semaphore.
# --- Option C: Create a wrapper generator that handles context ---
async def stream_wrapper(manager: CrawlerManager, crawl_request: CrawlRequest, config: dict) -> AsyncGenerator[bytes, None]:
active_crawler = None
try:
async with manager.get_crawler() as acquired_crawler:
active_crawler = acquired_crawler # Keep reference for cleanup
# Call the handler which returns the raw result generator
_crawler_ref, results_gen = await handle_stream_crawl_request(
crawler=acquired_crawler,
urls=crawl_request.urls,
browser_config=crawl_request.browser_config or {},
crawler_config=crawl_request.crawler_config or {},
config=config
)
# Use the stream_results utility to format and yield
async for data_bytes in stream_results(_crawler_ref, results_gen):
yield data_bytes
except (PoolTimeoutError, NoHealthyCrawlerError) as e:
# Yield a final error message in the stream
error_payload = {"status": "error", "detail": str(e)}
yield (json.dumps(error_payload) + "\n").encode('utf-8')
logger.warning(f"Stream request failed: {e}")
# Re-raise might be better if StreamingResponse handles it? Test needed.
except HTTPException as e: # Catch HTTP exceptions from handler setup
error_payload = {"status": "error",
"detail": e.detail, "status_code": e.status_code}
yield (json.dumps(error_payload) + "\n").encode('utf-8')
logger.warning(
f"Stream request failed with HTTPException: {e.detail}")
except Exception as e:
error_payload = {"status": "error",
"detail": f"Unexpected stream error: {e}"}
yield (json.dumps(error_payload) + "\n").encode('utf-8')
logger.error(
f"Unexpected error during stream processing: {e}", exc_info=True)
# finally:
# Ensure crawler cleanup if stream_results doesn't handle it?
# stream_results *should* call crawler.close(), but only on the
# instance it received. If we pass the *manager* instead, this gets complex.
# Let's stick to passing the acquired_crawler and rely on stream_results.
# Create the generator using the wrapper
streaming_generator = stream_wrapper(manager, crawl_request, config)
return StreamingResponse(
streaming_generator, # Use the wrapper
media_type='application/x-ndjson',
headers={'Cache-Control': 'no-cache',
'Connection': 'keep-alive', 'X-Stream-Status': 'active'}
)
except (PoolTimeoutError, NoHealthyCrawlerError) as e:
# These might occur if get_crawler fails *before* stream starts
# Or if the wrapper re-raises them.
logger.warning(f"Stream request rejected before starting: {e}")
status_code = status.HTTP_503_SERVICE_UNAVAILABLE # Or 429 for timeout
# Don't raise HTTPException here, let the wrapper yield the error message.
# If we want to return a non-200 initial status, need more complex handling.
# Return an *empty* stream with error headers? Or just let wrapper yield error.
async def _error_stream():
error_payload = {"status": "error", "detail": str(e)}
yield (json.dumps(error_payload) + "\n").encode('utf-8')
return StreamingResponse(_error_stream(), status_code=status_code, media_type='application/x-ndjson')
except HTTPException: # Re-raise HTTP exceptions from setup
raise
except Exception as e:
logger.error(
f"Unexpected error setting up stream crawl: {e}", exc_info=True)
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=f"An unexpected error occurred setting up the stream: {e}"
)
raise HTTPException(400, "At least one URL required")
crawler, gen = await handle_stream_crawl_request(
urls=crawl_request.urls,
browser_config=crawl_request.browser_config,
crawler_config=crawl_request.crawler_config,
config=config,
)
return StreamingResponse(
stream_results(crawler, gen),
media_type="application/x-ndjson",
headers={
"Cache-Control": "no-cache",
"Connection": "keep-alive",
"X-Stream-Status": "active",
},
)
# ────────────────────────── cli ──────────────────────────────
if __name__ == "__main__":
import uvicorn
uvicorn.run(
@@ -402,5 +288,6 @@ if __name__ == "__main__":
host=config["app"]["host"],
port=config["app"]["port"],
reload=config["app"]["reload"],
timeout_keep_alive=config["app"]["timeout_keep_alive"]
timeout_keep_alive=config["app"]["timeout_keep_alive"],
)
# ─────────────────────────────────────────────────────────────