chore: Update configuration values for chunk token threshold, overlap rate, and minimum word threshold. Create a new example for LLMExtraction Strategy, update Dockerfile, and README
This commit is contained in:
@@ -3,12 +3,12 @@ from typing import Any, List, Dict, Optional, Union
|
||||
from concurrent.futures import ThreadPoolExecutor, as_completed
|
||||
import json, time
|
||||
# from optimum.intel import IPEXModel
|
||||
from .prompts import PROMPT_EXTRACT_BLOCKS, PROMPT_EXTRACT_BLOCKS_WITH_INSTRUCTION
|
||||
from .prompts import *
|
||||
from .config import *
|
||||
from .utils import *
|
||||
from functools import partial
|
||||
from .model_loader import *
|
||||
|
||||
import math
|
||||
|
||||
import numpy as np
|
||||
class ExtractionStrategy(ABC):
|
||||
@@ -55,7 +55,9 @@ class NoExtractionStrategy(ExtractionStrategy):
|
||||
return [{"index": i, "tags": [], "content": section} for i, section in enumerate(sections)]
|
||||
|
||||
class LLMExtractionStrategy(ExtractionStrategy):
|
||||
def __init__(self, provider: str = DEFAULT_PROVIDER, api_token: Optional[str] = None, instruction:str = None, **kwargs):
|
||||
def __init__(self,
|
||||
provider: str = DEFAULT_PROVIDER, api_token: Optional[str] = None,
|
||||
instruction:str = None, schema:Dict = None, extraction_type = "block", **kwargs):
|
||||
"""
|
||||
Initialize the strategy with clustering parameters.
|
||||
|
||||
@@ -67,6 +69,13 @@ class LLMExtractionStrategy(ExtractionStrategy):
|
||||
self.provider = provider
|
||||
self.api_token = api_token or PROVIDER_MODELS.get(provider, None) or os.getenv("OPENAI_API_KEY")
|
||||
self.instruction = instruction
|
||||
self.extract_type = extraction_type
|
||||
self.schema = schema
|
||||
|
||||
self.chunk_token_threshold = kwargs.get("chunk_token_threshold", CHUNK_TOKEN_THRESHOLD)
|
||||
self.overlap_rate = kwargs.get("overlap_rate", OVERLAP_RATE)
|
||||
self.word_token_rate = kwargs.get("word_token_rate", WORD_TOKEN_RATE)
|
||||
|
||||
self.verbose = kwargs.get("verbose", False)
|
||||
|
||||
if not self.api_token:
|
||||
@@ -81,10 +90,15 @@ class LLMExtractionStrategy(ExtractionStrategy):
|
||||
"HTML": escape_json_string(sanitize_html(html)),
|
||||
}
|
||||
|
||||
prompt_with_variables = PROMPT_EXTRACT_BLOCKS
|
||||
if self.instruction:
|
||||
variable_values["REQUEST"] = self.instruction
|
||||
prompt_with_variables = PROMPT_EXTRACT_BLOCKS_WITH_INSTRUCTION
|
||||
|
||||
if self.extract_type == "schema":
|
||||
variable_values["SCHEMA"] = json.dumps(self.schema)
|
||||
prompt_with_variables = PROMPT_EXTRACT_SCHEMA_WITH_INSTRUCTION
|
||||
|
||||
prompt_with_variables = PROMPT_EXTRACT_BLOCKS if not self.instruction else PROMPT_EXTRACT_BLOCKS_WITH_INSTRUCTION
|
||||
for variable in variable_values:
|
||||
prompt_with_variables = prompt_with_variables.replace(
|
||||
"{" + variable + "}", variable_values[variable]
|
||||
@@ -112,32 +126,62 @@ class LLMExtractionStrategy(ExtractionStrategy):
|
||||
print("[LOG] Extracted", len(blocks), "blocks from URL:", url, "block index:", ix)
|
||||
return blocks
|
||||
|
||||
def _merge(self, documents):
|
||||
def _merge(self, documents, chunk_token_threshold, overlap):
|
||||
chunks = []
|
||||
sections = []
|
||||
total_tokens = 0
|
||||
|
||||
# Calculate the total tokens across all documents
|
||||
for document in documents:
|
||||
total_tokens += len(document.split(' ')) * self.word_token_rate
|
||||
|
||||
# Calculate the number of sections needed
|
||||
num_sections = math.floor(total_tokens / chunk_token_threshold)
|
||||
if num_sections < 1:
|
||||
num_sections = 1 # Ensure there is at least one section
|
||||
adjusted_chunk_threshold = total_tokens / num_sections
|
||||
|
||||
total_token_so_far = 0
|
||||
current_chunk = []
|
||||
|
||||
for document in documents:
|
||||
if total_token_so_far < CHUNK_TOKEN_THRESHOLD:
|
||||
chunk = document.split(' ')
|
||||
total_token_so_far += len(chunk) * 1.3
|
||||
chunks.append(document)
|
||||
else:
|
||||
sections.append('\n\n'.join(chunks))
|
||||
chunks = [document]
|
||||
total_token_so_far = len(document.split(' ')) * 1.3
|
||||
|
||||
if chunks:
|
||||
sections.append('\n\n'.join(chunks))
|
||||
tokens = document.split(' ')
|
||||
token_count = len(tokens) * self.word_token_rate
|
||||
|
||||
return sections
|
||||
if total_token_so_far + token_count <= adjusted_chunk_threshold:
|
||||
current_chunk.extend(tokens)
|
||||
total_token_so_far += token_count
|
||||
else:
|
||||
# Ensure to handle the last section properly
|
||||
if len(sections) == num_sections - 1:
|
||||
current_chunk.extend(tokens)
|
||||
continue
|
||||
|
||||
# Add overlap if specified
|
||||
if overlap > 0 and current_chunk:
|
||||
overlap_tokens = current_chunk[-overlap:]
|
||||
current_chunk.extend(overlap_tokens)
|
||||
|
||||
sections.append(' '.join(current_chunk))
|
||||
current_chunk = tokens
|
||||
total_token_so_far = token_count
|
||||
|
||||
# Add the last chunk
|
||||
if current_chunk:
|
||||
sections.append(' '.join(current_chunk))
|
||||
|
||||
return sections
|
||||
|
||||
|
||||
def run(self, url: str, sections: List[str]) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
Process sections sequentially with a delay for rate limiting issues, specifically for LLMExtractionStrategy.
|
||||
"""
|
||||
|
||||
merged_sections = self._merge(sections)
|
||||
merged_sections = self._merge(
|
||||
sections, self.chunk_token_threshold,
|
||||
overlap= int(self.chunk_token_threshold * self.overlap_rate)
|
||||
)
|
||||
extracted_content = []
|
||||
if self.provider.startswith("groq/"):
|
||||
# Sequential processing with a delay
|
||||
|
||||
Reference in New Issue
Block a user