Add model parameter for clustering.
This commit is contained in:
@@ -118,7 +118,7 @@ class LLMExtractionStrategy(ExtractionStrategy):
|
|||||||
return parsed_json
|
return parsed_json
|
||||||
|
|
||||||
class HierarchicalClusteringStrategy(ExtractionStrategy):
|
class HierarchicalClusteringStrategy(ExtractionStrategy):
|
||||||
def __init__(self, word_count_threshold=20, max_dist=0.2, linkage_method='ward', top_k=3):
|
def __init__(self, word_count_threshold=20, max_dist=0.2, linkage_method='ward', top_k=3, model_name = 'BAAI/bge-small-en-v1.5'):
|
||||||
"""
|
"""
|
||||||
Initialize the strategy with clustering parameters.
|
Initialize the strategy with clustering parameters.
|
||||||
|
|
||||||
@@ -132,17 +132,22 @@ class HierarchicalClusteringStrategy(ExtractionStrategy):
|
|||||||
self.max_dist = max_dist
|
self.max_dist = max_dist
|
||||||
self.linkage_method = linkage_method
|
self.linkage_method = linkage_method
|
||||||
self.top_k = top_k
|
self.top_k = top_k
|
||||||
# self.tokenizer = BertTokenizer.from_pretrained('bert-base-uncased', resume_download=None)
|
self.timer = time.time()
|
||||||
# self.model = BertModel.from_pretrained('bert-base-uncased', resume_download=None)
|
|
||||||
|
if model_name == "bert-base-uncased":
|
||||||
|
self.tokenizer = BertTokenizer.from_pretrained('bert-base-uncased', resume_download=None)
|
||||||
|
self.model = BertModel.from_pretrained('bert-base-uncased', resume_download=None)
|
||||||
|
elif model_name == "sshleifer/distilbart-cnn-12-6":
|
||||||
|
# self.model = IPEXModel.from_pretrained("Intel/bge-small-en-v1.5-rag-int8-static")
|
||||||
|
# self.tokenizer = AutoTokenizer.from_pretrained("Intel/bge-small-en-v1.5-rag-int8-static")
|
||||||
|
pass
|
||||||
|
elif model_name == "BAAI/bge-small-en-v1.5":
|
||||||
|
self.tokenizer = AutoTokenizer.from_pretrained('BAAI/bge-small-en-v1.5', resume_download=None)
|
||||||
|
self.model = AutoModel.from_pretrained('BAAI/bge-small-en-v1.5', resume_download=None)
|
||||||
|
self.model.eval()
|
||||||
|
|
||||||
self.nlp = spacy.load("models/reuters")
|
self.nlp = spacy.load("models/reuters")
|
||||||
|
print(f"[LOG] Model loaded {model_name}, models/reuters, took " + str(time.time() - self.timer) + " seconds")
|
||||||
# self.model = IPEXModel.from_pretrained("Intel/bge-small-en-v1.5-rag-int8-static")
|
|
||||||
# self.tokenizer = AutoTokenizer.from_pretrained("Intel/bge-small-en-v1.5-rag-int8-static")
|
|
||||||
|
|
||||||
self.tokenizer = AutoTokenizer.from_pretrained('BAAI/bge-small-en-v1.5', resume_download=None)
|
|
||||||
self.model = AutoModel.from_pretrained('BAAI/bge-small-en-v1.5', resume_download=None)
|
|
||||||
self.model.eval()
|
|
||||||
|
|
||||||
def get_embeddings(self, sentences: List[str]):
|
def get_embeddings(self, sentences: List[str]):
|
||||||
"""
|
"""
|
||||||
@@ -154,13 +159,11 @@ class HierarchicalClusteringStrategy(ExtractionStrategy):
|
|||||||
# Tokenize sentences and convert to tensor
|
# Tokenize sentences and convert to tensor
|
||||||
encoded_input = self.tokenizer(sentences, padding=True, truncation=True, return_tensors='pt')
|
encoded_input = self.tokenizer(sentences, padding=True, truncation=True, return_tensors='pt')
|
||||||
# Compute token embeddings
|
# Compute token embeddings
|
||||||
t = time.time()
|
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
model_output = self.model(**encoded_input)
|
model_output = self.model(**encoded_input)
|
||||||
|
|
||||||
# Get embeddings from the last hidden state (mean pooling)
|
# Get embeddings from the last hidden state (mean pooling)
|
||||||
embeddings = model_output.last_hidden_state.mean(1)
|
embeddings = model_output.last_hidden_state.mean(1)
|
||||||
print(f"Embeddings computed in {time.time() - t:.2f} seconds")
|
|
||||||
return embeddings.numpy()
|
return embeddings.numpy()
|
||||||
|
|
||||||
def hierarchical_clustering(self, sentences: List[str]):
|
def hierarchical_clustering(self, sentences: List[str]):
|
||||||
@@ -171,7 +174,9 @@ class HierarchicalClusteringStrategy(ExtractionStrategy):
|
|||||||
:return: NumPy array of cluster labels.
|
:return: NumPy array of cluster labels.
|
||||||
"""
|
"""
|
||||||
# Get embeddings
|
# Get embeddings
|
||||||
|
self.timer = time.time()
|
||||||
embeddings = self.get_embeddings(sentences)
|
embeddings = self.get_embeddings(sentences)
|
||||||
|
print(f"[LOG] 🚀 Embeddings computed in {time.time() - self.timer:.2f} seconds")
|
||||||
# Compute pairwise cosine distances
|
# Compute pairwise cosine distances
|
||||||
distance_matrix = pdist(embeddings, 'cosine')
|
distance_matrix = pdist(embeddings, 'cosine')
|
||||||
# Perform agglomerative clustering respecting order
|
# Perform agglomerative clustering respecting order
|
||||||
@@ -214,7 +219,7 @@ class HierarchicalClusteringStrategy(ExtractionStrategy):
|
|||||||
|
|
||||||
# Perform clustering
|
# Perform clustering
|
||||||
labels = self.hierarchical_clustering(text_chunks)
|
labels = self.hierarchical_clustering(text_chunks)
|
||||||
print(f"Clustering done in {time.time() - t:.2f} seconds")
|
print(f"[LOG] 🚀 Clustering done in {time.time() - t:.2f} seconds")
|
||||||
|
|
||||||
# Organize texts by their cluster labels, retaining order
|
# Organize texts by their cluster labels, retaining order
|
||||||
t = time.time()
|
t = time.time()
|
||||||
@@ -235,7 +240,7 @@ class HierarchicalClusteringStrategy(ExtractionStrategy):
|
|||||||
top_categories = sorted(doc.cats.items(), key=lambda x: x[1], reverse=True)[:tok_k]
|
top_categories = sorted(doc.cats.items(), key=lambda x: x[1], reverse=True)[:tok_k]
|
||||||
cluster['tags'] = [cat for cat, _ in top_categories]
|
cluster['tags'] = [cat for cat, _ in top_categories]
|
||||||
|
|
||||||
print(f"Categorization done in {time.time() - t:.2f} seconds")
|
print(f"[LOG] 🚀 Categorization done in {time.time() - t:.2f} seconds")
|
||||||
|
|
||||||
return cluster_list
|
return cluster_list
|
||||||
|
|
||||||
|
|||||||
@@ -75,7 +75,7 @@ class WebCrawler:
|
|||||||
|
|
||||||
parsed_json = []
|
parsed_json = []
|
||||||
if extract_blocks_flag:
|
if extract_blocks_flag:
|
||||||
print(f"[LOG] 🚀 Extracting semantic blocks for {url_model.url}")
|
print(f"[LOG] 🔥 Extracting semantic blocks for {url_model.url}")
|
||||||
t = time.time()
|
t = time.time()
|
||||||
# Split markdown into sections
|
# Split markdown into sections
|
||||||
sections = chunking_strategy.chunk(markdown)
|
sections = chunking_strategy.chunk(markdown)
|
||||||
|
|||||||
Reference in New Issue
Block a user