chore: Update extraction strategy to support GPU, MPS, and CPU, add batch processing for CPU devices

This commit is contained in:
unclecode
2024-05-18 15:42:19 +08:00
parent eb6423875f
commit 3846648c12
2 changed files with 78 additions and 19 deletions

View File

@@ -187,8 +187,12 @@ class CosineStrategy(ExtractionStrategy):
elif model_name == "BAAI/bge-small-en-v1.5":
self.tokenizer, self.model = load_bge_small_en_v1_5()
self.model.eval() # Ensure the model is in evaluation mode
self.buffer_embeddings = None
self.nlp, self.device = load_text_multilabel_classifier()
# self.default_batch_size = 16 if self.device.type == 'cpu' else 64
self.default_batch_size = calculate_batch_size(self.device)
if self.verbose:
print(f"[LOG] Model loaded {model_name}, models/reuters, took " + str(time.time() - self.timer) + " seconds")
@@ -219,7 +223,7 @@ class CosineStrategy(ExtractionStrategy):
return filtered_docs
def get_embeddings(self, sentences: List[str], bypass_buffer=True):
def get_embeddings(self, sentences: List[str], batch_size=None, bypass_buffer=True):
"""
Get BERT embeddings for a list of sentences.
@@ -231,15 +235,25 @@ class CosineStrategy(ExtractionStrategy):
import torch
# Tokenize sentences and convert to tensor
encoded_input = self.tokenizer(sentences, padding=True, truncation=True, return_tensors='pt')
# Compute token embeddings
with torch.no_grad():
model_output = self.model(**encoded_input)
if batch_size is None:
batch_size = self.default_batch_size
all_embeddings = []
for i in range(0, len(sentences), batch_size):
batch_sentences = sentences[i:i + batch_size]
encoded_input = self.tokenizer(batch_sentences, padding=True, truncation=True, return_tensors='pt')
encoded_input = {key: tensor.to(self.device) for key, tensor in encoded_input.items()}
# Get embeddings from the last hidden state (mean pooling)
embeddings = model_output.last_hidden_state.mean(1)
self.buffer_embeddings = embeddings.numpy()
return embeddings.numpy()
# Ensure no gradients are calculated
with torch.no_grad():
model_output = self.model(**encoded_input)
# Get embeddings from the last hidden state (mean pooling)
embeddings = model_output.last_hidden_state.mean(dim=1).cpu().numpy()
all_embeddings.append(embeddings)
self.buffer_embeddings = np.vstack(all_embeddings)
return self.buffer_embeddings
def hierarchical_clustering(self, sentences: List[str]):
"""
@@ -319,7 +333,7 @@ class CosineStrategy(ExtractionStrategy):
if self.verbose:
print(f"[LOG] 🚀 Assign tags using {self.device}")
if self.device == "gpu":
if self.device.type in ["gpu", "cuda", "mps"]:
labels = self.nlp([cluster['content'] for cluster in cluster_list])
for cluster, label in zip(cluster_list, labels):

View File

@@ -5,6 +5,48 @@ import shutil
from crawl4ai.config import MODEL_REPO_BRANCH
import argparse
@lru_cache()
def get_available_memory(device):
    """
    Estimate the memory budget, in bytes, of the given compute device.

    Despite the name, the CUDA branch reports the device's *total* memory
    (``total_memory`` from the device properties), not the amount currently
    free.

    Args:
        device: A ``torch.device``-like object. Only its ``type`` attribute is
            inspected here. It must be hashable because results are memoized
            with ``lru_cache``.

    Returns:
        int: Total memory of the CUDA device as reported by PyTorch, a fixed
        8 GiB assumption for ``'mps'``, or 0 for any other device type.
    """
    if device.type == 'cuda':
        # Local import: only the CUDA branch needs to talk to PyTorch.
        import torch
        return torch.cuda.get_device_properties(device).total_memory
    if device.type == 'mps':
        # No memory query is made for MPS here, so assume 8 GiB as a
        # conservative estimate. The previous value of 48 GiB contradicted
        # this stated intent and over-estimated memory on smaller machines.
        return 8 * 1024 ** 3
    return 0
@lru_cache()
def calculate_batch_size(device):
    """
    Choose a batch size suited to the given compute device.

    Args:
        device: A ``torch.device``-like object. Only its ``type`` attribute is
            inspected here. It must be hashable because results are memoized
            with ``lru_cache``.

    Returns:
        int: 16 for ``'cpu'`` and for any unrecognised device type. For
        ``'cuda'`` and ``'mps'`` the value depends on the figure returned by
        ``get_available_memory(device)``: 256 above 32 GiB, 128 above 16 GiB,
        64 above 8 GiB, otherwise 32.
    """
    if device.type not in ('cuda', 'mps'):
        # CPU and unknown devices always use the small default batch size,
        # so there is no need to query memory for them.
        return 16

    available_memory = get_available_memory(device)
    gib = 1024 ** 3

    # Adjust these thresholds based on your model size and available memory
    if available_memory > 32 * gib:  # more than 32 GiB
        return 256
    if available_memory > 16 * gib:  # more than 16 GiB, up to 32 GiB
        return 128
    if available_memory > 8 * gib:  # more than 8 GiB, up to 16 GiB
        return 64
    return 32  # 8 GiB or less
