Replace embedding model with smaller one

This commit is contained in:
unclecode
2024-05-12 23:55:57 +08:00
parent 5693e324a4
commit cf087cfa58

View File

@@ -3,12 +3,14 @@ from typing import Any, List, Dict, Optional, Union
from scipy.cluster.hierarchy import linkage, fcluster
from scipy.spatial.distance import pdist
from transformers import BertTokenizer, BertModel, pipeline
from transformers import AutoTokenizer, AutoModel
from concurrent.futures import ThreadPoolExecutor, as_completed
import nltk
from nltk.tokenize import TextTilingTokenizer
import json, time
import torch
import spacy
# from optimum.intel import IPEXModel
from .prompts import PROMPT_EXTRACT_BLOCKS
from .config import *
@@ -130,11 +132,17 @@ class HierarchicalClusteringStrategy(ExtractionStrategy):
self.max_dist = max_dist
self.linkage_method = linkage_method
self.top_k = top_k
self.tokenizer = BertTokenizer.from_pretrained('bert-base-uncased', resume_download=None)
self.model = BertModel.from_pretrained('bert-base-uncased', resume_download=None)
# self.tokenizer = BertTokenizer.from_pretrained('bert-base-uncased', resume_download=None)
# self.model = BertModel.from_pretrained('bert-base-uncased', resume_download=None)
self.nlp = spacy.load("models/reuters")
# self.model = IPEXModel.from_pretrained("Intel/bge-small-en-v1.5-rag-int8-static")
# self.tokenizer = AutoTokenizer.from_pretrained("Intel/bge-small-en-v1.5-rag-int8-static")
self.tokenizer = AutoTokenizer.from_pretrained('BAAI/bge-small-en-v1.5', resume_download=None)
self.model = AutoModel.from_pretrained('BAAI/bge-small-en-v1.5', resume_download=None)
self.model.eval()
def get_embeddings(self, sentences: List[str]):
"""
@@ -146,10 +154,13 @@ class HierarchicalClusteringStrategy(ExtractionStrategy):
# Tokenize sentences and convert to tensor
encoded_input = self.tokenizer(sentences, padding=True, truncation=True, return_tensors='pt')
# Compute token embeddings
t = time.time()
with torch.no_grad():
model_output = self.model(**encoded_input)
# Get embeddings from the last hidden state (mean pooling)
embeddings = model_output.last_hidden_state.mean(1)
print(f"Embeddings computed in {time.time() - t:.2f} seconds")
return embeddings.numpy()
def hierarchical_clustering(self, sentences: List[str]):
@@ -224,7 +235,7 @@ class HierarchicalClusteringStrategy(ExtractionStrategy):
top_categories = sorted(doc.cats.items(), key=lambda x: x[1], reverse=True)[:tok_k]
cluster['tags'] = [cat for cat, _ in top_categories]
print(f"Processing done in {time.time() - t:.2f} seconds")
print(f"Categorization done in {time.time() - t:.2f} seconds")
return cluster_list