From 7679064521a248d3f68014d2f29743b08b19bdf2 Mon Sep 17 00:00:00 2001 From: unclecode Date: Mon, 13 May 2024 00:06:16 +0800 Subject: [PATCH] Add model parameter for clustering. --- crawl4ai/extraction_strategy.py | 33 +++++++++++++++++++-------------- crawl4ai/web_crawler.py | 2 +- 2 files changed, 20 insertions(+), 15 deletions(-) diff --git a/crawl4ai/extraction_strategy.py b/crawl4ai/extraction_strategy.py index cebfeeea..6c265b47 100644 --- a/crawl4ai/extraction_strategy.py +++ b/crawl4ai/extraction_strategy.py @@ -118,7 +118,7 @@ class LLMExtractionStrategy(ExtractionStrategy): return parsed_json class HierarchicalClusteringStrategy(ExtractionStrategy): - def __init__(self, word_count_threshold=20, max_dist=0.2, linkage_method='ward', top_k=3): + def __init__(self, word_count_threshold=20, max_dist=0.2, linkage_method='ward', top_k=3, model_name = 'BAAI/bge-small-en-v1.5'): """ Initialize the strategy with clustering parameters. @@ -132,17 +132,22 @@ class HierarchicalClusteringStrategy(ExtractionStrategy): self.max_dist = max_dist self.linkage_method = linkage_method self.top_k = top_k - # self.tokenizer = BertTokenizer.from_pretrained('bert-base-uncased', resume_download=None) - # self.model = BertModel.from_pretrained('bert-base-uncased', resume_download=None) + self.timer = time.time() + + if model_name == "bert-base-uncased": + self.tokenizer = BertTokenizer.from_pretrained('bert-base-uncased', resume_download=None) + self.model = BertModel.from_pretrained('bert-base-uncased', resume_download=None) + elif model_name == "sshleifer/distilbart-cnn-12-6": + # self.model = IPEXModel.from_pretrained("Intel/bge-small-en-v1.5-rag-int8-static") + # self.tokenizer = AutoTokenizer.from_pretrained("Intel/bge-small-en-v1.5-rag-int8-static") + pass + elif model_name == "BAAI/bge-small-en-v1.5": + self.tokenizer = AutoTokenizer.from_pretrained('BAAI/bge-small-en-v1.5', resume_download=None) + self.model = AutoModel.from_pretrained('BAAI/bge-small-en-v1.5', 
resume_download=None) + self.model.eval() self.nlp = spacy.load("models/reuters") - - # self.model = IPEXModel.from_pretrained("Intel/bge-small-en-v1.5-rag-int8-static") - # self.tokenizer = AutoTokenizer.from_pretrained("Intel/bge-small-en-v1.5-rag-int8-static") - - self.tokenizer = AutoTokenizer.from_pretrained('BAAI/bge-small-en-v1.5', resume_download=None) - self.model = AutoModel.from_pretrained('BAAI/bge-small-en-v1.5', resume_download=None) - self.model.eval() + print(f"[LOG] Model loaded {model_name}, models/reuters, took " + str(time.time() - self.timer) + " seconds") def get_embeddings(self, sentences: List[str]): """ @@ -154,13 +159,11 @@ class HierarchicalClusteringStrategy(ExtractionStrategy): # Tokenize sentences and convert to tensor encoded_input = self.tokenizer(sentences, padding=True, truncation=True, return_tensors='pt') # Compute token embeddings - t = time.time() with torch.no_grad(): model_output = self.model(**encoded_input) # Get embeddings from the last hidden state (mean pooling) embeddings = model_output.last_hidden_state.mean(1) - print(f"Embeddings computed in {time.time() - t:.2f} seconds") return embeddings.numpy() def hierarchical_clustering(self, sentences: List[str]): @@ -171,7 +174,9 @@ class HierarchicalClusteringStrategy(ExtractionStrategy): :return: NumPy array of cluster labels. 
""" # Get embeddings + self.timer = time.time() embeddings = self.get_embeddings(sentences) + print(f"[LOG] 🚀 Embeddings computed in {time.time() - self.timer:.2f} seconds") # Compute pairwise cosine distances distance_matrix = pdist(embeddings, 'cosine') # Perform agglomerative clustering respecting order @@ -214,7 +219,7 @@ class HierarchicalClusteringStrategy(ExtractionStrategy): # Perform clustering labels = self.hierarchical_clustering(text_chunks) - print(f"Clustering done in {time.time() - t:.2f} seconds") + print(f"[LOG] 🚀 Clustering done in {time.time() - t:.2f} seconds") # Organize texts by their cluster labels, retaining order t = time.time() @@ -235,7 +240,7 @@ class HierarchicalClusteringStrategy(ExtractionStrategy): top_categories = sorted(doc.cats.items(), key=lambda x: x[1], reverse=True)[:tok_k] cluster['tags'] = [cat for cat, _ in top_categories] - print(f"Categorization done in {time.time() - t:.2f} seconds") + print(f"[LOG] 🚀 Categorization done in {time.time() - t:.2f} seconds") return cluster_list diff --git a/crawl4ai/web_crawler.py b/crawl4ai/web_crawler.py index 19ad8af5..34c3afc6 100644 --- a/crawl4ai/web_crawler.py +++ b/crawl4ai/web_crawler.py @@ -75,7 +75,7 @@ class WebCrawler: parsed_json = [] if extract_blocks_flag: - print(f"[LOG] 🚀 Extracting semantic blocks for {url_model.url}") + print(f"[LOG] 🔥 Extracting semantic blocks for {url_model.url}") t = time.time() # Split markdown into sections sections = chunking_strategy.chunk(markdown)