chore: Update extraction strategy to support GPU, MPS, and CPU, add batch procesing for CPU devices
This commit is contained in:
@@ -5,6 +5,48 @@ import shutil
|
||||
from crawl4ai.config import MODEL_REPO_BRANCH
|
||||
import argparse
|
||||
|
||||
@lru_cache()
|
||||
def get_available_memory(device):
|
||||
import torch
|
||||
if device.type == 'cuda':
|
||||
return torch.cuda.get_device_properties(device).total_memory
|
||||
elif device.type == 'mps':
|
||||
return 48 * 1024 ** 3 # Assuming 8GB for MPS, as a conservative estimate
|
||||
else:
|
||||
return 0
|
||||
|
||||
@lru_cache()
|
||||
def calculate_batch_size(device):
|
||||
available_memory = get_available_memory(device)
|
||||
|
||||
if device.type == 'cpu':
|
||||
return 16
|
||||
elif device.type in ['cuda', 'mps']:
|
||||
# Adjust these thresholds based on your model size and available memory
|
||||
if available_memory > 32 * 1024 ** 3: # > 16GB
|
||||
return 256
|
||||
elif available_memory > 16 * 1024 ** 3: # > 16GB
|
||||
return 128
|
||||
elif available_memory > 8 * 1024 ** 3: # 8GB to 16GB
|
||||
return 64
|
||||
else:
|
||||
return 32
|
||||
else:
|
||||
return 16 # Default batch size
|
||||
|
||||
def set_model_device(model):
|
||||
import torch
|
||||
if torch.cuda.is_available():
|
||||
device = torch.device('cuda')
|
||||
elif torch.backends.mps.is_available():
|
||||
device = torch.device('mps')
|
||||
else:
|
||||
device = torch.device('cpu')
|
||||
|
||||
model.to(device)
|
||||
return model, device
|
||||
|
||||
@lru_cache()
|
||||
def get_home_folder():
|
||||
home_folder = os.path.join(Path.home(), ".crawl4ai")
|
||||
os.makedirs(home_folder, exist_ok=True)
|
||||
@@ -17,6 +59,8 @@ def load_bert_base_uncased():
|
||||
from transformers import BertTokenizer, BertModel, AutoTokenizer, AutoModel
|
||||
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased', resume_download=None)
|
||||
model = BertModel.from_pretrained('bert-base-uncased', resume_download=None)
|
||||
model.eval()
|
||||
model, device = set_model_device(model)
|
||||
return tokenizer, model
|
||||
|
||||
@lru_cache()
|
||||
@@ -25,17 +69,20 @@ def load_bge_small_en_v1_5():
|
||||
tokenizer = AutoTokenizer.from_pretrained('BAAI/bge-small-en-v1.5', resume_download=None)
|
||||
model = AutoModel.from_pretrained('BAAI/bge-small-en-v1.5', resume_download=None)
|
||||
model.eval()
|
||||
model, device = set_model_device(model)
|
||||
return tokenizer, model
|
||||
|
||||
@lru_cache()
|
||||
def load_text_classifier():
|
||||
from transformers import AutoTokenizer, AutoModelForSequenceClassification
|
||||
from transformers import pipeline
|
||||
import torch
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained("dstefa/roberta-base_topic_classification_nyt_news")
|
||||
model = AutoModelForSequenceClassification.from_pretrained("dstefa/roberta-base_topic_classification_nyt_news")
|
||||
model.eval()
|
||||
model, device = set_model_device(model)
|
||||
pipe = pipeline("text-classification", model=model, tokenizer=tokenizer)
|
||||
|
||||
return pipe
|
||||
|
||||
@lru_cache()
|
||||
@@ -51,18 +98,16 @@ def load_text_multilabel_classifier():
|
||||
elif torch.backends.mps.is_available():
|
||||
device = torch.device("mps")
|
||||
else:
|
||||
return load_spacy_model()
|
||||
# device = torch.device("cpu")
|
||||
return load_spacy_model(), torch.device("cpu")
|
||||
|
||||
|
||||
MODEL = "cardiffnlp/tweet-topic-21-multi"
|
||||
tokenizer = AutoTokenizer.from_pretrained(MODEL, resume_download=None)
|
||||
model = AutoModelForSequenceClassification.from_pretrained(MODEL, resume_download=None)
|
||||
model.eval()
|
||||
model, device = set_model_device(model)
|
||||
class_mapping = model.config.id2label
|
||||
|
||||
|
||||
model.to(device)
|
||||
|
||||
def _classifier(texts, threshold=0.5, max_length=64):
|
||||
tokens = tokenizer(texts, return_tensors='pt', padding=True, truncation=True, max_length=max_length)
|
||||
tokens = {key: val.to(device) for key, val in tokens.items()} # Move tokens to the selected device
|
||||
@@ -81,7 +126,7 @@ def load_text_multilabel_classifier():
|
||||
|
||||
return batch_labels
|
||||
|
||||
return _classifier, "gpu"
|
||||
return _classifier, device
|
||||
|
||||
@lru_cache()
|
||||
def load_nltk_punkt():
|
||||
@@ -142,7 +187,7 @@ def load_spacy_model():
|
||||
except Exception as e:
|
||||
print(f"An error occurred: {e}")
|
||||
|
||||
return spacy.load(model_folder), "cpu"
|
||||
return spacy.load(model_folder)
|
||||
|
||||
def download_all_models(remove_existing=False):
|
||||
"""Download all models required for Crawl4AI."""
|
||||
|
||||
Reference in New Issue
Block a user