Improve library imports

This commit is contained in:
unclecode
2024-05-13 02:46:35 +08:00
parent 11393183f7
commit 5fea6c064b
5 changed files with 231 additions and 125 deletions

View File

@@ -1,20 +1,12 @@
from abc import ABC, abstractmethod
from typing import Any, List, Dict, Optional, Union
from scipy.cluster.hierarchy import linkage, fcluster
from scipy.spatial.distance import pdist
from transformers import BertTokenizer, BertModel, pipeline
from transformers import AutoTokenizer, AutoModel
from concurrent.futures import ThreadPoolExecutor, as_completed
import nltk
from nltk.tokenize import TextTilingTokenizer
import json, time
import torch
import spacy
# from optimum.intel import IPEXModel
from .prompts import PROMPT_EXTRACT_BLOCKS
from .config import *
from .utils import *
from functools import partial
class ExtractionStrategy(ABC):
"""
@@ -50,23 +42,32 @@ class ExtractionStrategy(ABC):
parsed_json.extend(future.result())
return parsed_json
class NoExtractionStrategy(ExtractionStrategy):
    """Pass-through strategy that wraps its input in block dicts without extracting anything."""

    def extract(self, url: str, html: str, *q, **kwargs) -> List[Dict[str, Any]]:
        """
        Return the whole input as a single block.

        :param url: Unused by this strategy.
        :param html: Content placed verbatim in the returned block.
        :return: A one-element list holding a block with index 0.
        """
        single_block = {"index": 0, "content": html}
        return [single_block]

    def run(self, url: str, sections: List[str], *q, **kwargs) -> List[Dict[str, Any]]:
        """
        Return one untagged block per section, indexed by position.

        :param url: Unused by this strategy.
        :param sections: Sections to wrap, in order.
        :return: A list of blocks, each with an empty "tags" list.
        """
        wrapped = []
        for position, text in enumerate(sections):
            wrapped.append({"index": position, "tags": [], "content": text})
        return wrapped
class LLMExtractionStrategy(ExtractionStrategy):
def __init__(self, provider: str = DEFAULT_PROVIDER, api_token: Optional[str] = None):
"""
Initialize the strategy with clustering parameters.
"""
Initialize the strategy with clustering parameters.
:param word_count_threshold: Minimum number of words per cluster.
:param max_dist: The maximum cophenetic distance on the dendrogram to form clusters.
:param linkage_method: The linkage method for hierarchical clustering.
"""
super().__init__()
self.provider = provider
self.api_token = api_token
def extract(self, url: str, html: str, provider: str = DEFAULT_PROVIDER, api_token: Optional[str] = None) -> List[Dict[str, Any]]:
api_token = PROVIDER_MODELS.get(provider, None) if not api_token else api_token
:param word_count_threshold: Minimum number of words per cluster.
:param max_dist: The maximum cophenetic distance on the dendrogram to form clusters.
:param linkage_method: The linkage method for hierarchical clustering.
"""
super().__init__()
self.provider = provider
self.api_token = api_token or PROVIDER_MODELS.get(provider, None) or os.getenv("OPENAI_API_KEY")
if not self.api_token:
raise ValueError("API token must be provided for LLMExtractionStrategy. Update the config.py or set OPENAI_API_KEY environment variable.")
def extract(self, url: str, html: str) -> List[Dict[str, Any]]:
print("Extracting blocks ...")
variable_values = {
"URL": url,
"HTML": escape_json_string(sanitize_html(html)),
@@ -78,7 +79,7 @@ class LLMExtractionStrategy(ExtractionStrategy):
"{" + variable + "}", variable_values[variable]
)
response = perform_completion_with_backoff(provider, prompt_with_variables, api_token)
response = perform_completion_with_backoff(self.provider, prompt_with_variables, self.api_token)
try:
blocks = extract_xml_data(["blocks"], response.choices[0].message.content)['blocks']
@@ -96,28 +97,54 @@ class LLMExtractionStrategy(ExtractionStrategy):
"tags": ["error"],
"content": unparsed
})
print("Extracted", len(blocks), "blocks.")
return blocks
def run(self, url: str, sections: List[str], provider: str, api_token: Optional[str]) -> List[Dict[str, Any]]:
def _merge(self, documents):
    """
    Group consecutive documents into larger sections joined by blank lines.

    The token count of a document is estimated as its number of
    space-separated pieces times 1.3. Documents keep being added to the
    current group while the running estimate is below CHUNK_TOKEN_THRESHOLD;
    the check happens before adding, so a group can end up over the threshold.

    :param documents: Iterable of text documents, in order.
    :return: List of merged section strings.
    """
    merged = []
    pending = []
    estimated_tokens = 0

    for doc in documents:
        if estimated_tokens < CHUNK_TOKEN_THRESHOLD:
            # Still under budget: keep growing the current group.
            pieces = doc.split(' ')
            estimated_tokens += len(pieces) * 1.3
            pending.append(doc)
            continue

        # Budget reached: close the current group and start a new one with this document.
        merged.append('\n\n'.join(pending))
        pending = [doc]
        estimated_tokens = len(doc.split(' ')) * 1.3

    # Flush whatever is left in the last, possibly partial, group.
    if pending:
        merged.append('\n\n'.join(pending))

    return merged
def run(self, url: str, sections: List[str]) -> List[Dict[str, Any]]:
"""
Merge the sections with self._merge, then extract blocks from every merged section.

