Replace embedding model with smaller one

Author: unclecode
Date: 2024-05-12 23:55:57 +08:00
parent 5693e324a4
commit cf087cfa58


@@ -3,12 +3,14 @@ from typing import Any, List, Dict, Optional, Union
 from scipy.cluster.hierarchy import linkage, fcluster
 from scipy.spatial.distance import pdist
 from transformers import BertTokenizer, BertModel, pipeline
+from transformers import AutoTokenizer, AutoModel
 from concurrent.futures import ThreadPoolExecutor, as_completed
 import nltk
 from nltk.tokenize import TextTilingTokenizer
 import json, time
 import torch
 import spacy
+# from optimum.intel import IPEXModel
 from .prompts import PROMPT_EXTRACT_BLOCKS
 from .config import *
@@ -130,11 +132,17 @@ class HierarchicalClusteringStrategy(ExtractionStrategy):
         self.max_dist = max_dist
         self.linkage_method = linkage_method
         self.top_k = top_k
-        self.tokenizer = BertTokenizer.from_pretrained('bert-base-uncased', resume_download=None)
-        self.model = BertModel.from_pretrained('bert-base-uncased', resume_download=None)
+        # self.tokenizer = BertTokenizer.from_pretrained('bert-base-uncased', resume_download=None)
+        # self.model = BertModel.from_pretrained('bert-base-uncased', resume_download=None)
         self.nlp = spacy.load("models/reuters")
+        # self.model = IPEXModel.from_pretrained("Intel/bge-small-en-v1.5-rag-int8-static")
+        # self.tokenizer = AutoTokenizer.from_pretrained("Intel/bge-small-en-v1.5-rag-int8-static")
+        self.tokenizer = AutoTokenizer.from_pretrained('BAAI/bge-small-en-v1.5', resume_download=None)
+        self.model = AutoModel.from_pretrained('BAAI/bge-small-en-v1.5', resume_download=None)
+        self.model.eval()

     def get_embeddings(self, sentences: List[str]):
         """
@@ -146,10 +154,13 @@ class HierarchicalClusteringStrategy(ExtractionStrategy):
         # Tokenize sentences and convert to tensor
         encoded_input = self.tokenizer(sentences, padding=True, truncation=True, return_tensors='pt')
         # Compute token embeddings
+        t = time.time()
         with torch.no_grad():
             model_output = self.model(**encoded_input)
         # Get embeddings from the last hidden state (mean pooling)
         embeddings = model_output.last_hidden_state.mean(1)
+        print(f"Embeddings computed in {time.time() - t:.2f} seconds")
         return embeddings.numpy()

     def hierarchical_clustering(self, sentences: List[str]):
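Aside on the hunk above: last_hidden_state.mean(1) averages over padding positions along with real tokens. A mask-aware mean, sketched below as a hypothetical variant (not part of this commit; masked_mean_pool is an illustrative name), weights only real tokens:

import torch

def masked_mean_pool(last_hidden_state: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
    # Weight only real tokens so [PAD] positions do not dilute the embedding.
    mask = attention_mask.unsqueeze(-1).float()      # (batch, seq, 1)
    summed = (last_hidden_state * mask).sum(dim=1)   # zero out padding, sum over tokens
    return summed / mask.sum(dim=1).clamp(min=1e-9)  # divide by true token counts

Inside get_embeddings this would replace the .mean(1) call with masked_mean_pool(model_output.last_hidden_state, encoded_input['attention_mask']).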
@@ -224,7 +235,7 @@ class HierarchicalClusteringStrategy(ExtractionStrategy):
             top_categories = sorted(doc.cats.items(), key=lambda x: x[1], reverse=True)[:tok_k]
             cluster['tags'] = [cat for cat, _ in top_categories]
-        print(f"Processing done in {time.time() - t:.2f} seconds")
+        print(f"Categorization done in {time.time() - t:.2f} seconds")
         return cluster_list
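Taken together, the commit swaps bert-base-uncased (roughly 110M parameters, 768-dimensional hidden states) for BAAI/bge-small-en-v1.5 (roughly 33M parameters, 384-dimensional), which is what the commit title means by a smaller model. A minimal standalone sketch of the new embedding path, assuming the same tokenizer call and mean pooling as the diff:

from typing import List

import torch
from transformers import AutoModel, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained('BAAI/bge-small-en-v1.5')
model = AutoModel.from_pretrained('BAAI/bge-small-en-v1.5')
model.eval()  # inference only, matching the diff

def get_embeddings(sentences: List[str]):
    # Tokenize with padding/truncation, then mean-pool the last hidden state,
    # mirroring HierarchicalClusteringStrategy.get_embeddings after this commit.
    encoded = tokenizer(sentences, padding=True, truncation=True, return_tensors='pt')
    with torch.no_grad():
        output = model(**encoded)
    return output.last_hidden_state.mean(1).numpy()

print(get_embeddings(["hello world", "a smaller embedding model"]).shape)  # (2, 384)

One caveat worth noting: the bge-small-en-v1.5 model card recommends CLS-token pooling plus L2 normalization for retrieval, so the mean pooling kept here is carried over from the BERT version rather than the new model's canonical usage.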