- Text Categorization
- Crawler, Extraction, and Chunking strategies
- Clustering for semantic segmentation
unclecode
2024-05-12 22:37:21 +08:00
parent 7039e3c1ee
commit 82706129f5
19 changed files with 84568 additions and 102 deletions


@@ -1,39 +1,21 @@
import asyncio
import os, time
import json
from pathlib import Path
from selenium import webdriver
from selenium.webdriver.chrome.service import Service
from selenium.webdriver.common.by import By
from selenium.webdriver.support.ui import WebDriverWait
from selenium.webdriver.support import expected_conditions as EC
from selenium.webdriver.chrome.options import Options
import chromedriver_autoinstaller
from pydantic import parse_obj_as
from .models import UrlModel, CrawlResult
from .database import init_db, get_cached_url, cache_url
from .utils import *
from .chunking_strategy import *
from .extraction_strategy import *
from .crawler_strategy import *
from typing import List
from concurrent.futures import ThreadPoolExecutor, as_completed
from concurrent.futures import ThreadPoolExecutor
from .config import *
class WebCrawler:
def __init__(self, db_path: str):
def __init__(self, db_path: str, crawler_strategy: CrawlerStrategy = LocalSeleniumCrawlerStrategy()):
self.db_path = db_path
init_db(self.db_path)
self.options = Options()
self.options.headless = True
self.options.add_argument("--no-sandbox")
self.options.add_argument("--disable-dev-shm-usage")
# make it headless
self.options.add_argument("--headless")
# Automatically install or update chromedriver
chromedriver_autoinstaller.install()
# Initialize WebDriver for crawling
self.service = Service(chromedriver_autoinstaller.install())
self.driver = webdriver.Chrome(service=self.service, options=self.options)
self.crawler_strategy = crawler_strategy
# Create the .crawl4ai folder in the user's home directory if it doesn't exist
self.crawl4ai_folder = os.path.join(Path.home(), ".crawl4ai")
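
The constructor no longer owns a Selenium driver directly; page fetching is delegated to a pluggable crawler strategy. The real interface lives in crawler_strategy.py and is not part of this hunk, so the sketch below is only an assumption of what it plausibly looks like:

from abc import ABC, abstractmethod
from selenium import webdriver
from selenium.webdriver.chrome.options import Options

# Sketch only: the actual classes are defined in crawl4ai/crawler_strategy.py.
class CrawlerStrategy(ABC):
    @abstractmethod
    def crawl(self, url: str) -> str:
        """Fetch the page at `url` and return its raw HTML."""

class LocalSeleniumCrawlerStrategy(CrawlerStrategy):
    def __init__(self):
        options = Options()
        options.add_argument("--headless")
        options.add_argument("--no-sandbox")
        options.add_argument("--disable-dev-shm-usage")
        # Selenium 4 can resolve a matching chromedriver itself; the original
        # constructor used chromedriver_autoinstaller for the same purpose.
        self.driver = webdriver.Chrome(options=options)

    def crawl(self, url: str) -> str:
        self.driver.get(url)
        return self.driver.page_source
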
@@ -47,10 +29,15 @@ class WebCrawler:
api_token: str = None,
extract_blocks_flag: bool = True,
word_count_threshold = MIN_WORD_THRESHOLD,
use_cached_html: bool = False) -> CrawlResult:
use_cached_html: bool = False,
extraction_strategy: ExtractionStrategy = LLMExtractionStrategy(),
chunking_strategy: ChunkingStrategy = RegexChunking(),
**kwargs
) -> CrawlResult:
# make sure word_count_threshold is not less than MIN_WORD_THRESHOLD
# if word_count_threshold < MIN_WORD_THRESHOLD:
# word_count_threshold = MIN_WORD_THRESHOLD
if word_count_threshold < MIN_WORD_THRESHOLD:
word_count_threshold = MIN_WORD_THRESHOLD
# Check cache first
cached = get_cached_url(self.db_path, str(url_model.url))
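
fetch_page now accepts the chunking and extraction strategies as per-call arguments with defaults, so callers can mix and match them. A hypothetical invocation follows; the UrlModel constructor and the CrawlResult fields printed here are assumptions rather than verified against models.py:

import os

crawler = WebCrawler(db_path="crawl4ai.db")
result = crawler.fetch_page(
    UrlModel(url="https://example.com"),
    provider=DEFAULT_PROVIDER,
    api_token=os.environ.get("OPENAI_API_KEY"),
    word_count_threshold=MIN_WORD_THRESHOLD,
    chunking_strategy=RegexChunking(),
    extraction_strategy=LLMExtractionStrategy(),
)
print(result.success, result.error_message)
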
@@ -67,87 +54,38 @@ class WebCrawler:
# Initialize WebDriver for crawling
if use_cached_html:
# load html from crawl4ai_folder/cache
valid_file_name = str(url_model.url).replace("/", "_").replace(":", "_")
if os.path.exists(os.path.join(self.crawl4ai_folder, "cache", valid_file_name)):
with open(os.path.join(self.crawl4ai_folder, "cache", valid_file_name), "r") as f:
html = f.read()
else:
raise Exception("Cached HTML file not found")
t = time.time()
try:
html = self.crawler_strategy.crawl(str(url_model.url))
success = True
error_message = ""
else:
service = self.service
driver = self.driver
try:
driver.get(str(url_model.url))
WebDriverWait(driver, 10).until(
EC.presence_of_all_elements_located((By.TAG_NAME, "html"))
)
html = driver.page_source
success = True
error_message = ""
# Save html in crawl4ai_folder/cache
valid_file_name = str(url_model.url).replace("/", "_").replace(":", "_")
with open(os.path.join(self.crawl4ai_folder, "cache", valid_file_name), "w") as f:
f.write(html)
except Exception as e:
html = ""
success = False
error_message = str(e)
finally:
driver.quit()
except Exception as e:
html = ""
success = False
error_message = str(e)
# Extract content from HTML
result = get_content_of_website(html, word_count_threshold)
cleaned_html = result.get('cleaned_html', html)
markdown = result.get('markdown', "")
print("Crawling is done 🚀")
# Print a professional LOG-style message showing the time taken and that crawling is done
print(f"[LOG] 🚀 Crawling done for {url_model.url}, success: {success}, time taken: {time.time() - t} seconds")
parsed_json = []
if extract_blocks_flag:
print(f"[LOG] 🚀 Extracting semantic blocks for {url_model.url}")
# Split markdown into sections
paragraphs = markdown.split('\n\n')
sections = []
chunks = []
total_token_so_far = 0
for paragraph in paragraphs:
if total_token_so_far < CHUNK_TOKEN_THRESHOLD:
chunk = paragraph.split(' ')
total_token_so_far += len(chunk) * 1.3
chunks.append(paragraph)
else:
sections.append('\n\n'.join(chunks))
chunks = [paragraph]
total_token_so_far = len(paragraph.split(' ')) * 1.3
if chunks:
sections.append('\n\n'.join(chunks))
# Process sections to extract blocks
parsed_json = []
if provider.startswith("groq/"):
# Sequential processing with a delay
for section in sections:
parsed_json.extend(extract_blocks(str(url_model.url), section, provider, api_token))
time.sleep(0.5) # 500 ms delay between each processing
else:
# Parallel processing using ThreadPoolExecutor
with ThreadPoolExecutor() as executor:
futures = [executor.submit(extract_blocks, str(url_model.url), section, provider, api_token) for section in sections]
for future in as_completed(futures):
parsed_json.extend(future.result())
sections = chunking_strategy.chunk(markdown)
# sections = merge_chunks_based_on_token_threshold(sections, CHUNK_TOKEN_THRESHOLD)
parsed_json = extraction_strategy.run(str(url_model.url), sections, provider, api_token)
parsed_json = json.dumps(parsed_json)
print(f"[LOG] 🚀 Extraction done for {url_model.url}")
else:
parsed_json = "{}"
print(f"[LOG] 🚀 Skipping extraction for {url_model.url}")
# Cache the result
cleaned_html = beautify_html(cleaned_html)
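
The paragraph splitting and token-budget merging that used to live inline here now sit behind chunking_strategy.chunk() and extraction_strategy.run(). Below is a minimal sketch of how RegexChunking might reproduce the old blank-line split; the actual implementation in chunking_strategy.py may differ:

import re

class ChunkingStrategy:
    def chunk(self, text: str) -> list:
        raise NotImplementedError

class RegexChunking(ChunkingStrategy):
    # Splits text on blank-line boundaries, mirroring the old paragraph split.
    def __init__(self, patterns=None):
        self.patterns = patterns or [r"\n\n"]

    def chunk(self, text: str) -> list:
        chunks = [text]
        for pattern in self.patterns:
            chunks = [piece for chunk in chunks
                      for piece in re.split(pattern, chunk) if piece.strip()]
        return chunks
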
@@ -163,7 +101,23 @@ class WebCrawler:
error_message=error_message
)
def fetch_pages(self, url_models: List[UrlModel], provider: str = DEFAULT_PROVIDER, api_token: str = None) -> List[CrawlResult]:
def fetch_pages(self, url_models: List[UrlModel], provider: str = DEFAULT_PROVIDER, api_token: str = None,
extract_blocks_flag: bool = True, word_count_threshold=MIN_WORD_THRESHOLD,
use_cached_html: bool = False, extraction_strategy: ExtractionStrategy = LLMExtractionStrategy(),
chunking_strategy: ChunkingStrategy = RegexChunking(), **kwargs) -> List[CrawlResult]:
def fetch_page_wrapper(url_model, *args):
# The last positional argument carries the shared kwargs dict.
*positional, kw = args
return self.fetch_page(url_model, *positional, **kw)
with ThreadPoolExecutor() as executor:
results = list(executor.map(self.fetch_page, url_models, [provider] * len(url_models), [api_token] * len(url_models)))
results = list(executor.map(fetch_page_wrapper, url_models,
[provider] * len(url_models),
[api_token] * len(url_models),
[extract_blocks_flag] * len(url_models),
[word_count_threshold] * len(url_models),
[use_cached_html] * len(url_models),
[extraction_strategy] * len(url_models),
[chunking_strategy] * len(url_models),
[kwargs] * len(url_models)))
return results
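
For completeness, a hypothetical batch call through the new fetch_pages path, reusing the crawler instance from the earlier sketch; strategy choices and CrawlResult fields remain assumptions:

urls = [UrlModel(url=u) for u in ("https://example.com", "https://example.org")]
results = crawler.fetch_pages(
    urls,
    provider=DEFAULT_PROVIDER,
    api_token=os.environ.get("OPENAI_API_KEY"),
    chunking_strategy=RegexChunking(),
    extraction_strategy=LLMExtractionStrategy(),
)
for r in results:
    print(r.url, r.success)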