Update extraction_strategy.py Support GPU, MPS, and CPU

This commit is contained in:
UncleCode
2024-05-17 21:40:48 +08:00
committed by GitHub
parent 33fddc27ad
commit 454135856e

View File

@@ -183,7 +183,7 @@ class CosineStrategy(ExtractionStrategy):
elif model_name == "BAAI/bge-small-en-v1.5": elif model_name == "BAAI/bge-small-en-v1.5":
self.tokenizer, self.model = load_bge_small_en_v1_5() self.tokenizer, self.model = load_bge_small_en_v1_5()
self.nlp = load_text_multilabel_classifier() self.nlp, self.device = load_text_multilabel_classifier()
if self.verbose: if self.verbose:
print(f"[LOG] Model loaded {model_name}, models/reuters, took " + str(time.time() - self.timer) + " seconds") print(f"[LOG] Model loaded {model_name}, models/reuters, took " + str(time.time() - self.timer) + " seconds")
@@ -311,20 +311,21 @@ class CosineStrategy(ExtractionStrategy):
# Convert filtered clusters to a sorted list of dictionaries # Convert filtered clusters to a sorted list of dictionaries
cluster_list = [{"index": int(idx), "tags" : [], "content": " ".join(filtered_clusters[idx])} for idx in sorted(filtered_clusters)] cluster_list = [{"index": int(idx), "tags" : [], "content": " ".join(filtered_clusters[idx])} for idx in sorted(filtered_clusters)]
labels = self.nlp([cluster['content'] for cluster in cluster_list]) if self.device == "gpu":
labels = self.nlp([cluster['content'] for cluster in cluster_list])
for cluster, label in zip(cluster_list, labels):
cluster['tags'] = label
elif self.device == "cpu":
# Process the text with the loaded model
for cluster in cluster_list:
doc = self.nlp(cluster['content'])
tok_k = self.top_k
top_categories = sorted(doc.cats.items(), key=lambda x: x[1], reverse=True)[:tok_k]
cluster['tags'] = [cat for cat, _ in top_categories]
for cluster, label in zip(cluster_list, labels): if self.verbose:
cluster['tags'] = label print(f"[LOG] 🚀 Categorization done in {time.time() - t:.2f} seconds")
# Process the text with the loaded model
# for cluster in cluster_list:
# cluster['tags'] = self.nlp(cluster['content'])[0]['label']
# doc = self.nlp(cluster['content'])
# tok_k = self.top_k
# top_categories = sorted(doc.cats.items(), key=lambda x: x[1], reverse=True)[:tok_k]
# cluster['tags'] = [cat for cat, _ in top_categories]
# print(f"[LOG] 🚀 Categorization done in {time.time() - t:.2f} seconds")
return cluster_list return cluster_list
@@ -463,4 +464,4 @@ class ContentSummarizationStrategy(ExtractionStrategy):
# Sort summaries by the original section index to maintain order # Sort summaries by the original section index to maintain order
summaries.sort(key=lambda x: x[0]) summaries.sort(key=lambda x: x[0])
return [summary for _, summary in summaries] return [summary for _, summary in summaries]