Replace embedding model with smaller one
@@ -3,12 +3,14 @@ from typing import Any, List, Dict, Optional, Union
 from scipy.cluster.hierarchy import linkage, fcluster
 from scipy.spatial.distance import pdist
 from transformers import BertTokenizer, BertModel, pipeline
+from transformers import AutoTokenizer, AutoModel
 from concurrent.futures import ThreadPoolExecutor, as_completed
 import nltk
 from nltk.tokenize import TextTilingTokenizer
 import json, time
 import torch
 import spacy
+# from optimum.intel import IPEXModel
 
 from .prompts import PROMPT_EXTRACT_BLOCKS
 from .config import *
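The import change brings in the `Auto*` factory classes, which read a checkpoint's config and instantiate the matching concrete tokenizer and model classes, so the same loading code works for `bert-base-uncased` and the BGE checkpoint alike. A minimal sketch of that behavior (checkpoint name as in the diff):

```python
from transformers import AutoTokenizer, AutoModel

# Auto* classes inspect the checkpoint's config.json and resolve the
# matching concrete classes; no architecture-specific import is needed.
tokenizer = AutoTokenizer.from_pretrained("BAAI/bge-small-en-v1.5")
model = AutoModel.from_pretrained("BAAI/bge-small-en-v1.5")
# Typically prints the concrete classes resolved from the config,
# e.g. BertTokenizerFast BertModel (BGE uses a BERT-style architecture).
print(type(tokenizer).__name__, type(model).__name__)
```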
@@ -130,11 +132,17 @@ class HierarchicalClusteringStrategy(ExtractionStrategy):
         self.max_dist = max_dist
         self.linkage_method = linkage_method
         self.top_k = top_k
-        self.tokenizer = BertTokenizer.from_pretrained('bert-base-uncased', resume_download=None)
-        self.model = BertModel.from_pretrained('bert-base-uncased', resume_download=None)
+        # self.tokenizer = BertTokenizer.from_pretrained('bert-base-uncased', resume_download=None)
+        # self.model = BertModel.from_pretrained('bert-base-uncased', resume_download=None)
 
         self.nlp = spacy.load("models/reuters")
 
+        # self.model = IPEXModel.from_pretrained("Intel/bge-small-en-v1.5-rag-int8-static")
+        # self.tokenizer = AutoTokenizer.from_pretrained("Intel/bge-small-en-v1.5-rag-int8-static")
+
+        self.tokenizer = AutoTokenizer.from_pretrained('BAAI/bge-small-en-v1.5', resume_download=None)
+        self.model = AutoModel.from_pretrained('BAAI/bge-small-en-v1.5', resume_download=None)
+        self.model.eval()
 
     def get_embeddings(self, sentences: List[str]):
         """
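This hunk is the substance of the commit: `bert-base-uncased` (roughly 110M parameters, 768-dimensional hidden states) is swapped for `BAAI/bge-small-en-v1.5` (roughly 33M parameters, 384-dimensional), and the new `eval()` call disables dropout so repeated embedding runs are deterministic. A standalone sketch (not part of the commit) to confirm what "smaller" buys:

```python
import torch
from transformers import AutoTokenizer, AutoModel

# Standalone check: the smaller checkpoint loads through the same API
# and emits 384-dim token states where bert-base-uncased emits 768-dim.
tokenizer = AutoTokenizer.from_pretrained("BAAI/bge-small-en-v1.5")
model = AutoModel.from_pretrained("BAAI/bge-small-en-v1.5")
model.eval()  # disable dropout for deterministic inference

batch = tokenizer(["a test sentence"], padding=True, truncation=True,
                  return_tensors="pt")
with torch.no_grad():
    out = model(**batch)
print(out.last_hidden_state.shape)  # e.g. torch.Size([1, 5, 384])
```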
@@ -146,10 +154,13 @@ class HierarchicalClusteringStrategy(ExtractionStrategy):
         # Tokenize sentences and convert to tensor
         encoded_input = self.tokenizer(sentences, padding=True, truncation=True, return_tensors='pt')
         # Compute token embeddings
+        t = time.time()
         with torch.no_grad():
             model_output = self.model(**encoded_input)
+
         # Get embeddings from the last hidden state (mean pooling)
         embeddings = model_output.last_hidden_state.mean(1)
+        print(f"Embeddings computed in {time.time() - t:.2f} seconds")
         return embeddings.numpy()
 
     def hierarchical_clustering(self, sentences: List[str]):
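Note that `get_embeddings` mean-pools `last_hidden_state` over every token position, so in a padded batch the padding tokens are averaged in as well. For reference only — the commit does not do this — a mask-aware mean pool that weights just the real tokens looks like:

```python
import torch

def masked_mean_pool(last_hidden_state: torch.Tensor,
                     attention_mask: torch.Tensor) -> torch.Tensor:
    # Illustrative alternative, not in the commit: exclude padding
    # tokens from the mean so short sentences in a padded batch
    # aren't diluted toward the padding embedding.
    mask = attention_mask.unsqueeze(-1).float()     # (B, T, 1)
    summed = (last_hidden_state * mask).sum(dim=1)  # (B, H)
    counts = mask.sum(dim=1).clamp(min=1e-9)        # (B, 1)
    return summed / counts
```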
@@ -224,7 +235,7 @@ class HierarchicalClusteringStrategy(ExtractionStrategy):
             top_categories = sorted(doc.cats.items(), key=lambda x: x[1], reverse=True)[:tok_k]
             cluster['tags'] = [cat for cat, _ in top_categories]
 
-            print(f"Processing done in {time.time() - t:.2f} seconds")
+            print(f"Categorization done in {time.time() - t:.2f} seconds")
 
         return cluster_list
 
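For context: the tagging step feeds each cluster's text through the spaCy pipeline loaded from `models/reuters` (presumably a custom text-categorization model), and `doc.cats` maps category labels to confidence scores; picking tags is a plain sort over that dict. The `[:tok_k]` slice in the context line above looks like a typo for the `top_k` stored in `__init__`. A hedged sketch of the pattern, with hypothetical input text:

```python
import spacy

# Sketch assuming "models/reuters" is a local spaCy pipeline containing
# a textcat component; doc.cats maps each label to a score in [0, 1].
nlp = spacy.load("models/reuters")

def top_k_tags(text: str, k: int = 3) -> list:
    doc = nlp(text)
    ranked = sorted(doc.cats.items(), key=lambda kv: kv[1], reverse=True)
    return [label for label, _ in ranked[:k]]
```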