chore: Update extraction strategy to support GPU, MPS, and CPU, add batch processing for CPU devices

This commit is contained in:
Unclecode
2024-05-19 16:18:58 +00:00
parent 52c4be0696
commit 53d1176d53
5 changed files with 56 additions and 23 deletions

View File

@@ -16,7 +16,7 @@ class ChunkingStrategy(ABC):
# Regex-based chunking
class RegexChunking(ChunkingStrategy):
def __init__(self, patterns=None):
def __init__(self, patterns=None, **kwargs):
if patterns is None:
patterns = [r'\n\n'] # Default split pattern
self.patterns = patterns
@@ -32,7 +32,7 @@ class RegexChunking(ChunkingStrategy):
# NLP-based sentence chunking
class NlpSentenceChunking(ChunkingStrategy):
def __init__(self):
def __init__(self, **kwargs):
load_nltk_punkt()
pass
@@ -52,7 +52,7 @@ class NlpSentenceChunking(ChunkingStrategy):
# Topic-based segmentation using TextTiling
class TopicSegmentationChunking(ChunkingStrategy):
def __init__(self, num_keywords=3):
def __init__(self, num_keywords=3, **kwargs):
import nltk as nl
self.tokenizer = nl.tokenize.TextTilingTokenizer()
self.num_keywords = num_keywords
@@ -82,7 +82,7 @@ class TopicSegmentationChunking(ChunkingStrategy):
# Fixed-length word chunks
class FixedLengthWordChunking(ChunkingStrategy):
def __init__(self, chunk_size=100):
def __init__(self, chunk_size=100, **kwargs):
self.chunk_size = chunk_size
def chunk(self, text: str) -> list:
@@ -91,7 +91,7 @@ class FixedLengthWordChunking(ChunkingStrategy):
# Sliding window chunking
class SlidingWindowChunking(ChunkingStrategy):
def __init__(self, window_size=100, step=50):
def __init__(self, window_size=100, step=50, **kwargs):
self.window_size = window_size
self.step = step

View File

@@ -6,6 +6,24 @@ from selenium.webdriver.support.ui import WebDriverWait
from selenium.webdriver.support import expected_conditions as EC
from selenium.webdriver.chrome.options import Options
from selenium.common.exceptions import InvalidArgumentException
import logging

# Quiet noisy third-party loggers: raise their threshold to WARNING so that
# routine DEBUG/INFO chatter from Selenium, urllib3 and http.client does not
# flood the application's output. (Note: this lowers verbosity, it does not
# fully disable these loggers.)
_NOISY_LOGGERS = (
    'selenium.webdriver.remote.remote_connection',
    'selenium.webdriver.common.service',
    'selenium.webdriver.common.driver_finder',
    'urllib3.connectionpool',
    'http.client',
)
for _noisy_name in _NOISY_LOGGERS:
    logging.getLogger(_noisy_name).setLevel(logging.WARNING)
from typing import List
import requests
@@ -43,20 +61,20 @@ class LocalSeleniumCrawlerStrategy(CrawlerStrategy):
self.options.headless = True
self.options.add_argument("--no-sandbox")
self.options.add_argument("--headless")
self.options.add_argument("--disable-dev-shm-usage")
# self.options.add_argument("--disable-dev-shm-usage")
self.options.add_argument("--disable-gpu")
self.options.add_argument("--disable-extensions")
self.options.add_argument("--disable-infobars")
self.options.add_argument("--disable-logging")
self.options.add_argument("--disable-popup-blocking")
self.options.add_argument("--disable-translate")
self.options.add_argument("--disable-default-apps")
self.options.add_argument("--disable-background-networking")
self.options.add_argument("--disable-sync")
self.options.add_argument("--disable-features=NetworkService,NetworkServiceInProcess")
self.options.add_argument("--disable-browser-side-navigation")
self.options.add_argument("--dns-prefetch-disable")
self.options.add_argument("--disable-web-security")
# self.options.add_argument("--disable-extensions")
# self.options.add_argument("--disable-infobars")
# self.options.add_argument("--disable-logging")
# self.options.add_argument("--disable-popup-blocking")
# self.options.add_argument("--disable-translate")
# self.options.add_argument("--disable-default-apps")
# self.options.add_argument("--disable-background-networking")
# self.options.add_argument("--disable-sync")
# self.options.add_argument("--disable-features=NetworkService,NetworkServiceInProcess")
# self.options.add_argument("--disable-browser-side-navigation")
# self.options.add_argument("--dns-prefetch-disable")
# self.options.add_argument("--disable-web-security")
self.options.add_argument("--log-level=3")
self.use_cached_html = use_cached_html
self.use_cached_html = use_cached_html
@@ -66,6 +84,7 @@ class LocalSeleniumCrawlerStrategy(CrawlerStrategy):
# chromedriver_autoinstaller.install()
import chromedriver_autoinstaller
self.service = Service(chromedriver_autoinstaller.install())
self.service.log_path = "NUL"
self.driver = webdriver.Chrome(service=self.service, options=self.options)
def crawl(self, url: str) -> str:

