Remove dependency on spaCy model.

This commit is contained in:
unclecode
2024-05-17 15:08:03 +08:00
parent f85df91ca6
commit a5f9d07dbf
18 changed files with 123 additions and 83955 deletions

View File

@@ -28,68 +28,66 @@ def load_bge_small_en_v1_5():
return tokenizer, model
@lru_cache()
def load_spacy_en_core_web_sm():
    """Load the spaCy ``en_core_web_sm`` pipeline, downloading it if missing.

    The first load attempt is made directly; if spaCy raises ``IOError``
    the package is fetched with ``spacy.cli.download`` and loaded again.
    The result is memoized by ``lru_cache``.

    Returns:
        The object returned by ``spacy.load("en_core_web_sm")``.
    """
    import spacy

    model_name = "en_core_web_sm"
    try:
        print("[LOG] Loading spaCy model")
        pipeline = spacy.load(model_name)
    except IOError:
        # Model package is not installed yet: fetch it, then retry the load.
        print("[LOG] ⏬ Downloading spaCy model for the first time")
        spacy.cli.download(model_name)
        pipeline = spacy.load(model_name)

    print("[LOG] ✅ spaCy model loaded successfully")
    return pipeline
@lru_cache()
def load_text_classifier():
    """Build a text-classification pipeline for NYT-style news topics.

    Loads the tokenizer and sequence-classification model published as
    ``dstefa/roberta-base_topic_classification_nyt_news`` and wraps them in
    a ``transformers`` ``text-classification`` pipeline.

    The result is memoized with ``lru_cache`` like the other loaders in this
    module, so repeated calls reuse the same pipeline instead of reloading
    the model weights every time.

    Returns:
        The ``transformers`` pipeline object.
    """
    from transformers import AutoModelForSequenceClassification, AutoTokenizer
    from transformers import pipeline

    checkpoint = "dstefa/roberta-base_topic_classification_nyt_news"
    tokenizer = AutoTokenizer.from_pretrained(checkpoint)
    model = AutoModelForSequenceClassification.from_pretrained(checkpoint)
    return pipeline("text-classification", model=model, tokenizer=tokenizer)
@lru_cache()
def load_spacy_model():
    """Load the custom ``models/reuters`` spaCy model, fetching it if needed.

    The model is expected under ``<home_folder>/models/reuters``. When that
    directory is missing or empty, the crawl4ai repository is cloned (branch
    ``MODEL_REPO_BRANCH``), the ``models/reuters`` folder is copied out of the
    clone, and the clone is deleted again.

    Download errors are printed rather than raised; in that case the final
    ``spacy.load`` is still attempted and will surface the failure.

    Returns:
        The object returned by ``spacy.load`` for the model folder.
    """
    import spacy

    name = "models/reuters"
    home_folder = get_home_folder()
    model_folder = os.path.join(home_folder, name)

    # Only download when the model directory is absent or empty.
    if not (Path(model_folder).exists() and any(Path(model_folder).iterdir())):
        repo_url = "https://github.com/unclecode/crawl4ai.git"
        branch = MODEL_REPO_BRANCH
        repo_folder = os.path.join(home_folder, "crawl4ai")

        print("[LOG] ⏬ Downloading model for the first time...")

        # Clear leftovers independently: a stale clone may exist without a
        # model folder (and vice versa), and copytree needs a fresh target.
        for stale_folder in (repo_folder, model_folder):
            if Path(stale_folder).exists():
                shutil.rmtree(stale_folder)

        try:
            # Clone the repository
            subprocess.run(
                ["git", "clone", "-b", branch, repo_url, repo_folder],
                stdout=subprocess.DEVNULL,
                stderr=subprocess.DEVNULL,
                check=True,
            )

            # Create the models directory if it doesn't exist
            models_folder = os.path.join(home_folder, "models")
            os.makedirs(models_folder, exist_ok=True)

            # Copy the reuters model folder to the models directory
            source_folder = os.path.join(repo_folder, "models/reuters")
            shutil.copytree(source_folder, model_folder)

            # Remove the cloned repository
            shutil.rmtree(repo_folder)

            # Print completion message
            print("[LOG] ✅ Model downloaded successfully")
        except subprocess.CalledProcessError as e:
            print(f"An error occurred while cloning the repository: {e}")
        except Exception as e:
            print(f"An error occurred: {e}")

    return spacy.load(model_folder)


