diff --git a/crawl4ai/extraction_strategy.py b/crawl4ai/extraction_strategy.py index 8567ea6b..49cce284 100644 --- a/crawl4ai/extraction_strategy.py +++ b/crawl4ai/extraction_strategy.py @@ -183,7 +183,7 @@ class CosineStrategy(ExtractionStrategy): elif model_name == "BAAI/bge-small-en-v1.5": self.tokenizer, self.model = load_bge_small_en_v1_5() - self.nlp = load_text_multilabel_classifier() + self.nlp, self.device = load_text_multilabel_classifier() if self.verbose: print(f"[LOG] Model loaded {model_name}, models/reuters, took " + str(time.time() - self.timer) + " seconds") @@ -311,20 +311,21 @@ class CosineStrategy(ExtractionStrategy): # Convert filtered clusters to a sorted list of dictionaries cluster_list = [{"index": int(idx), "tags" : [], "content": " ".join(filtered_clusters[idx])} for idx in sorted(filtered_clusters)] - labels = self.nlp([cluster['content'] for cluster in cluster_list]) + if self.device == "gpu": + labels = self.nlp([cluster['content'] for cluster in cluster_list]) + + for cluster, label in zip(cluster_list, labels): + cluster['tags'] = label + elif self.device == "cpu": + # Process the text with the loaded model + for cluster in cluster_list: + doc = self.nlp(cluster['content']) + tok_k = self.top_k + top_categories = sorted(doc.cats.items(), key=lambda x: x[1], reverse=True)[:tok_k] + cluster['tags'] = [cat for cat, _ in top_categories] - for cluster, label in zip(cluster_list, labels): - cluster['tags'] = label - - # Process the text with the loaded model - # for cluster in cluster_list: - # cluster['tags'] = self.nlp(cluster['content'])[0]['label'] - # doc = self.nlp(cluster['content']) - # tok_k = self.top_k - # top_categories = sorted(doc.cats.items(), key=lambda x: x[1], reverse=True)[:tok_k] - # cluster['tags'] = [cat for cat, _ in top_categories] - - # print(f"[LOG] 🚀 Categorization done in {time.time() - t:.2f} seconds") + if self.verbose: + print(f"[LOG] 🚀 Categorization done in {time.time() - t:.2f} seconds") return cluster_list @@ -463,4 +464,4 @@ class ContentSummarizationStrategy(ExtractionStrategy): # Sort summaries by the original section index to maintain order summaries.sort(key=lambda x: x[0]) - return [summary for _, summary in summaries] \ No newline at end of file + return [summary for _, summary in summaries]