- Test all methods

- Update index.html
- Update Readme
- Resolve some bugs
This commit is contained in:
unclecode
2024-05-14 21:27:41 +08:00
parent 5fea6c064b
commit f6e59157bf
17 changed files with 1004 additions and 402 deletions

View File

@@ -7,6 +7,8 @@ from .prompts import PROMPT_EXTRACT_BLOCKS
from .config import *
from .utils import *
from functools import partial
from .model_loader import load_bert_base_uncased, load_bge_small_en_v1_5, load_spacy_model
class ExtractionStrategy(ABC):
"""
@@ -15,6 +17,7 @@ class ExtractionStrategy(ABC):
def __init__(self):
self.DEL = "<|DEL|>"
self.name = self.__class__.__name__
@abstractmethod
def extract(self, url: str, html: str, *q, **kwargs) -> List[Dict[str, Any]]:
@@ -67,7 +70,7 @@ class LLMExtractionStrategy(ExtractionStrategy):
def extract(self, url: str, html: str) -> List[Dict[str, Any]]:
print("Extracting blocks ...")
print("[LOG] Extracting blocks from URL:", url)
variable_values = {
"URL": url,
"HTML": escape_json_string(sanitize_html(html)),
@@ -98,7 +101,7 @@ class LLMExtractionStrategy(ExtractionStrategy):
"content": unparsed
})
print("Extracted", len(blocks), "blocks.")
print("[LOG] Extracted", len(blocks), "blocks from URL:", url)
return blocks
def _merge(self, documents):
@@ -125,6 +128,7 @@ class LLMExtractionStrategy(ExtractionStrategy):
"""
Process sections sequentially with a delay for rate limiting issues, specifically for LLMExtractionStrategy.
"""
merged_sections = self._merge(sections)
parsed_json = []
if self.provider.startswith("groq/"):
@@ -144,7 +148,7 @@ class LLMExtractionStrategy(ExtractionStrategy):
return parsed_json
class CosinegStrategy(ExtractionStrategy):
class CosineStrategy(ExtractionStrategy):
def __init__(self, word_count_threshold=20, max_dist=0.2, linkage_method='ward', top_k=3, model_name = 'BAAI/bge-small-en-v1.5'):
"""
Initialize the strategy with clustering parameters.
@@ -164,20 +168,13 @@ class CosinegStrategy(ExtractionStrategy):
self.linkage_method = linkage_method
self.top_k = top_k
self.timer = time.time()
if model_name == "bert-base-uncased":
self.tokenizer = BertTokenizer.from_pretrained('bert-base-uncased', resume_download=None)
self.model = BertModel.from_pretrained('bert-base-uncased', resume_download=None)
elif model_name == "sshleifer/distilbart-cnn-12-6":
# self.model = IPEXModel.from_pretrained("Intel/bge-small-en-v1.5-rag-int8-static")
# self.tokenizer = AutoTokenizer.from_pretrained("Intel/bge-small-en-v1.5-rag-int8-static")
pass
elif model_name == "BAAI/bge-small-en-v1.5":
self.tokenizer = AutoTokenizer.from_pretrained('BAAI/bge-small-en-v1.5', resume_download=None)
self.model = AutoModel.from_pretrained('BAAI/bge-small-en-v1.5', resume_download=None)
self.model.eval()
self.nlp = spacy.load("models/reuters")
if model_name == "bert-base-uncased":
self.tokenizer, self.model = load_bert_base_uncased()
elif model_name == "BAAI/bge-small-en-v1.5":
self.tokenizer, self.model = load_bge_small_en_v1_5()
self.nlp = load_spacy_model()
print(f"[LOG] Model loaded {model_name}, models/reuters, took " + str(time.time() - self.timer) + " seconds")
def get_embeddings(self, sentences: List[str]):