Creating the API server component
This commit is contained in:
550
main.py
550
main.py
@@ -1,254 +1,346 @@
|
|||||||
import os
|
|
||||||
import importlib
|
|
||||||
import asyncio
|
import asyncio
|
||||||
from functools import lru_cache
|
from fastapi import FastAPI, HTTPException, BackgroundTasks, Request
|
||||||
import logging
|
from fastapi.responses import JSONResponse
|
||||||
logging.basicConfig(level=logging.DEBUG)
|
from pydantic import BaseModel, HttpUrl, Field
|
||||||
|
from typing import Optional, List, Dict, Any, Union
|
||||||
from fastapi import FastAPI, HTTPException, Request
|
import psutil
|
||||||
from fastapi.responses import HTMLResponse, JSONResponse
|
|
||||||
from fastapi.staticfiles import StaticFiles
|
|
||||||
from fastapi.middleware.cors import CORSMiddleware
|
|
||||||
from fastapi.templating import Jinja2Templates
|
|
||||||
from fastapi.exceptions import RequestValidationError
|
|
||||||
from starlette.middleware.base import BaseHTTPMiddleware
|
|
||||||
from starlette.responses import FileResponse
|
|
||||||
from fastapi.responses import RedirectResponse
|
|
||||||
|
|
||||||
from pydantic import BaseModel, HttpUrl
|
|
||||||
from concurrent.futures import ThreadPoolExecutor, as_completed
|
|
||||||
from typing import List, Optional
|
|
||||||
|
|
||||||
from crawl4ai.web_crawler import WebCrawler
|
|
||||||
from crawl4ai.database import get_total_count, clear_db
|
|
||||||
|
|
||||||
import time
|
import time
|
||||||
from slowapi import Limiter, _rate_limit_exceeded_handler
|
import uuid
|
||||||
from slowapi.util import get_remote_address
|
from collections import defaultdict
|
||||||
from slowapi.errors import RateLimitExceeded
|
from urllib.parse import urlparse
|
||||||
|
import math
|
||||||
# load .env file
|
import logging
|
||||||
from dotenv import load_dotenv
|
from enum import Enum
|
||||||
load_dotenv()
|
from dataclasses import dataclass
|
||||||
|
import json
|
||||||
# Configuration
|
from crawl4ai import AsyncWebCrawler, CrawlResult
|
||||||
__location__ = os.path.realpath(os.path.join(os.getcwd(), os.path.dirname(__file__)))
|
from crawl4ai.extraction_strategy import (
|
||||||
MAX_CONCURRENT_REQUESTS = 10 # Adjust this to change the maximum concurrent requests
|
LLMExtractionStrategy,
|
||||||
current_requests = 0
|
CosineStrategy,
|
||||||
lock = asyncio.Lock()
|
JsonCssExtractionStrategy,
|
||||||
|
|
||||||
app = FastAPI()
|
|
||||||
|
|
||||||
# Initialize rate limiter
|
|
||||||
def rate_limit_key_func(request: Request):
|
|
||||||
access_token = request.headers.get("access-token")
|
|
||||||
if access_token == os.environ.get('ACCESS_TOKEN'):
|
|
||||||
return None
|
|
||||||
return get_remote_address(request)
|
|
||||||
|
|
||||||
limiter = Limiter(key_func=rate_limit_key_func)
|
|
||||||
app.state.limiter = limiter
|
|
||||||
|
|
||||||
# Dictionary to store last request times for each client
|
|
||||||
last_request_times = {}
|
|
||||||
last_rate_limit = {}
|
|
||||||
|
|
||||||
|
|
||||||
def get_rate_limit():
|
|
||||||
limit = os.environ.get('ACCESS_PER_MIN', "5")
|
|
||||||
return f"{limit}/minute"
|
|
||||||
|
|
||||||
# Custom rate limit exceeded handler
|
|
||||||
async def custom_rate_limit_exceeded_handler(request: Request, exc: RateLimitExceeded) -> JSONResponse:
|
|
||||||
if request.client.host not in last_rate_limit or time.time() - last_rate_limit[request.client.host] > 60:
|
|
||||||
last_rate_limit[request.client.host] = time.time()
|
|
||||||
retry_after = 60 - (time.time() - last_rate_limit[request.client.host])
|
|
||||||
reset_at = time.time() + retry_after
|
|
||||||
return JSONResponse(
|
|
||||||
status_code=429,
|
|
||||||
content={
|
|
||||||
"detail": "Rate limit exceeded",
|
|
||||||
"limit": str(exc.limit.limit),
|
|
||||||
"retry_after": retry_after,
|
|
||||||
'reset_at': reset_at,
|
|
||||||
"message": f"You have exceeded the rate limit of {exc.limit.limit}."
|
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
app.add_exception_handler(RateLimitExceeded, custom_rate_limit_exceeded_handler)
|
|
||||||
|
|
||||||
|
|
||||||
# Middleware for token-based bypass and per-request limit
|
|
||||||
class RateLimitMiddleware(BaseHTTPMiddleware):
|
|
||||||
async def dispatch(self, request: Request, call_next):
|
|
||||||
SPAN = int(os.environ.get('ACCESS_TIME_SPAN', 10))
|
|
||||||
access_token = request.headers.get("access-token")
|
|
||||||
if access_token == os.environ.get('ACCESS_TOKEN'):
|
|
||||||
return await call_next(request)
|
|
||||||
|
|
||||||
path = request.url.path
|
|
||||||
if path in ["/crawl", "/old"]:
|
|
||||||
client_ip = request.client.host
|
|
||||||
current_time = time.time()
|
|
||||||
|
|
||||||
# Check time since last request
|
|
||||||
if client_ip in last_request_times:
|
|
||||||
time_since_last_request = current_time - last_request_times[client_ip]
|
|
||||||
if time_since_last_request < SPAN:
|
|
||||||
return JSONResponse(
|
|
||||||
status_code=429,
|
|
||||||
content={
|
|
||||||
"detail": "Too many requests",
|
|
||||||
"message": "Rate limit exceeded. Please wait 10 seconds between requests.",
|
|
||||||
"retry_after": max(0, SPAN - time_since_last_request),
|
|
||||||
"reset_at": current_time + max(0, SPAN - time_since_last_request),
|
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
last_request_times[client_ip] = current_time
|
|
||||||
|
|
||||||
return await call_next(request)
|
|
||||||
|
|
||||||
app.add_middleware(RateLimitMiddleware)
|
|
||||||
|
|
||||||
# CORS configuration
|
|
||||||
origins = ["*"] # Allow all origins
|
|
||||||
app.add_middleware(
|
|
||||||
CORSMiddleware,
|
|
||||||
allow_origins=origins, # List of origins that are allowed to make requests
|
|
||||||
allow_credentials=True,
|
|
||||||
allow_methods=["*"], # Allows all methods
|
|
||||||
allow_headers=["*"], # Allows all headers
|
|
||||||
)
|
)
|
||||||
|
|
||||||
# Mount the pages directory as a static directory
|
logging.basicConfig(level=logging.INFO)
|
||||||
app.mount("/pages", StaticFiles(directory=__location__ + "/pages"), name="pages")
|
logger = logging.getLogger(__name__)
|
||||||
app.mount("/mkdocs", StaticFiles(directory="site", html=True), name="mkdocs")
|
|
||||||
site_templates = Jinja2Templates(directory=__location__ + "/site")
|
|
||||||
templates = Jinja2Templates(directory=__location__ + "/pages")
|
|
||||||
|
|
||||||
@lru_cache()
|
class TaskStatus(str, Enum):
|
||||||
def get_crawler():
|
PENDING = "pending"
|
||||||
# Initialize and return a WebCrawler instance
|
PROCESSING = "processing"
|
||||||
crawler = WebCrawler(verbose = True)
|
COMPLETED = "completed"
|
||||||
crawler.warmup()
|
FAILED = "failed"
|
||||||
return crawler
|
|
||||||
|
class CrawlerType(str, Enum):
|
||||||
|
BASIC = "basic"
|
||||||
|
LLM = "llm"
|
||||||
|
COSINE = "cosine"
|
||||||
|
JSON_CSS = "json_css"
|
||||||
|
|
||||||
|
class ExtractionConfig(BaseModel):
|
||||||
|
type: CrawlerType
|
||||||
|
params: Dict[str, Any] = {}
|
||||||
|
|
||||||
class CrawlRequest(BaseModel):
|
class CrawlRequest(BaseModel):
|
||||||
urls: List[str]
|
urls: Union[HttpUrl, List[HttpUrl]]
|
||||||
include_raw_html: Optional[bool] = False
|
extraction_config: Optional[ExtractionConfig] = None
|
||||||
bypass_cache: bool = False
|
crawler_params: Dict[str, Any] = {}
|
||||||
extract_blocks: bool = True
|
priority: int = Field(default=5, ge=1, le=10)
|
||||||
word_count_threshold: Optional[int] = 5
|
ttl: Optional[int] = 3600
|
||||||
extraction_strategy: Optional[str] = "NoExtractionStrategy"
|
js_code: Optional[List[str]] = None
|
||||||
extraction_strategy_args: Optional[dict] = {}
|
wait_for: Optional[str] = None
|
||||||
chunking_strategy: Optional[str] = "RegexChunking"
|
|
||||||
chunking_strategy_args: Optional[dict] = {}
|
|
||||||
css_selector: Optional[str] = None
|
css_selector: Optional[str] = None
|
||||||
screenshot: Optional[bool] = False
|
screenshot: bool = False
|
||||||
user_agent: Optional[str] = None
|
magic: bool = False
|
||||||
verbose: Optional[bool] = True
|
|
||||||
|
|
||||||
@app.get("/")
|
@dataclass
|
||||||
def read_root():
|
class TaskInfo:
|
||||||
return RedirectResponse(url="/mkdocs")
|
id: str
|
||||||
|
status: TaskStatus
|
||||||
|
result: Optional[Union[CrawlResult, List[CrawlResult]]] = None
|
||||||
|
error: Optional[str] = None
|
||||||
|
created_at: float = time.time()
|
||||||
|
ttl: int = 3600
|
||||||
|
|
||||||
@app.get("/old", response_class=HTMLResponse)
|
class ResourceMonitor:
|
||||||
@limiter.limit(get_rate_limit())
|
def __init__(self, max_concurrent_tasks: int = 10):
|
||||||
async def read_index(request: Request):
|
self.max_concurrent_tasks = max_concurrent_tasks
|
||||||
partials_dir = os.path.join(__location__, "pages", "partial")
|
self.memory_threshold = 0.85
|
||||||
partials = {}
|
self.cpu_threshold = 0.90
|
||||||
|
self._last_check = 0
|
||||||
|
self._check_interval = 1 # seconds
|
||||||
|
self._last_available_slots = max_concurrent_tasks
|
||||||
|
|
||||||
for filename in os.listdir(partials_dir):
|
async def get_available_slots(self) -> int:
|
||||||
if filename.endswith(".html"):
|
current_time = time.time()
|
||||||
with open(os.path.join(partials_dir, filename), "r", encoding="utf8") as file:
|
if current_time - self._last_check < self._check_interval:
|
||||||
partials[filename[:-5]] = file.read()
|
return self._last_available_slots
|
||||||
|
|
||||||
return templates.TemplateResponse("index.html", {"request": request, **partials})
|
mem_usage = psutil.virtual_memory().percent / 100
|
||||||
|
cpu_usage = psutil.cpu_percent() / 100
|
||||||
|
|
||||||
@app.get("/total-count")
|
memory_factor = max(0, (self.memory_threshold - mem_usage) / self.memory_threshold)
|
||||||
async def get_total_url_count():
|
cpu_factor = max(0, (self.cpu_threshold - cpu_usage) / self.cpu_threshold)
|
||||||
count = get_total_count()
|
|
||||||
return JSONResponse(content={"count": count})
|
|
||||||
|
|
||||||
@app.get("/clear-db")
|
self._last_available_slots = math.floor(
|
||||||
async def clear_database():
|
self.max_concurrent_tasks * min(memory_factor, cpu_factor)
|
||||||
# clear_db()
|
)
|
||||||
return JSONResponse(content={"message": "Database cleared."})
|
self._last_check = current_time
|
||||||
|
|
||||||
def import_strategy(module_name: str, class_name: str, *args, **kwargs):
|
return self._last_available_slots
|
||||||
try:
|
|
||||||
module = importlib.import_module(module_name)
|
class TaskManager:
|
||||||
strategy_class = getattr(module, class_name)
|
def __init__(self, cleanup_interval: int = 300):
|
||||||
return strategy_class(*args, **kwargs)
|
self.tasks: Dict[str, TaskInfo] = {}
|
||||||
except ImportError:
|
self.high_priority = asyncio.PriorityQueue()
|
||||||
print("ImportError: Module not found.")
|
self.low_priority = asyncio.PriorityQueue()
|
||||||
raise HTTPException(status_code=400, detail=f"Module {module_name} not found.")
|
self.cleanup_interval = cleanup_interval
|
||||||
except AttributeError:
|
self.cleanup_task = None
|
||||||
print("AttributeError: Class not found.")
|
|
||||||
raise HTTPException(status_code=400, detail=f"Class {class_name} not found in {module_name}.")
|
async def start(self):
|
||||||
|
self.cleanup_task = asyncio.create_task(self._cleanup_loop())
|
||||||
|
|
||||||
|
async def stop(self):
|
||||||
|
if self.cleanup_task:
|
||||||
|
self.cleanup_task.cancel()
|
||||||
|
try:
|
||||||
|
await self.cleanup_task
|
||||||
|
except asyncio.CancelledError:
|
||||||
|
pass
|
||||||
|
|
||||||
|
async def add_task(self, task_id: str, priority: int, ttl: int) -> None:
|
||||||
|
task_info = TaskInfo(id=task_id, status=TaskStatus.PENDING, ttl=ttl)
|
||||||
|
self.tasks[task_id] = task_info
|
||||||
|
queue = self.high_priority if priority > 5 else self.low_priority
|
||||||
|
await queue.put((-priority, task_id)) # Negative for proper priority ordering
|
||||||
|
|
||||||
|
async def get_next_task(self) -> Optional[str]:
|
||||||
|
try:
|
||||||
|
# Try high priority first
|
||||||
|
_, task_id = await asyncio.wait_for(self.high_priority.get(), timeout=0.1)
|
||||||
|
return task_id
|
||||||
|
except asyncio.TimeoutError:
|
||||||
|
try:
|
||||||
|
# Then try low priority
|
||||||
|
_, task_id = await asyncio.wait_for(self.low_priority.get(), timeout=0.1)
|
||||||
|
return task_id
|
||||||
|
except asyncio.TimeoutError:
|
||||||
|
return None
|
||||||
|
|
||||||
|
def update_task(self, task_id: str, status: TaskStatus, result: Any = None, error: str = None):
|
||||||
|
if task_id in self.tasks:
|
||||||
|
task_info = self.tasks[task_id]
|
||||||
|
task_info.status = status
|
||||||
|
task_info.result = result
|
||||||
|
task_info.error = error
|
||||||
|
|
||||||
|
def get_task(self, task_id: str) -> Optional[TaskInfo]:
|
||||||
|
return self.tasks.get(task_id)
|
||||||
|
|
||||||
|
async def _cleanup_loop(self):
|
||||||
|
while True:
|
||||||
|
try:
|
||||||
|
await asyncio.sleep(self.cleanup_interval)
|
||||||
|
current_time = time.time()
|
||||||
|
expired_tasks = [
|
||||||
|
task_id
|
||||||
|
for task_id, task in self.tasks.items()
|
||||||
|
if current_time - task.created_at > task.ttl
|
||||||
|
and task.status in [TaskStatus.COMPLETED, TaskStatus.FAILED]
|
||||||
|
]
|
||||||
|
for task_id in expired_tasks:
|
||||||
|
del self.tasks[task_id]
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error in cleanup loop: {e}")
|
||||||
|
|
||||||
|
class CrawlerPool:
|
||||||
|
def __init__(self, max_size: int = 10):
|
||||||
|
self.max_size = max_size
|
||||||
|
self.active_crawlers: Dict[AsyncWebCrawler, float] = {}
|
||||||
|
self._lock = asyncio.Lock()
|
||||||
|
|
||||||
|
async def acquire(self, **kwargs) -> AsyncWebCrawler:
|
||||||
|
async with self._lock:
|
||||||
|
# Clean up inactive crawlers
|
||||||
|
current_time = time.time()
|
||||||
|
inactive = [
|
||||||
|
crawler
|
||||||
|
for crawler, last_used in self.active_crawlers.items()
|
||||||
|
if current_time - last_used > 600 # 10 minutes timeout
|
||||||
|
]
|
||||||
|
for crawler in inactive:
|
||||||
|
await crawler.__aexit__(None, None, None)
|
||||||
|
del self.active_crawlers[crawler]
|
||||||
|
|
||||||
|
# Create new crawler if needed
|
||||||
|
if len(self.active_crawlers) < self.max_size:
|
||||||
|
crawler = AsyncWebCrawler(**kwargs)
|
||||||
|
await crawler.__aenter__()
|
||||||
|
self.active_crawlers[crawler] = current_time
|
||||||
|
return crawler
|
||||||
|
|
||||||
|
# Reuse least recently used crawler
|
||||||
|
crawler = min(self.active_crawlers.items(), key=lambda x: x[1])[0]
|
||||||
|
self.active_crawlers[crawler] = current_time
|
||||||
|
return crawler
|
||||||
|
|
||||||
|
async def release(self, crawler: AsyncWebCrawler):
|
||||||
|
async with self._lock:
|
||||||
|
if crawler in self.active_crawlers:
|
||||||
|
self.active_crawlers[crawler] = time.time()
|
||||||
|
|
||||||
|
async def cleanup(self):
|
||||||
|
async with self._lock:
|
||||||
|
for crawler in list(self.active_crawlers.keys()):
|
||||||
|
await crawler.__aexit__(None, None, None)
|
||||||
|
self.active_crawlers.clear()
|
||||||
|
|
||||||
|
class CrawlerService:
|
||||||
|
def __init__(self, max_concurrent_tasks: int = 10):
|
||||||
|
self.resource_monitor = ResourceMonitor(max_concurrent_tasks)
|
||||||
|
self.task_manager = TaskManager()
|
||||||
|
self.crawler_pool = CrawlerPool(max_concurrent_tasks)
|
||||||
|
self._processing_task = None
|
||||||
|
|
||||||
|
async def start(self):
|
||||||
|
await self.task_manager.start()
|
||||||
|
self._processing_task = asyncio.create_task(self._process_queue())
|
||||||
|
|
||||||
|
async def stop(self):
|
||||||
|
if self._processing_task:
|
||||||
|
self._processing_task.cancel()
|
||||||
|
try:
|
||||||
|
await self._processing_task
|
||||||
|
except asyncio.CancelledError:
|
||||||
|
pass
|
||||||
|
await self.task_manager.stop()
|
||||||
|
await self.crawler_pool.cleanup()
|
||||||
|
|
||||||
|
def _create_extraction_strategy(self, config: ExtractionConfig):
|
||||||
|
if not config:
|
||||||
|
return None
|
||||||
|
|
||||||
|
if config.type == CrawlerType.LLM:
|
||||||
|
return LLMExtractionStrategy(**config.params)
|
||||||
|
elif config.type == CrawlerType.COSINE:
|
||||||
|
return CosineStrategy(**config.params)
|
||||||
|
elif config.type == CrawlerType.JSON_CSS:
|
||||||
|
return JsonCssExtractionStrategy(**config.params)
|
||||||
|
return None
|
||||||
|
|
||||||
|
async def submit_task(self, request: CrawlRequest) -> str:
|
||||||
|
task_id = str(uuid.uuid4())
|
||||||
|
await self.task_manager.add_task(task_id, request.priority, request.ttl or 3600)
|
||||||
|
|
||||||
|
# Store request data with task
|
||||||
|
self.task_manager.tasks[task_id].request = request
|
||||||
|
|
||||||
|
return task_id
|
||||||
|
|
||||||
|
async def _process_queue(self):
|
||||||
|
while True:
|
||||||
|
try:
|
||||||
|
available_slots = await self.resource_monitor.get_available_slots()
|
||||||
|
if available_slots <= 0:
|
||||||
|
await asyncio.sleep(1)
|
||||||
|
continue
|
||||||
|
|
||||||
|
task_id = await self.task_manager.get_next_task()
|
||||||
|
if not task_id:
|
||||||
|
await asyncio.sleep(1)
|
||||||
|
continue
|
||||||
|
|
||||||
|
task_info = self.task_manager.get_task(task_id)
|
||||||
|
if not task_info:
|
||||||
|
continue
|
||||||
|
|
||||||
|
request = task_info.request
|
||||||
|
self.task_manager.update_task(task_id, TaskStatus.PROCESSING)
|
||||||
|
|
||||||
|
try:
|
||||||
|
crawler = await self.crawler_pool.acquire(**request.crawler_params)
|
||||||
|
|
||||||
|
extraction_strategy = self._create_extraction_strategy(request.extraction_config)
|
||||||
|
|
||||||
|
if isinstance(request.urls, list):
|
||||||
|
results = await crawler.arun_many(
|
||||||
|
urls=[str(url) for url in request.urls],
|
||||||
|
extraction_strategy=extraction_strategy,
|
||||||
|
js_code=request.js_code,
|
||||||
|
wait_for=request.wait_for,
|
||||||
|
css_selector=request.css_selector,
|
||||||
|
screenshot=request.screenshot,
|
||||||
|
magic=request.magic,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
results = await crawler.arun(
|
||||||
|
url=str(request.urls),
|
||||||
|
extraction_strategy=extraction_strategy,
|
||||||
|
js_code=request.js_code,
|
||||||
|
wait_for=request.wait_for,
|
||||||
|
css_selector=request.css_selector,
|
||||||
|
screenshot=request.screenshot,
|
||||||
|
magic=request.magic,
|
||||||
|
)
|
||||||
|
|
||||||
|
await self.crawler_pool.release(crawler)
|
||||||
|
self.task_manager.update_task(task_id, TaskStatus.COMPLETED, results)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error processing task {task_id}: {str(e)}")
|
||||||
|
self.task_manager.update_task(task_id, TaskStatus.FAILED, error=str(e))
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error in queue processing: {str(e)}")
|
||||||
|
await asyncio.sleep(1)
|
||||||
|
|
||||||
|
app = FastAPI(title="Crawl4AI API")
|
||||||
|
crawler_service = CrawlerService()
|
||||||
|
|
||||||
|
@app.on_event("startup")
|
||||||
|
async def startup_event():
|
||||||
|
await crawler_service.start()
|
||||||
|
|
||||||
|
@app.on_event("shutdown")
|
||||||
|
async def shutdown_event():
|
||||||
|
await crawler_service.stop()
|
||||||
|
|
||||||
@app.post("/crawl")
|
@app.post("/crawl")
|
||||||
@limiter.limit(get_rate_limit())
|
async def crawl(request: CrawlRequest) -> Dict[str, str]:
|
||||||
async def crawl_urls(crawl_request: CrawlRequest, request: Request):
|
task_id = await crawler_service.submit_task(request)
|
||||||
logging.debug(f"[LOG] Crawl request for URL: {crawl_request.urls}")
|
return {"task_id": task_id}
|
||||||
global current_requests
|
|
||||||
async with lock:
|
|
||||||
if current_requests >= MAX_CONCURRENT_REQUESTS:
|
|
||||||
raise HTTPException(status_code=429, detail="Too many requests - please try again later.")
|
|
||||||
current_requests += 1
|
|
||||||
|
|
||||||
try:
|
@app.get("/task/{task_id}")
|
||||||
logging.debug("[LOG] Loading extraction and chunking strategies...")
|
async def get_task_status(task_id: str):
|
||||||
crawl_request.extraction_strategy_args['verbose'] = True
|
task_info = crawler_service.task_manager.get_task(task_id)
|
||||||
crawl_request.chunking_strategy_args['verbose'] = True
|
if not task_info:
|
||||||
|
raise HTTPException(status_code=404, detail="Task not found")
|
||||||
|
|
||||||
extraction_strategy = import_strategy("crawl4ai.extraction_strategy", crawl_request.extraction_strategy, **crawl_request.extraction_strategy_args)
|
response = {
|
||||||
chunking_strategy = import_strategy("crawl4ai.chunking_strategy", crawl_request.chunking_strategy, **crawl_request.chunking_strategy_args)
|
"status": task_info.status,
|
||||||
|
"created_at": task_info.created_at,
|
||||||
|
}
|
||||||
|
|
||||||
# Use ThreadPoolExecutor to run the synchronous WebCrawler in async manner
|
if task_info.status == TaskStatus.COMPLETED:
|
||||||
logging.debug("[LOG] Running the WebCrawler...")
|
# Convert CrawlResult to dict for JSON response
|
||||||
with ThreadPoolExecutor() as executor:
|
if isinstance(task_info.result, list):
|
||||||
loop = asyncio.get_event_loop()
|
response["results"] = [result.dict() for result in task_info.result]
|
||||||
futures = [
|
else:
|
||||||
loop.run_in_executor(
|
response["result"] = task_info.result.dict()
|
||||||
executor,
|
elif task_info.status == TaskStatus.FAILED:
|
||||||
get_crawler().run,
|
response["error"] = task_info.error
|
||||||
str(url),
|
|
||||||
crawl_request.word_count_threshold,
|
|
||||||
extraction_strategy,
|
|
||||||
chunking_strategy,
|
|
||||||
crawl_request.bypass_cache,
|
|
||||||
crawl_request.css_selector,
|
|
||||||
crawl_request.screenshot,
|
|
||||||
crawl_request.user_agent,
|
|
||||||
crawl_request.verbose
|
|
||||||
)
|
|
||||||
for url in crawl_request.urls
|
|
||||||
]
|
|
||||||
results = await asyncio.gather(*futures)
|
|
||||||
|
|
||||||
# if include_raw_html is False, remove the raw HTML content from the results
|
return response
|
||||||
if not crawl_request.include_raw_html:
|
|
||||||
for result in results:
|
|
||||||
result.html = None
|
|
||||||
|
|
||||||
return {"results": [result.model_dump() for result in results]}
|
|
||||||
finally:
|
|
||||||
async with lock:
|
|
||||||
current_requests -= 1
|
|
||||||
|
|
||||||
@app.get("/strategies/extraction", response_class=JSONResponse)
|
|
||||||
async def get_extraction_strategies():
|
|
||||||
with open(f"{__location__}/docs/extraction_strategies.json", "r") as file:
|
|
||||||
return JSONResponse(content=file.read())
|
|
||||||
|
|
||||||
@app.get("/strategies/chunking", response_class=JSONResponse)
|
|
||||||
async def get_chunking_strategies():
|
|
||||||
with open(f"{__location__}/docs/chunking_strategies.json", "r") as file:
|
|
||||||
return JSONResponse(content=file.read())
|
|
||||||
|
|
||||||
|
@app.get("/health")
|
||||||
|
async def health_check():
|
||||||
|
available_slots = await crawler_service.resource_monitor.get_available_slots()
|
||||||
|
memory = psutil.virtual_memory()
|
||||||
|
return {
|
||||||
|
"status": "healthy",
|
||||||
|
"available_slots": available_slots,
|
||||||
|
"memory_usage": memory.percent,
|
||||||
|
"cpu_usage": psutil.cpu_percent(),
|
||||||
|
}
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
import uvicorn
|
import uvicorn
|
||||||
uvicorn.run(app, host="0.0.0.0", port=8888)
|
uvicorn.run(app, host="0.0.0.0", port=8000)
|
||||||
254
main_v0.py
Normal file
254
main_v0.py
Normal file
@@ -0,0 +1,254 @@
|
|||||||
|
import os
|
||||||
|
import importlib
|
||||||
|
import asyncio
|
||||||
|
from functools import lru_cache
|
||||||
|
import logging
|
||||||
|
logging.basicConfig(level=logging.DEBUG)
|
||||||
|
|
||||||
|
from fastapi import FastAPI, HTTPException, Request
|
||||||
|
from fastapi.responses import HTMLResponse, JSONResponse
|
||||||
|
from fastapi.staticfiles import StaticFiles
|
||||||
|
from fastapi.middleware.cors import CORSMiddleware
|
||||||
|
from fastapi.templating import Jinja2Templates
|
||||||
|
from fastapi.exceptions import RequestValidationError
|
||||||
|
from starlette.middleware.base import BaseHTTPMiddleware
|
||||||
|
from starlette.responses import FileResponse
|
||||||
|
from fastapi.responses import RedirectResponse
|
||||||
|
|
||||||
|
from pydantic import BaseModel, HttpUrl
|
||||||
|
from concurrent.futures import ThreadPoolExecutor, as_completed
|
||||||
|
from typing import List, Optional
|
||||||
|
|
||||||
|
from crawl4ai.web_crawler import WebCrawler
|
||||||
|
from crawl4ai.database import get_total_count, clear_db
|
||||||
|
|
||||||
|
import time
|
||||||
|
from slowapi import Limiter, _rate_limit_exceeded_handler
|
||||||
|
from slowapi.util import get_remote_address
|
||||||
|
from slowapi.errors import RateLimitExceeded
|
||||||
|
|
||||||
|
# load .env file
|
||||||
|
from dotenv import load_dotenv
|
||||||
|
load_dotenv()
|
||||||
|
|
||||||
|
# Configuration
|
||||||
|
__location__ = os.path.realpath(os.path.join(os.getcwd(), os.path.dirname(__file__)))
|
||||||
|
MAX_CONCURRENT_REQUESTS = 10 # Adjust this to change the maximum concurrent requests
|
||||||
|
current_requests = 0
|
||||||
|
lock = asyncio.Lock()
|
||||||
|
|
||||||
|
app = FastAPI()
|
||||||
|
|
||||||
|
# Initialize rate limiter
|
||||||
|
def rate_limit_key_func(request: Request):
|
||||||
|
access_token = request.headers.get("access-token")
|
||||||
|
if access_token == os.environ.get('ACCESS_TOKEN'):
|
||||||
|
return None
|
||||||
|
return get_remote_address(request)
|
||||||
|
|
||||||
|
limiter = Limiter(key_func=rate_limit_key_func)
|
||||||
|
app.state.limiter = limiter
|
||||||
|
|
||||||
|
# Dictionary to store last request times for each client
|
||||||
|
last_request_times = {}
|
||||||
|
last_rate_limit = {}
|
||||||
|
|
||||||
|
|
||||||
|
def get_rate_limit():
|
||||||
|
limit = os.environ.get('ACCESS_PER_MIN', "5")
|
||||||
|
return f"{limit}/minute"
|
||||||
|
|
||||||
|
# Custom rate limit exceeded handler
|
||||||
|
async def custom_rate_limit_exceeded_handler(request: Request, exc: RateLimitExceeded) -> JSONResponse:
|
||||||
|
if request.client.host not in last_rate_limit or time.time() - last_rate_limit[request.client.host] > 60:
|
||||||
|
last_rate_limit[request.client.host] = time.time()
|
||||||
|
retry_after = 60 - (time.time() - last_rate_limit[request.client.host])
|
||||||
|
reset_at = time.time() + retry_after
|
||||||
|
return JSONResponse(
|
||||||
|
status_code=429,
|
||||||
|
content={
|
||||||
|
"detail": "Rate limit exceeded",
|
||||||
|
"limit": str(exc.limit.limit),
|
||||||
|
"retry_after": retry_after,
|
||||||
|
'reset_at': reset_at,
|
||||||
|
"message": f"You have exceeded the rate limit of {exc.limit.limit}."
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
app.add_exception_handler(RateLimitExceeded, custom_rate_limit_exceeded_handler)
|
||||||
|
|
||||||
|
|
||||||
|
# Middleware for token-based bypass and per-request limit
|
||||||
|
class RateLimitMiddleware(BaseHTTPMiddleware):
|
||||||
|
async def dispatch(self, request: Request, call_next):
|
||||||
|
SPAN = int(os.environ.get('ACCESS_TIME_SPAN', 10))
|
||||||
|
access_token = request.headers.get("access-token")
|
||||||
|
if access_token == os.environ.get('ACCESS_TOKEN'):
|
||||||
|
return await call_next(request)
|
||||||
|
|
||||||
|
path = request.url.path
|
||||||
|
if path in ["/crawl", "/old"]:
|
||||||
|
client_ip = request.client.host
|
||||||
|
current_time = time.time()
|
||||||
|
|
||||||
|
# Check time since last request
|
||||||
|
if client_ip in last_request_times:
|
||||||
|
time_since_last_request = current_time - last_request_times[client_ip]
|
||||||
|
if time_since_last_request < SPAN:
|
||||||
|
return JSONResponse(
|
||||||
|
status_code=429,
|
||||||
|
content={
|
||||||
|
"detail": "Too many requests",
|
||||||
|
"message": "Rate limit exceeded. Please wait 10 seconds between requests.",
|
||||||
|
"retry_after": max(0, SPAN - time_since_last_request),
|
||||||
|
"reset_at": current_time + max(0, SPAN - time_since_last_request),
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
last_request_times[client_ip] = current_time
|
||||||
|
|
||||||
|
return await call_next(request)
|
||||||
|
|
||||||
|
app.add_middleware(RateLimitMiddleware)
|
||||||
|
|
||||||
|
# CORS configuration
|
||||||
|
origins = ["*"] # Allow all origins
|
||||||
|
app.add_middleware(
|
||||||
|
CORSMiddleware,
|
||||||
|
allow_origins=origins, # List of origins that are allowed to make requests
|
||||||
|
allow_credentials=True,
|
||||||
|
allow_methods=["*"], # Allows all methods
|
||||||
|
allow_headers=["*"], # Allows all headers
|
||||||
|
)
|
||||||
|
|
||||||
|
# Mount the pages directory as a static directory
|
||||||
|
app.mount("/pages", StaticFiles(directory=__location__ + "/pages"), name="pages")
|
||||||
|
app.mount("/mkdocs", StaticFiles(directory="site", html=True), name="mkdocs")
|
||||||
|
site_templates = Jinja2Templates(directory=__location__ + "/site")
|
||||||
|
templates = Jinja2Templates(directory=__location__ + "/pages")
|
||||||
|
|
||||||
|
@lru_cache()
|
||||||
|
def get_crawler():
|
||||||
|
# Initialize and return a WebCrawler instance
|
||||||
|
crawler = WebCrawler(verbose = True)
|
||||||
|
crawler.warmup()
|
||||||
|
return crawler
|
||||||
|
|
||||||
|
class CrawlRequest(BaseModel):
|
||||||
|
urls: List[str]
|
||||||
|
include_raw_html: Optional[bool] = False
|
||||||
|
bypass_cache: bool = False
|
||||||
|
extract_blocks: bool = True
|
||||||
|
word_count_threshold: Optional[int] = 5
|
||||||
|
extraction_strategy: Optional[str] = "NoExtractionStrategy"
|
||||||
|
extraction_strategy_args: Optional[dict] = {}
|
||||||
|
chunking_strategy: Optional[str] = "RegexChunking"
|
||||||
|
chunking_strategy_args: Optional[dict] = {}
|
||||||
|
css_selector: Optional[str] = None
|
||||||
|
screenshot: Optional[bool] = False
|
||||||
|
user_agent: Optional[str] = None
|
||||||
|
verbose: Optional[bool] = True
|
||||||
|
|
||||||
|
@app.get("/")
|
||||||
|
def read_root():
|
||||||
|
return RedirectResponse(url="/mkdocs")
|
||||||
|
|
||||||
|
@app.get("/old", response_class=HTMLResponse)
|
||||||
|
@limiter.limit(get_rate_limit())
|
||||||
|
async def read_index(request: Request):
|
||||||
|
partials_dir = os.path.join(__location__, "pages", "partial")
|
||||||
|
partials = {}
|
||||||
|
|
||||||
|
for filename in os.listdir(partials_dir):
|
||||||
|
if filename.endswith(".html"):
|
||||||
|
with open(os.path.join(partials_dir, filename), "r", encoding="utf8") as file:
|
||||||
|
partials[filename[:-5]] = file.read()
|
||||||
|
|
||||||
|
return templates.TemplateResponse("index.html", {"request": request, **partials})
|
||||||
|
|
||||||
|
@app.get("/total-count")
|
||||||
|
async def get_total_url_count():
|
||||||
|
count = get_total_count()
|
||||||
|
return JSONResponse(content={"count": count})
|
||||||
|
|
||||||
|
@app.get("/clear-db")
|
||||||
|
async def clear_database():
|
||||||
|
# clear_db()
|
||||||
|
return JSONResponse(content={"message": "Database cleared."})
|
||||||
|
|
||||||
|
def import_strategy(module_name: str, class_name: str, *args, **kwargs):
|
||||||
|
try:
|
||||||
|
module = importlib.import_module(module_name)
|
||||||
|
strategy_class = getattr(module, class_name)
|
||||||
|
return strategy_class(*args, **kwargs)
|
||||||
|
except ImportError:
|
||||||
|
print("ImportError: Module not found.")
|
||||||
|
raise HTTPException(status_code=400, detail=f"Module {module_name} not found.")
|
||||||
|
except AttributeError:
|
||||||
|
print("AttributeError: Class not found.")
|
||||||
|
raise HTTPException(status_code=400, detail=f"Class {class_name} not found in {module_name}.")
|
||||||
|
|
||||||
|
@app.post("/crawl")
@limiter.limit(get_rate_limit())
async def crawl_urls(crawl_request: CrawlRequest, request: Request):
    """Crawl one or more URLs concurrently and return the extracted results.

    Admission is throttled by ``MAX_CONCURRENT_REQUESTS``; requests beyond
    the limit are rejected with HTTP 429 rather than queued.

    Raises:
        HTTPException: 429 when the server is saturated; 400 when a strategy
            name cannot be resolved by ``import_strategy``.
    """
    logging.debug(f"[LOG] Crawl request for URL: {crawl_request.urls}")
    global current_requests
    # Reserve a concurrency slot under the lock before any work starts.
    async with lock:
        if current_requests >= MAX_CONCURRENT_REQUESTS:
            raise HTTPException(status_code=429, detail="Too many requests - please try again later.")
        current_requests += 1

    try:
        logging.debug("[LOG] Loading extraction and chunking strategies...")
        crawl_request.extraction_strategy_args['verbose'] = True
        crawl_request.chunking_strategy_args['verbose'] = True

        extraction_strategy = import_strategy("crawl4ai.extraction_strategy", crawl_request.extraction_strategy, **crawl_request.extraction_strategy_args)
        chunking_strategy = import_strategy("crawl4ai.chunking_strategy", crawl_request.chunking_strategy, **crawl_request.chunking_strategy_args)

        # The WebCrawler is synchronous; fan the URLs out onto a thread pool
        # so the event loop stays responsive while crawls run.
        logging.debug("[LOG] Running the WebCrawler...")
        with ThreadPoolExecutor() as executor:
            # get_running_loop() is the supported call inside a coroutine;
            # get_event_loop() is deprecated in this context since 3.10.
            loop = asyncio.get_running_loop()
            futures = [
                loop.run_in_executor(
                    executor,
                    get_crawler().run,
                    str(url),
                    crawl_request.word_count_threshold,
                    extraction_strategy,
                    chunking_strategy,
                    crawl_request.bypass_cache,
                    crawl_request.css_selector,
                    crawl_request.screenshot,
                    crawl_request.user_agent,
                    crawl_request.verbose
                )
                for url in crawl_request.urls
            ]
            results = await asyncio.gather(*futures)

        # Strip raw HTML from the payload unless the caller asked for it.
        if not crawl_request.include_raw_html:
            for result in results:
                result.html = None

        return {"results": [result.model_dump() for result in results]}
    finally:
        # Always release the slot, even when a crawl raises.
        async with lock:
            current_requests -= 1
|
@app.get("/strategies/extraction", response_class=JSONResponse)
async def get_extraction_strategies():
    """List the available extraction strategies and their parameters."""
    import json  # local import: the module header does not import json

    with open(f"{__location__}/docs/extraction_strategies.json", "r", encoding="utf-8") as file:
        # Parse before responding: handing the raw string to JSONResponse
        # would double-encode it, so clients would get one big JSON string
        # instead of the object itself.
        return JSONResponse(content=json.load(file))
|
@app.get("/strategies/chunking", response_class=JSONResponse)
async def get_chunking_strategies():
    """List the available chunking strategies and their parameters."""
    import json  # local import: the module header does not import json

    with open(f"{__location__}/docs/chunking_strategies.json", "r", encoding="utf-8") as file:
        # Parse before responding: handing the raw string to JSONResponse
        # would double-encode it, so clients would get one big JSON string
        # instead of the object itself.
        return JSONResponse(content=json.load(file))
|
# Run the API directly with uvicorn when this module is executed as a script.
if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=8888)
281
tests/test_main.py
Normal file
281
tests/test_main.py
Normal file
@@ -0,0 +1,281 @@
|
|||||||
|
import asyncio
|
||||||
|
import aiohttp
|
||||||
|
import json
|
||||||
|
import time
|
||||||
|
import os
|
||||||
|
from typing import Optional, Dict, Any
|
||||||
|
from pydantic import BaseModel, HttpUrl
|
||||||
|
|
||||||
|
class NBCNewsAPITest:
    """Async context-managed client for exercising the crawl API server."""

    def __init__(self, base_url: str = "http://localhost:8000"):
        self.base_url = base_url
        self.session = None

    async def __aenter__(self):
        # One shared HTTP session for the lifetime of the context.
        self.session = aiohttp.ClientSession()
        return self

    async def __aexit__(self, exc_type, exc_val, exc_tb):
        if self.session:
            await self.session.close()

    async def submit_crawl(self, request_data: Dict[str, Any]) -> str:
        """POST a crawl request and hand back the server-assigned task id."""
        async with self.session.post(f"{self.base_url}/crawl", json=request_data) as response:
            payload = await response.json()
        return payload["task_id"]

    async def get_task_status(self, task_id: str) -> Dict[str, Any]:
        """Fetch the current status document for a single task."""
        async with self.session.get(f"{self.base_url}/task/{task_id}") as response:
            return await response.json()

    async def wait_for_task(self, task_id: str, timeout: int = 300, poll_interval: int = 2) -> Dict[str, Any]:
        """Poll until the task completes or fails; raise TimeoutError past ``timeout`` seconds."""
        start_time = time.time()
        while True:
            if time.time() - start_time > timeout:
                raise TimeoutError(f"Task {task_id} did not complete within {timeout} seconds")
            snapshot = await self.get_task_status(task_id)
            if snapshot["status"] in ["completed", "failed"]:
                return snapshot
            await asyncio.sleep(poll_interval)

    async def check_health(self) -> Dict[str, Any]:
        """Hit the /health endpoint and return its JSON body."""
        async with self.session.get(f"{self.base_url}/health") as response:
            return await response.json()
async def test_basic_crawl():
    """Smoke-test a plain single-URL crawl with no extra options."""
    print("\n=== Testing Basic Crawl ===")
    async with NBCNewsAPITest() as api:
        payload = {"urls": "https://www.nbcnews.com/business", "priority": 10}
        task_id = await api.submit_crawl(payload)
        outcome = await api.wait_for_task(task_id)
        print(f"Basic crawl result length: {len(outcome['result']['markdown'])}")
        assert outcome["status"] == "completed"
        assert "result" in outcome
        assert outcome["result"]["success"]
|
async def test_js_execution():
    """Exercise in-page JS: click 'Load More' and wait for extra article cards."""
    print("\n=== Testing JS Execution ===")
    async with NBCNewsAPITest() as api:
        payload = {
            "urls": "https://www.nbcnews.com/business",
            "priority": 8,
            "js_code": [
                "const loadMoreButton = Array.from(document.querySelectorAll('button')).find(button => button.textContent.includes('Load More')); loadMoreButton && loadMoreButton.click();"
            ],
            # Block until at least ten teaser cards are present.
            "wait_for": "article.tease-card:nth-child(10)",
            "crawler_params": {"headless": True},
        }
        task_id = await api.submit_crawl(payload)
        outcome = await api.wait_for_task(task_id)
        print(f"JS execution result length: {len(outcome['result']['markdown'])}")
        assert outcome["status"] == "completed"
        assert outcome["result"]["success"]
|
async def test_css_selector():
    """Restrict extraction to one CSS selector and verify a successful crawl."""
    print("\n=== Testing CSS Selector ===")
    async with NBCNewsAPITest() as api:
        payload = {
            "urls": "https://www.nbcnews.com/business",
            "priority": 7,
            "css_selector": ".wide-tease-item__description",
        }
        task_id = await api.submit_crawl(payload)
        outcome = await api.wait_for_task(task_id)
        print(f"CSS selector result length: {len(outcome['result']['markdown'])}")
        assert outcome["status"] == "completed"
        assert outcome["result"]["success"]
|
async def test_structured_extraction():
    """Extract article fields via a CSS-to-JSON schema and sanity-check the output."""
    print("\n=== Testing Structured Extraction ===")
    # One entry per field the crawler should pull from each article card.
    field_specs = [
        {"name": "title", "selector": "h2", "type": "text"},
        {"name": "description", "selector": ".tease-card__description", "type": "text"},
        {"name": "link", "selector": "a", "type": "attribute", "attribute": "href"},
    ]
    schema = {
        "name": "NBC News Articles",
        "baseSelector": "article.tease-card",
        "fields": field_specs,
    }
    async with NBCNewsAPITest() as api:
        payload = {
            "urls": "https://www.nbcnews.com/business",
            "priority": 9,
            "extraction_config": {"type": "json_css", "params": {"schema": schema}},
        }
        task_id = await api.submit_crawl(payload)
        outcome = await api.wait_for_task(task_id)
        extracted = json.loads(outcome["result"]["extracted_content"])
        print(f"Extracted {len(extracted)} articles")
        assert outcome["status"] == "completed"
        assert outcome["result"]["success"]
        assert len(extracted) > 0
|
async def test_batch_crawl():
    """Submit three URLs in one request and expect exactly three results back."""
    print("\n=== Testing Batch Crawl ===")
    target_urls = [
        "https://www.nbcnews.com/business",
        "https://www.nbcnews.com/business/consumer",
        "https://www.nbcnews.com/business/economy",
    ]
    async with NBCNewsAPITest() as api:
        payload = {
            "urls": target_urls,
            "priority": 6,
            "crawler_params": {"headless": True},
        }
        task_id = await api.submit_crawl(payload)
        outcome = await api.wait_for_task(task_id)
        print(f"Batch crawl completed, got {len(outcome['results'])} results")
        assert outcome["status"] == "completed"
        assert "results" in outcome
        assert len(outcome["results"]) == 3
|
async def test_llm_extraction():
    """Drive an LLM-based structured extraction against the business page.

    NOTE(review): the original read ``OLLAMA_API_KEY`` (and the banner said
    "with Ollama") while the configured provider is ``openai/gpt-4o-mini``.
    The key lookup and banner now match the provider actually used.
    """
    print("\n=== Testing LLM Extraction ===")
    async with NBCNewsAPITest() as api:
        # JSON Schema the LLM must populate for the article.
        schema = {
            "type": "object",
            "properties": {
                "article_title": {
                    "type": "string",
                    "description": "The main title of the news article"
                },
                "summary": {
                    "type": "string",
                    "description": "A brief summary of the article content"
                },
                "main_topics": {
                    "type": "array",
                    "items": {"type": "string"},
                    "description": "Main topics or themes discussed in the article"
                }
            },
            "required": ["article_title", "summary", "main_topics"]
        }

        request = {
            "urls": "https://www.nbcnews.com/business",
            "priority": 8,
            "extraction_config": {
                "type": "llm",
                "params": {
                    "provider": "openai/gpt-4o-mini",
                    # Key must match the provider family above.
                    "api_key": os.getenv("OPENAI_API_KEY"),
                    "schema": schema,
                    "extraction_type": "schema",
                    "instruction": """Extract the main article information including title, a brief summary, and main topics discussed.
                    Focus on the primary business news article on the page."""
                }
            },
            "crawler_params": {
                "headless": True,
                # Keep even very short blocks so the LLM sees the whole page.
                "word_count_threshold": 1
            }
        }

        task_id = await api.submit_crawl(request)
        result = await api.wait_for_task(task_id)

        if result["status"] == "completed":
            extracted = json.loads(result["result"]["extracted_content"])
            print("Extracted article analysis:")
            print(json.dumps(extracted, indent=2))

        assert result["status"] == "completed"
        assert result["result"]["success"]
|
async def test_screenshot():
    """Request a page screenshot alongside the crawl and confirm one arrives."""
    print("\n=== Testing Screenshot ===")
    async with NBCNewsAPITest() as api:
        payload = {
            "urls": "https://www.nbcnews.com/business",
            "priority": 5,
            "screenshot": True,
            "crawler_params": {"headless": True},
        }
        task_id = await api.submit_crawl(payload)
        outcome = await api.wait_for_task(task_id)
        print("Screenshot captured:", bool(outcome["result"]["screenshot"]))
        assert outcome["status"] == "completed"
        assert outcome["result"]["success"]
        assert outcome["result"]["screenshot"] is not None
|
async def test_priority_handling():
    """Queue a low- and a high-priority task and make sure both complete."""
    print("\n=== Testing Priority Handling ===")
    async with NBCNewsAPITest() as api:
        # Low-priority job goes in first...
        low_task_id = await api.submit_crawl({
            "urls": "https://www.nbcnews.com/business",
            "priority": 1,
            "crawler_params": {"headless": True}
        })
        # ...then a high-priority one that should jump the queue.
        high_task_id = await api.submit_crawl({
            "urls": "https://www.nbcnews.com/business/consumer",
            "priority": 10,
            "crawler_params": {"headless": True}
        })

        high_result = await api.wait_for_task(high_task_id)
        low_result = await api.wait_for_task(low_task_id)

        print("Both tasks completed")
        assert high_result["status"] == "completed"
        assert low_result["status"] == "completed"
|
async def main():
    """Health-check the server, then run the currently enabled test cases."""
    try:
        async with NBCNewsAPITest() as api:
            health = await api.check_health()
            print("Server health:", health)

        # Individual scenarios — re-enable as needed.
        # await test_basic_crawl()
        # await test_js_execution()
        # await test_css_selector()
        # await test_structured_extraction()
        await test_llm_extraction()
        # await test_batch_crawl()
        # await test_screenshot()
        # await test_priority_handling()
    except Exception as e:
        print(f"Test failed: {str(e)}")
        raise
||||||
|
# Entry point: run the suite against a locally running API server.
if __name__ == "__main__":
    asyncio.run(main())
|
||||||
Reference in New Issue
Block a user