refactor: Update extraction strategy to handle schema extraction with non-empty schema

This code change updates the `LLMExtractionStrategy` class to handle schema extraction when the schema is non-empty. Previously, the schema extraction was only triggered when the `extract_type` was set to "schema", regardless of whether a schema was provided. With this update, the schema extraction will only be performed if the `extract_type` is "schema" and a non-empty schema is provided. This ensures that the extraction strategy behaves correctly and avoids unnecessary schema extraction when not needed. Also "numpy" is removed from default installation mode.
This commit is contained in:
unclecode
2024-08-19 15:37:07 +08:00
parent e5e6a34e80
commit dec3d44224
3 changed files with 2 additions and 5 deletions

View File

@@ -101,7 +101,7 @@ class LLMExtractionStrategy(ExtractionStrategy):
variable_values["REQUEST"] = self.instruction
prompt_with_variables = PROMPT_EXTRACT_BLOCKS_WITH_INSTRUCTION
if self.extract_type == "schema":
if self.extract_type == "schema" and self.schema:
variable_values["SCHEMA"] = json.dumps(self.schema, indent=2)
prompt_with_variables = PROMPT_EXTRACT_SCHEMA_WITH_INSTRUCTION

View File

@@ -834,7 +834,6 @@ def extract_blocks_batch(batch_data, provider = "groq/llama3-70b-8192", api_toke
return sum(all_blocks, [])
def merge_chunks_based_on_token_threshold(chunks, token_threshold):
"""
Merges small chunks into larger ones based on the total token threshold.
@@ -880,7 +879,6 @@ def process_sections(url: str, sections: list, provider: str, api_token: str) ->
return extracted_content
def wrap_text(draw, text, font, max_width):
# Wrap the text to fit within the specified width
lines = []
@@ -892,7 +890,6 @@ def wrap_text(draw, text, font, max_width):
lines.append(line)
return '\n'.join(lines)
def format_html(html_string):
soup = BeautifulSoup(html_string, 'html.parser')
return soup.prettify()

View File

@@ -19,7 +19,7 @@ with open("requirements.txt") as f:
requirements = f.read().splitlines()
# Define the requirements for different environments
default_requirements = [req for req in requirements if not req.startswith(("torch", "transformers", "onnxruntime", "nltk", "spacy", "tokenizers", "scikit-learn", "numpy"))]
default_requirements = [req for req in requirements if not req.startswith(("torch", "transformers", "onnxruntime", "nltk", "spacy", "tokenizers", "scikit-learn"))]
torch_requirements = [req for req in requirements if req.startswith(("torch", "nltk", "spacy", "scikit-learn", "numpy"))]
transformer_requirements = [req for req in requirements if req.startswith(("transformers", "tokenizers", "onnxruntime"))]