chore: Update extraction strategy to support GPU, MPS, and CPU, add batch procesing for CPU devices

This commit is contained in:
unclecode
2024-05-18 15:42:19 +08:00
parent eb6423875f
commit 3846648c12
2 changed files with 78 additions and 19 deletions

View File

@@ -187,8 +187,12 @@ class CosineStrategy(ExtractionStrategy):
elif model_name == "BAAI/bge-small-en-v1.5":
self.tokenizer, self.model = load_bge_small_en_v1_5()
self.model.eval() # Ensure the model is in evaluation mode
self.buffer_embeddings = None
self.nlp, self.device = load_text_multilabel_classifier()
# self.default_batch_size = 16 if self.device.type == 'cpu' else 64
self.default_batch_size = calculate_batch_size(self.device)
if self.verbose:
print(f"[LOG] Model loaded {model_name}, models/reuters, took " + str(time.time() - self.timer) + " seconds")
@@ -219,7 +223,7 @@ class CosineStrategy(ExtractionStrategy):
return filtered_docs
def get_embeddings(self, sentences: List[str], bypass_buffer=True):
def get_embeddings(self, sentences: List[str], batch_size=None, bypass_buffer=True):
"""
Get BERT embeddings for a list of sentences.
@@ -231,15 +235,25 @@ class CosineStrategy(ExtractionStrategy):
import torch
# Tokenize sentences and convert to tensor
encoded_input = self.tokenizer(sentences, padding=True, truncation=True, return_tensors='pt')
# Compute token embeddings
with torch.no_grad():
model_output = self.model(**encoded_input)
if batch_size is None:
batch_size = self.default_batch_size
all_embeddings = []
for i in range(0, len(sentences), batch_size):
batch_sentences = sentences[i:i + batch_size]
encoded_input = self.tokenizer(batch_sentences, padding=True, truncation=True, return_tensors='pt')
encoded_input = {key: tensor.to(self.device) for key, tensor in encoded_input.items()}
# Get embeddings from the last hidden state (mean pooling)
embeddings = model_output.last_hidden_state.mean(1)
self.buffer_embeddings = embeddings.numpy()
return embeddings.numpy()
# Ensure no gradients are calculated
with torch.no_grad():
model_output = self.model(**encoded_input)
# Get embeddings from the last hidden state (mean pooling)
embeddings = model_output.last_hidden_state.mean(dim=1).cpu().numpy()
all_embeddings.append(embeddings)
self.buffer_embeddings = np.vstack(all_embeddings)
return self.buffer_embeddings
def hierarchical_clustering(self, sentences: List[str]):
"""
@@ -319,7 +333,7 @@ class CosineStrategy(ExtractionStrategy):
if self.verbose:
print(f"[LOG] 🚀 Assign tags using {self.device}")
if self.device == "gpu":
if self.device.type in ["gpu", "cuda", "mps"]:
labels = self.nlp([cluster['content'] for cluster in cluster_list])
for cluster, label in zip(cluster_list, labels):