diff --git a/crawl4ai/extraction_strategy.py b/crawl4ai/extraction_strategy.py index d1f68a37..cebfeeea 100644 --- a/crawl4ai/extraction_strategy.py +++ b/crawl4ai/extraction_strategy.py @@ -3,12 +3,14 @@ from typing import Any, List, Dict, Optional, Union from scipy.cluster.hierarchy import linkage, fcluster from scipy.spatial.distance import pdist from transformers import BertTokenizer, BertModel, pipeline +from transformers import AutoTokenizer, AutoModel from concurrent.futures import ThreadPoolExecutor, as_completed import nltk from nltk.tokenize import TextTilingTokenizer import json, time import torch import spacy +# from optimum.intel import IPEXModel from .prompts import PROMPT_EXTRACT_BLOCKS from .config import * @@ -130,11 +132,17 @@ class HierarchicalClusteringStrategy(ExtractionStrategy): self.max_dist = max_dist self.linkage_method = linkage_method self.top_k = top_k - self.tokenizer = BertTokenizer.from_pretrained('bert-base-uncased', resume_download=None) - self.model = BertModel.from_pretrained('bert-base-uncased', resume_download=None) + # self.tokenizer = BertTokenizer.from_pretrained('bert-base-uncased', resume_download=None) + # self.model = BertModel.from_pretrained('bert-base-uncased', resume_download=None) + self.nlp = spacy.load("models/reuters") + # self.model = IPEXModel.from_pretrained("Intel/bge-small-en-v1.5-rag-int8-static") + # self.tokenizer = AutoTokenizer.from_pretrained("Intel/bge-small-en-v1.5-rag-int8-static") + self.tokenizer = AutoTokenizer.from_pretrained('BAAI/bge-small-en-v1.5', resume_download=None) + self.model = AutoModel.from_pretrained('BAAI/bge-small-en-v1.5', resume_download=None) + self.model.eval() def get_embeddings(self, sentences: List[str]): """ @@ -146,10 +154,13 @@ class HierarchicalClusteringStrategy(ExtractionStrategy): # Tokenize sentences and convert to tensor encoded_input = self.tokenizer(sentences, padding=True, truncation=True, return_tensors='pt') # Compute token embeddings + t = time.time() with torch.no_grad(): model_output = self.model(**encoded_input) + # Get embeddings from the last hidden state (mean pooling) embeddings = model_output.last_hidden_state.mean(1) + print(f"Embeddings computed in {time.time() - t:.2f} seconds") return embeddings.numpy() def hierarchical_clustering(self, sentences: List[str]): @@ -224,7 +235,7 @@ class HierarchicalClusteringStrategy(ExtractionStrategy): top_categories = sorted(doc.cats.items(), key=lambda x: x[1], reverse=True)[:tok_k] cluster['tags'] = [cat for cat, _ in top_categories] - print(f"Processing done in {time.time() - t:.2f} seconds") + print(f"Categorization done in {time.time() - t:.2f} seconds") return cluster_list