Feat/llm config (#724)

* feature: Add LlmConfig to easily configure and pass LLM configs to different strategies

* pulled in next branch and resolved conflicts

* feat: Add gemini and deepseek providers. Make ignore_cache in the LLM content filter default to true to avoid confusion

* Refactor: Update LlmConfig in LLMExtractionStrategy class and deprecate old params

* updated tests, docs and readme
This commit is contained in:
Aravind
2025-02-21 13:11:37 +05:30
committed by GitHub
parent 3cb28875c3
commit 2af958e12c
25 changed files with 420 additions and 240 deletions

View File

@@ -1,4 +1,5 @@
from abc import ABC, abstractmethod
import inspect
from typing import Any, List, Dict, Optional
from concurrent.futures import ThreadPoolExecutor, as_completed
import json
@@ -496,20 +497,26 @@ class LLMExtractionStrategy(ExtractionStrategy):
usages: List of individual token usages.
total_usage: Accumulated token usage.
"""
# Deprecated constructor arguments, each mapped to the migration hint that is
# appended to the AttributeError raised when a caller still sets it.
_UNWANTED_PROPS = {
    'provider': 'Instead, use llmConfig=LlmConfig(provider="...")',
    # Hint previously spelled the class "LlMConfig"; it now matches the other entries.
    'api_token': 'Instead, use llmConfig=LlmConfig(api_token="...")',
    'base_url': 'Instead, use llmConfig=LlmConfig(base_url="...")',
    # api_base is steered to LlmConfig's base_url field as well.
    'api_base': 'Instead, use llmConfig=LlmConfig(base_url="...")',
}
def __init__(
self,
llmConfig: 'LLMConfig' = None,
instruction: str = None,
provider: str = DEFAULT_PROVIDER,
api_token: Optional[str] = None,
instruction: str = None,
base_url: str = None,
api_base: str = None,
schema: Dict = None,
extraction_type="block",
chunk_token_threshold=CHUNK_TOKEN_THRESHOLD,
overlap_rate=OVERLAP_RATE,
word_token_rate=WORD_TOKEN_RATE,
apply_chunking=True,
api_base: str =None,
base_url: str =None,
input_format: str = "markdown",
verbose=False,
**kwargs,
@@ -518,6 +525,7 @@ class LLMExtractionStrategy(ExtractionStrategy):
Initialize the strategy with clustering parameters.
Args:
llmConfig: The LLM configuration object.
provider: The provider to use for extraction. It follows the format <provider_name>/<model_name>, e.g., "ollama/llama3.3".
api_token: The API token for the provider.
instruction: The instruction to use for the LLM model.
@@ -536,41 +544,39 @@ class LLMExtractionStrategy(ExtractionStrategy):
"""
super().__init__( input_format=input_format, **kwargs)
self.llmConfig = llmConfig
self.provider = provider
if api_token and not api_token.startswith("env:"):
self.api_token = api_token
elif api_token and api_token.startswith("env:"):
self.api_token = os.getenv(api_token[4:])
else:
self.api_token = (
PROVIDER_MODELS.get(provider, "no-token")
or os.getenv("OPENAI_API_KEY")
)
self.api_token = api_token
self.base_url = base_url
self.api_base = api_base
self.instruction = instruction
self.extract_type = extraction_type
self.schema = schema
if schema:
self.extract_type = "schema"
self.chunk_token_threshold = chunk_token_threshold or CHUNK_TOKEN_THRESHOLD
self.overlap_rate = overlap_rate
self.word_token_rate = word_token_rate
self.apply_chunking = apply_chunking
self.base_url = base_url
self.api_base = api_base or base_url
self.extra_args = kwargs.get("extra_args", {})
if not self.apply_chunking:
self.chunk_token_threshold = 1e9
self.verbose = verbose
self.usages = [] # Store individual usages
self.total_usage = TokenUsage() # Accumulated usage
if not self.api_token:
raise ValueError(
"API token must be provided for LLMExtractionStrategy. Update the config.py or set OPENAI_API_KEY environment variable."
)
def __setattr__(self, name, value):
    """Set an attribute, rejecting deprecated settings that are not left at their default.

    Args:
        name: Name of the attribute being assigned.
        value: Value being assigned.

    Raises:
        AttributeError: If ``name`` is a key of ``_UNWANTED_PROPS`` and ``value``
            is not the default object declared for that parameter in ``__init__``.
    """
    # TODO: Planning to set properties dynamically based on the __init__ signature
    if name in self._UNWANTED_PROPS:
        # Only deprecated names need the __init__ signature, so the reflection
        # cost is skipped for every other attribute assignment.
        declared_default = inspect.signature(self.__init__).parameters[name].default
        # Identity comparison: only the very default object from the signature is
        # accepted. NOTE(review): presumably meant to detect "explicitly passed
        # by the caller" -- confirm before switching to an equality check.
        if value is not declared_default:
            raise AttributeError(f"Setting '{name}' is deprecated. {self._UNWANTED_PROPS[name]}")
    super().__setattr__(name, value)
def extract(self, url: str, ix: int, html: str) -> List[Dict[str, Any]]:
"""
Extract meaningful blocks or chunks from the given HTML using an LLM.
@@ -603,7 +609,7 @@ class LLMExtractionStrategy(ExtractionStrategy):
prompt_with_variables = PROMPT_EXTRACT_BLOCKS_WITH_INSTRUCTION
if self.extract_type == "schema" and self.schema:
variable_values["SCHEMA"] = json.dumps(self.schema, indent=2)
variable_values["SCHEMA"] = json.dumps(self.schema, indent=2) # if type of self.schema is dict else self.schema
prompt_with_variables = PROMPT_EXTRACT_SCHEMA_WITH_INSTRUCTION
for variable in variable_values:
@@ -612,10 +618,10 @@ class LLMExtractionStrategy(ExtractionStrategy):
)
response = perform_completion_with_backoff(
self.provider,
self.llmConfig.provider,
prompt_with_variables,
self.api_token,
base_url=self.api_base or self.base_url,
self.llmConfig.api_token,
base_url=self.llmConfig.base_url,
extra_args=self.extra_args,
) # , json_response=self.extract_type == "schema")
# Track usage
@@ -695,7 +701,7 @@ class LLMExtractionStrategy(ExtractionStrategy):
overlap=int(self.chunk_token_threshold * self.overlap_rate),
)
extracted_content = []
if self.provider.startswith("groq/"):
if self.llmConfig.provider.startswith("groq/"):
# Sequential processing with a delay
for ix, section in enumerate(merged_sections):
extract_func = partial(self.extract, url)
@@ -1036,14 +1042,20 @@ class JsonElementExtractionStrategy(ExtractionStrategy):
"""Get attribute value from element"""
pass
# Legacy generate_schema() arguments, each mapped to the migration hint that is
# appended to the AttributeError raised when a caller still passes it.
_GENERATE_SCHEMA_UNWANTED_PROPS = {
    'provider': 'Instead, use llmConfig=LlmConfig(provider="...")',
    # Hint previously spelled the class "LlMConfig"; it now matches the provider entry.
    'api_token': 'Instead, use llmConfig=LlmConfig(api_token="...")',
}
@staticmethod
def generate_schema(
html: str,
schema_type: str = "CSS", # or XPATH
query: str = None,
target_json_example: str = None,
provider: str = "gpt-4o",
api_token: str = os.getenv("OPENAI_API_KEY"),
llmConfig: 'LLMConfig' = None,
provider: str = None,
api_token: str = None,
**kwargs
) -> dict:
"""
@@ -1052,8 +1064,9 @@ class JsonElementExtractionStrategy(ExtractionStrategy):
Args:
html (str): The HTML content to analyze
query (str, optional): Natural language description of what data to extract
provider (str): LLM provider to use
api_token (str): API token for LLM provider
provider (str): Legacy Parameter. LLM provider to use
api_token (str): Legacy Parameter. API token for LLM provider
llmConfig (LlmConfig): LLM configuration object
prompt (str, optional): Custom prompt template to use
**kwargs: Additional args passed to perform_completion_with_backoff
@@ -1062,6 +1075,9 @@ class JsonElementExtractionStrategy(ExtractionStrategy):
"""
from .prompts import JSON_SCHEMA_BUILDER
from .utils import perform_completion_with_backoff
for name, message in JsonElementExtractionStrategy._GENERATE_SCHEMA_UNWANTED_PROPS.items():
if locals()[name] is not None:
raise AttributeError(f"Setting '{name}' is deprecated. {message}")
# Use default or custom prompt
prompt_template = JSON_SCHEMA_BUILDER if schema_type == "CSS" else JSON_SCHEMA_BUILDER_XPATH
@@ -1114,10 +1130,10 @@ In this scenario, use your best judgment to generate the schema. Try to maximize
try:
# Call LLM with backoff handling
response = perform_completion_with_backoff(
provider=provider,
provider=llmConfig.provider,
prompt_with_variables="\n\n".join([system_message["content"], user_message["content"]]),
json_response = True,
api_token=api_token,
api_token=llmConfig.api_token,
**kwargs
)