feat: Add rate limiting functionality with custom handlers
This commit is contained in:
77
main.py
77
main.py
@@ -22,6 +22,15 @@ from typing import List, Optional
|
||||
from crawl4ai.web_crawler import WebCrawler
|
||||
from crawl4ai.database import get_total_count, clear_db
|
||||
|
||||
import time
|
||||
from slowapi import Limiter, _rate_limit_exceeded_handler
|
||||
from slowapi.util import get_remote_address
|
||||
from slowapi.errors import RateLimitExceeded
|
||||
|
||||
# load .env file
|
||||
from dotenv import load_dotenv
|
||||
load_dotenv()
|
||||
|
||||
# Configuration
|
||||
__location__ = os.path.realpath(os.path.join(os.getcwd(), os.path.dirname(__file__)))
|
||||
MAX_CONCURRENT_REQUESTS = 10 # Adjust this to change the maximum concurrent requests
|
||||
@@ -30,6 +39,72 @@ lock = asyncio.Lock()
|
||||
|
||||
app = FastAPI()
|
||||
|
||||
# Initialize rate limiter
|
||||
def rate_limit_key_func(request: Request):
|
||||
access_token = request.headers.get("access-token")
|
||||
if access_token == os.environ.get('ACCESS_TOKEN'):
|
||||
return None
|
||||
return get_remote_address(request)
|
||||
|
||||
limiter = Limiter(key_func=rate_limit_key_func)
|
||||
app.state.limiter = limiter
|
||||
|
||||
# Dictionary to store last request times for each client
|
||||
last_request_times = {}
|
||||
|
||||
def get_rate_limit():
|
||||
limit = os.environ.get('ACCESS_PER_MIN', "5")
|
||||
return f"{limit}/minute"
|
||||
|
||||
# Custom rate limit exceeded handler
|
||||
async def custom_rate_limit_exceeded_handler(request: Request, exc: RateLimitExceeded) -> JSONResponse:
|
||||
try_after = last_request_times.get(request.client.host, 0) + 10 - time.time()
|
||||
reset_at = time.time() + try_after
|
||||
return JSONResponse(
|
||||
status_code=429,
|
||||
content={
|
||||
"detail": "Rate limit exceeded",
|
||||
"limit": str(exc.limit.limit),
|
||||
"reset_at": reset_at,
|
||||
"message": f"You have exceeded the rate limit of {exc.limit.limit}. Please try again after {try_after} seconds."
|
||||
}
|
||||
)
|
||||
|
||||
app.add_exception_handler(RateLimitExceeded, custom_rate_limit_exceeded_handler)
|
||||
|
||||
|
||||
# Middleware for token-based bypass and per-request limit
|
||||
class RateLimitMiddleware(BaseHTTPMiddleware):
|
||||
async def dispatch(self, request: Request, call_next):
|
||||
SPAN = int(os.environ.get('ACCESS_TIME_SPAN', 10))
|
||||
access_token = request.headers.get("access-token")
|
||||
if access_token == os.environ.get('ACCESS_TOKEN'):
|
||||
return await call_next(request)
|
||||
|
||||
path = request.url.path
|
||||
if path in ["/crawl", "/old"]:
|
||||
client_ip = request.client.host
|
||||
current_time = time.time()
|
||||
|
||||
# Check time since last request
|
||||
if client_ip in last_request_times:
|
||||
time_since_last_request = current_time - last_request_times[client_ip]
|
||||
if time_since_last_request < SPAN:
|
||||
return JSONResponse(
|
||||
status_code=429,
|
||||
content={
|
||||
"detail": "Too many requests",
|
||||
"message": "Rate limit exceeded. Please wait 10 seconds between requests.",
|
||||
"retry_after": max(0, SPAN - time_since_last_request)
|
||||
}
|
||||
)
|
||||
|
||||
last_request_times[client_ip] = current_time
|
||||
|
||||
return await call_next(request)
|
||||
|
||||
app.add_middleware(RateLimitMiddleware)
|
||||
|
||||
# CORS configuration
|
||||
origins = ["*"] # Allow all origins
|
||||
app.add_middleware(
|
||||
@@ -73,6 +148,7 @@ def read_root():
|
||||
return RedirectResponse(url="/mkdocs")
|
||||
|
||||
@app.get("/old", response_class=HTMLResponse)
|
||||
@limiter.limit(get_rate_limit())
|
||||
async def read_index(request: Request):
|
||||
partials_dir = os.path.join(__location__, "pages", "partial")
|
||||
partials = {}
|
||||
@@ -107,6 +183,7 @@ def import_strategy(module_name: str, class_name: str, *args, **kwargs):
|
||||
raise HTTPException(status_code=400, detail=f"Class {class_name} not found in {module_name}.")
|
||||
|
||||
@app.post("/crawl")
|
||||
@limiter.limit(get_rate_limit())
|
||||
async def crawl_urls(crawl_request: CrawlRequest, request: Request):
|
||||
logging.debug(f"[LOG] Crawl request for URL: {crawl_request.urls}")
|
||||
global current_requests
|
||||
|
||||
0
middlewares.py
Normal file
0
middlewares.py
Normal file
Reference in New Issue
Block a user