chore: Add support for GPU, MPS, and CPU
This commit is contained in:
@@ -36,7 +36,7 @@ class CloudCrawlerStrategy(CrawlerStrategy):
|
|||||||
return html
|
return html
|
||||||
|
|
||||||
class LocalSeleniumCrawlerStrategy(CrawlerStrategy):
|
class LocalSeleniumCrawlerStrategy(CrawlerStrategy):
|
||||||
def __init__(self, use_cached_html=False, js_code=None):
|
def __init__(self, use_cached_html=False, js_code=None, **kwargs):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
print("[LOG] 🚀 Initializing LocalSeleniumCrawlerStrategy")
|
print("[LOG] 🚀 Initializing LocalSeleniumCrawlerStrategy")
|
||||||
self.options = Options()
|
self.options = Options()
|
||||||
@@ -48,6 +48,7 @@ class LocalSeleniumCrawlerStrategy(CrawlerStrategy):
|
|||||||
self.options.add_argument("--headless")
|
self.options.add_argument("--headless")
|
||||||
self.use_cached_html = use_cached_html
|
self.use_cached_html = use_cached_html
|
||||||
self.js_code = js_code
|
self.js_code = js_code
|
||||||
|
self.verbose = kwargs.get("verbose", False)
|
||||||
|
|
||||||
# chromedriver_autoinstaller.install()
|
# chromedriver_autoinstaller.install()
|
||||||
import chromedriver_autoinstaller
|
import chromedriver_autoinstaller
|
||||||
@@ -62,6 +63,8 @@ class LocalSeleniumCrawlerStrategy(CrawlerStrategy):
|
|||||||
return f.read()
|
return f.read()
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
if self.verbose:
|
||||||
|
print(f"[LOG] 🕸️ Crawling {url} using LocalSeleniumCrawlerStrategy...")
|
||||||
self.driver.get(url)
|
self.driver.get(url)
|
||||||
WebDriverWait(self.driver, 10).until(
|
WebDriverWait(self.driver, 10).until(
|
||||||
EC.presence_of_all_elements_located((By.TAG_NAME, "html"))
|
EC.presence_of_all_elements_located((By.TAG_NAME, "html"))
|
||||||
@@ -81,6 +84,9 @@ class LocalSeleniumCrawlerStrategy(CrawlerStrategy):
|
|||||||
cache_file_path = os.path.join(Path.home(), ".crawl4ai", "cache", url.replace("/", "_"))
|
cache_file_path = os.path.join(Path.home(), ".crawl4ai", "cache", url.replace("/", "_"))
|
||||||
with open(cache_file_path, "w") as f:
|
with open(cache_file_path, "w") as f:
|
||||||
f.write(html)
|
f.write(html)
|
||||||
|
|
||||||
|
if self.verbose:
|
||||||
|
print(f"[LOG] ✅ Crawled {url} successfully!")
|
||||||
|
|
||||||
return html
|
return html
|
||||||
except InvalidArgumentException:
|
except InvalidArgumentException:
|
||||||
|
|||||||
@@ -178,11 +178,15 @@ class CosineStrategy(ExtractionStrategy):
|
|||||||
|
|
||||||
self.buffer_embeddings = np.array([])
|
self.buffer_embeddings = np.array([])
|
||||||
|
|
||||||
|
if self.verbose:
|
||||||
|
print(f"[LOG] Loading Extraction Model {model_name}")
|
||||||
|
|
||||||
if model_name == "bert-base-uncased":
|
if model_name == "bert-base-uncased":
|
||||||
self.tokenizer, self.model = load_bert_base_uncased()
|
self.tokenizer, self.model = load_bert_base_uncased()
|
||||||
elif model_name == "BAAI/bge-small-en-v1.5":
|
elif model_name == "BAAI/bge-small-en-v1.5":
|
||||||
self.tokenizer, self.model = load_bge_small_en_v1_5()
|
self.tokenizer, self.model = load_bge_small_en_v1_5()
|
||||||
|
|
||||||
|
|
||||||
self.nlp, self.device = load_text_multilabel_classifier()
|
self.nlp, self.device = load_text_multilabel_classifier()
|
||||||
|
|
||||||
if self.verbose:
|
if self.verbose:
|
||||||
|
|||||||
@@ -108,7 +108,7 @@ def load_spacy_model():
|
|||||||
repo_folder = os.path.join(home_folder, "crawl4ai")
|
repo_folder = os.path.join(home_folder, "crawl4ai")
|
||||||
model_folder = os.path.join(home_folder, name)
|
model_folder = os.path.join(home_folder, name)
|
||||||
|
|
||||||
print("[LOG] ⏬ Downloading model for the first time...")
|
# print("[LOG] ⏬ Downloading Spacy model for the first time...")
|
||||||
|
|
||||||
# Remove existing repo folder if it exists
|
# Remove existing repo folder if it exists
|
||||||
if Path(repo_folder).exists():
|
if Path(repo_folder).exists():
|
||||||
@@ -136,7 +136,7 @@ def load_spacy_model():
|
|||||||
shutil.rmtree(repo_folder)
|
shutil.rmtree(repo_folder)
|
||||||
|
|
||||||
# Print completion message
|
# Print completion message
|
||||||
print("[LOG] ✅ Model downloaded successfully")
|
# print("[LOG] ✅ Spacy Model downloaded successfully")
|
||||||
except subprocess.CalledProcessError as e:
|
except subprocess.CalledProcessError as e:
|
||||||
print(f"An error occurred while cloning the repository: {e}")
|
print(f"An error occurred while cloning the repository: {e}")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
@@ -164,7 +164,8 @@ def download_all_models(remove_existing=False):
|
|||||||
print("[LOG] Downloading BGE Small EN v1.5...")
|
print("[LOG] Downloading BGE Small EN v1.5...")
|
||||||
load_bge_small_en_v1_5()
|
load_bge_small_en_v1_5()
|
||||||
print("[LOG] Downloading text classifier...")
|
print("[LOG] Downloading text classifier...")
|
||||||
load_text_multilabel_classifier()
|
_, device = load_text_multilabel_classifier()
|
||||||
|
print(f"[LOG] Text classifier loaded on {device}")
|
||||||
print("[LOG] Downloading custom NLTK Punkt model...")
|
print("[LOG] Downloading custom NLTK Punkt model...")
|
||||||
load_nltk_punkt()
|
load_nltk_punkt()
|
||||||
print("[LOG] ✅ All models downloaded successfully.")
|
print("[LOG] ✅ All models downloaded successfully.")
|
||||||
|
|||||||
@@ -19,9 +19,10 @@ class WebCrawler:
|
|||||||
# db_path: str = None,
|
# db_path: str = None,
|
||||||
crawler_strategy: CrawlerStrategy = None,
|
crawler_strategy: CrawlerStrategy = None,
|
||||||
always_by_pass_cache: bool = False,
|
always_by_pass_cache: bool = False,
|
||||||
|
verbose: bool = False,
|
||||||
):
|
):
|
||||||
# self.db_path = db_path
|
# self.db_path = db_path
|
||||||
self.crawler_strategy = crawler_strategy or LocalSeleniumCrawlerStrategy()
|
self.crawler_strategy = crawler_strategy or LocalSeleniumCrawlerStrategy(verbose=verbose)
|
||||||
self.always_by_pass_cache = always_by_pass_cache
|
self.always_by_pass_cache = always_by_pass_cache
|
||||||
|
|
||||||
# Create the .crawl4ai folder in the user's home directory if it doesn't exist
|
# Create the .crawl4ai folder in the user's home directory if it doesn't exist
|
||||||
|
|||||||
@@ -12,7 +12,7 @@ console = Console()
|
|||||||
|
|
||||||
@lru_cache()
|
@lru_cache()
|
||||||
def create_crawler():
|
def create_crawler():
|
||||||
crawler = WebCrawler()
|
crawler = WebCrawler(verbose=True)
|
||||||
crawler.warmup()
|
crawler.warmup()
|
||||||
return crawler
|
return crawler
|
||||||
|
|
||||||
@@ -86,7 +86,7 @@ def add_extraction_strategy(crawler):
|
|||||||
cprint("CosineStrategy uses cosine similarity to extract semantically similar blocks of text. Let's see it in action!")
|
cprint("CosineStrategy uses cosine similarity to extract semantically similar blocks of text. Let's see it in action!")
|
||||||
result = crawler.run(
|
result = crawler.run(
|
||||||
url="https://www.nbcnews.com/business",
|
url="https://www.nbcnews.com/business",
|
||||||
extraction_strategy=CosineStrategy(word_count_threshold=10, max_dist=0.2, linkage_method="ward", top_k=3)
|
extraction_strategy=CosineStrategy(word_count_threshold=10, max_dist=0.2, linkage_method="ward", top_k=3, verbose=True)
|
||||||
)
|
)
|
||||||
cprint("[LOG] 📦 [bold yellow]CosineStrategy result:[/bold yellow]")
|
cprint("[LOG] 📦 [bold yellow]CosineStrategy result:[/bold yellow]")
|
||||||
print_result(result)
|
print_result(result)
|
||||||
@@ -171,7 +171,7 @@ def main():
|
|||||||
|
|
||||||
crawler = create_crawler()
|
crawler = create_crawler()
|
||||||
|
|
||||||
basic_usage(crawler)
|
# basic_usage(crawler)
|
||||||
understanding_parameters(crawler)
|
understanding_parameters(crawler)
|
||||||
|
|
||||||
crawler.always_by_pass_cache = True
|
crawler.always_by_pass_cache = True
|
||||||
|
|||||||
Reference in New Issue
Block a user