feat(adaptive-crawling): implement adaptive crawling endpoints and job management
This commit is contained in:
1
.gitignore
vendored
1
.gitignore
vendored
@@ -271,3 +271,4 @@ docs/**/data
|
|||||||
|
|
||||||
docs/apps/linkdin/debug*/
|
docs/apps/linkdin/debug*/
|
||||||
docs/apps/linkdin/samples/insights/*
|
docs/apps/linkdin/samples/insights/*
|
||||||
|
.yoyo/
|
||||||
155
deploy/docker/adaptive_routes.py
Normal file
155
deploy/docker/adaptive_routes.py
Normal file
@@ -0,0 +1,155 @@
|
|||||||
|
import asyncio
|
||||||
|
import uuid
|
||||||
|
from typing import Any, Dict, List, Optional
|
||||||
|
|
||||||
|
from fastapi import APIRouter, BackgroundTasks, Body, HTTPException
|
||||||
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
|
from crawl4ai import AsyncWebCrawler
|
||||||
|
from crawl4ai.adaptive_crawler import AdaptiveConfig, AdaptiveCrawler
|
||||||
|
from crawl4ai.utils import get_error_context
|
||||||
|
|
||||||
|
# --- In-memory storage for job statuses. For production, use Redis or a database. ---
|
||||||
|
ADAPTIVE_JOBS: Dict[str, Dict[str, Any]] = {}
|
||||||
|
|
||||||
|
# --- Pydantic Models for API Validation ---
|
||||||
|
|
||||||
|
|
||||||
|
class AdaptiveConfigPayload(BaseModel):
    """Request-side mirror of ``AdaptiveConfig``.

    Exposes a curated subset of the adaptive crawler's tuning knobs over
    the HTTP API; the worker expands this payload via ``model_dump()``
    into a real ``AdaptiveConfig`` instance.
    """

    # Stop crawling once confidence reaches this fraction.
    confidence_threshold: float = 0.7
    # Hard cap on the number of pages fetched per job.
    max_pages: int = 20
    # How many candidate links to follow from each page.
    top_k_links: int = 3
    strategy: str = "statistical"  # "statistical" or "embedding"
    embedding_model: Optional[str] = "sentence-transformers/all-MiniLM-L6-v2"
    # Add any other AdaptiveConfig fields you want to expose
|
||||||
|
|
||||||
|
|
||||||
|
class AdaptiveCrawlRequest(BaseModel):
    """Body of a job-submission request: where to start, what to look for,
    and (optionally) how the adaptive crawler should be tuned."""

    start_url: str = Field(..., description="The starting URL for the adaptive crawl.")
    query: str = Field(..., description="The user query to guide the crawl.")
    # When omitted, the worker falls back to a default AdaptiveConfig().
    config: Optional[AdaptiveConfigPayload] = Field(
        None, description="Optional adaptive crawler configuration."
    )
|
||||||
|
|
||||||
|
|
||||||
|
class AdaptiveJobStatus(BaseModel):
    """Response shape for both the submission and the polling endpoints.

    Mirrors one entry of the ``ADAPTIVE_JOBS`` store.
    """

    # Opaque UUID identifying the job.
    task_id: str
    # One of: PENDING, RUNNING, COMPLETED, FAILED.
    status: str
    # Crawl metrics (populated when available).
    metrics: Optional[Dict[str, Any]] = None
    # Final summarized result; only set once status is COMPLETED.
    result: Optional[Dict[str, Any]] = None
    # Human-readable failure description; only set once status is FAILED.
    error: Optional[str] = None
|
||||||
|
|
||||||
|
|
||||||
|
# --- APIRouter for Adaptive Crawling Endpoints ---
# Every route declared below is mounted under /adaptive/digest.
router = APIRouter(prefix="/adaptive/digest", tags=["Adaptive Crawling"])
|
||||||
|
|
||||||
|
# --- Background Worker Function ---
|
||||||
|
|
||||||
|
|
||||||
|
async def run_adaptive_digest(task_id: str, request: AdaptiveCrawlRequest):
    """Background worker: run the adaptive crawl and record its outcome.

    Mutates the shared ``ADAPTIVE_JOBS`` entry for *task_id* in place:
    status moves RUNNING -> COMPLETED (with ``result``/``metrics`` filled
    in) or RUNNING -> FAILED (with ``error`` filled in). Never raises.
    """
    try:
        # Mark the job as in-flight so pollers see progress immediately.
        ADAPTIVE_JOBS[task_id]["status"] = "RUNNING"

        # Build the crawler configuration from the payload when one was sent,
        # otherwise fall back to library defaults.
        payload = request.config
        adaptive_config = (
            AdaptiveConfig(**payload.model_dump()) if payload else AdaptiveConfig()
        )

        # The adaptive crawler drives a plain AsyncWebCrawler instance.
        async with AsyncWebCrawler() as crawler:
            adaptive_crawler = AdaptiveCrawler(crawler, config=adaptive_config)

            # Long-running operation: crawls until the configured
            # confidence threshold or page limits are reached.
            final_state = await adaptive_crawler.digest(
                start_url=request.start_url, query=request.query
            )

            # Distill the final state into a JSON-friendly summary.
            # NOTE(review): the status endpoint looks for a "live_state" key
            # to refresh metrics mid-run; nothing is published here while the
            # crawl is in flight — confirm whether that was intended.
            result_data = {
                "confidence": final_state.metrics.get("confidence", 0.0),
                "is_sufficient": adaptive_crawler.is_sufficient,
                "coverage_stats": adaptive_crawler.coverage_stats,
                "relevant_content": adaptive_crawler.get_relevant_content(top_k=5),
            }

            # Publish the terminal state for the polling endpoint.
            ADAPTIVE_JOBS[task_id].update(
                {
                    "status": "COMPLETED",
                    "result": result_data,
                    "metrics": final_state.metrics,
                }
            )

    except Exception as e:
        # Record the failure (with traceback context) on the job record
        # instead of letting the background task die silently.
        import sys

        error_context = get_error_context(sys.exc_info())
        error_message = f"Adaptive crawl failed: {str(e)}\nContext: {error_context}"

        ADAPTIVE_JOBS[task_id].update({"status": "FAILED", "error": error_message})
|
||||||
|
|
||||||
|
|
||||||
|
# --- API Endpoints ---
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/job", response_model=AdaptiveJobStatus, status_code=202)
async def submit_adaptive_digest_job(
    request: AdaptiveCrawlRequest,
    background_tasks: BackgroundTasks,
):
    """
    Submit a new adaptive crawling job.

    This endpoint starts a long-running adaptive crawl in the background and
    immediately returns a task ID for polling the job's status.
    """
    # Fix: removed leftover debug `print(...)` that dumped the full request
    # payload (start URL and user query) to stdout on every submission.
    task_id = str(uuid.uuid4())

    # Initialize the job in our in-memory store BEFORE scheduling the
    # worker, so an immediate poll of /job/{task_id} never 404s.
    ADAPTIVE_JOBS[task_id] = {
        "task_id": task_id,
        "status": "PENDING",
        "metrics": None,
        "result": None,
        "error": None,
    }

    # Hand the long-running crawl to FastAPI's background-task machinery;
    # the response below returns without waiting for it.
    background_tasks.add_task(run_adaptive_digest, task_id, request)

    return ADAPTIVE_JOBS[task_id]
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/job/{task_id}", response_model=AdaptiveJobStatus)
async def get_adaptive_digest_status(task_id: str):
    """
    Get the status and result of an adaptive crawling job.

    Poll this endpoint with the `task_id` returned from the submission
    endpoint until the status is 'COMPLETED' or 'FAILED'.
    """
    job = ADAPTIVE_JOBS.get(task_id)
    if job is None:
        raise HTTPException(status_code=404, detail="Job not found")

    # While the crawl is running, refresh metrics from its live state.
    # NOTE(review): no visible code ever stores a "live_state" key on the
    # job record, so this branch may be dead — verify against the worker.
    live_state = job.get("live_state")
    if job["status"] == "RUNNING" and live_state:
        job["metrics"] = live_state.metrics

    return job
|
||||||
@@ -7,70 +7,63 @@ Crawl4AI FastAPI entry‑point
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
# ── stdlib & 3rd‑party imports ───────────────────────────────
|
# ── stdlib & 3rd‑party imports ───────────────────────────────
|
||||||
from crawler_pool import get_crawler, close_all, janitor
|
import ast
|
||||||
from crawl4ai import AsyncWebCrawler, BrowserConfig, CrawlerRunConfig
|
|
||||||
from auth import create_access_token, get_token_dependency, TokenRequest
|
|
||||||
from pydantic import BaseModel
|
|
||||||
from typing import Optional, List, Dict, Any
|
|
||||||
from fastapi import Request, Depends
|
|
||||||
from fastapi.responses import FileResponse
|
|
||||||
import base64
|
|
||||||
import re
|
|
||||||
import os
|
|
||||||
import sys
|
|
||||||
import asyncio
|
import asyncio
|
||||||
from contextlib import asynccontextmanager
|
import base64
|
||||||
from pathlib import Path
|
|
||||||
from urllib.parse import urlparse
|
|
||||||
from crawl4ai import AsyncWebCrawler, BrowserConfig, CrawlerRunConfig
|
|
||||||
from api import (
|
|
||||||
handle_markdown_request, handle_llm_qa,
|
|
||||||
handle_stream_crawl_request, handle_crawl_request,
|
|
||||||
stream_results, handle_seed
|
|
||||||
)
|
|
||||||
from schemas import (
|
|
||||||
CrawlRequestWithHooks,
|
|
||||||
MarkdownRequest,
|
|
||||||
RawCode,
|
|
||||||
HTMLRequest,
|
|
||||||
ScreenshotRequest,
|
|
||||||
PDFRequest,
|
|
||||||
JSEndpointRequest,
|
|
||||||
SeedRequest,
|
|
||||||
)
|
|
||||||
|
|
||||||
from utils import (
|
|
||||||
FilterType, load_config, setup_logging, verify_email_domain
|
|
||||||
)
|
|
||||||
import os
|
import os
|
||||||
|
import pathlib
|
||||||
|
import re
|
||||||
import sys
|
import sys
|
||||||
import time
|
import time
|
||||||
import asyncio
|
|
||||||
from typing import List
|
|
||||||
from contextlib import asynccontextmanager
|
from contextlib import asynccontextmanager
|
||||||
import pathlib
|
from pathlib import Path
|
||||||
|
from typing import Any, Dict, List, Optional
|
||||||
|
from urllib.parse import urlparse
|
||||||
|
|
||||||
from fastapi import (
|
import adaptive_routes
|
||||||
FastAPI, HTTPException, Request, Path, Query, Depends
|
from api import (
|
||||||
)
|
handle_crawl_request,
|
||||||
from rank_bm25 import BM25Okapi
|
handle_llm_qa,
|
||||||
from fastapi.responses import (
|
handle_markdown_request,
|
||||||
StreamingResponse, RedirectResponse, PlainTextResponse, JSONResponse
|
handle_seed,
|
||||||
|
handle_stream_crawl_request,
|
||||||
|
stream_results,
|
||||||
)
|
)
|
||||||
|
from auth import TokenRequest, create_access_token, get_token_dependency
|
||||||
|
from crawler_pool import close_all, get_crawler, janitor
|
||||||
|
from fastapi import Depends, FastAPI, HTTPException, Path, Query, Request
|
||||||
from fastapi.middleware.httpsredirect import HTTPSRedirectMiddleware
|
from fastapi.middleware.httpsredirect import HTTPSRedirectMiddleware
|
||||||
from fastapi.middleware.trustedhost import TrustedHostMiddleware
|
from fastapi.middleware.trustedhost import TrustedHostMiddleware
|
||||||
|
from fastapi.responses import (
|
||||||
|
FileResponse,
|
||||||
|
JSONResponse,
|
||||||
|
PlainTextResponse,
|
||||||
|
RedirectResponse,
|
||||||
|
StreamingResponse,
|
||||||
|
)
|
||||||
from fastapi.staticfiles import StaticFiles
|
from fastapi.staticfiles import StaticFiles
|
||||||
from job import init_job_router
|
from job import init_job_router
|
||||||
|
|
||||||
from mcp_bridge import attach_mcp, mcp_resource, mcp_template, mcp_tool
|
from mcp_bridge import attach_mcp, mcp_resource, mcp_template, mcp_tool
|
||||||
|
from prometheus_fastapi_instrumentator import Instrumentator
|
||||||
import ast
|
|
||||||
import crawl4ai as _c4
|
|
||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, Field
|
||||||
|
from rank_bm25 import BM25Okapi
|
||||||
|
from redis import asyncio as aioredis
|
||||||
|
from schemas import (
|
||||||
|
CrawlRequestWithHooks,
|
||||||
|
HTMLRequest,
|
||||||
|
JSEndpointRequest,
|
||||||
|
MarkdownRequest,
|
||||||
|
PDFRequest,
|
||||||
|
RawCode,
|
||||||
|
ScreenshotRequest,
|
||||||
|
SeedRequest,
|
||||||
|
)
|
||||||
from slowapi import Limiter
|
from slowapi import Limiter
|
||||||
from slowapi.util import get_remote_address
|
from slowapi.util import get_remote_address
|
||||||
from prometheus_fastapi_instrumentator import Instrumentator
|
from utils import FilterType, load_config, setup_logging, verify_email_domain
|
||||||
from redis import asyncio as aioredis
|
|
||||||
|
import crawl4ai as _c4
|
||||||
|
from crawl4ai import AsyncWebCrawler, BrowserConfig, CrawlerRunConfig
|
||||||
|
|
||||||
# ── internal imports (after sys.path append) ─────────────────
|
# ── internal imports (after sys.path append) ─────────────────
|
||||||
sys.path.append(os.path.dirname(os.path.realpath(__file__)))
|
sys.path.append(os.path.dirname(os.path.realpath(__file__)))
|
||||||
@@ -103,6 +96,8 @@ orig_arun = AsyncWebCrawler.arun
|
|||||||
async def capped_arun(self, *a, **kw):
|
async def capped_arun(self, *a, **kw):
|
||||||
async with GLOBAL_SEM:
|
async with GLOBAL_SEM:
|
||||||
return await orig_arun(self, *a, **kw)
|
return await orig_arun(self, *a, **kw)
|
||||||
|
|
||||||
|
|
||||||
AsyncWebCrawler.arun = capped_arun
|
AsyncWebCrawler.arun = capped_arun
|
||||||
|
|
||||||
# ───────────────────── FastAPI lifespan ──────────────────────
|
# ───────────────────── FastAPI lifespan ──────────────────────
|
||||||
@@ -110,15 +105,18 @@ AsyncWebCrawler.arun = capped_arun
|
|||||||
|
|
||||||
@asynccontextmanager
|
@asynccontextmanager
|
||||||
async def lifespan(_: FastAPI):
|
async def lifespan(_: FastAPI):
|
||||||
await get_crawler(BrowserConfig(
|
await get_crawler(
|
||||||
extra_args=config["crawler"]["browser"].get("extra_args", []),
|
BrowserConfig(
|
||||||
**config["crawler"]["browser"].get("kwargs", {}),
|
extra_args=config["crawler"]["browser"].get("extra_args", []),
|
||||||
)) # warm‑up
|
**config["crawler"]["browser"].get("kwargs", {}),
|
||||||
app.state.janitor = asyncio.create_task(janitor()) # idle GC
|
)
|
||||||
|
) # warm‑up
|
||||||
|
app.state.janitor = asyncio.create_task(janitor()) # idle GC
|
||||||
yield
|
yield
|
||||||
app.state.janitor.cancel()
|
app.state.janitor.cancel()
|
||||||
await close_all()
|
await close_all()
|
||||||
|
|
||||||
|
|
||||||
# ───────────────────── FastAPI instance ──────────────────────
|
# ───────────────────── FastAPI instance ──────────────────────
|
||||||
app = FastAPI(
|
app = FastAPI(
|
||||||
title=config["app"]["title"],
|
title=config["app"]["title"],
|
||||||
@@ -126,6 +124,7 @@ app = FastAPI(
|
|||||||
lifespan=lifespan,
|
lifespan=lifespan,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
app.include_router(adaptive_routes.router)
|
||||||
# ── static playground ──────────────────────────────────────
|
# ── static playground ──────────────────────────────────────
|
||||||
STATIC_DIR = pathlib.Path(__file__).parent / "static" / "playground"
|
STATIC_DIR = pathlib.Path(__file__).parent / "static" / "playground"
|
||||||
if not STATIC_DIR.exists():
|
if not STATIC_DIR.exists():
|
||||||
@@ -141,6 +140,7 @@ app.mount(
|
|||||||
async def root():
|
async def root():
|
||||||
return RedirectResponse("/playground")
|
return RedirectResponse("/playground")
|
||||||
|
|
||||||
|
|
||||||
# ─────────────────── infra / middleware ─────────────────────
|
# ─────────────────── infra / middleware ─────────────────────
|
||||||
redis = aioredis.from_url(config["redis"].get("uri", "redis://localhost"))
|
redis = aioredis.from_url(config["redis"].get("uri", "redis://localhost"))
|
||||||
|
|
||||||
@@ -158,9 +158,7 @@ def _setup_security(app_: FastAPI):
|
|||||||
if sec.get("https_redirect"):
|
if sec.get("https_redirect"):
|
||||||
app_.add_middleware(HTTPSRedirectMiddleware)
|
app_.add_middleware(HTTPSRedirectMiddleware)
|
||||||
if sec.get("trusted_hosts", []) != ["*"]:
|
if sec.get("trusted_hosts", []) != ["*"]:
|
||||||
app_.add_middleware(
|
app_.add_middleware(TrustedHostMiddleware, allowed_hosts=sec["trusted_hosts"])
|
||||||
TrustedHostMiddleware, allowed_hosts=sec["trusted_hosts"]
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
_setup_security(app)
|
_setup_security(app)
|
||||||
@@ -178,6 +176,7 @@ async def add_security_headers(request: Request, call_next):
|
|||||||
resp.headers.update(config["security"]["headers"])
|
resp.headers.update(config["security"]["headers"])
|
||||||
return resp
|
return resp
|
||||||
|
|
||||||
|
|
||||||
# ───────────────── safe config‑dump helper ─────────────────
|
# ───────────────── safe config‑dump helper ─────────────────
|
||||||
ALLOWED_TYPES = {
|
ALLOWED_TYPES = {
|
||||||
"CrawlerRunConfig": CrawlerRunConfig,
|
"CrawlerRunConfig": CrawlerRunConfig,
|
||||||
@@ -199,9 +198,11 @@ def _safe_eval_config(expr: str) -> dict:
|
|||||||
raise ValueError("Expression must be a single constructor call")
|
raise ValueError("Expression must be a single constructor call")
|
||||||
|
|
||||||
call = tree.body
|
call = tree.body
|
||||||
if not (isinstance(call.func, ast.Name) and call.func.id in {"CrawlerRunConfig", "BrowserConfig"}):
|
if not (
|
||||||
raise ValueError(
|
isinstance(call.func, ast.Name)
|
||||||
"Only CrawlerRunConfig(...) or BrowserConfig(...) are allowed")
|
and call.func.id in {"CrawlerRunConfig", "BrowserConfig"}
|
||||||
|
):
|
||||||
|
raise ValueError("Only CrawlerRunConfig(...) or BrowserConfig(...) are allowed")
|
||||||
|
|
||||||
# forbid nested calls to keep the surface tiny
|
# forbid nested calls to keep the surface tiny
|
||||||
for node in ast.walk(call):
|
for node in ast.walk(call):
|
||||||
@@ -209,16 +210,17 @@ def _safe_eval_config(expr: str) -> dict:
|
|||||||
raise ValueError("Nested function calls are not permitted")
|
raise ValueError("Nested function calls are not permitted")
|
||||||
|
|
||||||
# expose everything that crawl4ai exports, nothing else
|
# expose everything that crawl4ai exports, nothing else
|
||||||
safe_env = {name: getattr(_c4, name)
|
safe_env = {
|
||||||
for name in dir(_c4) if not name.startswith("_")}
|
name: getattr(_c4, name) for name in dir(_c4) if not name.startswith("_")
|
||||||
obj = eval(compile(tree, "<config>", "eval"),
|
}
|
||||||
{"__builtins__": {}}, safe_env)
|
obj = eval(compile(tree, "<config>", "eval"), {"__builtins__": {}}, safe_env)
|
||||||
return obj.dump()
|
return obj.dump()
|
||||||
|
|
||||||
|
|
||||||
# ── job router ──────────────────────────────────────────────
|
# ── job router ──────────────────────────────────────────────
|
||||||
app.include_router(init_job_router(redis, config, token_dep))
|
app.include_router(init_job_router(redis, config, token_dep))
|
||||||
|
|
||||||
|
|
||||||
# ──────────────────────── Endpoints ──────────────────────────
|
# ──────────────────────── Endpoints ──────────────────────────
|
||||||
@app.post("/token")
|
@app.post("/token")
|
||||||
async def get_token(req: TokenRequest):
|
async def get_token(req: TokenRequest):
|
||||||
@@ -252,8 +254,8 @@ async def seed_url(request: SeedRequest):
|
|||||||
status_code=400,
|
status_code=400,
|
||||||
detail="Invalid URL provided. Could not extract domain.",
|
detail="Invalid URL provided. Could not extract domain.",
|
||||||
)
|
)
|
||||||
res = await handle_seed(request.url , request.config)
|
res = await handle_seed(request.url, request.config)
|
||||||
return JSONResponse({"seed_url":res , "count":len(res)})
|
return JSONResponse({"seed_url": res, "count": len(res)})
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"❌ Error in seed_url: {e}")
|
print(f"❌ Error in seed_url: {e}")
|
||||||
@@ -268,21 +270,33 @@ async def get_markdown(
|
|||||||
body: MarkdownRequest,
|
body: MarkdownRequest,
|
||||||
_td: Dict = Depends(token_dep),
|
_td: Dict = Depends(token_dep),
|
||||||
):
|
):
|
||||||
if not body.url.startswith(("http://", "https://")) and not body.url.startswith(("raw:", "raw://")):
|
if not body.url.startswith(("http://", "https://")) and not body.url.startswith(
|
||||||
|
("raw:", "raw://")
|
||||||
|
):
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
400, "Invalid URL format. Must start with http://, https://, or for raw HTML (raw:, raw://)")
|
400,
|
||||||
|
"Invalid URL format. Must start with http://, https://, or for raw HTML (raw:, raw://)",
|
||||||
|
)
|
||||||
markdown = await handle_markdown_request(
|
markdown = await handle_markdown_request(
|
||||||
body.url, body.f, body.q, body.c, config, body.provider,
|
body.url,
|
||||||
body.temperature, body.base_url
|
body.f,
|
||||||
|
body.q,
|
||||||
|
body.c,
|
||||||
|
config,
|
||||||
|
body.provider,
|
||||||
|
body.temperature,
|
||||||
|
body.base_url,
|
||||||
|
)
|
||||||
|
return JSONResponse(
|
||||||
|
{
|
||||||
|
"url": body.url,
|
||||||
|
"filter": body.f,
|
||||||
|
"query": body.q,
|
||||||
|
"cache": body.c,
|
||||||
|
"markdown": markdown,
|
||||||
|
"success": True,
|
||||||
|
}
|
||||||
)
|
)
|
||||||
return JSONResponse({
|
|
||||||
"url": body.url,
|
|
||||||
"filter": body.f,
|
|
||||||
"query": body.q,
|
|
||||||
"cache": body.c,
|
|
||||||
"markdown": markdown,
|
|
||||||
"success": True
|
|
||||||
})
|
|
||||||
|
|
||||||
|
|
||||||
@app.post("/html")
|
@app.post("/html")
|
||||||
@@ -304,20 +318,18 @@ async def generate_html(
|
|||||||
# Check if the crawl was successful
|
# Check if the crawl was successful
|
||||||
if not results[0].success:
|
if not results[0].success:
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=500,
|
status_code=500, detail=results[0].error_message or "Crawl failed"
|
||||||
detail=results[0].error_message or "Crawl failed"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
raw_html = results[0].html
|
raw_html = results[0].html
|
||||||
from crawl4ai.utils import preprocess_html_for_schema
|
from crawl4ai.utils import preprocess_html_for_schema
|
||||||
|
|
||||||
processed_html = preprocess_html_for_schema(raw_html)
|
processed_html = preprocess_html_for_schema(raw_html)
|
||||||
return JSONResponse({"html": processed_html, "url": body.url, "success": True})
|
return JSONResponse({"html": processed_html, "url": body.url, "success": True})
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
# Log and raise as HTTP 500 for other exceptions
|
# Log and raise as HTTP 500 for other exceptions
|
||||||
raise HTTPException(
|
raise HTTPException(status_code=500, detail=str(e))
|
||||||
status_code=500,
|
|
||||||
detail=str(e)
|
|
||||||
)
|
|
||||||
|
|
||||||
# Screenshot endpoint
|
# Screenshot endpoint
|
||||||
|
|
||||||
@@ -337,13 +349,13 @@ async def generate_screenshot(
|
|||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
cfg = CrawlerRunConfig(
|
cfg = CrawlerRunConfig(
|
||||||
screenshot=True, screenshot_wait_for=body.screenshot_wait_for)
|
screenshot=True, screenshot_wait_for=body.screenshot_wait_for
|
||||||
|
)
|
||||||
async with AsyncWebCrawler(config=BrowserConfig()) as crawler:
|
async with AsyncWebCrawler(config=BrowserConfig()) as crawler:
|
||||||
results = await crawler.arun(url=body.url, config=cfg)
|
results = await crawler.arun(url=body.url, config=cfg)
|
||||||
if not results[0].success:
|
if not results[0].success:
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=500,
|
status_code=500, detail=results[0].error_message or "Crawl failed"
|
||||||
detail=results[0].error_message or "Crawl failed"
|
|
||||||
)
|
)
|
||||||
screenshot_data = results[0].screenshot
|
screenshot_data = results[0].screenshot
|
||||||
if body.output_path:
|
if body.output_path:
|
||||||
@@ -354,10 +366,8 @@ async def generate_screenshot(
|
|||||||
return {"success": True, "path": abs_path}
|
return {"success": True, "path": abs_path}
|
||||||
return {"success": True, "screenshot": screenshot_data}
|
return {"success": True, "screenshot": screenshot_data}
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
raise HTTPException(
|
raise HTTPException(status_code=500, detail=str(e))
|
||||||
status_code=500,
|
|
||||||
detail=str(e)
|
|
||||||
)
|
|
||||||
|
|
||||||
# PDF endpoint
|
# PDF endpoint
|
||||||
|
|
||||||
@@ -381,8 +391,7 @@ async def generate_pdf(
|
|||||||
results = await crawler.arun(url=body.url, config=cfg)
|
results = await crawler.arun(url=body.url, config=cfg)
|
||||||
if not results[0].success:
|
if not results[0].success:
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=500,
|
status_code=500, detail=results[0].error_message or "Crawl failed"
|
||||||
detail=results[0].error_message or "Crawl failed"
|
|
||||||
)
|
)
|
||||||
pdf_data = results[0].pdf
|
pdf_data = results[0].pdf
|
||||||
if body.output_path:
|
if body.output_path:
|
||||||
@@ -393,10 +402,7 @@ async def generate_pdf(
|
|||||||
return {"success": True, "path": abs_path}
|
return {"success": True, "path": abs_path}
|
||||||
return {"success": True, "pdf": base64.b64encode(pdf_data).decode()}
|
return {"success": True, "pdf": base64.b64encode(pdf_data).decode()}
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
raise HTTPException(
|
raise HTTPException(status_code=500, detail=str(e))
|
||||||
status_code=500,
|
|
||||||
detail=str(e)
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@app.post("/execute_js")
|
@app.post("/execute_js")
|
||||||
@@ -458,17 +464,13 @@ async def execute_js(
|
|||||||
results = await crawler.arun(url=body.url, config=cfg)
|
results = await crawler.arun(url=body.url, config=cfg)
|
||||||
if not results[0].success:
|
if not results[0].success:
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=500,
|
status_code=500, detail=results[0].error_message or "Crawl failed"
|
||||||
detail=results[0].error_message or "Crawl failed"
|
|
||||||
)
|
)
|
||||||
# Return JSON-serializable dict of the first CrawlResult
|
# Return JSON-serializable dict of the first CrawlResult
|
||||||
data = results[0].model_dump()
|
data = results[0].model_dump()
|
||||||
return JSONResponse(data)
|
return JSONResponse(data)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
raise HTTPException(
|
raise HTTPException(status_code=500, detail=str(e))
|
||||||
status_code=500,
|
|
||||||
detail=str(e)
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@app.get("/llm/{url:path}")
|
@app.get("/llm/{url:path}")
|
||||||
@@ -480,7 +482,9 @@ async def llm_endpoint(
|
|||||||
):
|
):
|
||||||
if not q:
|
if not q:
|
||||||
raise HTTPException(400, "Query parameter 'q' is required")
|
raise HTTPException(400, "Query parameter 'q' is required")
|
||||||
if not url.startswith(("http://", "https://")) and not url.startswith(("raw:", "raw://")):
|
if not url.startswith(("http://", "https://")) and not url.startswith(
|
||||||
|
("raw:", "raw://")
|
||||||
|
):
|
||||||
url = "https://" + url
|
url = "https://" + url
|
||||||
answer = await handle_llm_qa(url, q, config)
|
answer = await handle_llm_qa(url, q, config)
|
||||||
return JSONResponse({"answer": answer})
|
return JSONResponse({"answer": answer})
|
||||||
@@ -489,8 +493,8 @@ async def llm_endpoint(
|
|||||||
@app.get("/schema")
|
@app.get("/schema")
|
||||||
async def get_schema():
|
async def get_schema():
|
||||||
from crawl4ai import BrowserConfig, CrawlerRunConfig
|
from crawl4ai import BrowserConfig, CrawlerRunConfig
|
||||||
return {"browser": BrowserConfig().dump(),
|
|
||||||
"crawler": CrawlerRunConfig().dump()}
|
return {"browser": BrowserConfig().dump(), "crawler": CrawlerRunConfig().dump()}
|
||||||
|
|
||||||
|
|
||||||
@app.get("/hooks/info")
|
@app.get("/hooks/info")
|
||||||
@@ -503,17 +507,15 @@ async def get_hooks_info():
|
|||||||
hook_info[hook_point] = {
|
hook_info[hook_point] = {
|
||||||
"parameters": params,
|
"parameters": params,
|
||||||
"description": get_hook_description(hook_point),
|
"description": get_hook_description(hook_point),
|
||||||
"example": get_hook_example(hook_point)
|
"example": get_hook_example(hook_point),
|
||||||
}
|
}
|
||||||
|
|
||||||
return JSONResponse({
|
return JSONResponse(
|
||||||
"available_hooks": hook_info,
|
{
|
||||||
"timeout_limits": {
|
"available_hooks": hook_info,
|
||||||
"min": 1,
|
"timeout_limits": {"min": 1, "max": 120, "default": 30},
|
||||||
"max": 120,
|
|
||||||
"default": 30
|
|
||||||
}
|
}
|
||||||
})
|
)
|
||||||
|
|
||||||
|
|
||||||
def get_hook_description(hook_point: str) -> str:
|
def get_hook_description(hook_point: str) -> str:
|
||||||
@@ -526,7 +528,7 @@ def get_hook_description(hook_point: str) -> str:
|
|||||||
"on_user_agent_updated": "Called when user agent is updated",
|
"on_user_agent_updated": "Called when user agent is updated",
|
||||||
"on_execution_started": "Called when custom JavaScript execution begins",
|
"on_execution_started": "Called when custom JavaScript execution begins",
|
||||||
"before_retrieve_html": "Called before retrieving the final HTML - ideal for scrolling",
|
"before_retrieve_html": "Called before retrieving the final HTML - ideal for scrolling",
|
||||||
"before_return_html": "Called just before returning the HTML content"
|
"before_return_html": "Called just before returning the HTML content",
|
||||||
}
|
}
|
||||||
return descriptions.get(hook_point, "")
|
return descriptions.get(hook_point, "")
|
||||||
|
|
||||||
@@ -542,19 +544,17 @@ def get_hook_example(hook_point: str) -> str:
|
|||||||
'domain': '.example.com'
|
'domain': '.example.com'
|
||||||
}])
|
}])
|
||||||
return page""",
|
return page""",
|
||||||
|
|
||||||
"before_retrieve_html": """async def hook(page, context, **kwargs):
|
"before_retrieve_html": """async def hook(page, context, **kwargs):
|
||||||
# Scroll to load lazy content
|
# Scroll to load lazy content
|
||||||
await page.evaluate("window.scrollTo(0, document.body.scrollHeight)")
|
await page.evaluate("window.scrollTo(0, document.body.scrollHeight)")
|
||||||
await page.wait_for_timeout(2000)
|
await page.wait_for_timeout(2000)
|
||||||
return page""",
|
return page""",
|
||||||
|
|
||||||
"before_goto": """async def hook(page, context, url, **kwargs):
|
"before_goto": """async def hook(page, context, url, **kwargs):
|
||||||
# Set custom headers
|
# Set custom headers
|
||||||
await page.set_extra_http_headers({
|
await page.set_extra_http_headers({
|
||||||
'X-Custom-Header': 'value'
|
'X-Custom-Header': 'value'
|
||||||
})
|
})
|
||||||
return page"""
|
return page""",
|
||||||
}
|
}
|
||||||
return examples.get(hook_point, "# Implement your hook logic here\nreturn page")
|
return examples.get(hook_point, "# Implement your hook logic here\nreturn page")
|
||||||
|
|
||||||
@@ -593,8 +593,8 @@ async def crawl(
|
|||||||
hooks_config = None
|
hooks_config = None
|
||||||
if crawl_request.hooks:
|
if crawl_request.hooks:
|
||||||
hooks_config = {
|
hooks_config = {
|
||||||
'code': crawl_request.hooks.code,
|
"code": crawl_request.hooks.code,
|
||||||
'timeout': crawl_request.hooks.timeout
|
"timeout": crawl_request.hooks.timeout,
|
||||||
}
|
}
|
||||||
|
|
||||||
results = await handle_crawl_request(
|
results = await handle_crawl_request(
|
||||||
@@ -602,11 +602,13 @@ async def crawl(
|
|||||||
browser_config=crawl_request.browser_config,
|
browser_config=crawl_request.browser_config,
|
||||||
crawler_config=crawl_request.crawler_config,
|
crawler_config=crawl_request.crawler_config,
|
||||||
config=config,
|
config=config,
|
||||||
hooks_config=hooks_config
|
hooks_config=hooks_config,
|
||||||
)
|
)
|
||||||
# check if all of the results are not successful
|
# check if all of the results are not successful
|
||||||
if all(not result["success"] for result in results["results"]):
|
if all(not result["success"] for result in results["results"]):
|
||||||
raise HTTPException(500, f"Crawl request failed: {results['results'][0]['error_message']}")
|
raise HTTPException(
|
||||||
|
500, f"Crawl request failed: {results['results'][0]['error_message']}"
|
||||||
|
)
|
||||||
return JSONResponse(results)
|
return JSONResponse(results)
|
||||||
|
|
||||||
|
|
||||||
@@ -622,14 +624,14 @@ async def crawl_stream(
|
|||||||
|
|
||||||
return await stream_process(crawl_request=crawl_request)
|
return await stream_process(crawl_request=crawl_request)
|
||||||
|
|
||||||
async def stream_process(crawl_request: CrawlRequestWithHooks):
|
|
||||||
|
|
||||||
|
async def stream_process(crawl_request: CrawlRequestWithHooks):
|
||||||
# Prepare hooks config if provided# Prepare hooks config if provided
|
# Prepare hooks config if provided# Prepare hooks config if provided
|
||||||
hooks_config = None
|
hooks_config = None
|
||||||
if crawl_request.hooks:
|
if crawl_request.hooks:
|
||||||
hooks_config = {
|
hooks_config = {
|
||||||
'code': crawl_request.hooks.code,
|
"code": crawl_request.hooks.code,
|
||||||
'timeout': crawl_request.hooks.timeout
|
"timeout": crawl_request.hooks.timeout,
|
||||||
}
|
}
|
||||||
|
|
||||||
crawler, gen, hooks_info = await handle_stream_crawl_request(
|
crawler, gen, hooks_info = await handle_stream_crawl_request(
|
||||||
@@ -637,7 +639,7 @@ async def stream_process(crawl_request: CrawlRequestWithHooks):
|
|||||||
browser_config=crawl_request.browser_config,
|
browser_config=crawl_request.browser_config,
|
||||||
crawler_config=crawl_request.crawler_config,
|
crawler_config=crawl_request.crawler_config,
|
||||||
config=config,
|
config=config,
|
||||||
hooks_config=hooks_config
|
hooks_config=hooks_config,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Add hooks info to response headers if available
|
# Add hooks info to response headers if available
|
||||||
@@ -648,7 +650,8 @@ async def stream_process(crawl_request: CrawlRequestWithHooks):
|
|||||||
}
|
}
|
||||||
if hooks_info:
|
if hooks_info:
|
||||||
import json
|
import json
|
||||||
headers["X-Hooks-Status"] = json.dumps(hooks_info['status']['status'])
|
|
||||||
|
headers["X-Hooks-Status"] = json.dumps(hooks_info["status"]["status"])
|
||||||
|
|
||||||
return StreamingResponse(
|
return StreamingResponse(
|
||||||
stream_results(crawler, gen),
|
stream_results(crawler, gen),
|
||||||
@@ -661,10 +664,10 @@ def chunk_code_functions(code_md: str) -> List[str]:
|
|||||||
"""Extract each function/class from markdown code blocks per file."""
|
"""Extract each function/class from markdown code blocks per file."""
|
||||||
pattern = re.compile(
|
pattern = re.compile(
|
||||||
# match "## File: <path>" then a ```py fence, then capture until the closing ```
|
# match "## File: <path>" then a ```py fence, then capture until the closing ```
|
||||||
r'##\s*File:\s*(?P<path>.+?)\s*?\r?\n' # file header
|
r"##\s*File:\s*(?P<path>.+?)\s*?\r?\n" # file header
|
||||||
r'```py\s*?\r?\n' # opening fence
|
r"```py\s*?\r?\n" # opening fence
|
||||||
r'(?P<code>.*?)(?=\r?\n```)', # code block
|
r"(?P<code>.*?)(?=\r?\n```)", # code block
|
||||||
re.DOTALL
|
re.DOTALL,
|
||||||
)
|
)
|
||||||
chunks: List[str] = []
|
chunks: List[str] = []
|
||||||
for m in pattern.finditer(code_md):
|
for m in pattern.finditer(code_md):
|
||||||
@@ -704,12 +707,11 @@ async def get_context(
|
|||||||
request: Request,
|
request: Request,
|
||||||
_td: Dict = Depends(token_dep),
|
_td: Dict = Depends(token_dep),
|
||||||
context_type: str = Query("all", regex="^(code|doc|all)$"),
|
context_type: str = Query("all", regex="^(code|doc|all)$"),
|
||||||
query: Optional[str] = Query(
|
query: Optional[str] = Query(None, description="search query to filter chunks"),
|
||||||
None, description="search query to filter chunks"),
|
|
||||||
score_ratio: float = Query(
|
score_ratio: float = Query(
|
||||||
0.5, ge=0.0, le=1.0, description="min score as fraction of max_score"),
|
0.5, ge=0.0, le=1.0, description="min score as fraction of max_score"
|
||||||
max_results: int = Query(
|
),
|
||||||
20, ge=1, description="absolute cap on returned chunks"),
|
max_results: int = Query(20, ge=1, description="absolute cap on returned chunks"),
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
This end point is design for any questions about Crawl4ai library. It returns a plain text markdown with extensive information about Crawl4ai.
|
This end point is design for any questions about Crawl4ai library. It returns a plain text markdown with extensive information about Crawl4ai.
|
||||||
@@ -746,10 +748,12 @@ async def get_context(
|
|||||||
return JSONResponse({"code_context": code_content})
|
return JSONResponse({"code_context": code_content})
|
||||||
if context_type == "doc":
|
if context_type == "doc":
|
||||||
return JSONResponse({"doc_context": doc_content})
|
return JSONResponse({"doc_context": doc_content})
|
||||||
return JSONResponse({
|
return JSONResponse(
|
||||||
"code_context": code_content,
|
{
|
||||||
"doc_context": doc_content,
|
"code_context": code_content,
|
||||||
})
|
"doc_context": doc_content,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
tokens = query.split()
|
tokens = query.split()
|
||||||
results: Dict[str, List[Dict[str, float]]] = {}
|
results: Dict[str, List[Dict[str, float]]] = {}
|
||||||
@@ -773,7 +777,7 @@ async def get_context(
|
|||||||
max_sd = float(scores_d.max()) if scores_d.size > 0 else 0.0
|
max_sd = float(scores_d.max()) if scores_d.size > 0 else 0.0
|
||||||
cutoff_d = max_sd * score_ratio
|
cutoff_d = max_sd * score_ratio
|
||||||
idxs = [i for i, s in enumerate(scores_d) if s >= cutoff_d]
|
idxs = [i for i, s in enumerate(scores_d) if s >= cutoff_d]
|
||||||
neighbors = set(i for idx in idxs for i in (idx-1, idx, idx+1))
|
neighbors = set(i for idx in idxs for i in (idx - 1, idx, idx + 1))
|
||||||
valid = [i for i in sorted(neighbors) if 0 <= i < len(sections)]
|
valid = [i for i in sorted(neighbors) if 0 <= i < len(sections)]
|
||||||
valid = valid[:max_results]
|
valid = valid[:max_results]
|
||||||
results["doc_results"] = [
|
results["doc_results"] = [
|
||||||
@@ -785,14 +789,12 @@ async def get_context(
|
|||||||
|
|
||||||
# attach MCP layer (adds /mcp/ws, /mcp/sse, /mcp/schema)
|
# attach MCP layer (adds /mcp/ws, /mcp/sse, /mcp/schema)
|
||||||
print(f"MCP server running on {config['app']['host']}:{config['app']['port']}")
|
print(f"MCP server running on {config['app']['host']}:{config['app']['port']}")
|
||||||
attach_mcp(
|
attach_mcp(app, base_url=f"http://{config['app']['host']}:{config['app']['port']}")
|
||||||
app,
|
|
||||||
base_url=f"http://{config['app']['host']}:{config['app']['port']}"
|
|
||||||
)
|
|
||||||
|
|
||||||
# ────────────────────────── cli ──────────────────────────────
|
# ────────────────────────── cli ──────────────────────────────
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
import uvicorn
|
import uvicorn
|
||||||
|
|
||||||
uvicorn.run(
|
uvicorn.run(
|
||||||
"server:app",
|
"server:app",
|
||||||
host=config["app"]["host"],
|
host=config["app"]["host"],
|
||||||
|
|||||||
Reference in New Issue
Block a user