feat(docker-api): add job-based polling endpoints for crawl and LLM tasks
Implements new asynchronous endpoints for handling long-running crawl and LLM tasks:
- POST /crawl/job and GET /crawl/job/{task_id} for crawl operations
- POST /llm/job and GET /llm/job/{task_id} for LLM operations
- Added Redis-based task management with configurable TTL
- Moved schema definitions to dedicated schemas.py
- Added example polling client demo_docker_polling.py
This change allows clients to handle long-running operations asynchronously through a polling pattern rather than holding connections open.
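For reference, the polling pattern looks roughly like the sketch below. This is a hypothetical client, not the shipped demo_docker_polling.py: it assumes the server is reachable on localhost:11234 (the port set in config.yml in this commit), uses httpx, skips authentication, and guesses that the status payload exposes the status/result/error fields stored in the Redis task hash; the real field names and status strings may differ.

```python
# Hypothetical polling client sketch (not the shipped demo_docker_polling.py).
import time
import httpx

BASE = "http://localhost:11234"  # assumed host/port from config.yml

def crawl_and_wait(urls, timeout=300, interval=2):
    # Enqueue the job; the endpoint answers 202 with a task_id.
    resp = httpx.post(f"{BASE}/crawl/job", json={"urls": urls})
    resp.raise_for_status()
    task_id = resp.json()["task_id"]

    # Poll the status endpoint until the task leaves the processing state.
    deadline = time.time() + timeout
    while time.time() < deadline:
        status = httpx.get(f"{BASE}/crawl/job/{task_id}").json()
        # Assumed terminal status values; adjust to the actual TaskStatus strings.
        if status.get("status") in ("completed", "failed"):
            return status
        time.sleep(interval)
    raise TimeoutError(f"task {task_id} did not finish within {timeout}s")

if __name__ == "__main__":
    print(crawl_and_wait(["https://example.com"]))
```

The same loop works for the LLM endpoints; only the enqueue call changes (POST /llm/job with url and q, then GET /llm/job/{task_id}).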
@@ -1,8 +1,10 @@
 import os
 import json
 import asyncio
-from typing import List, Tuple
+from typing import List, Tuple, Dict
 from functools import partial
+from uuid import uuid4
+from datetime import datetime
 
 import logging
 from typing import Optional, AsyncGenerator
@@ -272,7 +274,9 @@ async def handle_llm_request(
 async def handle_task_status(
     redis: aioredis.Redis,
     task_id: str,
-    base_url: str
+    base_url: str,
+    *,
+    keep: bool = False
 ) -> JSONResponse:
     """Handle task status check requests."""
     task = await redis.hgetall(f"task:{task_id}")
@@ -286,7 +290,7 @@ async def handle_task_status(
     response = create_task_response(task, task_id, base_url)
 
     if task["status"] in [TaskStatus.COMPLETED, TaskStatus.FAILED]:
-        if should_cleanup_task(task["created_at"]):
+        if not keep and should_cleanup_task(task["created_at"]):
             await redis.delete(f"task:{task_id}")
 
     return JSONResponse(response)
@@ -520,4 +524,48 @@ async def handle_stream_crawl_request(
         raise HTTPException(
             status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
             detail=str(e)
-        )
+        )
+
+async def handle_crawl_job(
+    redis,
+    background_tasks: BackgroundTasks,
+    urls: List[str],
+    browser_config: Dict,
+    crawler_config: Dict,
+    config: Dict,
+) -> Dict:
+    """
+    Fire-and-forget version of handle_crawl_request.
+    Creates a task in Redis, runs the heavy work in a background task,
+    lets /crawl/job/{task_id} polling fetch the result.
+    """
+    task_id = f"crawl_{uuid4().hex[:8]}"
+    await redis.hset(f"task:{task_id}", mapping={
+        "status": TaskStatus.PROCESSING,   # <-- keep enum values consistent
+        "created_at": datetime.utcnow().isoformat(),
+        "url": json.dumps(urls),           # store list as JSON string
+        "result": "",
+        "error": "",
+    })
+
+    async def _runner():
+        try:
+            result = await handle_crawl_request(
+                urls=urls,
+                browser_config=browser_config,
+                crawler_config=crawler_config,
+                config=config,
+            )
+            await redis.hset(f"task:{task_id}", mapping={
+                "status": TaskStatus.COMPLETED,
+                "result": json.dumps(result),
+            })
+            await asyncio.sleep(5)  # Give Redis time to process the update
+        except Exception as exc:
+            await redis.hset(f"task:{task_id}", mapping={
+                "status": TaskStatus.FAILED,
+                "error": str(exc),
+            })
+
+    background_tasks.add_task(_runner)
+    return {"task_id": task_id}
@@ -3,7 +3,7 @@ app:
   title: "Crawl4AI API"
   version: "1.0.0"
   host: "0.0.0.0"
-  port: 11235
+  port: 11234
   reload: False
   workers: 1
   timeout_keep_alive: 300
deploy/docker/job.py (new file, 99 lines)
@@ -0,0 +1,99 @@
"""
|
||||
Job endpoints (enqueue + poll) for long-running LLM extraction and raw crawl.
|
||||
Relies on the existing Redis task helpers in api.py
|
||||
"""
|
||||
|
||||
from typing import Dict, Optional, Callable
|
||||
from fastapi import APIRouter, BackgroundTasks, Depends, Request
|
||||
from pydantic import BaseModel, HttpUrl
|
||||
|
||||
from api import (
|
||||
handle_llm_request,
|
||||
handle_crawl_job,
|
||||
handle_task_status,
|
||||
)
|
||||
|
||||
# ------------- dependency placeholders -------------
|
||||
_redis = None # will be injected from server.py
|
||||
_config = None
|
||||
_token_dep: Callable = lambda: None # dummy until injected
|
||||
|
||||
# public router
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
# === init hook called by server.py =========================================
|
||||
def init_job_router(redis, config, token_dep) -> APIRouter:
|
||||
"""Inject shared singletons and return the router for mounting."""
|
||||
global _redis, _config, _token_dep
|
||||
_redis, _config, _token_dep = redis, config, token_dep
|
||||
return router
|
||||
|
||||
|
||||
# ---------- payload models --------------------------------------------------
|
||||
class LlmJobPayload(BaseModel):
|
||||
url: HttpUrl
|
||||
q: str
|
||||
schema: Optional[str] = None
|
||||
cache: bool = False
|
||||
|
||||
|
||||
class CrawlJobPayload(BaseModel):
|
||||
urls: list[HttpUrl]
|
||||
browser_config: Dict = {}
|
||||
crawler_config: Dict = {}
|
||||
|
||||
|
||||
# ---------- LLM job ---------------------------------------------------------
|
||||
@router.post("/llm/job", status_code=202)
|
||||
async def llm_job_enqueue(
|
||||
payload: LlmJobPayload,
|
||||
background_tasks: BackgroundTasks,
|
||||
request: Request,
|
||||
_td: Dict = Depends(lambda: _token_dep()), # late-bound dep
|
||||
):
|
||||
return await handle_llm_request(
|
||||
_redis,
|
||||
background_tasks,
|
||||
request,
|
||||
str(payload.url),
|
||||
query=payload.q,
|
||||
schema=payload.schema,
|
||||
cache=payload.cache,
|
||||
config=_config,
|
||||
)
|
||||
|
||||
|
||||
@router.get("/llm/job/{task_id}")
|
||||
async def llm_job_status(
|
||||
request: Request,
|
||||
task_id: str,
|
||||
_td: Dict = Depends(lambda: _token_dep())
|
||||
):
|
||||
return await handle_task_status(_redis, task_id)
|
||||
|
||||
|
||||
# ---------- CRAWL job -------------------------------------------------------
|
||||
@router.post("/crawl/job", status_code=202)
|
||||
async def crawl_job_enqueue(
|
||||
payload: CrawlJobPayload,
|
||||
background_tasks: BackgroundTasks,
|
||||
_td: Dict = Depends(lambda: _token_dep()),
|
||||
):
|
||||
return await handle_crawl_job(
|
||||
_redis,
|
||||
background_tasks,
|
||||
[str(u) for u in payload.urls],
|
||||
payload.browser_config,
|
||||
payload.crawler_config,
|
||||
config=_config,
|
||||
)
|
||||
|
||||
|
||||
@router.get("/crawl/job/{task_id}")
|
||||
async def crawl_job_status(
|
||||
request: Request,
|
||||
task_id: str,
|
||||
_td: Dict = Depends(lambda: _token_dep())
|
||||
):
|
||||
return await handle_task_status(_redis, task_id, base_url=str(request.base_url))
|
||||
deploy/docker/schemas.py (new file, 42 lines)
@@ -0,0 +1,42 @@
+from typing import List, Optional, Dict
+from enum import Enum
+from pydantic import BaseModel, Field
+from utils import FilterType
+
+
+class CrawlRequest(BaseModel):
+    urls: List[str] = Field(min_length=1, max_length=100)
+    browser_config: Optional[Dict] = Field(default_factory=dict)
+    crawler_config: Optional[Dict] = Field(default_factory=dict)
+
+class MarkdownRequest(BaseModel):
+    """Request body for the /md endpoint."""
+    url: str = Field(..., description="Absolute http/https URL to fetch")
+    f: FilterType = Field(FilterType.FIT,
+                          description="Content‑filter strategy: FIT, RAW, BM25, or LLM")
+    q: Optional[str] = Field(None, description="Query string used by BM25/LLM filters")
+    c: Optional[str] = Field("0", description="Cache‑bust / revision counter")
+
+
+class RawCode(BaseModel):
+    code: str
+
+class HTMLRequest(BaseModel):
+    url: str
+
+class ScreenshotRequest(BaseModel):
+    url: str
+    screenshot_wait_for: Optional[float] = 2
+    output_path: Optional[str] = None
+
+class PDFRequest(BaseModel):
+    url: str
+    output_path: Optional[str] = None
+
+
+class JSEndpointRequest(BaseModel):
+    url: str
+    scripts: List[str] = Field(
+        ...,
+        description="List of separated JavaScript snippets to execute"
+    )
@@ -12,7 +12,7 @@ from crawl4ai import AsyncWebCrawler, BrowserConfig, CrawlerRunConfig
 from auth import create_access_token, get_token_dependency, TokenRequest
 from pydantic import BaseModel
 from typing import Optional, List, Dict
-from fastapi import Request, Depends
+from fastapi import Request, Depends
 from fastapi.responses import FileResponse
 import base64
 import re
@@ -22,6 +22,16 @@ from api import (
     handle_stream_crawl_request, handle_crawl_request,
     stream_results
 )
+from schemas import (
+    CrawlRequest,
+    MarkdownRequest,
+    RawCode,
+    HTMLRequest,
+    ScreenshotRequest,
+    PDFRequest,
+    JSEndpointRequest,
+)
+
 from utils import (
     FilterType, load_config, setup_logging, verify_email_domain
 )
@@ -37,23 +47,13 @@ from fastapi import (
     FastAPI, HTTPException, Request, Path, Query, Depends
 )
 from rank_bm25 import BM25Okapi
-
-def chunk_code_functions(code: str) -> List[str]:
-    tree = ast.parse(code)
-    lines = code.splitlines()
-    chunks = []
-    for node in tree.body:
-        if isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef, ast.ClassDef)):
-            start = node.lineno - 1
-            end = getattr(node, 'end_lineno', start + 1)
-            chunks.append("\n".join(lines[start:end]))
-    return chunks
 from fastapi.responses import (
     StreamingResponse, RedirectResponse, PlainTextResponse, JSONResponse
 )
 from fastapi.middleware.httpsredirect import HTTPSRedirectMiddleware
 from fastapi.middleware.trustedhost import TrustedHostMiddleware
 from fastapi.staticfiles import StaticFiles
+from job import init_job_router
 
 from mcp_bridge import attach_mcp, mcp_resource, mcp_template, mcp_tool
 
@@ -129,8 +129,6 @@ app.mount(
     name="play",
 )
 
-# Optional nice‑to‑have: opening the root shows the playground
-
 
 @app.get("/")
 async def root():
@@ -211,48 +209,10 @@ def _safe_eval_config(expr: str) -> dict:
     return obj.dump()
-
-
-# ───────────────────────── Schemas ───────────────────────────
-class CrawlRequest(BaseModel):
-    urls: List[str] = Field(min_length=1, max_length=100)
-    browser_config: Optional[Dict] = Field(default_factory=dict)
-    crawler_config: Optional[Dict] = Field(default_factory=dict)
-
-# ────────────── Schemas ──────────────
-class MarkdownRequest(BaseModel):
-    """Request body for the /md endpoint."""
-    url: str = Field(..., description="Absolute http/https URL to fetch")
-    f: FilterType = Field(FilterType.FIT,
-                          description="Content‑filter strategy: FIT, RAW, BM25, or LLM")
-    q: Optional[str] = Field(None, description="Query string used by BM25/LLM filters")
-    c: Optional[str] = Field("0", description="Cache‑bust / revision counter")
-
-
-class RawCode(BaseModel):
-    code: str
-
-class HTMLRequest(BaseModel):
-    url: str
-
-class ScreenshotRequest(BaseModel):
-    url: str
-    screenshot_wait_for: Optional[float] = 2
-    output_path: Optional[str] = None
-
-class PDFRequest(BaseModel):
-    url: str
-    output_path: Optional[str] = None
-
-
-class JSEndpointRequest(BaseModel):
-    url: str
-    scripts: List[str] = Field(
-        ...,
-        description="List of separated JavaScript snippets to execute"
-    )
+# ── job router ──────────────────────────────────────────────
+app.include_router(init_job_router(redis, config, token_dep))
 
 # ──────────────────────── Endpoints ──────────────────────────
 
 
 @app.post("/token")
 async def get_token(req: TokenRequest):
     if not verify_email_domain(req.email):
@@ -278,7 +238,8 @@ async def get_markdown(
     _td: Dict = Depends(token_dep),
 ):
     if not body.url.startswith(("http://", "https://")):
-        raise HTTPException(400, "URL must be absolute and start with http/https")
+        raise HTTPException(
+            400, "URL must be absolute and start with http/https")
     markdown = await handle_markdown_request(
         body.url, body.f, body.q, body.c, config
     )
@@ -314,12 +275,13 @@ async def generate_html(
 
 # Screenshot endpoint
 
+
 @app.post("/screenshot")
 @limiter.limit(config["rate_limiting"]["default_limit"])
 @mcp_tool("screenshot")
 async def generate_screenshot(
     request: Request,
-    body: ScreenshotRequest,
+    body: ScreenshotRequest,
     _td: Dict = Depends(token_dep),
 ):
     """
@@ -327,7 +289,8 @@ async def generate_screenshot(
     Use when you need an image snapshot of the rendered page. Its recommened to provide an output path to save the screenshot.
     Then in result instead of the screenshot you will get a path to the saved file.
     """
-    cfg = CrawlerRunConfig(screenshot=True, screenshot_wait_for=body.screenshot_wait_for)
+    cfg = CrawlerRunConfig(
+        screenshot=True, screenshot_wait_for=body.screenshot_wait_for)
     async with AsyncWebCrawler(config=BrowserConfig()) as crawler:
         results = await crawler.arun(url=body.url, config=cfg)
         screenshot_data = results[0].screenshot
@@ -341,12 +304,13 @@ async def generate_screenshot(
 
 # PDF endpoint
 
+
 @app.post("/pdf")
 @limiter.limit(config["rate_limiting"]["default_limit"])
 @mcp_tool("pdf")
 async def generate_pdf(
     request: Request,
-    body: PDFRequest,
+    body: PDFRequest,
     _td: Dict = Depends(token_dep),
 ):
     """
@@ -384,7 +348,7 @@ async def execute_js(
     Your script will replace '{script}' and execute in the browser context. So provide either an IIFE or a sync/async function that returns a value.
     Return Format:
     - The return result is an instance of CrawlResult, so you have access to markdown, links, and other stuff. If this is enough, you don't need to call again for other endpoints.
-
+
     ```python
     class CrawlResult(BaseModel):
         url: str
@@ -418,7 +382,7 @@ async def execute_js(
         fit_markdown: Optional[str] = None
         fit_html: Optional[str] = None
     ```
-
+
     """
     cfg = CrawlerRunConfig(js_code=body.scripts)
    async with AsyncWebCrawler(config=BrowserConfig()) as crawler:
@@ -507,6 +471,7 @@ async def crawl_stream(
         },
     )
 
+
 def chunk_code_functions(code_md: str) -> List[str]:
     """Extract each function/class from markdown code blocks per file."""
     pattern = re.compile(
@@ -530,6 +495,7 @@ def chunk_code_functions(code_md: str) -> List[str]:
         chunks.append(f"# File: {file_path}\n{snippet}")
     return chunks
 
+
 def chunk_doc_sections(doc: str) -> List[str]:
     lines = doc.splitlines(keepends=True)
     sections = []
@@ -545,6 +511,7 @@ def chunk_doc_sections(doc: str) -> List[str]:
         sections.append("".join(current))
     return sections
 
+
 @app.get("/ask")
 @limiter.limit(config["rate_limiting"]["default_limit"])
 @mcp_tool("ask")
@@ -552,21 +519,24 @@ async def get_context(
     request: Request,
     _td: Dict = Depends(token_dep),
     context_type: str = Query("all", regex="^(code|doc|all)$"),
-    query: Optional[str] = Query(None, description="search query to filter chunks"),
-    score_ratio: float = Query(0.5, ge=0.0, le=1.0, description="min score as fraction of max_score"),
-    max_results: int = Query(20, ge=1, description="absolute cap on returned chunks"),
+    query: Optional[str] = Query(
+        None, description="search query to filter chunks"),
+    score_ratio: float = Query(
+        0.5, ge=0.0, le=1.0, description="min score as fraction of max_score"),
+    max_results: int = Query(
+        20, ge=1, description="absolute cap on returned chunks"),
 ):
     """
     This end point is design for any questions about Crawl4ai library. It returns a plain text markdown with extensive information about Crawl4ai.
     You can use this as a context for any AI assistant. Use this endpoint for AI assistants to retrieve library context for decision making or code generation tasks.
     Alway is BEST practice you provide a query to filter the context. Otherwise the lenght of the response will be very long.
-
+
     Parameters:
     - context_type: Specify "code" for code context, "doc" for documentation context, or "all" for both.
     - query: RECOMMENDED search query to filter paragraphs using BM25. You can leave this empty to get all the context.
     - score_ratio: Minimum score as a fraction of the maximum score for filtering results.
     - max_results: Maximum number of results to return. Default is 20.
-
+
     Returns:
     - JSON response with the requested context.
     - If "code" is specified, returns the code context.
@@ -576,7 +546,7 @@ async def get_context(
     # load contexts
     base = os.path.dirname(__file__)
     code_path = os.path.join(base, "c4ai-code-context.md")
-    doc_path = os.path.join(base, "c4ai-doc-context.md")
+    doc_path = os.path.join(base, "c4ai-doc-context.md")
     if not os.path.exists(code_path) or not os.path.exists(doc_path):
         raise HTTPException(404, "Context files not found")
 
@@ -626,7 +596,7 @@ async def get_context(
     ]
 
     return JSONResponse(results)
-
+
+
 # attach MCP layer (adds /mcp/ws, /mcp/sse, /mcp/schema)
 print(f"MCP server running on {config['app']['host']}:{config['app']['port']}")
-
@@ -45,10 +45,10 @@ def datetime_handler(obj: any) -> Optional[str]:
         return obj.isoformat()
     raise TypeError(f"Object of type {type(obj)} is not JSON serializable")
 
-def should_cleanup_task(created_at: str) -> bool:
+def should_cleanup_task(created_at: str, ttl_seconds: int = 3600) -> bool:
     """Check if task should be cleaned up based on creation time."""
     created = datetime.fromisoformat(created_at)
-    return (datetime.now() - created).total_seconds() > 3600
+    return (datetime.now() - created).total_seconds() > ttl_seconds
 
 def decode_redis_hash(hash_data: Dict[bytes, bytes]) -> Dict[str, str]:
     """Decode Redis hash data from bytes to strings."""