Update extraction_strategy.py Support GPU, MPS, and CPU
This commit is contained in:
@@ -183,7 +183,7 @@ class CosineStrategy(ExtractionStrategy):
|
||||
elif model_name == "BAAI/bge-small-en-v1.5":
|
||||
self.tokenizer, self.model = load_bge_small_en_v1_5()
|
||||
|
||||
self.nlp = load_text_multilabel_classifier()
|
||||
self.nlp, self.device = load_text_multilabel_classifier()
|
||||
|
||||
if self.verbose:
|
||||
print(f"[LOG] Model loaded {model_name}, models/reuters, took " + str(time.time() - self.timer) + " seconds")
|
||||
@@ -311,20 +311,21 @@ class CosineStrategy(ExtractionStrategy):
|
||||
# Convert filtered clusters to a sorted list of dictionaries
|
||||
cluster_list = [{"index": int(idx), "tags" : [], "content": " ".join(filtered_clusters[idx])} for idx in sorted(filtered_clusters)]
|
||||
|
||||
labels = self.nlp([cluster['content'] for cluster in cluster_list])
|
||||
if self.device == "gpu":
|
||||
labels = self.nlp([cluster['content'] for cluster in cluster_list])
|
||||
|
||||
for cluster, label in zip(cluster_list, labels):
|
||||
cluster['tags'] = label
|
||||
elif self.device == "cpu":
|
||||
# Process the text with the loaded model
|
||||
for cluster in cluster_list:
|
||||
doc = self.nlp(cluster['content'])
|
||||
tok_k = self.top_k
|
||||
top_categories = sorted(doc.cats.items(), key=lambda x: x[1], reverse=True)[:tok_k]
|
||||
cluster['tags'] = [cat for cat, _ in top_categories]
|
||||
|
||||
for cluster, label in zip(cluster_list, labels):
|
||||
cluster['tags'] = label
|
||||
|
||||
# Process the text with the loaded model
|
||||
# for cluster in cluster_list:
|
||||
# cluster['tags'] = self.nlp(cluster['content'])[0]['label']
|
||||
# doc = self.nlp(cluster['content'])
|
||||
# tok_k = self.top_k
|
||||
# top_categories = sorted(doc.cats.items(), key=lambda x: x[1], reverse=True)[:tok_k]
|
||||
# cluster['tags'] = [cat for cat, _ in top_categories]
|
||||
|
||||
# print(f"[LOG] 🚀 Categorization done in {time.time() - t:.2f} seconds")
|
||||
if self.verbose:
|
||||
print(f"[LOG] 🚀 Categorization done in {time.time() - t:.2f} seconds")
|
||||
|
||||
return cluster_list
|
||||
|
||||
@@ -463,4 +464,4 @@ class ContentSummarizationStrategy(ExtractionStrategy):
|
||||
|
||||
# Sort summaries by the original section index to maintain order
|
||||
summaries.sort(key=lambda x: x[0])
|
||||
return [summary for _, summary in summaries]
|
||||
return [summary for _, summary in summaries]
|
||||
|
||||
Reference in New Issue
Block a user