From b6319c6f6e398e901c489dedfe11cbe50eb402e8 Mon Sep 17 00:00:00 2001 From: unclecode Date: Fri, 17 May 2024 21:56:13 +0800 Subject: [PATCH] chore: Add support for GPU, MPS, and CPU --- crawl4ai/crawler_strategy.py | 8 +++++++- crawl4ai/extraction_strategy.py | 4 ++++ crawl4ai/model_loader.py | 7 ++++--- crawl4ai/web_crawler.py | 3 ++- docs/examples/quickstart.py | 6 +++--- 5 files changed, 20 insertions(+), 8 deletions(-) diff --git a/crawl4ai/crawler_strategy.py b/crawl4ai/crawler_strategy.py index 24add103..0b189e67 100644 --- a/crawl4ai/crawler_strategy.py +++ b/crawl4ai/crawler_strategy.py @@ -36,7 +36,7 @@ class CloudCrawlerStrategy(CrawlerStrategy): return html class LocalSeleniumCrawlerStrategy(CrawlerStrategy): - def __init__(self, use_cached_html=False, js_code=None): + def __init__(self, use_cached_html=False, js_code=None, **kwargs): super().__init__() print("[LOG] πŸš€ Initializing LocalSeleniumCrawlerStrategy") self.options = Options() @@ -48,6 +48,7 @@ class LocalSeleniumCrawlerStrategy(CrawlerStrategy): self.options.add_argument("--headless") self.use_cached_html = use_cached_html self.js_code = js_code + self.verbose = kwargs.get("verbose", False) # chromedriver_autoinstaller.install() import chromedriver_autoinstaller @@ -62,6 +63,8 @@ class LocalSeleniumCrawlerStrategy(CrawlerStrategy): return f.read() try: + if self.verbose: + print(f"[LOG] πŸ•ΈοΈ Crawling {url} using LocalSeleniumCrawlerStrategy...") self.driver.get(url) WebDriverWait(self.driver, 10).until( EC.presence_of_all_elements_located((By.TAG_NAME, "html")) @@ -81,6 +84,9 @@ class LocalSeleniumCrawlerStrategy(CrawlerStrategy): cache_file_path = os.path.join(Path.home(), ".crawl4ai", "cache", url.replace("/", "_")) with open(cache_file_path, "w") as f: f.write(html) + + if self.verbose: + print(f"[LOG] βœ… Crawled {url} successfully!") return html except InvalidArgumentException: diff --git a/crawl4ai/extraction_strategy.py b/crawl4ai/extraction_strategy.py index 49cce284..e76c1084 100644 --- a/crawl4ai/extraction_strategy.py +++ b/crawl4ai/extraction_strategy.py @@ -178,11 +178,15 @@ class CosineStrategy(ExtractionStrategy): self.buffer_embeddings = np.array([]) + if self.verbose: + print(f"[LOG] Loading Extraction Model {model_name}") + if model_name == "bert-base-uncased": self.tokenizer, self.model = load_bert_base_uncased() elif model_name == "BAAI/bge-small-en-v1.5": self.tokenizer, self.model = load_bge_small_en_v1_5() + self.nlp, self.device = load_text_multilabel_classifier() if self.verbose: diff --git a/crawl4ai/model_loader.py b/crawl4ai/model_loader.py index 459dec1e..1d7181d7 100644 --- a/crawl4ai/model_loader.py +++ b/crawl4ai/model_loader.py @@ -108,7 +108,7 @@ def load_spacy_model(): repo_folder = os.path.join(home_folder, "crawl4ai") model_folder = os.path.join(home_folder, name) - print("[LOG] ⏬ Downloading model for the first time...") + # print("[LOG] ⏬ Downloading Spacy model for the first time...") # Remove existing repo folder if it exists if Path(repo_folder).exists(): @@ -136,7 +136,7 @@ def load_spacy_model(): shutil.rmtree(repo_folder) # Print completion message - print("[LOG] βœ… Model downloaded successfully") + # print("[LOG] βœ… Spacy Model downloaded successfully") except subprocess.CalledProcessError as e: print(f"An error occurred while cloning the repository: {e}") except Exception as e: @@ -164,7 +164,8 @@ def download_all_models(remove_existing=False): print("[LOG] Downloading BGE Small EN v1.5...") load_bge_small_en_v1_5() print("[LOG] Downloading text classifier...") - load_text_multilabel_classifier() + _, device = load_text_multilabel_classifier() + print(f"[LOG] Text classifier loaded on {device}") print("[LOG] Downloading custom NLTK Punkt model...") load_nltk_punkt() print("[LOG] βœ… All models downloaded successfully.") diff --git a/crawl4ai/web_crawler.py b/crawl4ai/web_crawler.py index 564f64f0..4535930c 100644 --- a/crawl4ai/web_crawler.py +++ b/crawl4ai/web_crawler.py @@ -19,9 +19,10 @@ class WebCrawler: # db_path: str = None, crawler_strategy: CrawlerStrategy = None, always_by_pass_cache: bool = False, + verbose: bool = False, ): # self.db_path = db_path - self.crawler_strategy = crawler_strategy or LocalSeleniumCrawlerStrategy() + self.crawler_strategy = crawler_strategy or LocalSeleniumCrawlerStrategy(verbose=verbose) self.always_by_pass_cache = always_by_pass_cache # Create the .crawl4ai folder in the user's home directory if it doesn't exist diff --git a/docs/examples/quickstart.py b/docs/examples/quickstart.py index 73772c25..012ea65a 100644 --- a/docs/examples/quickstart.py +++ b/docs/examples/quickstart.py @@ -12,7 +12,7 @@ console = Console() @lru_cache() def create_crawler(): - crawler = WebCrawler() + crawler = WebCrawler(verbose=True) crawler.warmup() return crawler @@ -86,7 +86,7 @@ def add_extraction_strategy(crawler): cprint("CosineStrategy uses cosine similarity to extract semantically similar blocks of text. Let's see it in action!") result = crawler.run( url="https://www.nbcnews.com/business", - extraction_strategy=CosineStrategy(word_count_threshold=10, max_dist=0.2, linkage_method="ward", top_k=3) + extraction_strategy=CosineStrategy(word_count_threshold=10, max_dist=0.2, linkage_method="ward", top_k=3, verbose=True) ) cprint("[LOG] πŸ“¦ [bold yellow]CosineStrategy result:[/bold yellow]") print_result(result) @@ -171,7 +171,7 @@ def main(): crawler = create_crawler() - basic_usage(crawler) + # basic_usage(crawler) understanding_parameters(crawler) crawler.always_by_pass_cache = True