Remove dependency on Spacy model.
This commit is contained in:
@@ -28,68 +28,66 @@ def load_bge_small_en_v1_5():
|
||||
return tokenizer, model
|
||||
|
||||
@lru_cache()
def load_spacy_en_core_web_sm():
    """Return the spaCy ``en_core_web_sm`` pipeline, downloading it if needed.

    The first load attempt is made directly; if spaCy reports the model as
    missing (``IOError``), the model is downloaded through ``spacy.cli`` and
    loaded again. The result is memoized by ``lru_cache``.
    """
    import spacy

    model_name = "en_core_web_sm"
    try:
        print("[LOG] Loading spaCy model")
        pipeline = spacy.load(model_name)
    except IOError:
        # Model package is not installed yet: fetch it, then retry the load.
        print("[LOG] ⏬ Downloading spaCy model for the first time")
        spacy.cli.download(model_name)
        pipeline = spacy.load(model_name)
    print("[LOG] ✅ spaCy model loaded successfully")
    return pipeline
|
||||
@lru_cache()
def load_text_classifier():
    """Build a Hugging Face text-classification pipeline for NYT news topics.

    Loads the tokenizer and sequence-classification model published as
    ``dstefa/roberta-base_topic_classification_nyt_news`` and wraps them in a
    ``transformers`` ``pipeline("text-classification", ...)``.

    The function is memoized, like the other loaders in this module, so the
    model is instantiated only once per process instead of on every call.

    Returns:
        The ``transformers`` pipeline object.
    """
    from transformers import (
        AutoModelForSequenceClassification,
        AutoTokenizer,
        pipeline,
    )

    checkpoint = "dstefa/roberta-base_topic_classification_nyt_news"
    tokenizer = AutoTokenizer.from_pretrained(checkpoint)
    model = AutoModelForSequenceClassification.from_pretrained(checkpoint)
    return pipeline("text-classification", model=model, tokenizer=tokenizer)
|
||||
|
||||
@lru_cache()
def load_spacy_model():
    """Load the custom spaCy model stored under ``<home>/models/reuters``.

    If the model folder is missing or empty, the crawl4ai repository is
    cloned (branch ``MODEL_REPO_BRANCH``) into ``<home>/crawl4ai``, its
    ``models/reuters`` folder is copied into place, and the clone is removed.

    Download failures are printed rather than raised (best effort); in that
    case the final ``spacy.load`` call is still attempted on the model folder.

    Returns:
        The object returned by ``spacy.load(model_folder)``.
    """
    import spacy

    name = "models/reuters"
    home_folder = get_home_folder()
    model_folder = os.path.join(home_folder, name)

    # Download only when there is no usable (non-empty) local copy.
    model_path = Path(model_folder)
    if not (model_path.exists() and any(model_path.iterdir())):
        repo_url = "https://github.com/unclecode/crawl4ai.git"
        branch = MODEL_REPO_BRANCH
        repo_folder = os.path.join(home_folder, "crawl4ai")

        print("[LOG] ⏬ Downloading model for the first time...")

        # Clear leftovers from an earlier attempt. Each folder is checked on
        # its own: removing a non-existent folder would raise, and an empty
        # model folder left in place would make copytree fail.
        for stale_folder in (repo_folder, model_folder):
            if Path(stale_folder).exists():
                shutil.rmtree(stale_folder)

        try:
            # Clone the repository
            subprocess.run(
                ["git", "clone", "-b", branch, repo_url, repo_folder],
                stdout=subprocess.DEVNULL,
                stderr=subprocess.DEVNULL,
                check=True,
            )

            # Create the models directory if it doesn't exist
            models_folder = os.path.join(home_folder, "models")
            os.makedirs(models_folder, exist_ok=True)

            # Copy the reuters model folder to the models directory
            source_folder = os.path.join(repo_folder, "models/reuters")
            shutil.copytree(source_folder, model_folder)

            # Remove the cloned repository
            shutil.rmtree(repo_folder)

            # Print completion message
            print("[LOG] ✅ Model downloaded successfully")
        except subprocess.CalledProcessError as e:
            print(f"An error occurred while cloning the repository: {e}")
        except Exception as e:
            print(f"An error occurred: {e}")

    return spacy.load(model_folder)


