Add model parameter for clustering.
This commit is contained in:
@@ -118,7 +118,7 @@ class LLMExtractionStrategy(ExtractionStrategy):
|
||||
return parsed_json
|
||||
|
||||
class HierarchicalClusteringStrategy(ExtractionStrategy):
|
||||
def __init__(self, word_count_threshold=20, max_dist=0.2, linkage_method='ward', top_k=3):
|
||||
def __init__(self, word_count_threshold=20, max_dist=0.2, linkage_method='ward', top_k=3, model_name = 'BAAI/bge-small-en-v1.5'):
|
||||
"""
|
||||
Initialize the strategy with clustering parameters.
|
||||
|
||||
@@ -132,17 +132,22 @@ class HierarchicalClusteringStrategy(ExtractionStrategy):
|
||||
self.max_dist = max_dist
|
||||
self.linkage_method = linkage_method
|
||||
self.top_k = top_k
|
||||
# self.tokenizer = BertTokenizer.from_pretrained('bert-base-uncased', resume_download=None)
|
||||
# self.model = BertModel.from_pretrained('bert-base-uncased', resume_download=None)
|
||||
self.timer = time.time()
|
||||
|
||||
if model_name == "bert-base-uncased":
|
||||
self.tokenizer = BertTokenizer.from_pretrained('bert-base-uncased', resume_download=None)
|
||||
self.model = BertModel.from_pretrained('bert-base-uncased', resume_download=None)
|
||||
elif model_name == "sshleifer/distilbart-cnn-12-6":
|
||||
# self.model = IPEXModel.from_pretrained("Intel/bge-small-en-v1.5-rag-int8-static")
|
||||
# self.tokenizer = AutoTokenizer.from_pretrained("Intel/bge-small-en-v1.5-rag-int8-static")
|
||||
pass
|
||||
elif model_name == "BAAI/bge-small-en-v1.5":
|
||||
self.tokenizer = AutoTokenizer.from_pretrained('BAAI/bge-small-en-v1.5', resume_download=None)
|
||||
self.model = AutoModel.from_pretrained('BAAI/bge-small-en-v1.5', resume_download=None)
|
||||
self.model.eval()
|
||||
|
||||
self.nlp = spacy.load("models/reuters")
|
||||
|
||||
# self.model = IPEXModel.from_pretrained("Intel/bge-small-en-v1.5-rag-int8-static")
|
||||
# self.tokenizer = AutoTokenizer.from_pretrained("Intel/bge-small-en-v1.5-rag-int8-static")
|
||||
|
||||
self.tokenizer = AutoTokenizer.from_pretrained('BAAI/bge-small-en-v1.5', resume_download=None)
|
||||
self.model = AutoModel.from_pretrained('BAAI/bge-small-en-v1.5', resume_download=None)
|
||||
self.model.eval()
|
||||
print(f"[LOG] Model loaded {model_name}, models/reuters, took " + str(time.time() - self.timer) + " seconds")
|
||||
|
||||
def get_embeddings(self, sentences: List[str]):
|
||||
"""
|
||||
@@ -154,13 +159,11 @@ class HierarchicalClusteringStrategy(ExtractionStrategy):
|
||||
# Tokenize sentences and convert to tensor
|
||||
encoded_input = self.tokenizer(sentences, padding=True, truncation=True, return_tensors='pt')
|
||||
# Compute token embeddings
|
||||
t = time.time()
|
||||
with torch.no_grad():
|
||||
model_output = self.model(**encoded_input)
|
||||
|
||||
# Get embeddings from the last hidden state (mean pooling)
|
||||
embeddings = model_output.last_hidden_state.mean(1)
|
||||
print(f"Embeddings computed in {time.time() - t:.2f} seconds")
|
||||
return embeddings.numpy()
|
||||
|
||||
def hierarchical_clustering(self, sentences: List[str]):
|
||||
@@ -171,7 +174,9 @@ class HierarchicalClusteringStrategy(ExtractionStrategy):
|
||||
:return: NumPy array of cluster labels.
|
||||
"""
|
||||
# Get embeddings
|
||||
self.timer = time.time()
|
||||
embeddings = self.get_embeddings(sentences)
|
||||
print(f"[LOG] 🚀 Embeddings computed in {time.time() - self.timer:.2f} seconds")
|
||||
# Compute pairwise cosine distances
|
||||
distance_matrix = pdist(embeddings, 'cosine')
|
||||
# Perform agglomerative clustering respecting order
|
||||
@@ -214,7 +219,7 @@ class HierarchicalClusteringStrategy(ExtractionStrategy):
|
||||
|
||||
# Perform clustering
|
||||
labels = self.hierarchical_clustering(text_chunks)
|
||||
print(f"Clustering done in {time.time() - t:.2f} seconds")
|
||||
print(f"[LOG] 🚀 Clustering done in {time.time() - t:.2f} seconds")
|
||||
|
||||
# Organize texts by their cluster labels, retaining order
|
||||
t = time.time()
|
||||
@@ -235,7 +240,7 @@ class HierarchicalClusteringStrategy(ExtractionStrategy):
|
||||
top_categories = sorted(doc.cats.items(), key=lambda x: x[1], reverse=True)[:tok_k]
|
||||
cluster['tags'] = [cat for cat, _ in top_categories]
|
||||
|
||||
print(f"Categorization done in {time.time() - t:.2f} seconds")
|
||||
print(f"[LOG] 🚀 Categorization done in {time.time() - t:.2f} seconds")
|
||||
|
||||
return cluster_list
|
||||
|
||||
|
||||
@@ -75,7 +75,7 @@ class WebCrawler:
|
||||
|
||||
parsed_json = []
|
||||
if extract_blocks_flag:
|
||||
print(f"[LOG] 🚀 Extracting semantic blocks for {url_model.url}")
|
||||
print(f"[LOG] 🔥 Extracting semantic blocks for {url_model.url}")
|
||||
t = time.time()
|
||||
# Split markdown into sections
|
||||
sections = chunking_strategy.chunk(markdown)
|
||||
|
||||
Reference in New Issue
Block a user