Update extraction_strategy.py: support GPU, MPS, and CPU
This commit is contained in:
@@ -183,7 +183,7 @@ class CosineStrategy(ExtractionStrategy):
|
|||||||
elif model_name == "BAAI/bge-small-en-v1.5":
|
elif model_name == "BAAI/bge-small-en-v1.5":
|
||||||
self.tokenizer, self.model = load_bge_small_en_v1_5()
|
self.tokenizer, self.model = load_bge_small_en_v1_5()
|
||||||
|
|
||||||
self.nlp = load_text_multilabel_classifier()
|
self.nlp, self.device = load_text_multilabel_classifier()
|
||||||
|
|
||||||
if self.verbose:
|
if self.verbose:
|
||||||
print(f"[LOG] Model loaded {model_name}, models/reuters, took " + str(time.time() - self.timer) + " seconds")
|
print(f"[LOG] Model loaded {model_name}, models/reuters, took " + str(time.time() - self.timer) + " seconds")
|
||||||
@@ -311,20 +311,21 @@ class CosineStrategy(ExtractionStrategy):
|
|||||||
# Convert filtered clusters to a sorted list of dictionaries
|
# Convert filtered clusters to a sorted list of dictionaries
|
||||||
cluster_list = [{"index": int(idx), "tags" : [], "content": " ".join(filtered_clusters[idx])} for idx in sorted(filtered_clusters)]
|
cluster_list = [{"index": int(idx), "tags" : [], "content": " ".join(filtered_clusters[idx])} for idx in sorted(filtered_clusters)]
|
||||||
|
|
||||||
labels = self.nlp([cluster['content'] for cluster in cluster_list])
|
if self.device == "gpu":
|
||||||
|
labels = self.nlp([cluster['content'] for cluster in cluster_list])
|
||||||
|
|
||||||
|
for cluster, label in zip(cluster_list, labels):
|
||||||
|
cluster['tags'] = label
|
||||||
|
elif self.device == "cpu":
|
||||||
|
# Process the text with the loaded model
|
||||||
|
for cluster in cluster_list:
|
||||||
|
doc = self.nlp(cluster['content'])
|
||||||
|
tok_k = self.top_k
|
||||||
|
top_categories = sorted(doc.cats.items(), key=lambda x: x[1], reverse=True)[:tok_k]
|
||||||
|
cluster['tags'] = [cat for cat, _ in top_categories]
|
||||||
|
|
||||||
for cluster, label in zip(cluster_list, labels):
|
if self.verbose:
|
||||||
cluster['tags'] = label
|
print(f"[LOG] 🚀 Categorization done in {time.time() - t:.2f} seconds")
|
||||||
|
|
||||||
# Process the text with the loaded model
|
|
||||||
# for cluster in cluster_list:
|
|
||||||
# cluster['tags'] = self.nlp(cluster['content'])[0]['label']
|
|
||||||
# doc = self.nlp(cluster['content'])
|
|
||||||
# tok_k = self.top_k
|
|
||||||
# top_categories = sorted(doc.cats.items(), key=lambda x: x[1], reverse=True)[:tok_k]
|
|
||||||
# cluster['tags'] = [cat for cat, _ in top_categories]
|
|
||||||
|
|
||||||
# print(f"[LOG] 🚀 Categorization done in {time.time() - t:.2f} seconds")
|
|
||||||
|
|
||||||
return cluster_list
|
return cluster_list
|
||||||
|
|
||||||
@@ -463,4 +464,4 @@ class ContentSummarizationStrategy(ExtractionStrategy):
|
|||||||
|
|
||||||
# Sort summaries by the original section index to maintain order
|
# Sort summaries by the original section index to maintain order
|
||||||
summaries.sort(key=lambda x: x[0])
|
summaries.sort(key=lambda x: x[0])
|
||||||
return [summary for _, summary in summaries]
|
return [summary for _, summary in summaries]
|
||||||
|
|||||||
Reference in New Issue
Block a user