View File

@@ -46,6 +46,7 @@ class ExtractionStrategy(ABC):
for future in as_completed(futures):
extracted_content.extend(future.result())
return extracted_content
class NoExtractionStrategy(ExtractionStrategy):
    """Pass-through strategy: performs no extraction at all.

    The entire HTML document is returned unmodified as a single block
    with index 0, so downstream consumers see one chunk containing the
    raw page content.
    """

    def extract(self, url: str, html: str, *q, **kwargs) -> List[Dict[str, Any]]:
        # Wrap the whole document as chunk index 0; extra positional and
        # keyword arguments are accepted for interface compatibility and ignored.
        passthrough: List[Dict[str, Any]] = [{"index": 0, "content": html}]
        return passthrough
@@ -187,7 +188,7 @@ class CosineStrategy(ExtractionStrategy):
if self.verbose:
print(f"[LOG] Loading Extraction Model for {self.device.type} device.")
if self.device.type == "cpu":
if False and self.device.type == "cpu":
self.model = load_onnx_all_MiniLM_l6_v2()
self.tokenizer = self.model.tokenizer
self.get_embedding_method = "direct"
@@ -273,7 +274,7 @@ class CosineStrategy(ExtractionStrategy):
# if self.buffer_embeddings.any() and not bypass_buffer:
# return self.buffer_embeddings
if self.device.type in ["gpu", "cuda", "mps"]:
if self.device.type in [ "cpu", "gpu", "cuda", "mps"]:
import torch
# Tokenize sentences and convert to tensor
if batch_size is None:
@@ -295,7 +296,17 @@ class CosineStrategy(ExtractionStrategy):
self.buffer_embeddings = np.vstack(all_embeddings)
elif self.device.type == "cpu":
self.buffer_embeddings = self.model(sentences)
# self.buffer_embeddings = self.model(sentences)
if batch_size is None:
batch_size = self.default_batch_size
all_embeddings = []
for i in range(0, len(sentences), batch_size):
batch_sentences = sentences[i:i + batch_size]
embeddings = self.model(batch_sentences)
all_embeddings.append(embeddings)
self.buffer_embeddings = np.vstack(all_embeddings)
return self.buffer_embeddings
def hierarchical_clustering(self, sentences: List[str], embeddings = None):

View File

@@ -43,7 +43,7 @@ templates = Jinja2Templates(directory=__location__ + "/pages")
@lru_cache()
def get_crawler():
# Initialize and return a WebCrawler instance
return WebCrawler()
return WebCrawler(verbose = True)
class CrawlRequest(BaseModel):
urls: List[str]
@@ -105,6 +105,9 @@ async def crawl_urls(crawl_request: CrawlRequest, request: Request):
try:
logging.debug("[LOG] Loading extraction and chunking strategies...")
crawl_request.extraction_strategy_args['verbose'] = True
crawl_request.chunking_strategy_args['verbose'] = True
extraction_strategy = import_strategy("crawl4ai.extraction_strategy", crawl_request.extraction_strategy, **crawl_request.extraction_strategy_args)
chunking_strategy = import_strategy("crawl4ai.chunking_strategy", crawl_request.chunking_strategy, **crawl_request.chunking_strategy_args)

View File

@@ -25,7 +25,7 @@
<header class="bg-zinc-950 text-lime-500 py-4 flex">
<div class="mx-auto px-4">
<h1 class="text-2xl font-bold">🔥🕷️ Crawl4AI: Web Data for your Thoughts</h1>
<h1 class="text-2xl font-bold">🔥🕷️ Crawl4AI: Web Data for your Thoughts v0.2.2</h1>
</div>
<div class="mx-auto px-4 flex font-bold text-xl gap-2">
<span>📊 Total Website Processed</span>