Update model loader to support GPU, MPS, and CPU

This commit is contained in:
UncleCode
2024-05-17 21:39:22 +08:00
committed by GitHub
parent ce052a4eb5
commit 33fddc27ad

View File

@@ -45,18 +45,21 @@ def load_text_multilabel_classifier():
from scipy.special import expit
import torch
MODEL = "cardiffnlp/tweet-topic-21-multi"
tokenizer = AutoTokenizer.from_pretrained(MODEL, resume_download=None)
model = AutoModelForSequenceClassification.from_pretrained(MODEL, resume_download=None)
class_mapping = model.config.id2label
# Check for available device: CUDA, MPS (for Apple Silicon), or CPU
if torch.cuda.is_available():
device = torch.device("cuda")
elif torch.backends.mps.is_available():
device = torch.device("mps")
else:
device = torch.device("cpu")
return load_spacy_model()
# device = torch.device("cpu")
MODEL = "cardiffnlp/tweet-topic-21-multi"
tokenizer = AutoTokenizer.from_pretrained(MODEL, resume_download=None)
model = AutoModelForSequenceClassification.from_pretrained(MODEL, resume_download=None)
class_mapping = model.config.id2label
model.to(device)
@@ -78,7 +81,7 @@ def load_text_multilabel_classifier():
return batch_labels
return _classifier
return _classifier, "gpu"
@lru_cache()
def load_nltk_punkt():
@@ -89,6 +92,58 @@ def load_nltk_punkt():
nltk.download('punkt')
return nltk.data.find('tokenizers/punkt')
@lru_cache()
def load_spacy_model():
import spacy
name = "models/reuters"
home_folder = get_home_folder()
model_folder = os.path.join(home_folder, name)
# Check if the model directory already exists
if not (Path(model_folder).exists() and any(Path(model_folder).iterdir())):
repo_url = "https://github.com/unclecode/crawl4ai.git"
# branch = "main"
branch = MODEL_REPO_BRANCH
repo_folder = os.path.join(home_folder, "crawl4ai")
model_folder = os.path.join(home_folder, name)
print("[LOG] ⏬ Downloading model for the first time...")
# Remove existing repo folder if it exists
if Path(repo_folder).exists():
shutil.rmtree(repo_folder)
shutil.rmtree(model_folder)
try:
# Clone the repository
subprocess.run(
["git", "clone", "-b", branch, repo_url, repo_folder],
stdout=subprocess.DEVNULL,
stderr=subprocess.DEVNULL,
check=True
)
# Create the models directory if it doesn't exist
models_folder = os.path.join(home_folder, "models")
os.makedirs(models_folder, exist_ok=True)
# Copy the reuters model folder to the models directory
source_folder = os.path.join(repo_folder, "models/reuters")
shutil.copytree(source_folder, model_folder)
# Remove the cloned repository
shutil.rmtree(repo_folder)
# Print completion message
print("[LOG] ✅ Model downloaded successfully")
except subprocess.CalledProcessError as e:
print(f"An error occurred while cloning the repository: {e}")
except Exception as e:
print(f"An error occurred: {e}")
return spacy.load(model_folder), "cpu"
def download_all_models(remove_existing=False):
"""Download all models required for Crawl4AI."""
if remove_existing:
@@ -109,7 +164,7 @@ def download_all_models(remove_existing=False):
print("[LOG] Downloading BGE Small EN v1.5...")
load_bge_small_en_v1_5()
print("[LOG] Downloading text classifier...")
load_text_multilabel_classifier
load_text_multilabel_classifier()
print("[LOG] Downloading custom NLTK Punkt model...")
load_nltk_punkt()
print("[LOG] ✅ All models downloaded successfully.")
@@ -124,4 +179,4 @@ def main():
download_all_models(remove_existing=args.remove_existing)
if __name__ == "__main__":
main()
main()