chore: Update extraction strategy to support GPU, MPS, and CPU, add batch processing for CPU devices
This commit is contained in:
@@ -16,7 +16,7 @@ class ChunkingStrategy(ABC):
|
||||
|
||||
# Regex-based chunking
|
||||
class RegexChunking(ChunkingStrategy):
|
||||
def __init__(self, patterns=None):
|
||||
def __init__(self, patterns=None, **kwargs):
|
||||
if patterns is None:
|
||||
patterns = [r'\n\n'] # Default split pattern
|
||||
self.patterns = patterns
|
||||
@@ -32,7 +32,7 @@ class RegexChunking(ChunkingStrategy):
|
||||
|
||||
# NLP-based sentence chunking
|
||||
class NlpSentenceChunking(ChunkingStrategy):
|
||||
def __init__(self):
|
||||
def __init__(self, **kwargs):
|
||||
load_nltk_punkt()
|
||||
pass
|
||||
|
||||
@@ -52,7 +52,7 @@ class NlpSentenceChunking(ChunkingStrategy):
|
||||
# Topic-based segmentation using TextTiling
|
||||
class TopicSegmentationChunking(ChunkingStrategy):
|
||||
|
||||
def __init__(self, num_keywords=3):
|
||||
def __init__(self, num_keywords=3, **kwargs):
|
||||
import nltk as nl
|
||||
self.tokenizer = nl.toknize.TextTilingTokenizer()
|
||||
self.num_keywords = num_keywords
|
||||
@@ -82,7 +82,7 @@ class TopicSegmentationChunking(ChunkingStrategy):
|
||||
|
||||
# Fixed-length word chunks
|
||||
class FixedLengthWordChunking(ChunkingStrategy):
|
||||
def __init__(self, chunk_size=100):
|
||||
def __init__(self, chunk_size=100, **kwargs):
|
||||
self.chunk_size = chunk_size
|
||||
|
||||
def chunk(self, text: str) -> list:
|
||||
@@ -91,7 +91,7 @@ class FixedLengthWordChunking(ChunkingStrategy):
|
||||
|
||||
# Sliding window chunking
|
||||
class SlidingWindowChunking(ChunkingStrategy):
|
||||
def __init__(self, window_size=100, step=50):
|
||||
def __init__(self, window_size=100, step=50, **kwargs):
|
||||
self.window_size = window_size
|
||||
self.step = step
|
||||
|
||||
|
||||
@@ -6,6 +6,24 @@ from selenium.webdriver.support.ui import WebDriverWait
|
||||
from selenium.webdriver.support import expected_conditions as EC
|
||||
from selenium.webdriver.chrome.options import Options
|
||||
from selenium.common.exceptions import InvalidArgumentException
|
||||
import logging
|
||||
logger = logging.getLogger('selenium.webdriver.remote.remote_connection')
|
||||
logger.setLevel(logging.WARNING)
|
||||
|
||||
logger_driver = logging.getLogger('selenium.webdriver.common.service')
|
||||
logger_driver.setLevel(logging.WARNING)
|
||||
|
||||
urllib3_logger = logging.getLogger('urllib3.connectionpool')
|
||||
urllib3_logger.setLevel(logging.WARNING)
|
||||
|
||||
# Disable http.client logging
|
||||
http_client_logger = logging.getLogger('http.client')
|
||||
http_client_logger.setLevel(logging.WARNING)
|
||||
|
||||
# Disable driver_finder and service logging
|
||||
driver_finder_logger = logging.getLogger('selenium.webdriver.common.driver_finder')
|
||||
driver_finder_logger.setLevel(logging.WARNING)
|
||||
|
||||
|
||||
from typing import List
|
||||
import requests
|
||||
@@ -43,20 +61,20 @@ class LocalSeleniumCrawlerStrategy(CrawlerStrategy):
|
||||
self.options.headless = True
|
||||
self.options.add_argument("--no-sandbox")
|
||||
self.options.add_argument("--headless")
|
||||
self.options.add_argument("--disable-dev-shm-usage")
|
||||
# self.options.add_argument("--disable-dev-shm-usage")
|
||||
self.options.add_argument("--disable-gpu")
|
||||
self.options.add_argument("--disable-extensions")
|
||||
self.options.add_argument("--disable-infobars")
|
||||
self.options.add_argument("--disable-logging")
|
||||
self.options.add_argument("--disable-popup-blocking")
|
||||
self.options.add_argument("--disable-translate")
|
||||
self.options.add_argument("--disable-default-apps")
|
||||
self.options.add_argument("--disable-background-networking")
|
||||
self.options.add_argument("--disable-sync")
|
||||
self.options.add_argument("--disable-features=NetworkService,NetworkServiceInProcess")
|
||||
self.options.add_argument("--disable-browser-side-navigation")
|
||||
self.options.add_argument("--dns-prefetch-disable")
|
||||
self.options.add_argument("--disable-web-security")
|
||||
# self.options.add_argument("--disable-extensions")
|
||||
# self.options.add_argument("--disable-infobars")
|
||||
# self.options.add_argument("--disable-logging")
|
||||
# self.options.add_argument("--disable-popup-blocking")
|
||||
# self.options.add_argument("--disable-translate")
|
||||
# self.options.add_argument("--disable-default-apps")
|
||||
# self.options.add_argument("--disable-background-networking")
|
||||
# self.options.add_argument("--disable-sync")
|
||||
# self.options.add_argument("--disable-features=NetworkService,NetworkServiceInProcess")
|
||||
# self.options.add_argument("--disable-browser-side-navigation")
|
||||
# self.options.add_argument("--dns-prefetch-disable")
|
||||
# self.options.add_argument("--disable-web-security")
|
||||
self.options.add_argument("--log-level=3")
|
||||
self.use_cached_html = use_cached_html
|
||||
self.use_cached_html = use_cached_html
|
||||
@@ -66,6 +84,7 @@ class LocalSeleniumCrawlerStrategy(CrawlerStrategy):
|
||||
# chromedriver_autoinstaller.install()
|
||||
import chromedriver_autoinstaller
|
||||
self.service = Service(chromedriver_autoinstaller.install())
|
||||
self.service.log_path = "NUL"
|
||||
self.driver = webdriver.Chrome(service=self.service, options=self.options)
|
||||
|
||||
def crawl(self, url: str) -> str:
|
||||
|
||||
@@ -46,6 +46,7 @@ class ExtractionStrategy(ABC):
|
||||
for future in as_completed(futures):
|
||||
extracted_content.extend(future.result())
|
||||
return extracted_content
|
||||
|
||||
class NoExtractionStrategy(ExtractionStrategy):
|
||||
def extract(self, url: str, html: str, *q, **kwargs) -> List[Dict[str, Any]]:
|
||||
return [{"index": 0, "content": html}]
|
||||
@@ -187,7 +188,7 @@ class CosineStrategy(ExtractionStrategy):
|
||||
if self.verbose:
|
||||
print(f"[LOG] Loading Extraction Model for {self.device.type} device.")
|
||||
|
||||
if self.device.type == "cpu":
|
||||
if False and self.device.type == "cpu":
|
||||
self.model = load_onnx_all_MiniLM_l6_v2()
|
||||
self.tokenizer = self.model.tokenizer
|
||||
self.get_embedding_method = "direct"
|
||||
@@ -273,7 +274,7 @@ class CosineStrategy(ExtractionStrategy):
|
||||
# if self.buffer_embeddings.any() and not bypass_buffer:
|
||||
# return self.buffer_embeddings
|
||||
|
||||
if self.device.type in ["gpu", "cuda", "mps"]:
|
||||
if self.device.type in [ "cpu", "gpu", "cuda", "mps"]:
|
||||
import torch
|
||||
# Tokenize sentences and convert to tensor
|
||||
if batch_size is None:
|
||||
@@ -295,7 +296,17 @@ class CosineStrategy(ExtractionStrategy):
|
||||
|
||||
self.buffer_embeddings = np.vstack(all_embeddings)
|
||||
elif self.device.type == "cpu":
|
||||
self.buffer_embeddings = self.model(sentences)
|
||||
# self.buffer_embeddings = self.model(sentences)
|
||||
if batch_size is None:
|
||||
batch_size = self.default_batch_size
|
||||
|
||||
all_embeddings = []
|
||||
for i in range(0, len(sentences), batch_size):
|
||||
batch_sentences = sentences[i:i + batch_size]
|
||||
embeddings = self.model(batch_sentences)
|
||||
all_embeddings.append(embeddings)
|
||||
|
||||
self.buffer_embeddings = np.vstack(all_embeddings)
|
||||
return self.buffer_embeddings
|
||||
|
||||
def hierarchical_clustering(self, sentences: List[str], embeddings = None):
|
||||
|
||||
5
main.py
5
main.py
@@ -43,7 +43,7 @@ templates = Jinja2Templates(directory=__location__ + "/pages")
|
||||
@lru_cache()
|
||||
def get_crawler():
|
||||
# Initialize and return a WebCrawler instance
|
||||
return WebCrawler()
|
||||
return WebCrawler(verbose = True)
|
||||
|
||||
class CrawlRequest(BaseModel):
|
||||
urls: List[str]
|
||||
@@ -105,6 +105,9 @@ async def crawl_urls(crawl_request: CrawlRequest, request: Request):
|
||||
|
||||
try:
|
||||
logging.debug("[LOG] Loading extraction and chunking strategies...")
|
||||
crawl_request.extraction_strategy_args['verbose'] = True
|
||||
crawl_request.chunking_strategy_args['verbose'] = True
|
||||
|
||||
extraction_strategy = import_strategy("crawl4ai.extraction_strategy", crawl_request.extraction_strategy, **crawl_request.extraction_strategy_args)
|
||||
chunking_strategy = import_strategy("crawl4ai.chunking_strategy", crawl_request.chunking_strategy, **crawl_request.chunking_strategy_args)
|
||||
|
||||
|
||||
@@ -25,7 +25,7 @@
|
||||
<header class="bg-zinc-950 text-lime-500 py-4 flex">
|
||||
|
||||
<div class="mx-auto px-4">
|
||||
<h1 class="text-2xl font-bold">🔥🕷️ Crawl4AI: Web Data for your Thoughts</h1>
|
||||
<h1 class="text-2xl font-bold">🔥🕷️ Crawl4AI: Web Data for your Thoughts v0.2.2</h1>
|
||||
</div>
|
||||
<div class="mx-auto px-4 flex font-bold text-xl gap-2">
|
||||
<span>📊 Total Website Processed</span>
|
||||
|
||||
Reference in New Issue
Block a user