chore: Add support for GPU, MPS, and CPU

This commit is contained in:
unclecode
2024-05-17 21:56:13 +08:00
parent 0a902f562f
commit b6319c6f6e
5 changed files with 20 additions and 8 deletions

View File

@@ -36,7 +36,7 @@ class CloudCrawlerStrategy(CrawlerStrategy):
return html return html
class LocalSeleniumCrawlerStrategy(CrawlerStrategy): class LocalSeleniumCrawlerStrategy(CrawlerStrategy):
def __init__(self, use_cached_html=False, js_code=None): def __init__(self, use_cached_html=False, js_code=None, **kwargs):
super().__init__() super().__init__()
print("[LOG] 🚀 Initializing LocalSeleniumCrawlerStrategy") print("[LOG] 🚀 Initializing LocalSeleniumCrawlerStrategy")
self.options = Options() self.options = Options()
@@ -48,6 +48,7 @@ class LocalSeleniumCrawlerStrategy(CrawlerStrategy):
self.options.add_argument("--headless") self.options.add_argument("--headless")
self.use_cached_html = use_cached_html self.use_cached_html = use_cached_html
self.js_code = js_code self.js_code = js_code
self.verbose = kwargs.get("verbose", False)
# chromedriver_autoinstaller.install() # chromedriver_autoinstaller.install()
import chromedriver_autoinstaller import chromedriver_autoinstaller
@@ -62,6 +63,8 @@ class LocalSeleniumCrawlerStrategy(CrawlerStrategy):
return f.read() return f.read()
try: try:
if self.verbose:
print(f"[LOG] 🕸️ Crawling {url} using LocalSeleniumCrawlerStrategy...")
self.driver.get(url) self.driver.get(url)
WebDriverWait(self.driver, 10).until( WebDriverWait(self.driver, 10).until(
EC.presence_of_all_elements_located((By.TAG_NAME, "html")) EC.presence_of_all_elements_located((By.TAG_NAME, "html"))
@@ -82,6 +85,9 @@ class LocalSeleniumCrawlerStrategy(CrawlerStrategy):
with open(cache_file_path, "w") as f: with open(cache_file_path, "w") as f:
f.write(html) f.write(html)
if self.verbose:
print(f"[LOG] ✅ Crawled {url} successfully!")
return html return html
except InvalidArgumentException: except InvalidArgumentException:
raise InvalidArgumentException(f"Invalid URL {url}") raise InvalidArgumentException(f"Invalid URL {url}")

View File

@@ -178,11 +178,15 @@ class CosineStrategy(ExtractionStrategy):
self.buffer_embeddings = np.array([]) self.buffer_embeddings = np.array([])
if self.verbose:
print(f"[LOG] Loading Extraction Model {model_name}")
if model_name == "bert-base-uncased": if model_name == "bert-base-uncased":
self.tokenizer, self.model = load_bert_base_uncased() self.tokenizer, self.model = load_bert_base_uncased()
elif model_name == "BAAI/bge-small-en-v1.5": elif model_name == "BAAI/bge-small-en-v1.5":
self.tokenizer, self.model = load_bge_small_en_v1_5() self.tokenizer, self.model = load_bge_small_en_v1_5()
self.nlp, self.device = load_text_multilabel_classifier() self.nlp, self.device = load_text_multilabel_classifier()
if self.verbose: if self.verbose:

View File

@@ -108,7 +108,7 @@ def load_spacy_model():
repo_folder = os.path.join(home_folder, "crawl4ai") repo_folder = os.path.join(home_folder, "crawl4ai")
model_folder = os.path.join(home_folder, name) model_folder = os.path.join(home_folder, name)
print("[LOG] ⏬ Downloading model for the first time...") # print("[LOG] ⏬ Downloading Spacy model for the first time...")
# Remove existing repo folder if it exists # Remove existing repo folder if it exists
if Path(repo_folder).exists(): if Path(repo_folder).exists():
@@ -136,7 +136,7 @@ def load_spacy_model():
shutil.rmtree(repo_folder) shutil.rmtree(repo_folder)
# Print completion message # Print completion message
print("[LOG] ✅ Model downloaded successfully") # print("[LOG] ✅ Spacy Model downloaded successfully")
except subprocess.CalledProcessError as e: except subprocess.CalledProcessError as e:
print(f"An error occurred while cloning the repository: {e}") print(f"An error occurred while cloning the repository: {e}")
except Exception as e: except Exception as e:
@@ -164,7 +164,8 @@ def download_all_models(remove_existing=False):
print("[LOG] Downloading BGE Small EN v1.5...") print("[LOG] Downloading BGE Small EN v1.5...")
load_bge_small_en_v1_5() load_bge_small_en_v1_5()
print("[LOG] Downloading text classifier...") print("[LOG] Downloading text classifier...")
load_text_multilabel_classifier() _, device = load_text_multilabel_classifier()
print(f"[LOG] Text classifier loaded on {device}")
print("[LOG] Downloading custom NLTK Punkt model...") print("[LOG] Downloading custom NLTK Punkt model...")
load_nltk_punkt() load_nltk_punkt()
print("[LOG] ✅ All models downloaded successfully.") print("[LOG] ✅ All models downloaded successfully.")

View File

@@ -19,9 +19,10 @@ class WebCrawler:
# db_path: str = None, # db_path: str = None,
crawler_strategy: CrawlerStrategy = None, crawler_strategy: CrawlerStrategy = None,
always_by_pass_cache: bool = False, always_by_pass_cache: bool = False,
verbose: bool = False,
): ):
# self.db_path = db_path # self.db_path = db_path
self.crawler_strategy = crawler_strategy or LocalSeleniumCrawlerStrategy() self.crawler_strategy = crawler_strategy or LocalSeleniumCrawlerStrategy(verbose=verbose)
self.always_by_pass_cache = always_by_pass_cache self.always_by_pass_cache = always_by_pass_cache
# Create the .crawl4ai folder in the user's home directory if it doesn't exist # Create the .crawl4ai folder in the user's home directory if it doesn't exist

View File

@@ -12,7 +12,7 @@ console = Console()
@lru_cache() @lru_cache()
def create_crawler(): def create_crawler():
crawler = WebCrawler() crawler = WebCrawler(verbose=True)
crawler.warmup() crawler.warmup()
return crawler return crawler
@@ -86,7 +86,7 @@ def add_extraction_strategy(crawler):
cprint("CosineStrategy uses cosine similarity to extract semantically similar blocks of text. Let's see it in action!") cprint("CosineStrategy uses cosine similarity to extract semantically similar blocks of text. Let's see it in action!")
result = crawler.run( result = crawler.run(
url="https://www.nbcnews.com/business", url="https://www.nbcnews.com/business",
extraction_strategy=CosineStrategy(word_count_threshold=10, max_dist=0.2, linkage_method="ward", top_k=3) extraction_strategy=CosineStrategy(word_count_threshold=10, max_dist=0.2, linkage_method="ward", top_k=3, verbose=True)
) )
cprint("[LOG] 📦 [bold yellow]CosineStrategy result:[/bold yellow]") cprint("[LOG] 📦 [bold yellow]CosineStrategy result:[/bold yellow]")
print_result(result) print_result(result)
@@ -171,7 +171,7 @@ def main():
crawler = create_crawler() crawler = create_crawler()
basic_usage(crawler) # basic_usage(crawler)
understanding_parameters(crawler) understanding_parameters(crawler)
crawler.always_by_pass_cache = True crawler.always_by_pass_cache = True