- Test all methods
- Update index.html - Update Readme - Resolve some bugs
This commit is contained in:
@@ -7,6 +7,8 @@ from .prompts import PROMPT_EXTRACT_BLOCKS
|
||||
from .config import *
|
||||
from .utils import *
|
||||
from functools import partial
|
||||
from .model_loader import load_bert_base_uncased, load_bge_small_en_v1_5, load_spacy_model
|
||||
|
||||
|
||||
class ExtractionStrategy(ABC):
|
||||
"""
|
||||
@@ -15,6 +17,7 @@ class ExtractionStrategy(ABC):
|
||||
|
||||
def __init__(self):
|
||||
self.DEL = "<|DEL|>"
|
||||
self.name = self.__class__.__name__
|
||||
|
||||
@abstractmethod
|
||||
def extract(self, url: str, html: str, *q, **kwargs) -> List[Dict[str, Any]]:
|
||||
@@ -67,7 +70,7 @@ class LLMExtractionStrategy(ExtractionStrategy):
|
||||
|
||||
|
||||
def extract(self, url: str, html: str) -> List[Dict[str, Any]]:
|
||||
print("Extracting blocks ...")
|
||||
print("[LOG] Extracting blocks from URL:", url)
|
||||
variable_values = {
|
||||
"URL": url,
|
||||
"HTML": escape_json_string(sanitize_html(html)),
|
||||
@@ -98,7 +101,7 @@ class LLMExtractionStrategy(ExtractionStrategy):
|
||||
"content": unparsed
|
||||
})
|
||||
|
||||
print("Extracted", len(blocks), "blocks.")
|
||||
print("[LOG] Extracted", len(blocks), "blocks from URL:", url)
|
||||
return blocks
|
||||
|
||||
def _merge(self, documents):
|
||||
@@ -125,6 +128,7 @@ class LLMExtractionStrategy(ExtractionStrategy):
|
||||
"""
|
||||
Process sections sequentially with a delay for rate limiting issues, specifically for LLMExtractionStrategy.
|
||||
"""
|
||||
|
||||
merged_sections = self._merge(sections)
|
||||
parsed_json = []
|
||||
if self.provider.startswith("groq/"):
|
||||
@@ -144,7 +148,7 @@ class LLMExtractionStrategy(ExtractionStrategy):
|
||||
|
||||
return parsed_json
|
||||
|
||||
class CosinegStrategy(ExtractionStrategy):
|
||||
class CosineStrategy(ExtractionStrategy):
|
||||
def __init__(self, word_count_threshold=20, max_dist=0.2, linkage_method='ward', top_k=3, model_name = 'BAAI/bge-small-en-v1.5'):
|
||||
"""
|
||||
Initialize the strategy with clustering parameters.
|
||||
@@ -164,20 +168,13 @@ class CosinegStrategy(ExtractionStrategy):
|
||||
self.linkage_method = linkage_method
|
||||
self.top_k = top_k
|
||||
self.timer = time.time()
|
||||
|
||||
if model_name == "bert-base-uncased":
|
||||
self.tokenizer = BertTokenizer.from_pretrained('bert-base-uncased', resume_download=None)
|
||||
self.model = BertModel.from_pretrained('bert-base-uncased', resume_download=None)
|
||||
elif model_name == "sshleifer/distilbart-cnn-12-6":
|
||||
# self.model = IPEXModel.from_pretrained("Intel/bge-small-en-v1.5-rag-int8-static")
|
||||
# self.tokenizer = AutoTokenizer.from_pretrained("Intel/bge-small-en-v1.5-rag-int8-static")
|
||||
pass
|
||||
elif model_name == "BAAI/bge-small-en-v1.5":
|
||||
self.tokenizer = AutoTokenizer.from_pretrained('BAAI/bge-small-en-v1.5', resume_download=None)
|
||||
self.model = AutoModel.from_pretrained('BAAI/bge-small-en-v1.5', resume_download=None)
|
||||
self.model.eval()
|
||||
|
||||
self.nlp = spacy.load("models/reuters")
|
||||
if model_name == "bert-base-uncased":
|
||||
self.tokenizer, self.model = load_bert_base_uncased()
|
||||
elif model_name == "BAAI/bge-small-en-v1.5":
|
||||
self.tokenizer, self.model = load_bge_small_en_v1_5()
|
||||
|
||||
self.nlp = load_spacy_model()
|
||||
print(f"[LOG] Model loaded {model_name}, models/reuters, took " + str(time.time() - self.timer) + " seconds")
|
||||
|
||||
def get_embeddings(self, sentences: List[str]):
|
||||
|
||||
Reference in New Issue
Block a user