Update model loader to support GPU, MPS, and CPU
This commit is contained in:
@@ -45,18 +45,21 @@ def load_text_multilabel_classifier():
|
|||||||
from scipy.special import expit
|
from scipy.special import expit
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
MODEL = "cardiffnlp/tweet-topic-21-multi"
|
|
||||||
tokenizer = AutoTokenizer.from_pretrained(MODEL, resume_download=None)
|
|
||||||
model = AutoModelForSequenceClassification.from_pretrained(MODEL, resume_download=None)
|
|
||||||
class_mapping = model.config.id2label
|
|
||||||
|
|
||||||
# Check for available device: CUDA, MPS (for Apple Silicon), or CPU
|
# Check for available device: CUDA, MPS (for Apple Silicon), or CPU
|
||||||
if torch.cuda.is_available():
|
if torch.cuda.is_available():
|
||||||
device = torch.device("cuda")
|
device = torch.device("cuda")
|
||||||
elif torch.backends.mps.is_available():
|
elif torch.backends.mps.is_available():
|
||||||
device = torch.device("mps")
|
device = torch.device("mps")
|
||||||
else:
|
else:
|
||||||
device = torch.device("cpu")
|
return load_spacy_model()
|
||||||
|
# device = torch.device("cpu")
|
||||||
|
|
||||||
|
|
||||||
|
MODEL = "cardiffnlp/tweet-topic-21-multi"
|
||||||
|
tokenizer = AutoTokenizer.from_pretrained(MODEL, resume_download=None)
|
||||||
|
model = AutoModelForSequenceClassification.from_pretrained(MODEL, resume_download=None)
|
||||||
|
class_mapping = model.config.id2label
|
||||||
|
|
||||||
|
|
||||||
model.to(device)
|
model.to(device)
|
||||||
|
|
||||||
@@ -78,7 +81,7 @@ def load_text_multilabel_classifier():
|
|||||||
|
|
||||||
return batch_labels
|
return batch_labels
|
||||||
|
|
||||||
return _classifier
|
return _classifier, "gpu"
|
||||||
|
|
||||||
@lru_cache()
|
@lru_cache()
|
||||||
def load_nltk_punkt():
|
def load_nltk_punkt():
|
||||||
@@ -89,6 +92,58 @@ def load_nltk_punkt():
|
|||||||
nltk.download('punkt')
|
nltk.download('punkt')
|
||||||
return nltk.data.find('tokenizers/punkt')
|
return nltk.data.find('tokenizers/punkt')
|
||||||
|
|
||||||
|
|
||||||
|
@lru_cache()
def load_spacy_model():
    """Load the "reuters" spaCy model, fetching it on first use.

    The model is expected under ``<home_folder>/models/reuters``. When that
    folder is missing or empty, the crawl4ai repository is cloned (branch
    ``MODEL_REPO_BRANCH``) into ``<home_folder>/crawl4ai``, its
    ``models/reuters`` folder is copied into place, and the clone is removed.

    Download problems are reported with ``print`` rather than raised; in that
    case the final ``spacy.load`` call is still attempted on the (possibly
    missing) folder and any error it raises propagates to the caller.

    Returns:
        tuple: ``(nlp, "cpu")`` -- the object returned by ``spacy.load`` and
        the device label string.
    """
    import spacy

    name = "models/reuters"
    home_folder = get_home_folder()
    model_folder = os.path.join(home_folder, name)
    model_path = Path(model_folder)

    # Only download when the model directory is missing or empty.
    if not (model_path.exists() and any(model_path.iterdir())):
        repo_url = "https://github.com/unclecode/crawl4ai.git"
        branch = MODEL_REPO_BRANCH
        repo_folder = os.path.join(home_folder, "crawl4ai")

        print("[LOG] ⏬ Downloading model for the first time...")

        # Clear leftovers from a previous, interrupted attempt. The two
        # folders are checked independently: a stale clone does not imply a
        # model folder exists, and an empty model folder would otherwise make
        # copytree fail because its destination already exists.
        if Path(repo_folder).exists():
            shutil.rmtree(repo_folder)
        if model_path.exists():
            shutil.rmtree(model_folder)

        try:
            # Clone the repository
            subprocess.run(
                ["git", "clone", "-b", branch, repo_url, repo_folder],
                stdout=subprocess.DEVNULL,
                stderr=subprocess.DEVNULL,
                check=True,
            )

            # Create the models directory if it doesn't exist
            models_folder = os.path.join(home_folder, "models")
            os.makedirs(models_folder, exist_ok=True)

            # Copy the reuters model folder to the models directory
            source_folder = os.path.join(repo_folder, "models/reuters")
            shutil.copytree(source_folder, model_folder)

            print("[LOG] ✅ Model downloaded successfully")
        except subprocess.CalledProcessError as e:
            print(f"An error occurred while cloning the repository: {e}")
        except Exception as e:
            print(f"An error occurred: {e}")
        finally:
            # Remove the cloned repository whether or not the copy succeeded,
            # so a failed attempt does not leave a full checkout behind.
            shutil.rmtree(repo_folder, ignore_errors=True)

    return spacy.load(model_folder), "cpu"
|
||||||
|
|
||||||
def download_all_models(remove_existing=False):
|
def download_all_models(remove_existing=False):
|
||||||
"""Download all models required for Crawl4AI."""
|
"""Download all models required for Crawl4AI."""
|
||||||
if remove_existing:
|
if remove_existing:
|
||||||
@@ -109,7 +164,7 @@ def download_all_models(remove_existing=False):
|
|||||||
print("[LOG] Downloading BGE Small EN v1.5...")
|
print("[LOG] Downloading BGE Small EN v1.5...")
|
||||||
load_bge_small_en_v1_5()
|
load_bge_small_en_v1_5()
|
||||||
print("[LOG] Downloading text classifier...")
|
print("[LOG] Downloading text classifier...")
|
||||||
load_text_multilabel_classifier
|
load_text_multilabel_classifier()
|
||||||
print("[LOG] Downloading custom NLTK Punkt model...")
|
print("[LOG] Downloading custom NLTK Punkt model...")
|
||||||
load_nltk_punkt()
|
load_nltk_punkt()
|
||||||
print("[LOG] ✅ All models downloaded successfully.")
|
print("[LOG] ✅ All models downloaded successfully.")
|
||||||
@@ -124,4 +179,4 @@ def main():
|
|||||||
download_all_models(remove_existing=args.remove_existing)
|
download_all_models(remove_existing=args.remove_existing)
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
main()
|
main()
|
||||||
|
|||||||
Reference in New Issue
Block a user