chore: Update extraction strategy to support GPU, MPS, and CPU, add batch processing for CPU devices

This commit is contained in:
unclecode
2024-05-18 15:42:19 +08:00
parent eb6423875f
commit 3846648c12
2 changed files with 78 additions and 19 deletions


@@ -5,6 +5,48 @@ import shutil
from crawl4ai.config import MODEL_REPO_BRANCH
import argparse
@lru_cache()
def get_available_memory(device):
    import torch
    if device.type == 'cuda':
        return torch.cuda.get_device_properties(device).total_memory
    elif device.type == 'mps':
        return 48 * 1024 ** 3  # Assume 48GB of unified memory for MPS devices
    else:
        return 0
@lru_cache()
def calculate_batch_size(device):
    available_memory = get_available_memory(device)
    if device.type == 'cpu':
        return 16
    elif device.type in ['cuda', 'mps']:
        # Adjust these thresholds based on your model size and available memory
        if available_memory > 32 * 1024 ** 3:  # > 32GB
            return 256
        elif available_memory > 16 * 1024 ** 3:  # 16GB to 32GB
            return 128
        elif available_memory > 8 * 1024 ** 3:  # 8GB to 16GB
            return 64
        else:
            return 32
    else:
        return 16  # Default batch size
def set_model_device(model):
    import torch
    if torch.cuda.is_available():
        device = torch.device('cuda')
    elif torch.backends.mps.is_available():
        device = torch.device('mps')
    else:
        device = torch.device('cpu')
    model.to(device)
    return model, device
@lru_cache()
def get_home_folder():
    home_folder = os.path.join(Path.home(), ".crawl4ai")
    os.makedirs(home_folder, exist_ok=True)
    return home_folder
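Taken together, these helpers pick the best available device and a memory-appropriate batch size. A minimal sketch of how they might drive batched inference with a BERT-style encoder (the embed_texts helper and its arguments are illustrative, not part of this commit):

import torch

def embed_texts(texts, tokenizer, model):
    # Move the model to the best available device and size batches accordingly.
    model, device = set_model_device(model)
    batch_size = calculate_batch_size(device)
    chunks = []
    with torch.no_grad():
        for i in range(0, len(texts), batch_size):
            batch = tokenizer(texts[i:i + batch_size], return_tensors='pt',
                              padding=True, truncation=True)
            batch = {k: v.to(device) for k, v in batch.items()}
            # Keep the [CLS] embedding of each text; move results back to CPU.
            chunks.append(model(**batch).last_hidden_state[:, 0, :].cpu())
    return torch.cat(chunks)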
@@ -17,6 +59,8 @@ def load_bert_base_uncased():
    from transformers import BertTokenizer, BertModel, AutoTokenizer, AutoModel
    tokenizer = BertTokenizer.from_pretrained('bert-base-uncased', resume_download=None)
    model = BertModel.from_pretrained('bert-base-uncased', resume_download=None)
    model.eval()
    model, device = set_model_device(model)
    return tokenizer, model
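Since set_model_device moves the model before it is returned, a quick (illustrative) way to confirm which device was selected:

tokenizer, model = load_bert_base_uncased()
print(next(model.parameters()).device)  # e.g. cuda:0, mps, or cpu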
@lru_cache()
@@ -25,17 +69,20 @@ def load_bge_small_en_v1_5():
    tokenizer = AutoTokenizer.from_pretrained('BAAI/bge-small-en-v1.5', resume_download=None)
    model = AutoModel.from_pretrained('BAAI/bge-small-en-v1.5', resume_download=None)
    model.eval()
    model, device = set_model_device(model)
    return tokenizer, model
@lru_cache()
def load_text_classifier():
    from transformers import AutoTokenizer, AutoModelForSequenceClassification
    from transformers import pipeline
    import torch
    tokenizer = AutoTokenizer.from_pretrained("dstefa/roberta-base_topic_classification_nyt_news")
    model = AutoModelForSequenceClassification.from_pretrained("dstefa/roberta-base_topic_classification_nyt_news")
    model.eval()
    model, device = set_model_device(model)
    pipe = pipeline("text-classification", model=model, tokenizer=tokenizer)
    return pipe
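The cached pipeline can then be called directly on raw strings; for example (the input text is illustrative):

pipe = load_text_classifier()
print(pipe("NASA announces a new mission to study the outer planets."))
# -> [{'label': ..., 'score': ...}]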
@lru_cache()
@@ -51,18 +98,16 @@ def load_text_multilabel_classifier():
    elif torch.backends.mps.is_available():
        device = torch.device("mps")
    else:
        # device = torch.device("cpu")
        return load_spacy_model(), torch.device("cpu")
    MODEL = "cardiffnlp/tweet-topic-21-multi"
    tokenizer = AutoTokenizer.from_pretrained(MODEL, resume_download=None)
    model = AutoModelForSequenceClassification.from_pretrained(MODEL, resume_download=None)
    model.eval()
    model, device = set_model_device(model)
    class_mapping = model.config.id2label

    def _classifier(texts, threshold=0.5, max_length=64):
        tokens = tokenizer(texts, return_tensors='pt', padding=True, truncation=True, max_length=max_length)
        tokens = {key: val.to(device) for key, val in tokens.items()}  # Move tokens to the selected device
@@ -81,7 +126,7 @@ def load_text_multilabel_classifier():
        return batch_labels

    return _classifier, device
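Callers now receive the actual torch.device rather than the hard-coded "gpu" string. A usage sketch (the sample texts are illustrative; note that on CPU-only machines the function falls back to the spaCy model instead of _classifier):

classifier, device = load_text_multilabel_classifier()
labels = classifier(["GPU-accelerated web crawling", "Last night's match highlights"])
print(labels, device)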
@lru_cache()
def load_nltk_punkt():
@@ -142,7 +187,7 @@ def load_spacy_model():
    except Exception as e:
        print(f"An error occurred: {e}")
    return spacy.load(model_folder)
def download_all_models(remove_existing=False):
    """Download all models required for Crawl4AI."""