- Test all methods

- Update index.html
- Update Readme
- Resolve some bugs
Author: unclecode
Date: 2024-05-14 21:27:41 +08:00
Parent: 5fea6c064b
Commit: f6e59157bf
17 changed files with 1004 additions and 402 deletions

main.py (157 changed lines)

@@ -1,24 +1,19 @@
from fastapi import FastAPI, HTTPException, Request
from fastapi.responses import HTMLResponse
from fastapi.staticfiles import StaticFiles
from fastapi.responses import JSONResponse
from pydantic import BaseModel, HttpUrl
from typing import List, Optional
from crawl4ai.web_crawler import WebCrawler
from crawl4ai.models import UrlModel
import asyncio
from concurrent.futures import ThreadPoolExecutor, as_completed
import chromedriver_autoinstaller
from functools import lru_cache
from crawl4ai.database import get_total_count, clear_db
import os
import uuid
# Import the CORS middleware
import importlib
import asyncio
from functools import lru_cache
from fastapi import FastAPI, HTTPException, Request
from fastapi.responses import HTMLResponse, JSONResponse
from fastapi.staticfiles import StaticFiles
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel, HttpUrl
from concurrent.futures import ThreadPoolExecutor, as_completed
from typing import List, Optional
# Task management
tasks = {}
from crawl4ai.web_crawler import WebCrawler
from crawl4ai.database import get_total_count, clear_db
# Configuration
__location__ = os.path.realpath(os.path.join(os.getcwd(), os.path.dirname(__file__)))
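The rewritten import block pulls in `CORSMiddleware`, and the next hunk's header shows it being registered via `app.add_middleware(`. The concrete middleware arguments fall outside the visible hunks, so the following is only a minimal sketch of a typical permissive setup; the origins, methods, and headers are assumptions, not the values from this commit.

```python
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware

app = FastAPI()

# A typical permissive CORS setup; the actual values configured in this
# commit are not visible in the hunks shown, so treat these as placeholders.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],   # assumption: allow any origin
    allow_methods=["*"],   # assumption: allow any HTTP method
    allow_headers=["*"],   # assumption: allow any header
)
```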
@@ -41,22 +36,25 @@ app.add_middleware(
# Mount the pages directory as a static directory
app.mount("/pages", StaticFiles(directory=__location__ + "/pages"), name="pages")
# chromedriver_autoinstaller.install() # Ensure chromedriver is installed
@lru_cache()
def get_crawler():
# Initialize and return a WebCrawler instance
return WebCrawler()
chromedriver_autoinstaller.install() # Ensure chromedriver is installed
class UrlsInput(BaseModel):
class CrawlRequest(BaseModel):
urls: List[HttpUrl]
provider_model: str
api_token: str
include_raw_html: Optional[bool] = False
forced: bool = False
bypass_cache: bool = False
extract_blocks: bool = True
word_count_threshold: Optional[int] = 5
extraction_strategy: Optional[str] = "CosineStrategy"
chunking_strategy: Optional[str] = "RegexChunking"
css_selector: Optional[str] = None
verbose: Optional[bool] = True
@lru_cache()
def get_crawler():
# Initialize and return a WebCrawler instance
return WebCrawler(db_path='crawler_data.db')
@app.get("/", response_class=HTMLResponse)
async def read_index():
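A note on the `@lru_cache()` decorator on `get_crawler()`: caching a zero-argument function means the first call constructs the `WebCrawler` and every later call returns that same cached object, so the app shares one crawler instance across requests. A minimal, self-contained sketch of that behavior (the stand-in class below is only illustrative; the real class lives in `crawl4ai.web_crawler`):

```python
from functools import lru_cache

class WebCrawler:
    """Stand-in for crawl4ai.web_crawler.WebCrawler, used only to show the caching behavior."""
    def __init__(self, db_path=None):
        self.db_path = db_path

@lru_cache()
def get_crawler():
    # First call builds the crawler; every later call returns the same cached instance
    return WebCrawler(db_path='crawler_data.db')

assert get_crawler() is get_crawler()  # same object every time
```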
@@ -66,20 +64,30 @@ async def read_index():
@app.get("/total-count")
async def get_total_url_count():
count = get_total_count(db_path='crawler_data.db')
count = get_total_count()
return JSONResponse(content={"count": count})
# Add endpoint to clear the DB
@app.get("/clear-db")
async def clear_database():
clear_db(db_path='crawler_data.db')
clear_db()
return JSONResponse(content={"message": "Database cleared."})
def import_strategy(module_name: str, class_name: str):
try:
module = importlib.import_module(module_name)
strategy_class = getattr(module, class_name)
return strategy_class()
except ImportError:
raise HTTPException(status_code=400, detail=f"Module {module_name} not found.")
except AttributeError:
raise HTTPException(status_code=400, detail=f"Class {class_name} not found in {module_name}.")
@app.post("/crawl")
async def crawl_urls(urls_input: UrlsInput, request: Request):
async def crawl_urls(crawl_request: CrawlRequest, request: Request):
global current_requests
# Raise error if api_token is not provided
if not urls_input.api_token:
if not crawl_request.api_token:
raise HTTPException(status_code=401, detail="API token is required.")
async with lock:
if current_requests >= MAX_CONCURRENT_REQUESTS:
@@ -87,87 +95,50 @@ async def crawl_urls(urls_input: UrlsInput, request: Request):
current_requests += 1
try:
# Prepare URL models for crawling
url_models = [UrlModel(url=url, forced=urls_input.forced) for url in urls_input.urls]
extraction_strategy = import_strategy("crawl4ai.extraction_strategy", crawl_request.extraction_strategy)
chunking_strategy = import_strategy("crawl4ai.chunking_strategy", crawl_request.chunking_strategy)
# Use ThreadPoolExecutor to run the synchronous WebCrawler in an async manner
with ThreadPoolExecutor() as executor:
loop = asyncio.get_event_loop()
futures = [
loop.run_in_executor(executor, get_crawler().fetch_page, url_model, urls_input.provider_model, urls_input.api_token, urls_input.extract_blocks, urls_input.word_count_threshold)
for url_model in url_models
loop.run_in_executor(
executor,
get_crawler().run,
str(url),
crawl_request.word_count_threshold,
extraction_strategy,
chunking_strategy,
crawl_request.bypass_cache,
crawl_request.css_selector,
crawl_request.verbose
)
for url in crawl_request.urls
]
results = await asyncio.gather(*futures)
# if include_raw_html is False, remove the raw HTML content from the results
if not urls_input.include_raw_html:
if not crawl_request.include_raw_html:
for result in results:
result.html = None
return {"results": [result.dict() for result in results]}
finally:
async with lock:
current_requests -= 1
@app.get("/strategies/extraction", response_class=JSONResponse)
async def get_extraction_strategies():
# Load docs/extraction_strategies.json and return it as a JSON response
with open(f"{__location__}/docs/extraction_strategies.json", "r") as file:
return JSONResponse(content=file.read())
@app.post("/crawl_async")
async def crawl_urls(urls_input: UrlsInput, request: Request):
global current_requests
if not urls_input.api_token:
raise HTTPException(status_code=401, detail="API token is required.")
async with lock:
if current_requests >= MAX_CONCURRENT_REQUESTS:
raise HTTPException(status_code=429, detail="Too many requests - please try again later.")
current_requests += 1
task_id = str(uuid.uuid4())
tasks[task_id] = {"status": "pending", "results": None}
try:
url_models = [UrlModel(url=url, forced=urls_input.forced) for url in urls_input.urls]
loop = asyncio.get_running_loop()
loop.create_task(
process_crawl_task(url_models, urls_input.provider_model, urls_input.api_token, task_id, urls_input.extract_blocks)
)
return {"task_id": task_id}
finally:
async with lock:
current_requests -= 1
async def process_crawl_task(url_models, provider, api_token, task_id, extract_blocks_flag):
try:
with ThreadPoolExecutor() as executor:
loop = asyncio.get_running_loop()
futures = [
loop.run_in_executor(executor, get_crawler().fetch_page, url_model, provider, api_token, extract_blocks_flag)
for url_model in url_models
]
results = await asyncio.gather(*futures)
tasks[task_id] = {"status": "done", "results": results}
except Exception as e:
tasks[task_id] = {"status": "failed", "error": str(e)}
@app.get("/task/{task_id}")
async def get_task_status(task_id: str):
task = tasks.get(task_id)
if not task:
raise HTTPException(status_code=404, detail="Task not found")
if task['status'] == 'done':
return {
"status": task['status'],
"results": [result.dict() for result in task['results']]
}
elif task['status'] == 'failed':
return {
"status": task['status'],
"error": task['error']
}
else:
return {"status": task['status']}
@app.get("/strategies/chunking", response_class=JSONResponse)
async def get_chunking_strategies():
with open(f"{__location__}/docs/chunking_strategies.json", "r") as file:
return JSONResponse(content=file.read())
if __name__ == "__main__":
import uvicorn
uvicorn.run(app, host="0.0.0.0", port=8000)
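For reference, a client calling the reshaped `/crawl` endpoint would POST a JSON body matching the new `CrawlRequest` model. A minimal sketch using `requests`; the host, port, URL, and token are placeholders, and because the hunk interleaves the old `UrlsInput` fields with the new `CrawlRequest` fields, additional fields such as `provider_model` may also be required depending on which ones lack defaults in the final model.

```python
import requests

# Assumed local deployment; main.py starts uvicorn on 0.0.0.0:8000
payload = {
    "urls": ["https://example.com"],
    "api_token": "YOUR_API_TOKEN",          # placeholder; requests without a token get a 401
    "include_raw_html": False,
    "bypass_cache": True,
    "word_count_threshold": 5,
    "extraction_strategy": "CosineStrategy",
    "chunking_strategy": "RegexChunking",
    "css_selector": None,
    "verbose": True,
}

resp = requests.post("http://localhost:8000/crawl", json=payload)
resp.raise_for_status()
for result in resp.json()["results"]:
    # The shape of each result comes from crawl4ai's result model, which this diff does not show
    print(sorted(result.keys()))
```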