Providers whose name starts with "groq/" are processed sequentially with a
500 ms pause between calls, because of rate limiting issues. All other
providers are submitted to a thread pool and their results are collected as
each call completes, so the order of the returned blocks may differ from the
order of the sections.

:param url: URL passed through to self.extract for every section.
:param sections: Text sections to merge and extract from.
:return: Concatenation of the block lists returned by each extract call.
"""
merged_sections = self._merge(sections)
parsed_json = []
if provider.startswith("groq/"):
if self.provider.startswith("groq/"):
# Sequential processing with a delay
for section in sections:
parsed_json.extend(self.extract(url, section, provider, api_token))
for section in merged_sections:
parsed_json.extend(self.extract(url, section))
time.sleep(0.5) # 500 ms delay between each processing
else:
# Parallel processing using ThreadPoolExecutor
with ThreadPoolExecutor() as executor:
futures = [executor.submit(self.extract, url, section, provider, api_token) for section in sections]
with ThreadPoolExecutor(max_workers=4) as executor:
extract_func = partial(self.extract, url)
futures = [executor.submit(extract_func, section) for section in merged_sections]
for future in as_completed(futures):
parsed_json.extend(future.result())
return parsed_json
class HierarchicalClusteringStrategy(ExtractionStrategy):
class CosinegStrategy(ExtractionStrategy):
def __init__(self, word_count_threshold=20, max_dist=0.2, linkage_method='ward', top_k=3, model_name = 'BAAI/bge-small-en-v1.5'):
"""
Initialize the strategy with clustering parameters.
@@ -128,6 +155,10 @@ class HierarchicalClusteringStrategy(ExtractionStrategy):
:param top_k: Number of top categories to extract.
"""
super().__init__()
from transformers import BertTokenizer, BertModel, pipeline
from transformers import AutoTokenizer, AutoModel
import spacy
self.word_count_threshold = word_count_threshold
self.max_dist = max_dist
self.linkage_method = linkage_method
@@ -156,6 +187,7 @@ class HierarchicalClusteringStrategy(ExtractionStrategy):
:param sentences: List of text chunks (sentences).
:return: NumPy array of embeddings.
"""
import torch
# Tokenize sentences and convert to tensor
encoded_input = self.tokenizer(sentences, padding=True, truncation=True, return_tensors='pt')
# Compute token embeddings
@@ -174,9 +206,11 @@ class HierarchicalClusteringStrategy(ExtractionStrategy):
:return: NumPy array of cluster labels.
"""
# Get embeddings
from scipy.cluster.hierarchy import linkage, fcluster
from scipy.spatial.distance import pdist
self.timer = time.time()
embeddings = self.get_embeddings(sentences)
print(f"[LOG] 🚀 Embeddings computed in {time.time() - self.timer:.2f} seconds")
# print(f"[LOG] 🚀 Embeddings computed in {time.time() - self.timer:.2f} seconds")
# Compute pairwise cosine distances
distance_matrix = pdist(embeddings, 'cosine')
# Perform agglomerative clustering respecting order
@@ -219,7 +253,7 @@ class HierarchicalClusteringStrategy(ExtractionStrategy):
# Perform clustering
labels = self.hierarchical_clustering(text_chunks)
print(f"[LOG] 🚀 Clustering done in {time.time() - t:.2f} seconds")
# print(f"[LOG] 🚀 Clustering done in {time.time() - t:.2f} seconds")
# Organize texts by their cluster labels, retaining order
t = time.time()
@@ -240,7 +274,7 @@ class HierarchicalClusteringStrategy(ExtractionStrategy):
top_categories = sorted(doc.cats.items(), key=lambda x: x[1], reverse=True)[:tok_k]
cluster['tags'] = [cat for cat, _ in top_categories]
print(f"[LOG] 🚀 Categorization done in {time.time() - t:.2f} seconds")
# print(f"[LOG] 🚀 Categorization done in {time.time() - t:.2f} seconds")
return cluster_list
@@ -265,9 +299,10 @@ class TopicExtractionStrategy(ExtractionStrategy):
:param num_keywords: Number of keywords to represent each topic segment.
"""
import nltk
super().__init__()
self.num_keywords = num_keywords
self.tokenizer = TextTilingTokenizer()
self.tokenizer = nltk.TextTilingTokenizer()
def extract_keywords(self, text: str) -> List[str]:
"""
@@ -276,6 +311,7 @@ class TopicExtractionStrategy(ExtractionStrategy):
:param text: The text segment from which to extract keywords.
:return: A list of keyword strings.
"""
import nltk
# Tokenize the text and compute word frequency
words = nltk.word_tokenize(text)
freq_dist = nltk.FreqDist(words)