feat(docker-api): add job-based polling endpoints for crawl and LLM tasks

Implements new asynchronous endpoints for handling long-running crawl and LLM tasks:
- POST /crawl/job and GET /crawl/job/{task_id} for crawl operations
- POST /llm/job and GET /llm/job/{task_id} for LLM operations
- Added Redis-based task management with configurable TTL
- Moved schema definitions to dedicated schemas.py
- Added example polling client demo_docker_polling.py

This change allows clients to handle long-running operations asynchronously through a polling pattern rather than holding connections open.
This commit is contained in:
UncleCode
2025-05-01 21:24:52 +08:00
parent ee01b81f3e
commit 94e9959fe0
8 changed files with 385 additions and 75 deletions

2
.gitignore vendored
View File

@@ -262,3 +262,5 @@ CLAUDE.md
tests/**/test_site tests/**/test_site
tests/**/reports tests/**/reports
tests/**/benchmark_reports tests/**/benchmark_reports
.codecat/

View File

@@ -1,8 +1,10 @@
import os import os
import json import json
import asyncio import asyncio
from typing import List, Tuple from typing import List, Tuple, Dict
from functools import partial from functools import partial
from uuid import uuid4
from datetime import datetime
import logging import logging
from typing import Optional, AsyncGenerator from typing import Optional, AsyncGenerator
@@ -272,7 +274,9 @@ async def handle_llm_request(
async def handle_task_status( async def handle_task_status(
redis: aioredis.Redis, redis: aioredis.Redis,
task_id: str, task_id: str,
base_url: str base_url: str,
*,
keep: bool = False
) -> JSONResponse: ) -> JSONResponse:
"""Handle task status check requests.""" """Handle task status check requests."""
task = await redis.hgetall(f"task:{task_id}") task = await redis.hgetall(f"task:{task_id}")
@@ -286,7 +290,7 @@ async def handle_task_status(
response = create_task_response(task, task_id, base_url) response = create_task_response(task, task_id, base_url)
if task["status"] in [TaskStatus.COMPLETED, TaskStatus.FAILED]: if task["status"] in [TaskStatus.COMPLETED, TaskStatus.FAILED]:
if should_cleanup_task(task["created_at"]): if not keep and should_cleanup_task(task["created_at"]):
await redis.delete(f"task:{task_id}") await redis.delete(f"task:{task_id}")
return JSONResponse(response) return JSONResponse(response)
@@ -521,3 +525,47 @@ async def handle_stream_crawl_request(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=str(e) detail=str(e)
) )
async def handle_crawl_job(
redis,
background_tasks: BackgroundTasks,
urls: List[str],
browser_config: Dict,
crawler_config: Dict,
config: Dict,
) -> Dict:
"""
Fire-and-forget version of handle_crawl_request.
Creates a task in Redis, runs the heavy work in a background task,
lets /crawl/job/{task_id} polling fetch the result.
"""
task_id = f"crawl_{uuid4().hex[:8]}"
await redis.hset(f"task:{task_id}", mapping={
"status": TaskStatus.PROCESSING, # <-- keep enum values consistent
"created_at": datetime.utcnow().isoformat(),
"url": json.dumps(urls), # store list as JSON string
"result": "",
"error": "",
})
async def _runner():
try:
result = await handle_crawl_request(
urls=urls,
browser_config=browser_config,
crawler_config=crawler_config,
config=config,
)
await redis.hset(f"task:{task_id}", mapping={
"status": TaskStatus.COMPLETED,
"result": json.dumps(result),
})
await asyncio.sleep(5) # Give Redis time to process the update
except Exception as exc:
await redis.hset(f"task:{task_id}", mapping={
"status": TaskStatus.FAILED,
"error": str(exc),
})
background_tasks.add_task(_runner)
return {"task_id": task_id}

View File

@@ -3,7 +3,7 @@ app:
title: "Crawl4AI API" title: "Crawl4AI API"
version: "1.0.0" version: "1.0.0"
host: "0.0.0.0" host: "0.0.0.0"
port: 11235 port: 11234
reload: False reload: False
workers: 1 workers: 1
timeout_keep_alive: 300 timeout_keep_alive: 300

99
deploy/docker/job.py Normal file
View File

@@ -0,0 +1,99 @@
"""
Job endpoints (enqueue + poll) for long-running LLM extraction and raw crawl.
Relies on the existing Redis task helpers in api.py
"""
from typing import Dict, Optional, Callable
from fastapi import APIRouter, BackgroundTasks, Depends, Request
from pydantic import BaseModel, HttpUrl
from api import (
handle_llm_request,
handle_crawl_job,
handle_task_status,
)
# ------------- dependency placeholders -------------
_redis = None # will be injected from server.py
_config = None
_token_dep: Callable = lambda: None # dummy until injected
# public router
router = APIRouter()
# === init hook called by server.py =========================================
def init_job_router(redis, config, token_dep) -> APIRouter:
"""Inject shared singletons and return the router for mounting."""
global _redis, _config, _token_dep
_redis, _config, _token_dep = redis, config, token_dep
return router
# ---------- payload models --------------------------------------------------
class LlmJobPayload(BaseModel):
url: HttpUrl
q: str
schema: Optional[str] = None
cache: bool = False
class CrawlJobPayload(BaseModel):
urls: list[HttpUrl]
browser_config: Dict = {}
crawler_config: Dict = {}
# ---------- LLM job ---------------------------------------------------------
@router.post("/llm/job", status_code=202)
async def llm_job_enqueue(
payload: LlmJobPayload,
background_tasks: BackgroundTasks,
request: Request,
_td: Dict = Depends(lambda: _token_dep()), # late-bound dep
):
return await handle_llm_request(
_redis,
background_tasks,
request,
str(payload.url),
query=payload.q,
schema=payload.schema,
cache=payload.cache,
config=_config,
)
@router.get("/llm/job/{task_id}")
async def llm_job_status(
request: Request,
task_id: str,
_td: Dict = Depends(lambda: _token_dep())
):
return await handle_task_status(_redis, task_id)
# ---------- CRAWL job -------------------------------------------------------
@router.post("/crawl/job", status_code=202)
async def crawl_job_enqueue(
payload: CrawlJobPayload,
background_tasks: BackgroundTasks,
_td: Dict = Depends(lambda: _token_dep()),
):
return await handle_crawl_job(
_redis,
background_tasks,
[str(u) for u in payload.urls],
payload.browser_config,
payload.crawler_config,
config=_config,
)
@router.get("/crawl/job/{task_id}")
async def crawl_job_status(
request: Request,
task_id: str,
_td: Dict = Depends(lambda: _token_dep())
):
return await handle_task_status(_redis, task_id, base_url=str(request.base_url))

42
deploy/docker/schemas.py Normal file
View File

@@ -0,0 +1,42 @@
from typing import List, Optional, Dict
from enum import Enum
from pydantic import BaseModel, Field
from utils import FilterType
class CrawlRequest(BaseModel):
urls: List[str] = Field(min_length=1, max_length=100)
browser_config: Optional[Dict] = Field(default_factory=dict)
crawler_config: Optional[Dict] = Field(default_factory=dict)
class MarkdownRequest(BaseModel):
"""Request body for the /md endpoint."""
url: str = Field(..., description="Absolute http/https URL to fetch")
f: FilterType = Field(FilterType.FIT,
description="Contentfilter strategy: FIT, RAW, BM25, or LLM")
q: Optional[str] = Field(None, description="Query string used by BM25/LLM filters")
c: Optional[str] = Field("0", description="Cachebust / revision counter")
class RawCode(BaseModel):
code: str
class HTMLRequest(BaseModel):
url: str
class ScreenshotRequest(BaseModel):
url: str
screenshot_wait_for: Optional[float] = 2
output_path: Optional[str] = None
class PDFRequest(BaseModel):
url: str
output_path: Optional[str] = None
class JSEndpointRequest(BaseModel):
url: str
scripts: List[str] = Field(
...,
description="List of separated JavaScript snippets to execute"
)

View File

@@ -22,6 +22,16 @@ from api import (
handle_stream_crawl_request, handle_crawl_request, handle_stream_crawl_request, handle_crawl_request,
stream_results stream_results
) )
from schemas import (
CrawlRequest,
MarkdownRequest,
RawCode,
HTMLRequest,
ScreenshotRequest,
PDFRequest,
JSEndpointRequest,
)
from utils import ( from utils import (
FilterType, load_config, setup_logging, verify_email_domain FilterType, load_config, setup_logging, verify_email_domain
) )
@@ -37,23 +47,13 @@ from fastapi import (
FastAPI, HTTPException, Request, Path, Query, Depends FastAPI, HTTPException, Request, Path, Query, Depends
) )
from rank_bm25 import BM25Okapi from rank_bm25 import BM25Okapi
def chunk_code_functions(code: str) -> List[str]:
tree = ast.parse(code)
lines = code.splitlines()
chunks = []
for node in tree.body:
if isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef, ast.ClassDef)):
start = node.lineno - 1
end = getattr(node, 'end_lineno', start + 1)
chunks.append("\n".join(lines[start:end]))
return chunks
from fastapi.responses import ( from fastapi.responses import (
StreamingResponse, RedirectResponse, PlainTextResponse, JSONResponse StreamingResponse, RedirectResponse, PlainTextResponse, JSONResponse
) )
from fastapi.middleware.httpsredirect import HTTPSRedirectMiddleware from fastapi.middleware.httpsredirect import HTTPSRedirectMiddleware
from fastapi.middleware.trustedhost import TrustedHostMiddleware from fastapi.middleware.trustedhost import TrustedHostMiddleware
from fastapi.staticfiles import StaticFiles from fastapi.staticfiles import StaticFiles
from job import init_job_router
from mcp_bridge import attach_mcp, mcp_resource, mcp_template, mcp_tool from mcp_bridge import attach_mcp, mcp_resource, mcp_template, mcp_tool
@@ -129,8 +129,6 @@ app.mount(
name="play", name="play",
) )
# Optional nicetohave: opening the root shows the playground
@app.get("/") @app.get("/")
async def root(): async def root():
@@ -211,48 +209,10 @@ def _safe_eval_config(expr: str) -> dict:
return obj.dump() return obj.dump()
# ───────────────────────── Schemas ─────────────────────────── # ── job router ──────────────────────────────────────────────
class CrawlRequest(BaseModel): app.include_router(init_job_router(redis, config, token_dep))
urls: List[str] = Field(min_length=1, max_length=100)
browser_config: Optional[Dict] = Field(default_factory=dict)
crawler_config: Optional[Dict] = Field(default_factory=dict)
# ────────────── Schemas ──────────────
class MarkdownRequest(BaseModel):
"""Request body for the /md endpoint."""
url: str = Field(..., description="Absolute http/https URL to fetch")
f: FilterType = Field(FilterType.FIT,
description="Contentfilter strategy: FIT, RAW, BM25, or LLM")
q: Optional[str] = Field(None, description="Query string used by BM25/LLM filters")
c: Optional[str] = Field("0", description="Cachebust / revision counter")
class RawCode(BaseModel):
code: str
class HTMLRequest(BaseModel):
url: str
class ScreenshotRequest(BaseModel):
url: str
screenshot_wait_for: Optional[float] = 2
output_path: Optional[str] = None
class PDFRequest(BaseModel):
url: str
output_path: Optional[str] = None
class JSEndpointRequest(BaseModel):
url: str
scripts: List[str] = Field(
...,
description="List of separated JavaScript snippets to execute"
)
# ──────────────────────── Endpoints ────────────────────────── # ──────────────────────── Endpoints ──────────────────────────
@app.post("/token") @app.post("/token")
async def get_token(req: TokenRequest): async def get_token(req: TokenRequest):
if not verify_email_domain(req.email): if not verify_email_domain(req.email):
@@ -278,7 +238,8 @@ async def get_markdown(
_td: Dict = Depends(token_dep), _td: Dict = Depends(token_dep),
): ):
if not body.url.startswith(("http://", "https://")): if not body.url.startswith(("http://", "https://")):
raise HTTPException(400, "URL must be absolute and start with http/https") raise HTTPException(
400, "URL must be absolute and start with http/https")
markdown = await handle_markdown_request( markdown = await handle_markdown_request(
body.url, body.f, body.q, body.c, config body.url, body.f, body.q, body.c, config
) )
@@ -314,6 +275,7 @@ async def generate_html(
# Screenshot endpoint # Screenshot endpoint
@app.post("/screenshot") @app.post("/screenshot")
@limiter.limit(config["rate_limiting"]["default_limit"]) @limiter.limit(config["rate_limiting"]["default_limit"])
@mcp_tool("screenshot") @mcp_tool("screenshot")
@@ -327,7 +289,8 @@ async def generate_screenshot(
Use when you need an image snapshot of the rendered page. Its recommened to provide an output path to save the screenshot. Use when you need an image snapshot of the rendered page. Its recommened to provide an output path to save the screenshot.
Then in result instead of the screenshot you will get a path to the saved file. Then in result instead of the screenshot you will get a path to the saved file.
""" """
cfg = CrawlerRunConfig(screenshot=True, screenshot_wait_for=body.screenshot_wait_for) cfg = CrawlerRunConfig(
screenshot=True, screenshot_wait_for=body.screenshot_wait_for)
async with AsyncWebCrawler(config=BrowserConfig()) as crawler: async with AsyncWebCrawler(config=BrowserConfig()) as crawler:
results = await crawler.arun(url=body.url, config=cfg) results = await crawler.arun(url=body.url, config=cfg)
screenshot_data = results[0].screenshot screenshot_data = results[0].screenshot
@@ -341,6 +304,7 @@ async def generate_screenshot(
# PDF endpoint # PDF endpoint
@app.post("/pdf") @app.post("/pdf")
@limiter.limit(config["rate_limiting"]["default_limit"]) @limiter.limit(config["rate_limiting"]["default_limit"])
@mcp_tool("pdf") @mcp_tool("pdf")
@@ -507,6 +471,7 @@ async def crawl_stream(
}, },
) )
def chunk_code_functions(code_md: str) -> List[str]: def chunk_code_functions(code_md: str) -> List[str]:
"""Extract each function/class from markdown code blocks per file.""" """Extract each function/class from markdown code blocks per file."""
pattern = re.compile( pattern = re.compile(
@@ -530,6 +495,7 @@ def chunk_code_functions(code_md: str) -> List[str]:
chunks.append(f"# File: {file_path}\n{snippet}") chunks.append(f"# File: {file_path}\n{snippet}")
return chunks return chunks
def chunk_doc_sections(doc: str) -> List[str]: def chunk_doc_sections(doc: str) -> List[str]:
lines = doc.splitlines(keepends=True) lines = doc.splitlines(keepends=True)
sections = [] sections = []
@@ -545,6 +511,7 @@ def chunk_doc_sections(doc: str) -> List[str]:
sections.append("".join(current)) sections.append("".join(current))
return sections return sections
@app.get("/ask") @app.get("/ask")
@limiter.limit(config["rate_limiting"]["default_limit"]) @limiter.limit(config["rate_limiting"]["default_limit"])
@mcp_tool("ask") @mcp_tool("ask")
@@ -552,9 +519,12 @@ async def get_context(
request: Request, request: Request,
_td: Dict = Depends(token_dep), _td: Dict = Depends(token_dep),
context_type: str = Query("all", regex="^(code|doc|all)$"), context_type: str = Query("all", regex="^(code|doc|all)$"),
query: Optional[str] = Query(None, description="search query to filter chunks"), query: Optional[str] = Query(
score_ratio: float = Query(0.5, ge=0.0, le=1.0, description="min score as fraction of max_score"), None, description="search query to filter chunks"),
max_results: int = Query(20, ge=1, description="absolute cap on returned chunks"), score_ratio: float = Query(
0.5, ge=0.0, le=1.0, description="min score as fraction of max_score"),
max_results: int = Query(
20, ge=1, description="absolute cap on returned chunks"),
): ):
""" """
This end point is design for any questions about Crawl4ai library. It returns a plain text markdown with extensive information about Crawl4ai. This end point is design for any questions about Crawl4ai library. It returns a plain text markdown with extensive information about Crawl4ai.

View File

@@ -45,10 +45,10 @@ def datetime_handler(obj: any) -> Optional[str]:
return obj.isoformat() return obj.isoformat()
raise TypeError(f"Object of type {type(obj)} is not JSON serializable") raise TypeError(f"Object of type {type(obj)} is not JSON serializable")
def should_cleanup_task(created_at: str) -> bool: def should_cleanup_task(created_at: str, ttl_seconds: int = 3600) -> bool:
"""Check if task should be cleaned up based on creation time.""" """Check if task should be cleaned up based on creation time."""
created = datetime.fromisoformat(created_at) created = datetime.fromisoformat(created_at)
return (datetime.now() - created).total_seconds() > 3600 return (datetime.now() - created).total_seconds() > ttl_seconds
def decode_redis_hash(hash_data: Dict[bytes, bytes]) -> Dict[str, str]: def decode_redis_hash(hash_data: Dict[bytes, bytes]) -> Dict[str, str]:
"""Decode Redis hash data from bytes to strings.""" """Decode Redis hash data from bytes to strings."""

View File

@@ -0,0 +1,149 @@
#!/usr/bin/env python3
"""
demo_docker_polling.py
Quick sanity-check for the asynchronous crawl job endpoints:
• POST /crawl/job enqueue work, get task_id
• GET /crawl/job/{id} poll status / fetch result
The style matches demo_docker_api.py (console.rule banners, helper
functions, coloured status lines). Adjust BASE_URL as needed.
Run: python demo_docker_polling.py
"""
import asyncio, json, os, time, urllib.parse
from typing import Dict, List
import httpx
from rich.console import Console
from rich.panel import Panel
from rich.syntax import Syntax
console = Console()
BASE_URL = os.getenv("BASE_URL", "http://localhost:11234")
SIMPLE_URL = "https://example.org"
LINKS_URL = "https://httpbin.org/links/10/1"
# --- helpers --------------------------------------------------------------
def print_payload(payload: Dict):
console.print(Panel(Syntax(json.dumps(payload, indent=2),
"json", theme="monokai", line_numbers=False),
title="Payload", border_style="cyan", expand=False))
async def check_server_health(client: httpx.AsyncClient) -> bool:
try:
resp = await client.get("/health")
if resp.is_success:
console.print("[green]Server healthy[/]")
return True
except Exception:
pass
console.print("[bold red]Server is not responding on /health[/]")
return False
async def poll_for_result(client: httpx.AsyncClient, task_id: str,
poll_interval: float = 1.5, timeout: float = 90.0):
"""Hit /crawl/job/{id} until COMPLETED/FAILED or timeout."""
start = time.time()
while True:
resp = await client.get(f"/crawl/job/{task_id}")
resp.raise_for_status()
data = resp.json()
status = data.get("status")
if status.upper() in ("COMPLETED", "FAILED"):
return data
if time.time() - start > timeout:
raise TimeoutError(f"Task {task_id} did not finish in {timeout}s")
await asyncio.sleep(poll_interval)
# --- demo functions -------------------------------------------------------
async def demo_poll_single_url(client: httpx.AsyncClient):
payload = {
"urls": [SIMPLE_URL],
"browser_config": {"type": "BrowserConfig",
"params": {"headless": True}},
"crawler_config": {"type": "CrawlerRunConfig",
"params": {"cache_mode": "BYPASS"}}
}
console.rule("[bold blue]Demo A: /crawl/job Single URL[/]", style="blue")
print_payload(payload)
# enqueue
resp = await client.post("/crawl/job", json=payload)
console.print(f"Enqueue status: [bold]{resp.status_code}[/]")
resp.raise_for_status()
task_id = resp.json()["task_id"]
console.print(f"Task ID: [yellow]{task_id}[/]")
# poll
console.print("Polling…")
result = await poll_for_result(client, task_id)
console.print(Panel(Syntax(json.dumps(result, indent=2),
"json", theme="fruity"),
title="Final result", border_style="green"))
if result["status"] == "COMPLETED":
console.print("[green]✅ Crawl succeeded[/]")
else:
console.print("[red]❌ Crawl failed[/]")
async def demo_poll_multi_url(client: httpx.AsyncClient):
payload = {
"urls": [SIMPLE_URL, LINKS_URL],
"browser_config": {"type": "BrowserConfig",
"params": {"headless": True}},
"crawler_config": {"type": "CrawlerRunConfig",
"params": {"cache_mode": "BYPASS"}}
}
console.rule("[bold magenta]Demo B: /crawl/job Multi-URL[/]",
style="magenta")
print_payload(payload)
resp = await client.post("/crawl/job", json=payload)
console.print(f"Enqueue status: [bold]{resp.status_code}[/]")
resp.raise_for_status()
task_id = resp.json()["task_id"]
console.print(f"Task ID: [yellow]{task_id}[/]")
console.print("Polling…")
result = await poll_for_result(client, task_id)
console.print(Panel(Syntax(json.dumps(result, indent=2),
"json", theme="fruity"),
title="Final result", border_style="green"))
if result["status"] == "COMPLETED":
console.print(
f"[green]✅ {len(json.loads(result['result'])['results'])} URLs crawled[/]")
else:
console.print("[red]❌ Crawl failed[/]")
# --- main runner ----------------------------------------------------------
async def main_demo():
async with httpx.AsyncClient(base_url=BASE_URL, timeout=300.0) as client:
if not await check_server_health(client):
return
await demo_poll_single_url(client)
await demo_poll_multi_url(client)
console.rule("[bold green]Polling demos complete[/]", style="green")
if __name__ == "__main__":
try:
asyncio.run(main_demo())
except KeyboardInterrupt:
console.print("\n[yellow]Interrupted by user[/]")
except Exception:
console.print_exception(show_locals=False)