chore: Update configuration values for chunk token threshold, overlap rate, and minimum word threshold. Create a new example for LLMExtraction Strategy, update Dockerfile, and README

This commit is contained in:
unclecode
2024-06-19 18:32:20 +08:00
parent 3f0e265baf
commit 539263a8ba
11 changed files with 212 additions and 130 deletions

View File

@@ -3,12 +3,12 @@ from typing import Any, List, Dict, Optional, Union
from concurrent.futures import ThreadPoolExecutor, as_completed
import json, time
# from optimum.intel import IPEXModel
from .prompts import PROMPT_EXTRACT_BLOCKS, PROMPT_EXTRACT_BLOCKS_WITH_INSTRUCTION
from .prompts import *
from .config import *
from .utils import *
from functools import partial
from .model_loader import *
import math
import numpy as np
class ExtractionStrategy(ABC):
@@ -55,7 +55,9 @@ class NoExtractionStrategy(ExtractionStrategy):
return [{"index": i, "tags": [], "content": section} for i, section in enumerate(sections)]
class LLMExtractionStrategy(ExtractionStrategy):
def __init__(self, provider: str = DEFAULT_PROVIDER, api_token: Optional[str] = None, instruction:str = None, **kwargs):
def __init__(self,
provider: str = DEFAULT_PROVIDER, api_token: Optional[str] = None,
instruction:str = None, schema:Dict = None, extraction_type = "block", **kwargs):
"""
Initialize the strategy with clustering parameters.
@@ -67,6 +69,13 @@ class LLMExtractionStrategy(ExtractionStrategy):
self.provider = provider
self.api_token = api_token or PROVIDER_MODELS.get(provider, None) or os.getenv("OPENAI_API_KEY")
self.instruction = instruction
self.extract_type = extraction_type
self.schema = schema
self.chunk_token_threshold = kwargs.get("chunk_token_threshold", CHUNK_TOKEN_THRESHOLD)
self.overlap_rate = kwargs.get("overlap_rate", OVERLAP_RATE)
self.word_token_rate = kwargs.get("word_token_rate", WORD_TOKEN_RATE)
self.verbose = kwargs.get("verbose", False)
if not self.api_token:
@@ -81,10 +90,15 @@ class LLMExtractionStrategy(ExtractionStrategy):
"HTML": escape_json_string(sanitize_html(html)),
}
prompt_with_variables = PROMPT_EXTRACT_BLOCKS
if self.instruction:
variable_values["REQUEST"] = self.instruction
prompt_with_variables = PROMPT_EXTRACT_BLOCKS_WITH_INSTRUCTION
if self.extract_type == "schema":
variable_values["SCHEMA"] = json.dumps(self.schema)
prompt_with_variables = PROMPT_EXTRACT_SCHEMA_WITH_INSTRUCTION
prompt_with_variables = PROMPT_EXTRACT_BLOCKS if not self.instruction else PROMPT_EXTRACT_BLOCKS_WITH_INSTRUCTION
for variable in variable_values:
prompt_with_variables = prompt_with_variables.replace(
"{" + variable + "}", variable_values[variable]
@@ -112,32 +126,62 @@ class LLMExtractionStrategy(ExtractionStrategy):
print("[LOG] Extracted", len(blocks), "blocks from URL:", url, "block index:", ix)
return blocks
def _merge(self, documents):
def _merge(self, documents, chunk_token_threshold, overlap):
chunks = []
sections = []
total_tokens = 0
# Calculate the total tokens across all documents
for document in documents:
total_tokens += len(document.split(' ')) * self.word_token_rate
# Calculate the number of sections needed
num_sections = math.floor(total_tokens / chunk_token_threshold)
if num_sections < 1:
num_sections = 1 # Ensure there is at least one section
adjusted_chunk_threshold = total_tokens / num_sections
total_token_so_far = 0
current_chunk = []
for document in documents:
if total_token_so_far < CHUNK_TOKEN_THRESHOLD:
chunk = document.split(' ')
total_token_so_far += len(chunk) * 1.3
chunks.append(document)
else:
sections.append('\n\n'.join(chunks))
chunks = [document]
total_token_so_far = len(document.split(' ')) * 1.3
if chunks:
sections.append('\n\n'.join(chunks))
tokens = document.split(' ')
token_count = len(tokens) * self.word_token_rate
return sections
if total_token_so_far + token_count <= adjusted_chunk_threshold:
current_chunk.extend(tokens)
total_token_so_far += token_count
else:
# Ensure to handle the last section properly
if len(sections) == num_sections - 1:
current_chunk.extend(tokens)
continue
# Add overlap if specified
if overlap > 0 and current_chunk:
overlap_tokens = current_chunk[-overlap:]
current_chunk.extend(overlap_tokens)
sections.append(' '.join(current_chunk))
current_chunk = tokens
total_token_so_far = token_count
# Add the last chunk
if current_chunk:
sections.append(' '.join(current_chunk))
return sections
def run(self, url: str, sections: List[str]) -> List[Dict[str, Any]]:
"""
Process sections sequentially with a delay for rate limiting issues, specifically for LLMExtractionStrategy.
"""
merged_sections = self._merge(sections)
merged_sections = self._merge(
sections, self.chunk_token_threshold,
overlap= int(self.chunk_token_threshold * self.overlap_rate)
)
extracted_content = []
if self.provider.startswith("groq/"):
# Sequential processing with a delay