chore: Add support for GPU, MPS, and CPU
This commit is contained in:
@@ -36,7 +36,7 @@ class CloudCrawlerStrategy(CrawlerStrategy):
|
||||
return html
|
||||
|
||||
class LocalSeleniumCrawlerStrategy(CrawlerStrategy):
|
||||
def __init__(self, use_cached_html=False, js_code=None):
|
||||
def __init__(self, use_cached_html=False, js_code=None, **kwargs):
|
||||
super().__init__()
|
||||
print("[LOG] 🚀 Initializing LocalSeleniumCrawlerStrategy")
|
||||
self.options = Options()
|
||||
@@ -48,6 +48,7 @@ class LocalSeleniumCrawlerStrategy(CrawlerStrategy):
|
||||
self.options.add_argument("--headless")
|
||||
self.use_cached_html = use_cached_html
|
||||
self.js_code = js_code
|
||||
self.verbose = kwargs.get("verbose", False)
|
||||
|
||||
# chromedriver_autoinstaller.install()
|
||||
import chromedriver_autoinstaller
|
||||
@@ -62,6 +63,8 @@ class LocalSeleniumCrawlerStrategy(CrawlerStrategy):
|
||||
return f.read()
|
||||
|
||||
try:
|
||||
if self.verbose:
|
||||
print(f"[LOG] 🕸️ Crawling {url} using LocalSeleniumCrawlerStrategy...")
|
||||
self.driver.get(url)
|
||||
WebDriverWait(self.driver, 10).until(
|
||||
EC.presence_of_all_elements_located((By.TAG_NAME, "html"))
|
||||
@@ -81,6 +84,9 @@ class LocalSeleniumCrawlerStrategy(CrawlerStrategy):
|
||||
cache_file_path = os.path.join(Path.home(), ".crawl4ai", "cache", url.replace("/", "_"))
|
||||
with open(cache_file_path, "w") as f:
|
||||
f.write(html)
|
||||
|
||||
if self.verbose:
|
||||
print(f"[LOG] ✅ Crawled {url} successfully!")
|
||||
|
||||
return html
|
||||
except InvalidArgumentException:
|
||||
|
||||
@@ -178,11 +178,15 @@ class CosineStrategy(ExtractionStrategy):
|
||||
|
||||
self.buffer_embeddings = np.array([])
|
||||
|
||||
if self.verbose:
|
||||
print(f"[LOG] Loading Extraction Model {model_name}")
|
||||
|
||||
if model_name == "bert-base-uncased":
|
||||
self.tokenizer, self.model = load_bert_base_uncased()
|
||||
elif model_name == "BAAI/bge-small-en-v1.5":
|
||||
self.tokenizer, self.model = load_bge_small_en_v1_5()
|
||||
|
||||
|
||||
self.nlp, self.device = load_text_multilabel_classifier()
|
||||
|
||||
if self.verbose:
|
||||
|
||||
@@ -108,7 +108,7 @@ def load_spacy_model():
|
||||
repo_folder = os.path.join(home_folder, "crawl4ai")
|
||||
model_folder = os.path.join(home_folder, name)
|
||||
|
||||
print("[LOG] ⏬ Downloading model for the first time...")
|
||||
# print("[LOG] ⏬ Downloading Spacy model for the first time...")
|
||||
|
||||
# Remove existing repo folder if it exists
|
||||
if Path(repo_folder).exists():
|
||||
@@ -136,7 +136,7 @@ def load_spacy_model():
|
||||
shutil.rmtree(repo_folder)
|
||||
|
||||
# Print completion message
|
||||
print("[LOG] ✅ Model downloaded successfully")
|
||||
# print("[LOG] ✅ Spacy Model downloaded successfully")
|
||||
except subprocess.CalledProcessError as e:
|
||||
print(f"An error occurred while cloning the repository: {e}")
|
||||
except Exception as e:
|
||||
@@ -164,7 +164,8 @@ def download_all_models(remove_existing=False):
|
||||
print("[LOG] Downloading BGE Small EN v1.5...")
|
||||
load_bge_small_en_v1_5()
|
||||
print("[LOG] Downloading text classifier...")
|
||||
load_text_multilabel_classifier()
|
||||
_, device = load_text_multilabel_classifier()
|
||||
print(f"[LOG] Text classifier loaded on {device}")
|
||||
print("[LOG] Downloading custom NLTK Punkt model...")
|
||||
load_nltk_punkt()
|
||||
print("[LOG] ✅ All models downloaded successfully.")
|
||||
|
||||
@@ -19,9 +19,10 @@ class WebCrawler:
|
||||
# db_path: str = None,
|
||||
crawler_strategy: CrawlerStrategy = None,
|
||||
always_by_pass_cache: bool = False,
|
||||
verbose: bool = False,
|
||||
):
|
||||
# self.db_path = db_path
|
||||
self.crawler_strategy = crawler_strategy or LocalSeleniumCrawlerStrategy()
|
||||
self.crawler_strategy = crawler_strategy or LocalSeleniumCrawlerStrategy(verbose=verbose)
|
||||
self.always_by_pass_cache = always_by_pass_cache
|
||||
|
||||
# Create the .crawl4ai folder in the user's home directory if it doesn't exist
|
||||
|
||||
@@ -12,7 +12,7 @@ console = Console()
|
||||
|
||||
@lru_cache()
|
||||
def create_crawler():
|
||||
crawler = WebCrawler()
|
||||
crawler = WebCrawler(verbose=True)
|
||||
crawler.warmup()
|
||||
return crawler
|
||||
|
||||
@@ -86,7 +86,7 @@ def add_extraction_strategy(crawler):
|
||||
cprint("CosineStrategy uses cosine similarity to extract semantically similar blocks of text. Let's see it in action!")
|
||||
result = crawler.run(
|
||||
url="https://www.nbcnews.com/business",
|
||||
extraction_strategy=CosineStrategy(word_count_threshold=10, max_dist=0.2, linkage_method="ward", top_k=3)
|
||||
extraction_strategy=CosineStrategy(word_count_threshold=10, max_dist=0.2, linkage_method="ward", top_k=3, verbose=True)
|
||||
)
|
||||
cprint("[LOG] 📦 [bold yellow]CosineStrategy result:[/bold yellow]")
|
||||
print_result(result)
|
||||
@@ -171,7 +171,7 @@ def main():
|
||||
|
||||
crawler = create_crawler()
|
||||
|
||||
basic_usage(crawler)
|
||||
# basic_usage(crawler)
|
||||
understanding_parameters(crawler)
|
||||
|
||||
crawler.always_by_pass_cache = True
|
||||
|
||||
Reference in New Issue
Block a user