diff --git a/crawl4ai/model_loader.py b/crawl4ai/model_loader.py
index 3a2b8695..459dec1e 100644
--- a/crawl4ai/model_loader.py
+++ b/crawl4ai/model_loader.py
@@ -45,18 +45,21 @@ def load_text_multilabel_classifier():
     from scipy.special import expit
     import torch
 
-    MODEL = "cardiffnlp/tweet-topic-21-multi"
-    tokenizer = AutoTokenizer.from_pretrained(MODEL, resume_download=None)
-    model = AutoModelForSequenceClassification.from_pretrained(MODEL, resume_download=None)
-    class_mapping = model.config.id2label
-
     # Check for available device: CUDA, MPS (for Apple Silicon), or CPU
     if torch.cuda.is_available():
         device = torch.device("cuda")
     elif torch.backends.mps.is_available():
         device = torch.device("mps")
     else:
-        device = torch.device("cpu")
+        return load_spacy_model()
+        # device = torch.device("cpu")
+
+
+    MODEL = "cardiffnlp/tweet-topic-21-multi"
+    tokenizer = AutoTokenizer.from_pretrained(MODEL, resume_download=None)
+    model = AutoModelForSequenceClassification.from_pretrained(MODEL, resume_download=None)
+    class_mapping = model.config.id2label
+
     model.to(device)
 
 
@@ -78,7 +81,7 @@ def load_text_multilabel_classifier():
 
         return batch_labels
 
-    return _classifier
+    return _classifier, "gpu"
 
 @lru_cache()
 def load_nltk_punkt():
@@ -89,6 +92,58 @@ def load_nltk_punkt():
         nltk.download('punkt')
     return nltk.data.find('tokenizers/punkt')
 
+
+@lru_cache()
+def load_spacy_model():
+    import spacy
+    name = "models/reuters"
+    home_folder = get_home_folder()
+    model_folder = os.path.join(home_folder, name)
+
+    # Check if the model directory already exists
+    if not (Path(model_folder).exists() and any(Path(model_folder).iterdir())):
+        repo_url = "https://github.com/unclecode/crawl4ai.git"
+        # branch = "main"
+        branch = MODEL_REPO_BRANCH
+        repo_folder = os.path.join(home_folder, "crawl4ai")
+        model_folder = os.path.join(home_folder, name)
+
+        print("[LOG] ⏬ Downloading model for the first time...")
+
+        # Remove existing repo folder if it exists
+        if Path(repo_folder).exists():
+            shutil.rmtree(repo_folder)
+            shutil.rmtree(model_folder, ignore_errors=True)
+
+        try:
+            # Clone the repository
+            subprocess.run(
+                ["git", "clone", "-b", branch, repo_url, repo_folder],
+                stdout=subprocess.DEVNULL,
+                stderr=subprocess.DEVNULL,
+                check=True
+            )
+
+            # Create the models directory if it doesn't exist
+            models_folder = os.path.join(home_folder, "models")
+            os.makedirs(models_folder, exist_ok=True)
+
+            # Copy the reuters model folder to the models directory
+            source_folder = os.path.join(repo_folder, "models/reuters")
+            shutil.copytree(source_folder, model_folder)
+
+            # Remove the cloned repository
+            shutil.rmtree(repo_folder)
+
+            # Print completion message
+            print("[LOG] ✅ Model downloaded successfully")
+        except subprocess.CalledProcessError as e:
+            print(f"An error occurred while cloning the repository: {e}")
+        except Exception as e:
+            print(f"An error occurred: {e}")
+
+    return spacy.load(model_folder), "cpu"
+
 def download_all_models(remove_existing=False):
     """Download all models required for Crawl4AI."""
     if remove_existing:
@@ -109,7 +164,7 @@ def download_all_models(remove_existing=False):
     print("[LOG] Downloading BGE Small EN v1.5...")
     load_bge_small_en_v1_5()
     print("[LOG] Downloading text classifier...")
-    load_text_multilabel_classifier
+    load_text_multilabel_classifier()
     print("[LOG] Downloading custom NLTK Punkt model...")
     load_nltk_punkt()
     print("[LOG] ✅ All models downloaded successfully.")
@@ -124,4 +179,4 @@ def main():
     download_all_models(remove_existing=args.remove_existing)
 
 if __name__ == "__main__":
-    main()
\ No newline at end of file
+    main()