def set_model_device(model):
    """
    Move a model onto the preferred available device.

    The preference order is CUDA, then MPS, then CPU.

    Args:
        model: Object exposing ``to(device)``. The return value of ``to`` is
            discarded, so the move is expected to happen in place.

    Returns:
        tuple: ``(model, device)`` - the same model object that was passed in
        and the ``torch.device`` that was selected.
    """
    import torch

    # torch.backends.mps does not exist on PyTorch builds without the MPS
    # backend; look it up defensively so those builds fall back to CPU
    # instead of raising AttributeError.
    mps_backend = getattr(torch.backends, 'mps', None)

    if torch.cuda.is_available():
        device = torch.device('cuda')
    elif mps_backend is not None and mps_backend.is_available():
        device = torch.device('mps')
    else:
        device = torch.device('cpu')

    model.to(device)
    return model, device
@lru_cache()
def get_home_folder():
home_folder = os.path.join(Path.home(), ".crawl4ai")
os.makedirs(home_folder, exist_ok=True)
@@ -17,6 +59,8 @@ def load_bert_base_uncased():
from transformers import BertTokenizer, BertModel, AutoTokenizer, AutoModel
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased', resume_download=None)
model = BertModel.from_pretrained('bert-base-uncased', resume_download=None)
model.eval()
model, device = set_model_device(model)
return tokenizer, model
@lru_cache()
@@ -25,17 +69,20 @@ def load_bge_small_en_v1_5():
tokenizer = AutoTokenizer.from_pretrained('BAAI/bge-small-en-v1.5', resume_download=None)
model = AutoModel.from_pretrained('BAAI/bge-small-en-v1.5', resume_download=None)
model.eval()
model, device = set_model_device(model)
return tokenizer, model
@lru_cache()
def load_text_classifier():
    """
    Build a text-classification pipeline and cache it for later calls.

    Loads the tokenizer and sequence-classification model for
    ``dstefa/roberta-base_topic_classification_nyt_news``, switches the model
    to evaluation mode, moves it to the device chosen by ``set_model_device``
    and wraps both in a ``transformers`` ``text-classification`` pipeline.

    Returns:
        The pipeline object returned by ``transformers.pipeline``. Because of
        ``lru_cache`` the same object is returned on every subsequent call.
    """
    from transformers import AutoTokenizer, AutoModelForSequenceClassification
    from transformers import pipeline

    tokenizer = AutoTokenizer.from_pretrained("dstefa/roberta-base_topic_classification_nyt_news")
    model = AutoModelForSequenceClassification.from_pretrained("dstefa/roberta-base_topic_classification_nyt_news")
    model.eval()
    model, device = set_model_device(model)

    # Hand the selected device to the pipeline explicitly so the pipeline
    # and the model agree on placement; previously the device returned by
    # set_model_device was computed and then ignored.
    pipe = pipeline("text-classification", model=model, tokenizer=tokenizer, device=device)
    return pipe
@lru_cache()
@@ -51,18 +98,16 @@ def load_text_multilabel_classifier():
elif torch.backends.mps.is_available():
device = torch.device("mps")
else:
return load_spacy_model()
# device = torch.device("cpu")
return load_spacy_model(), torch.device("cpu")
MODEL = "cardiffnlp/tweet-topic-21-multi"
tokenizer = AutoTokenizer.from_pretrained(MODEL, resume_download=None)
model = AutoModelForSequenceClassification.from_pretrained(MODEL, resume_download=None)
model.eval()
model, device = set_model_device(model)
class_mapping = model.config.id2label
model.to(device)
def _classifier(texts, threshold=0.5, max_length=64):
tokens = tokenizer(texts, return_tensors='pt', padding=True, truncation=True, max_length=max_length)
tokens = {key: val.to(device) for key, val in tokens.items()} # Move tokens to the selected device
@@ -81,7 +126,7 @@ def load_text_multilabel_classifier():
return batch_labels
return _classifier, "gpu"
return _classifier, device
@lru_cache()
def load_nltk_punkt():
@@ -142,7 +187,7 @@ def load_spacy_model():
except Exception as e:
print(f"An error occurred: {e}")
return spacy.load(model_folder), "cpu"
return spacy.load(model_folder)
def download_all_models(remove_existing=False):
"""Download all models required for Crawl4AI."""