diff --git a/main_v0.py b/main_v0.py deleted file mode 100644 index 71d5eeee..00000000 --- a/main_v0.py +++ /dev/null @@ -1,254 +0,0 @@ -import os -import importlib -import asyncio -from functools import lru_cache -import logging -logging.basicConfig(level=logging.DEBUG) - -from fastapi import FastAPI, HTTPException, Request -from fastapi.responses import HTMLResponse, JSONResponse -from fastapi.staticfiles import StaticFiles -from fastapi.middleware.cors import CORSMiddleware -from fastapi.templating import Jinja2Templates -from fastapi.exceptions import RequestValidationError -from starlette.middleware.base import BaseHTTPMiddleware -from starlette.responses import FileResponse -from fastapi.responses import RedirectResponse - -from pydantic import BaseModel, HttpUrl -from concurrent.futures import ThreadPoolExecutor, as_completed -from typing import List, Optional - -from crawl4ai.web_crawler import WebCrawler -from crawl4ai.database import get_total_count, clear_db - -import time -from slowapi import Limiter, _rate_limit_exceeded_handler -from slowapi.util import get_remote_address -from slowapi.errors import RateLimitExceeded - -# load .env file -from dotenv import load_dotenv -load_dotenv() - -# Configuration -__location__ = os.path.realpath(os.path.join(os.getcwd(), os.path.dirname(__file__))) -MAX_CONCURRENT_REQUESTS = 10 # Adjust this to change the maximum concurrent requests -current_requests = 0 -lock = asyncio.Lock() - -app = FastAPI() - -# Initialize rate limiter -def rate_limit_key_func(request: Request): - access_token = request.headers.get("access-token") - if access_token == os.environ.get('ACCESS_TOKEN'): - return None - return get_remote_address(request) - -limiter = Limiter(key_func=rate_limit_key_func) -app.state.limiter = limiter - -# Dictionary to store last request times for each client -last_request_times = {} -last_rate_limit = {} - - -def get_rate_limit(): - limit = os.environ.get('ACCESS_PER_MIN', "5") - return f"{limit}/minute" - -# Custom rate limit exceeded handler -async def custom_rate_limit_exceeded_handler(request: Request, exc: RateLimitExceeded) -> JSONResponse: - if request.client.host not in last_rate_limit or time.time() - last_rate_limit[request.client.host] > 60: - last_rate_limit[request.client.host] = time.time() - retry_after = 60 - (time.time() - last_rate_limit[request.client.host]) - reset_at = time.time() + retry_after - return JSONResponse( - status_code=429, - content={ - "detail": "Rate limit exceeded", - "limit": str(exc.limit.limit), - "retry_after": retry_after, - 'reset_at': reset_at, - "message": f"You have exceeded the rate limit of {exc.limit.limit}." - } - ) - -app.add_exception_handler(RateLimitExceeded, custom_rate_limit_exceeded_handler) - - -# Middleware for token-based bypass and per-request limit -class RateLimitMiddleware(BaseHTTPMiddleware): - async def dispatch(self, request: Request, call_next): - SPAN = int(os.environ.get('ACCESS_TIME_SPAN', 10)) - access_token = request.headers.get("access-token") - if access_token == os.environ.get('ACCESS_TOKEN'): - return await call_next(request) - - path = request.url.path - if path in ["/crawl", "/old"]: - client_ip = request.client.host - current_time = time.time() - - # Check time since last request - if client_ip in last_request_times: - time_since_last_request = current_time - last_request_times[client_ip] - if time_since_last_request < SPAN: - return JSONResponse( - status_code=429, - content={ - "detail": "Too many requests", - "message": "Rate limit exceeded. Please wait 10 seconds between requests.", - "retry_after": max(0, SPAN - time_since_last_request), - "reset_at": current_time + max(0, SPAN - time_since_last_request), - } - ) - - last_request_times[client_ip] = current_time - - return await call_next(request) - -app.add_middleware(RateLimitMiddleware) - -# CORS configuration -origins = ["*"] # Allow all origins -app.add_middleware( - CORSMiddleware, - allow_origins=origins, # List of origins that are allowed to make requests - allow_credentials=True, - allow_methods=["*"], # Allows all methods - allow_headers=["*"], # Allows all headers -) - -# Mount the pages directory as a static directory -app.mount("/pages", StaticFiles(directory=__location__ + "/pages"), name="pages") -app.mount("/mkdocs", StaticFiles(directory="site", html=True), name="mkdocs") -site_templates = Jinja2Templates(directory=__location__ + "/site") -templates = Jinja2Templates(directory=__location__ + "/pages") - -@lru_cache() -def get_crawler(): - # Initialize and return a WebCrawler instance - crawler = WebCrawler(verbose = True) - crawler.warmup() - return crawler - -class CrawlRequest(BaseModel): - urls: List[str] - include_raw_html: Optional[bool] = False - bypass_cache: bool = False - extract_blocks: bool = True - word_count_threshold: Optional[int] = 5 - extraction_strategy: Optional[str] = "NoExtractionStrategy" - extraction_strategy_args: Optional[dict] = {} - chunking_strategy: Optional[str] = "RegexChunking" - chunking_strategy_args: Optional[dict] = {} - css_selector: Optional[str] = None - screenshot: Optional[bool] = False - user_agent: Optional[str] = None - verbose: Optional[bool] = True - -@app.get("/") -def read_root(): - return RedirectResponse(url="/mkdocs") - -@app.get("/old", response_class=HTMLResponse) -@limiter.limit(get_rate_limit()) -async def read_index(request: Request): - partials_dir = os.path.join(__location__, "pages", "partial") - partials = {} - - for filename in os.listdir(partials_dir): - if filename.endswith(".html"): - with open(os.path.join(partials_dir, filename), "r", encoding="utf8") as file: - partials[filename[:-5]] = file.read() - - return templates.TemplateResponse("index.html", {"request": request, **partials}) - -@app.get("/total-count") -async def get_total_url_count(): - count = get_total_count() - return JSONResponse(content={"count": count}) - -@app.get("/clear-db") -async def clear_database(): - # clear_db() - return JSONResponse(content={"message": "Database cleared."}) - -def import_strategy(module_name: str, class_name: str, *args, **kwargs): - try: - module = importlib.import_module(module_name) - strategy_class = getattr(module, class_name) - return strategy_class(*args, **kwargs) - except ImportError: - print("ImportError: Module not found.") - raise HTTPException(status_code=400, detail=f"Module {module_name} not found.") - except AttributeError: - print("AttributeError: Class not found.") - raise HTTPException(status_code=400, detail=f"Class {class_name} not found in {module_name}.") - -@app.post("/crawl") -@limiter.limit(get_rate_limit()) -async def crawl_urls(crawl_request: CrawlRequest, request: Request): - logging.debug(f"[LOG] Crawl request for URL: {crawl_request.urls}") - global current_requests - async with lock: - if current_requests >= MAX_CONCURRENT_REQUESTS: - raise HTTPException(status_code=429, detail="Too many requests - please try again later.") - current_requests += 1 - - try: - logging.debug("[LOG] Loading extraction and chunking strategies...") - crawl_request.extraction_strategy_args['verbose'] = True - crawl_request.chunking_strategy_args['verbose'] = True - - extraction_strategy = import_strategy("crawl4ai.extraction_strategy", crawl_request.extraction_strategy, **crawl_request.extraction_strategy_args) - chunking_strategy = import_strategy("crawl4ai.chunking_strategy", crawl_request.chunking_strategy, **crawl_request.chunking_strategy_args) - - # Use ThreadPoolExecutor to run the synchronous WebCrawler in async manner - logging.debug("[LOG] Running the WebCrawler...") - with ThreadPoolExecutor() as executor: - loop = asyncio.get_event_loop() - futures = [ - loop.run_in_executor( - executor, - get_crawler().run, - str(url), - crawl_request.word_count_threshold, - extraction_strategy, - chunking_strategy, - crawl_request.bypass_cache, - crawl_request.css_selector, - crawl_request.screenshot, - crawl_request.user_agent, - crawl_request.verbose - ) - for url in crawl_request.urls - ] - results = await asyncio.gather(*futures) - - # if include_raw_html is False, remove the raw HTML content from the results - if not crawl_request.include_raw_html: - for result in results: - result.html = None - - return {"results": [result.model_dump() for result in results]} - finally: - async with lock: - current_requests -= 1 - -@app.get("/strategies/extraction", response_class=JSONResponse) -async def get_extraction_strategies(): - with open(f"{__location__}/docs/extraction_strategies.json", "r") as file: - return JSONResponse(content=file.read()) - -@app.get("/strategies/chunking", response_class=JSONResponse) -async def get_chunking_strategies(): - with open(f"{__location__}/docs/chunking_strategies.json", "r") as file: - return JSONResponse(content=file.read()) - - -if __name__ == "__main__": - import uvicorn - uvicorn.run(app, host="0.0.0.0", port=8888)