# --- Rate limiting: per-IP throttle with ACCESS_TOKEN bypass ----------------
# Relies on names already imported at the top of main.py: os, Request,
# JSONResponse, BaseHTTPMiddleware, and the `app` FastAPI instance.
import time

from slowapi import Limiter, _rate_limit_exceeded_handler
from slowapi.util import get_remote_address
from slowapi.errors import RateLimitExceeded

# Load .env so ACCESS_TOKEN / ACCESS_PER_MIN / ACCESS_TIME_SPAN are available.
from dotenv import load_dotenv
load_dotenv()


def get_time_span() -> int:
    """Seconds that must elapse between two requests from the same client.

    Configured via the ACCESS_TIME_SPAN env var; defaults to 10.
    """
    return int(os.environ.get('ACCESS_TIME_SPAN', 10))


def rate_limit_key_func(request: Request):
    """slowapi key function: rate-limit by client IP.

    Clients presenting the configured ACCESS_TOKEN header are not keyed.
    NOTE(review): returning None from a slowapi key_func is relied on here to
    exempt the request — confirm this against the slowapi version in use.
    """
    access_token = request.headers.get("access-token")
    if access_token == os.environ.get('ACCESS_TOKEN'):
        return None
    return get_remote_address(request)


limiter = Limiter(key_func=rate_limit_key_func)
app.state.limiter = limiter

# Last-request timestamp per client IP, shared by the middleware and the
# 429 handler.
# NOTE(review): grows unboundedly (one entry per distinct IP) — consider a
# TTL/LRU cache if this service faces the open internet.
last_request_times = {}


def get_rate_limit() -> str:
    """Return the slowapi limit string, e.g. "5/minute" (ACCESS_PER_MIN)."""
    limit = os.environ.get('ACCESS_PER_MIN', "5")
    return f"{limit}/minute"


async def custom_rate_limit_exceeded_handler(request: Request, exc: RateLimitExceeded) -> JSONResponse:
    """Return a 429 JSON response with the limit and a retry hint.

    Fix: the retry window now uses the configurable ACCESS_TIME_SPAN (the
    original hard-coded 10s, inconsistent with the middleware) and is
    clamped to >= 0 so clients never see a negative wait.
    """
    span = get_time_span()
    now = time.time()
    try_after = max(0.0, last_request_times.get(request.client.host, 0) + span - now)
    return JSONResponse(
        status_code=429,
        content={
            "detail": "Rate limit exceeded",
            "limit": str(exc.limit.limit),
            "reset_at": now + try_after,
            "message": f"You have exceeded the rate limit of {exc.limit.limit}. "
                       f"Please try again after {try_after} seconds.",
        },
    )


app.add_exception_handler(RateLimitExceeded, custom_rate_limit_exceeded_handler)


class RateLimitMiddleware(BaseHTTPMiddleware):
    """Enforce a minimum gap (ACCESS_TIME_SPAN seconds) between requests to
    the crawling endpoints ("/crawl", "/old"), with an ACCESS_TOKEN bypass.
    """

    async def dispatch(self, request: Request, call_next):
        span = get_time_span()

        # Trusted clients bypass the per-request throttle entirely.
        if request.headers.get("access-token") == os.environ.get('ACCESS_TOKEN'):
            return await call_next(request)

        if request.url.path in ["/crawl", "/old"]:
            client_ip = request.client.host
            current_time = time.time()

            last_seen = last_request_times.get(client_ip)
            if last_seen is not None:
                elapsed = current_time - last_seen
                if elapsed < span:
                    # Fix: the message now reports the configured span
                    # instead of a hard-coded "10 seconds".
                    return JSONResponse(
                        status_code=429,
                        content={
                            "detail": "Too many requests",
                            "message": f"Rate limit exceeded. Please wait {span} seconds between requests.",
                            "retry_after": max(0, span - elapsed),
                        },
                    )

            last_request_times[client_ip] = current_time

        return await call_next(request)


app.add_middleware(RateLimitMiddleware)
b/middlewares.py new file mode 100644 index 00000000..e69de29b