Merge branch 'main' of https://github.com/unclecode/crawl4ai
This commit is contained in:
77
main.py
77
main.py
@@ -22,6 +22,15 @@ from typing import List, Optional
|
|||||||
from crawl4ai.web_crawler import WebCrawler
|
from crawl4ai.web_crawler import WebCrawler
|
||||||
from crawl4ai.database import get_total_count, clear_db
|
from crawl4ai.database import get_total_count, clear_db
|
||||||
|
|
||||||
|
import time
|
||||||
|
from slowapi import Limiter, _rate_limit_exceeded_handler
|
||||||
|
from slowapi.util import get_remote_address
|
||||||
|
from slowapi.errors import RateLimitExceeded
|
||||||
|
|
||||||
|
# load .env file
|
||||||
|
from dotenv import load_dotenv
|
||||||
|
load_dotenv()
|
||||||
|
|
||||||
# Configuration
|
# Configuration
|
||||||
__location__ = os.path.realpath(os.path.join(os.getcwd(), os.path.dirname(__file__)))
|
__location__ = os.path.realpath(os.path.join(os.getcwd(), os.path.dirname(__file__)))
|
||||||
MAX_CONCURRENT_REQUESTS = 10 # Adjust this to change the maximum concurrent requests
|
MAX_CONCURRENT_REQUESTS = 10 # Adjust this to change the maximum concurrent requests
|
||||||
@@ -30,6 +39,72 @@ lock = asyncio.Lock()
|
|||||||
|
|
||||||
app = FastAPI()
|
app = FastAPI()
|
||||||
|
|
||||||
|
# Initialize rate limiter
|
||||||
|
def rate_limit_key_func(request: Request):
|
||||||
|
access_token = request.headers.get("access-token")
|
||||||
|
if access_token == os.environ.get('ACCESS_TOKEN'):
|
||||||
|
return None
|
||||||
|
return get_remote_address(request)
|
||||||
|
|
||||||
|
limiter = Limiter(key_func=rate_limit_key_func)
|
||||||
|
app.state.limiter = limiter
|
||||||
|
|
||||||
|
# Dictionary to store last request times for each client
|
||||||
|
last_request_times = {}
|
||||||
|
|
||||||
|
def get_rate_limit():
|
||||||
|
limit = os.environ.get('ACCESS_PER_MIN', "5")
|
||||||
|
return f"{limit}/minute"
|
||||||
|
|
||||||
|
# Custom rate limit exceeded handler
|
||||||
|
async def custom_rate_limit_exceeded_handler(request: Request, exc: RateLimitExceeded) -> JSONResponse:
|
||||||
|
try_after = last_request_times.get(request.client.host, 0) + 10 - time.time()
|
||||||
|
reset_at = time.time() + try_after
|
||||||
|
return JSONResponse(
|
||||||
|
status_code=429,
|
||||||
|
content={
|
||||||
|
"detail": "Rate limit exceeded",
|
||||||
|
"limit": str(exc.limit.limit),
|
||||||
|
"reset_at": reset_at,
|
||||||
|
"message": f"You have exceeded the rate limit of {exc.limit.limit}. Please try again after {try_after} seconds."
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
app.add_exception_handler(RateLimitExceeded, custom_rate_limit_exceeded_handler)
|
||||||
|
|
||||||
|
|
||||||
|
# Middleware for token-based bypass and per-request limit
|
||||||
|
class RateLimitMiddleware(BaseHTTPMiddleware):
|
||||||
|
async def dispatch(self, request: Request, call_next):
|
||||||
|
SPAN = int(os.environ.get('ACCESS_TIME_SPAN', 10))
|
||||||
|
access_token = request.headers.get("access-token")
|
||||||
|
if access_token == os.environ.get('ACCESS_TOKEN'):
|
||||||
|
return await call_next(request)
|
||||||
|
|
||||||
|
path = request.url.path
|
||||||
|
if path in ["/crawl", "/old"]:
|
||||||
|
client_ip = request.client.host
|
||||||
|
current_time = time.time()
|
||||||
|
|
||||||
|
# Check time since last request
|
||||||
|
if client_ip in last_request_times:
|
||||||
|
time_since_last_request = current_time - last_request_times[client_ip]
|
||||||
|
if time_since_last_request < SPAN:
|
||||||
|
return JSONResponse(
|
||||||
|
status_code=429,
|
||||||
|
content={
|
||||||
|
"detail": "Too many requests",
|
||||||
|
"message": "Rate limit exceeded. Please wait 10 seconds between requests.",
|
||||||
|
"retry_after": max(0, SPAN - time_since_last_request)
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
last_request_times[client_ip] = current_time
|
||||||
|
|
||||||
|
return await call_next(request)
|
||||||
|
|
||||||
|
app.add_middleware(RateLimitMiddleware)
|
||||||
|
|
||||||
# CORS configuration
|
# CORS configuration
|
||||||
origins = ["*"] # Allow all origins
|
origins = ["*"] # Allow all origins
|
||||||
app.add_middleware(
|
app.add_middleware(
|
||||||
@@ -73,6 +148,7 @@ def read_root():
|
|||||||
return RedirectResponse(url="/mkdocs")
|
return RedirectResponse(url="/mkdocs")
|
||||||
|
|
||||||
@app.get("/old", response_class=HTMLResponse)
|
@app.get("/old", response_class=HTMLResponse)
|
||||||
|
@limiter.limit(get_rate_limit())
|
||||||
async def read_index(request: Request):
|
async def read_index(request: Request):
|
||||||
partials_dir = os.path.join(__location__, "pages", "partial")
|
partials_dir = os.path.join(__location__, "pages", "partial")
|
||||||
partials = {}
|
partials = {}
|
||||||
@@ -107,6 +183,7 @@ def import_strategy(module_name: str, class_name: str, *args, **kwargs):
|
|||||||
raise HTTPException(status_code=400, detail=f"Class {class_name} not found in {module_name}.")
|
raise HTTPException(status_code=400, detail=f"Class {class_name} not found in {module_name}.")
|
||||||
|
|
||||||
@app.post("/crawl")
|
@app.post("/crawl")
|
||||||
|
@limiter.limit(get_rate_limit())
|
||||||
async def crawl_urls(crawl_request: CrawlRequest, request: Request):
|
async def crawl_urls(crawl_request: CrawlRequest, request: Request):
|
||||||
logging.debug(f"[LOG] Crawl request for URL: {crawl_request.urls}")
|
logging.debug(f"[LOG] Crawl request for URL: {crawl_request.urls}")
|
||||||
global current_requests
|
global current_requests
|
||||||
|
|||||||
0
middlewares.py
Normal file
0
middlewares.py
Normal file
Reference in New Issue
Block a user