Replace the bert-base-uncased embedding model with the smaller BAAI/bge-small-en-v1.5
This commit is contained in:
@@ -3,12 +3,14 @@ from typing import Any, List, Dict, Optional, Union
|
|||||||
from scipy.cluster.hierarchy import linkage, fcluster
|
from scipy.cluster.hierarchy import linkage, fcluster
|
||||||
from scipy.spatial.distance import pdist
|
from scipy.spatial.distance import pdist
|
||||||
from transformers import BertTokenizer, BertModel, pipeline
|
from transformers import BertTokenizer, BertModel, pipeline
|
||||||
|
from transformers import AutoTokenizer, AutoModel
|
||||||
from concurrent.futures import ThreadPoolExecutor, as_completed
|
from concurrent.futures import ThreadPoolExecutor, as_completed
|
||||||
import nltk
|
import nltk
|
||||||
from nltk.tokenize import TextTilingTokenizer
|
from nltk.tokenize import TextTilingTokenizer
|
||||||
import json, time
|
import json, time
|
||||||
import torch
|
import torch
|
||||||
import spacy
|
import spacy
|
||||||
|
# from optimum.intel import IPEXModel
|
||||||
|
|
||||||
from .prompts import PROMPT_EXTRACT_BLOCKS
|
from .prompts import PROMPT_EXTRACT_BLOCKS
|
||||||
from .config import *
|
from .config import *
|
||||||
@@ -130,11 +132,17 @@ class HierarchicalClusteringStrategy(ExtractionStrategy):
|
|||||||
self.max_dist = max_dist
|
self.max_dist = max_dist
|
||||||
self.linkage_method = linkage_method
|
self.linkage_method = linkage_method
|
||||||
self.top_k = top_k
|
self.top_k = top_k
|
||||||
self.tokenizer = BertTokenizer.from_pretrained('bert-base-uncased', resume_download=None)
|
# self.tokenizer = BertTokenizer.from_pretrained('bert-base-uncased', resume_download=None)
|
||||||
self.model = BertModel.from_pretrained('bert-base-uncased', resume_download=None)
|
# self.model = BertModel.from_pretrained('bert-base-uncased', resume_download=None)
|
||||||
|
|
||||||
self.nlp = spacy.load("models/reuters")
|
self.nlp = spacy.load("models/reuters")
|
||||||
|
|
||||||
|
# self.model = IPEXModel.from_pretrained("Intel/bge-small-en-v1.5-rag-int8-static")
|
||||||
|
# self.tokenizer = AutoTokenizer.from_pretrained("Intel/bge-small-en-v1.5-rag-int8-static")
|
||||||
|
|
||||||
|
self.tokenizer = AutoTokenizer.from_pretrained('BAAI/bge-small-en-v1.5', resume_download=None)
|
||||||
|
self.model = AutoModel.from_pretrained('BAAI/bge-small-en-v1.5', resume_download=None)
|
||||||
|
self.model.eval()
|
||||||
|
|
||||||
def get_embeddings(self, sentences: List[str]):
|
def get_embeddings(self, sentences: List[str]):
|
||||||
"""
|
"""
|
||||||
@@ -146,10 +154,13 @@ class HierarchicalClusteringStrategy(ExtractionStrategy):
|
|||||||
# Tokenize sentences and convert to tensor
|
# Tokenize sentences and convert to tensor
|
||||||
encoded_input = self.tokenizer(sentences, padding=True, truncation=True, return_tensors='pt')
|
encoded_input = self.tokenizer(sentences, padding=True, truncation=True, return_tensors='pt')
|
||||||
# Compute token embeddings
|
# Compute token embeddings
|
||||||
|
t = time.time()
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
model_output = self.model(**encoded_input)
|
model_output = self.model(**encoded_input)
|
||||||
|
|
||||||
# Get embeddings from the last hidden state (mean pooling)
|
# Get embeddings from the last hidden state (mean pooling)
|
||||||
embeddings = model_output.last_hidden_state.mean(1)
|
embeddings = model_output.last_hidden_state.mean(1)
|
||||||
|
print(f"Embeddings computed in {time.time() - t:.2f} seconds")
|
||||||
return embeddings.numpy()
|
return embeddings.numpy()
|
||||||
|
|
||||||
def hierarchical_clustering(self, sentences: List[str]):
|
def hierarchical_clustering(self, sentences: List[str]):
|
||||||
@@ -224,7 +235,7 @@ class HierarchicalClusteringStrategy(ExtractionStrategy):
|
|||||||
top_categories = sorted(doc.cats.items(), key=lambda x: x[1], reverse=True)[:tok_k]
|
top_categories = sorted(doc.cats.items(), key=lambda x: x[1], reverse=True)[:tok_k]
|
||||||
cluster['tags'] = [cat for cat, _ in top_categories]
|
cluster['tags'] = [cat for cat, _ in top_categories]
|
||||||
|
|
||||||
print(f"Processing done in {time.time() - t:.2f} seconds")
|
print(f"Categorization done in {time.time() - t:.2f} seconds")
|
||||||
|
|
||||||
return cluster_list
|
return cluster_list
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user