chore: Add support for GPU, MPS, and CPU
This commit is contained in:
@@ -36,7 +36,7 @@ class CloudCrawlerStrategy(CrawlerStrategy):
|
||||
return html
|
||||
|
||||
class LocalSeleniumCrawlerStrategy(CrawlerStrategy):
|
||||
def __init__(self, use_cached_html=False, js_code=None):
|
||||
def __init__(self, use_cached_html=False, js_code=None, **kwargs):
|
||||
super().__init__()
|
||||
print("[LOG] 🚀 Initializing LocalSeleniumCrawlerStrategy")
|
||||
self.options = Options()
|
||||
@@ -48,6 +48,7 @@ class LocalSeleniumCrawlerStrategy(CrawlerStrategy):
|
||||
self.options.add_argument("--headless")
|
||||
self.use_cached_html = use_cached_html
|
||||
self.js_code = js_code
|
||||
self.verbose = kwargs.get("verbose", False)
|
||||
|
||||
# chromedriver_autoinstaller.install()
|
||||
import chromedriver_autoinstaller
|
||||
@@ -62,6 +63,8 @@ class LocalSeleniumCrawlerStrategy(CrawlerStrategy):
|
||||
return f.read()
|
||||
|
||||
try:
|
||||
if self.verbose:
|
||||
print(f"[LOG] 🕸️ Crawling {url} using LocalSeleniumCrawlerStrategy...")
|
||||
self.driver.get(url)
|
||||
WebDriverWait(self.driver, 10).until(
|
||||
EC.presence_of_all_elements_located((By.TAG_NAME, "html"))
|
||||
@@ -81,6 +84,9 @@ class LocalSeleniumCrawlerStrategy(CrawlerStrategy):
|
||||
cache_file_path = os.path.join(Path.home(), ".crawl4ai", "cache", url.replace("/", "_"))
|
||||
with open(cache_file_path, "w") as f:
|
||||
f.write(html)
|
||||
|
||||
if self.verbose:
|
||||
print(f"[LOG] ✅ Crawled {url} successfully!")
|
||||
|
||||
return html
|
||||
except InvalidArgumentException:
|
||||
|
||||
@@ -178,11 +178,15 @@ class CosineStrategy(ExtractionStrategy):
|
||||
|
||||
self.buffer_embeddings = np.array([])
|
||||
|
||||
if self.verbose:
|
||||
print(f"[LOG] Loading Extraction Model {model_name}")
|
||||
|
||||
if model_name == "bert-base-uncased":
|
||||
self.tokenizer, self.model = load_bert_base_uncased()
|
||||
elif model_name == "BAAI/bge-small-en-v1.5":
|
||||
self.tokenizer, self.model = load_bge_small_en_v1_5()
|
||||
|
||||
|
||||
self.nlp, self.device = load_text_multilabel_classifier()
|
||||
|
||||
if self.verbose:
|
||||
|
||||
@@ -108,7 +108,7 @@ def load_spacy_model():
|
||||
repo_folder = os.path.join(home_folder, "crawl4ai")
|
||||
model_folder = os.path.join(home_folder, name)
|
||||
|
||||
print("[LOG] ⏬ Downloading model for the first time...")
|
||||
# print("[LOG] ⏬ Downloading Spacy model for the first time...")
|
||||
|
||||
# Remove existing repo folder if it exists
|
||||
if Path(repo_folder).exists():
|
||||
@@ -136,7 +136,7 @@ def load_spacy_model():
|
||||
shutil.rmtree(repo_folder)
|
||||
|
||||
# Print completion message
|
||||
print("[LOG] ✅ Model downloaded successfully")
|
||||
# print("[LOG] ✅ Spacy Model downloaded successfully")
|
||||
except subprocess.CalledProcessError as e:
|
||||
print(f"An error occurred while cloning the repository: {e}")
|
||||
except Exception as e:
|
||||
@@ -164,7 +164,8 @@ def download_all_models(remove_existing=False):
|
||||
print("[LOG] Downloading BGE Small EN v1.5...")
|
||||
load_bge_small_en_v1_5()
|
||||
print("[LOG] Downloading text classifier...")
|
||||
load_text_multilabel_classifier()
|
||||
_, device = load_text_multilabel_classifier()
|
||||
print(f"[LOG] Text classifier loaded on {device}")
|
||||
print("[LOG] Downloading custom NLTK Punkt model...")
|
||||
load_nltk_punkt()
|
||||
print("[LOG] ✅ All models downloaded successfully.")
|
||||
|
||||
@@ -19,9 +19,10 @@ class WebCrawler:
|
||||
# db_path: str = None,
|
||||
crawler_strategy: CrawlerStrategy = None,
|
||||
always_by_pass_cache: bool = False,
|
||||
verbose: bool = False,
|
||||
):
|
||||
# self.db_path = db_path
|
||||
self.crawler_strategy = crawler_strategy or LocalSeleniumCrawlerStrategy()
|
||||
self.crawler_strategy = crawler_strategy or LocalSeleniumCrawlerStrategy(verbose=verbose)
|
||||
self.always_by_pass_cache = always_by_pass_cache
|
||||
|
||||
# Create the .crawl4ai folder in the user's home directory if it doesn't exist
|
||||
|
||||
Reference in New Issue
Block a user