chore: Add support for GPU, MPS, and CPU

This commit is contained in:
unclecode
2024-05-17 21:56:13 +08:00
parent 0a902f562f
commit b6319c6f6e
5 changed files with 20 additions and 8 deletions

View File

@@ -36,7 +36,7 @@ class CloudCrawlerStrategy(CrawlerStrategy):
return html
class LocalSeleniumCrawlerStrategy(CrawlerStrategy):
def __init__(self, use_cached_html=False, js_code=None):
def __init__(self, use_cached_html=False, js_code=None, **kwargs):
super().__init__()
print("[LOG] 🚀 Initializing LocalSeleniumCrawlerStrategy")
self.options = Options()
@@ -48,6 +48,7 @@ class LocalSeleniumCrawlerStrategy(CrawlerStrategy):
self.options.add_argument("--headless")
self.use_cached_html = use_cached_html
self.js_code = js_code
self.verbose = kwargs.get("verbose", False)
# chromedriver_autoinstaller.install()
import chromedriver_autoinstaller
@@ -62,6 +63,8 @@ class LocalSeleniumCrawlerStrategy(CrawlerStrategy):
return f.read()
try:
if self.verbose:
print(f"[LOG] 🕸️ Crawling {url} using LocalSeleniumCrawlerStrategy...")
self.driver.get(url)
WebDriverWait(self.driver, 10).until(
EC.presence_of_all_elements_located((By.TAG_NAME, "html"))
@@ -81,6 +84,9 @@ class LocalSeleniumCrawlerStrategy(CrawlerStrategy):
cache_file_path = os.path.join(Path.home(), ".crawl4ai", "cache", url.replace("/", "_"))
with open(cache_file_path, "w") as f:
f.write(html)
if self.verbose:
print(f"[LOG] ✅ Crawled {url} successfully!")
return html
except InvalidArgumentException:

View File

@@ -178,11 +178,15 @@ class CosineStrategy(ExtractionStrategy):
self.buffer_embeddings = np.array([])
if self.verbose:
print(f"[LOG] Loading Extraction Model {model_name}")
if model_name == "bert-base-uncased":
self.tokenizer, self.model = load_bert_base_uncased()
elif model_name == "BAAI/bge-small-en-v1.5":
self.tokenizer, self.model = load_bge_small_en_v1_5()
self.nlp, self.device = load_text_multilabel_classifier()
if self.verbose:

View File

@@ -108,7 +108,7 @@ def load_spacy_model():
repo_folder = os.path.join(home_folder, "crawl4ai")
model_folder = os.path.join(home_folder, name)
print("[LOG] ⏬ Downloading model for the first time...")
# print("[LOG] ⏬ Downloading Spacy model for the first time...")
# Remove existing repo folder if it exists
if Path(repo_folder).exists():
@@ -136,7 +136,7 @@ def load_spacy_model():
shutil.rmtree(repo_folder)
# Print completion message
print("[LOG] ✅ Model downloaded successfully")
# print("[LOG] ✅ Spacy Model downloaded successfully")
except subprocess.CalledProcessError as e:
print(f"An error occurred while cloning the repository: {e}")
except Exception as e:
@@ -164,7 +164,8 @@ def download_all_models(remove_existing=False):
print("[LOG] Downloading BGE Small EN v1.5...")
load_bge_small_en_v1_5()
print("[LOG] Downloading text classifier...")
load_text_multilabel_classifier()
_, device = load_text_multilabel_classifier()
print(f"[LOG] Text classifier loaded on {device}")
print("[LOG] Downloading custom NLTK Punkt model...")
load_nltk_punkt()
print("[LOG] ✅ All models downloaded successfully.")

View File

@@ -19,9 +19,10 @@ class WebCrawler:
# db_path: str = None,
crawler_strategy: CrawlerStrategy = None,
always_by_pass_cache: bool = False,
verbose: bool = False,
):
# self.db_path = db_path
self.crawler_strategy = crawler_strategy or LocalSeleniumCrawlerStrategy()
self.crawler_strategy = crawler_strategy or LocalSeleniumCrawlerStrategy(verbose=verbose)
self.always_by_pass_cache = always_by_pass_cache
# Create the .crawl4ai folder in the user's home directory if it doesn't exist

View File

@@ -12,7 +12,7 @@ console = Console()
@lru_cache()
def create_crawler():
crawler = WebCrawler()
crawler = WebCrawler(verbose=True)
crawler.warmup()
return crawler
@@ -86,7 +86,7 @@ def add_extraction_strategy(crawler):
cprint("CosineStrategy uses cosine similarity to extract semantically similar blocks of text. Let's see it in action!")
result = crawler.run(
url="https://www.nbcnews.com/business",
extraction_strategy=CosineStrategy(word_count_threshold=10, max_dist=0.2, linkage_method="ward", top_k=3)
extraction_strategy=CosineStrategy(word_count_threshold=10, max_dist=0.2, linkage_method="ward", top_k=3, verbose=True)
)
cprint("[LOG] 📦 [bold yellow]CosineStrategy result:[/bold yellow]")
print_result(result)
@@ -171,7 +171,7 @@ def main():
crawler = create_crawler()
basic_usage(crawler)
# basic_usage(crawler)
understanding_parameters(crawler)
crawler.always_by_pass_cache = True