From 53d1176d53823e8f23f1dd5d627e10ef80ce861e Mon Sep 17 00:00:00 2001 From: Unclecode Date: Sun, 19 May 2024 16:18:58 +0000 Subject: [PATCH] chore: Update extraction strategy to support GPU, MPS, and CPU, add batch processing for CPU devices --- crawl4ai/chunking_strategy.py | 10 ++++---- crawl4ai/crawler_strategy.py | 45 +++++++++++++++++++++++---------- crawl4ai/extraction_strategy.py | 17 ++++++++++--- main.py | 5 +++- pages/index.html | 2 +- 5 files changed, 56 insertions(+), 23 deletions(-) diff --git a/crawl4ai/chunking_strategy.py b/crawl4ai/chunking_strategy.py index 6ece75e3..5fe9b5e1 100644 --- a/crawl4ai/chunking_strategy.py +++ b/crawl4ai/chunking_strategy.py @@ -16,7 +16,7 @@ class ChunkingStrategy(ABC): # Regex-based chunking class RegexChunking(ChunkingStrategy): - def __init__(self, patterns=None): + def __init__(self, patterns=None, **kwargs): if patterns is None: patterns = [r'\n\n'] # Default split pattern self.patterns = patterns @@ -32,7 +32,7 @@ class RegexChunking(ChunkingStrategy): # NLP-based sentence chunking class NlpSentenceChunking(ChunkingStrategy): - def __init__(self): + def __init__(self, **kwargs): load_nltk_punkt() pass @@ -52,7 +52,7 @@ class NlpSentenceChunking(ChunkingStrategy): # Topic-based segmentation using TextTiling class TopicSegmentationChunking(ChunkingStrategy): - def __init__(self, num_keywords=3): + def __init__(self, num_keywords=3, **kwargs): import nltk as nl self.tokenizer = nl.toknize.TextTilingTokenizer() self.num_keywords = num_keywords @@ -82,7 +82,7 @@ class TopicSegmentationChunking(ChunkingStrategy): # Fixed-length word chunks class FixedLengthWordChunking(ChunkingStrategy): - def __init__(self, chunk_size=100): + def __init__(self, chunk_size=100, **kwargs): self.chunk_size = chunk_size def chunk(self, text: str) -> list: @@ -91,7 +91,7 @@ class FixedLengthWordChunking(ChunkingStrategy): # Sliding window chunking class SlidingWindowChunking(ChunkingStrategy): - def __init__(self, window_size=100, step=50): + def __init__(self, window_size=100, step=50, **kwargs): self.window_size = window_size self.step = step diff --git a/crawl4ai/crawler_strategy.py b/crawl4ai/crawler_strategy.py index 33988dec..a98402bc 100644 --- a/crawl4ai/crawler_strategy.py +++ b/crawl4ai/crawler_strategy.py @@ -6,6 +6,24 @@ from selenium.webdriver.support.ui import WebDriverWait from selenium.webdriver.support import expected_conditions as EC from selenium.webdriver.chrome.options import Options from selenium.common.exceptions import InvalidArgumentException +import logging +logger = logging.getLogger('selenium.webdriver.remote.remote_connection') +logger.setLevel(logging.WARNING) + +logger_driver = logging.getLogger('selenium.webdriver.common.service') +logger_driver.setLevel(logging.WARNING) + +urllib3_logger = logging.getLogger('urllib3.connectionpool') +urllib3_logger.setLevel(logging.WARNING) + +# Disable http.client logging +http_client_logger = logging.getLogger('http.client') +http_client_logger.setLevel(logging.WARNING) + +# Disable driver_finder and service logging +driver_finder_logger = logging.getLogger('selenium.webdriver.common.driver_finder') +driver_finder_logger.setLevel(logging.WARNING) + from typing import List import requests @@ -43,20 +61,20 @@ class LocalSeleniumCrawlerStrategy(CrawlerStrategy): self.options.headless = True self.options.add_argument("--no-sandbox") self.options.add_argument("--headless") - self.options.add_argument("--disable-dev-shm-usage") + # self.options.add_argument("--disable-dev-shm-usage") self.options.add_argument("--disable-gpu") - self.options.add_argument("--disable-extensions") - self.options.add_argument("--disable-infobars") - self.options.add_argument("--disable-logging") - self.options.add_argument("--disable-popup-blocking") - self.options.add_argument("--disable-translate") - self.options.add_argument("--disable-default-apps") - self.options.add_argument("--disable-background-networking") - self.options.add_argument("--disable-sync") - self.options.add_argument("--disable-features=NetworkService,NetworkServiceInProcess") - self.options.add_argument("--disable-browser-side-navigation") - self.options.add_argument("--dns-prefetch-disable") - self.options.add_argument("--disable-web-security") + # self.options.add_argument("--disable-extensions") + # self.options.add_argument("--disable-infobars") + # self.options.add_argument("--disable-logging") + # self.options.add_argument("--disable-popup-blocking") + # self.options.add_argument("--disable-translate") + # self.options.add_argument("--disable-default-apps") + # self.options.add_argument("--disable-background-networking") + # self.options.add_argument("--disable-sync") + # self.options.add_argument("--disable-features=NetworkService,NetworkServiceInProcess") + # self.options.add_argument("--disable-browser-side-navigation") + # self.options.add_argument("--dns-prefetch-disable") + # self.options.add_argument("--disable-web-security") self.options.add_argument("--log-level=3") self.use_cached_html = use_cached_html self.use_cached_html = use_cached_html @@ -66,6 +84,7 @@ class LocalSeleniumCrawlerStrategy(CrawlerStrategy): # chromedriver_autoinstaller.install() import chromedriver_autoinstaller self.service = Service(chromedriver_autoinstaller.install()) + self.service.log_path = "NUL" self.driver = webdriver.Chrome(service=self.service, options=self.options) def crawl(self, url: str) -> str: diff --git a/crawl4ai/extraction_strategy.py b/crawl4ai/extraction_strategy.py index 25027bfc..2d164ff0 100644 --- a/crawl4ai/extraction_strategy.py +++ b/crawl4ai/extraction_strategy.py @@ -46,6 +46,7 @@ class ExtractionStrategy(ABC): for future in as_completed(futures): extracted_content.extend(future.result()) return extracted_content + class NoExtractionStrategy(ExtractionStrategy): def extract(self, url: str, html: str, *q, **kwargs) -> List[Dict[str, Any]]: return [{"index": 0, "content": html}] @@ -187,7 +188,7 @@ class CosineStrategy(ExtractionStrategy): if self.verbose: print(f"[LOG] Loading Extraction Model for {self.device.type} device.") - if self.device.type == "cpu": + if False and self.device.type == "cpu": self.model = load_onnx_all_MiniLM_l6_v2() self.tokenizer = self.model.tokenizer self.get_embedding_method = "direct" @@ -273,7 +274,7 @@ class CosineStrategy(ExtractionStrategy): # if self.buffer_embeddings.any() and not bypass_buffer: # return self.buffer_embeddings - if self.device.type in ["gpu", "cuda", "mps"]: + if self.device.type in [ "cpu", "gpu", "cuda", "mps"]: import torch # Tokenize sentences and convert to tensor if batch_size is None: @@ -295,7 +296,17 @@ class CosineStrategy(ExtractionStrategy): self.buffer_embeddings = np.vstack(all_embeddings) elif self.device.type == "cpu": - self.buffer_embeddings = self.model(sentences) + # self.buffer_embeddings = self.model(sentences) + if batch_size is None: + batch_size = self.default_batch_size + + all_embeddings = [] + for i in range(0, len(sentences), batch_size): + batch_sentences = sentences[i:i + batch_size] + embeddings = self.model(batch_sentences) + all_embeddings.append(embeddings) + + self.buffer_embeddings = np.vstack(all_embeddings) return self.buffer_embeddings def hierarchical_clustering(self, sentences: List[str], embeddings = None): diff --git a/main.py b/main.py index 20ee0acb..604fff3c 100644 --- a/main.py +++ b/main.py @@ -43,7 +43,7 @@ templates = Jinja2Templates(directory=__location__ + "/pages") @lru_cache() def get_crawler(): # Initialize and return a WebCrawler instance - return WebCrawler() + return WebCrawler(verbose = True) class CrawlRequest(BaseModel): urls: List[str] @@ -105,6 +105,9 @@ async def crawl_urls(crawl_request: CrawlRequest, request: Request): try: logging.debug("[LOG] Loading extraction and chunking strategies...") + crawl_request.extraction_strategy_args['verbose'] = True + crawl_request.chunking_strategy_args['verbose'] = True + extraction_strategy = import_strategy("crawl4ai.extraction_strategy", crawl_request.extraction_strategy, **crawl_request.extraction_strategy_args) chunking_strategy = import_strategy("crawl4ai.chunking_strategy", crawl_request.chunking_strategy, **crawl_request.chunking_strategy_args) diff --git a/pages/index.html b/pages/index.html index 2947c34a..aa74aa0e 100644 --- a/pages/index.html +++ b/pages/index.html @@ -25,7 +25,7 @@
-

🔥🕷️ Crawl4AI: Web Data for your Thoughts

+

🔥🕷️ Crawl4AI: Web Data for your Thoughts v0.2.2

📊 Total Website Processed