@lru_cache()
def load_text_multilabel_classifier():
    """Build a multi-label topic classifier from ``cardiffnlp/tweet-topic-21-multi``.

    The tokenizer and model are loaded once, the model is moved to the first
    available device (CUDA, then MPS, then CPU), and both are captured by the
    returned closure. The loader is memoized with ``lru_cache`` like the
    other loaders in this module.

    Returns:
        A callable ``_classifier(texts, threshold=0.5, max_length=64)`` that
        returns, for each input text, the list of label names whose sigmoid
        score is at least ``threshold``.
    """
    from transformers import AutoModelForSequenceClassification, AutoTokenizer
    from scipy.special import expit
    import torch

    MODEL = "cardiffnlp/tweet-topic-21-multi"
    tokenizer = AutoTokenizer.from_pretrained(MODEL, resume_download=None)
    model = AutoModelForSequenceClassification.from_pretrained(MODEL, resume_download=None)
    class_mapping = model.config.id2label

    # Check for available device: CUDA, MPS (for Apple Silicon), or CPU.
    # The MPS backend attribute is looked up defensively because it is not
    # present on every torch build.
    mps_backend = getattr(torch.backends, "mps", None)
    if torch.cuda.is_available():
        device = torch.device("cuda")
    elif mps_backend is not None and mps_backend.is_available():
        device = torch.device("mps")
    else:
        device = torch.device("cpu")

    model.to(device)

    def _classifier(texts, threshold=0.5, max_length=64):
        # Nothing to classify: skip the tokenizer/model round-trip entirely.
        if isinstance(texts, (list, tuple)) and not texts:
            return []

        tokens = tokenizer(texts, return_tensors='pt', padding=True, truncation=True, max_length=max_length)
        tokens = {key: val.to(device) for key, val in tokens.items()}  # Move tokens to the selected device

        with torch.no_grad():
            output = model(**tokens)

        # Independent sigmoid per label (multi-label), then threshold.
        scores = expit(output.logits.detach().cpu().numpy())
        predictions = scores >= threshold

        return [
            [class_mapping[i] for i, is_hit in enumerate(row) if is_hit]
            for row in predictions
        ]

    return _classifier
@lru_cache()
def load_nltk_punkt():
    """Return the location of NLTK's Punkt tokenizer data, downloading it if absent.

    ``nltk.data.find`` is probed first; a ``LookupError`` triggers
    ``nltk.download('punkt')``. The resource is then looked up again and
    that result is returned.
    """
    import nltk

    punkt_resource = 'tokenizers/punkt'
    try:
        nltk.data.find(punkt_resource)
    except LookupError:
        # Resource is not installed locally yet.
        nltk.download('punkt')

    return nltk.data.find(punkt_resource)
def download_all_models(remove_existing=False):
"""Download all models required for Crawl4AI."""
@@ -110,10 +108,10 @@ def download_all_models(remove_existing=False):
# Each loader is invoked for its side effect of fetching the model.
load_bert_base_uncased()
print("[LOG] Downloading BGE Small EN v1.5...")
load_bge_small_en_v1_5()
print("[LOG] Downloading spaCy EN Core Web SM...")
load_spacy_en_core_web_sm()
print("[LOG] Downloading custom spaCy model...")
load_spacy_model()
print("[LOG] Downloading text classifier...")
# The loader must be called: a bare function name is a no-op expression
# and would leave the classifier undownloaded.
load_text_multilabel_classifier()
print("[LOG] Downloading custom NLTK Punkt model...")
load_nltk_punkt()
print("[LOG] ✅ All models downloaded successfully.")
def main():