@lru_cache()
def load_text_multilabel_classifier():
    """Build a multi-label topic classifier on ``cardiffnlp/tweet-topic-21-multi``.

    The tokenizer and model are loaded once, the model is moved to CUDA, MPS
    or CPU (first one available, in that order), and a closure that
    classifies batches of texts is returned. The function is memoized like
    the other loaders in this module.

    Returns:
        ``_classifier(texts, threshold=0.5, max_length=64)``, which returns,
        for each input text, the list of labels whose sigmoid score is at
        least ``threshold``.
    """
    from transformers import AutoModelForSequenceClassification, AutoTokenizer
    from scipy.special import expit
    import torch

    MODEL = "cardiffnlp/tweet-topic-21-multi"
    tokenizer = AutoTokenizer.from_pretrained(MODEL, resume_download=None)
    model = AutoModelForSequenceClassification.from_pretrained(MODEL, resume_download=None)
    class_mapping = model.config.id2label

    # Check for available device: CUDA, MPS (for Apple Silicon), or CPU.
    # torch.backends.mps is looked up defensively because it is absent on
    # torch builds that predate MPS support.
    mps_backend = getattr(torch.backends, "mps", None)
    if torch.cuda.is_available():
        device = torch.device("cuda")
    elif mps_backend is not None and mps_backend.is_available():
        device = torch.device("mps")
    else:
        device = torch.device("cpu")

    model.to(device)

    def _classifier(texts, threshold=0.5, max_length=64):
        """Return the predicted label list for each text in ``texts``."""
        # An empty batch has nothing to tokenize; answer with an empty result.
        if isinstance(texts, (list, tuple)) and not texts:
            return []

        tokens = tokenizer(texts, return_tensors='pt', padding=True, truncation=True, max_length=max_length)
        tokens = {key: val.to(device) for key, val in tokens.items()}  # Move tokens to the selected device

        with torch.no_grad():
            output = model(**tokens)

        # Logits -> independent per-label probabilities -> 0/1 decisions.
        scores = expit(output.logits.detach().cpu().numpy())
        predictions = (scores >= threshold) * 1

        return [
            [class_mapping[i] for i, value in enumerate(prediction) if value == 1]
            for prediction in predictions
        ]

    return _classifier
|
||||
|
||||
@lru_cache()
def load_nltk_punkt():
    """Return the location of NLTK's Punkt tokenizer data, downloading it if absent.

    ``nltk.data.find`` is probed first; on ``LookupError`` the ``punkt``
    package is downloaded. The value returned is whatever
    ``nltk.data.find('tokenizers/punkt')`` yields afterwards.
    """
    import nltk

    resource = 'tokenizers/punkt'
    try:
        nltk.data.find(resource)
    except LookupError:
        # Resource not installed locally yet.
        nltk.download('punkt')
    return nltk.data.find(resource)
|
||||
|
||||
def download_all_models(remove_existing=False):
|
||||
"""Download all models required for Crawl4AI."""
|
||||
@@ -110,10 +108,10 @@ def download_all_models(remove_existing=False):
|
||||
load_bert_base_uncased()
|
||||
print("[LOG] Downloading BGE Small EN v1.5...")
|
||||
load_bge_small_en_v1_5()
|
||||
print("[LOG] Downloading spaCy EN Core Web SM...")
|
||||
load_spacy_en_core_web_sm()
|
||||
print("[LOG] Downloading custom spaCy model...")
|
||||
load_spacy_model()
|
||||
print("[LOG] Downloading text classifier...")
|
||||
# Actually invoke the loader: the bare name was a no-op expression, so the
# text classifier was never downloaded despite the log message.
load_text_multilabel_classifier()
|
||||
print("[LOG] Downloading custom NLTK Punkt model...")
|
||||
load_nltk_punkt()
|
||||
print("[LOG] ✅ All models downloaded successfully.")
|
||||
|
||||
def main():
|
||||
|
||||
Reference in New Issue
Block a user