Feat/llm config (#724)

* feature: Add LlmConfig to easily configure and pass LLM configs to different strategies

* pulled in next branch and resolved conflicts

* feat: Add gemini and deepseek providers. Make ignore_cache in llm content filter to true by default to avoid confusions

* Refactor: Update LlmConfig in LLMExtractionStrategy class and deprecate old params

* updated tests, docs and readme
This commit is contained in:
Aravind
2025-02-21 13:11:37 +05:30
committed by GitHub
parent 3cb28875c3
commit 2af958e12c
25 changed files with 420 additions and 240 deletions

View File

@@ -1,13 +1,16 @@
import os
from .config import (
DEFAULT_PROVIDER,
MIN_WORD_THRESHOLD,
IMAGE_DESCRIPTION_MIN_WORD_THRESHOLD,
PROVIDER_MODELS,
SCREENSHOT_HEIGHT_TRESHOLD,
PAGE_TIMEOUT,
IMAGE_SCORE_THRESHOLD,
SOCIAL_MEDIA_DOMAINS,
)
from .user_agent_generator import UAGen, ValidUAGenerator # , OnlineUAGenerator
from .user_agent_generator import UAGen, ValidUAGenerator # , OnlineUAGenerator
from .extraction_strategy import ExtractionStrategy
from .chunking_strategy import ChunkingStrategy, RegexChunking
from .markdown_generation_strategy import MarkdownGenerationStrategy
@@ -19,7 +22,8 @@ from .proxy_strategy import ProxyRotationStrategy
import inspect
from typing import Any, Dict, Optional
from enum import Enum
from enum import Enum
def to_serializable_dict(obj: Any, ignore_default_value : bool = False) -> Dict:
"""
@@ -28,26 +32,23 @@ def to_serializable_dict(obj: Any, ignore_default_value : bool = False) -> Dict:
"""
if obj is None:
return None
# Handle basic types
if isinstance(obj, (str, int, float, bool)):
return obj
# Handle Enum
if isinstance(obj, Enum):
return {
"type": obj.__class__.__name__,
"params": obj.value
}
return {"type": obj.__class__.__name__, "params": obj.value}
# Handle datetime objects
if hasattr(obj, 'isoformat'):
if hasattr(obj, "isoformat"):
return obj.isoformat()
# Handle lists, tuples, and sets, and basically any iterable
if isinstance(obj, (list, tuple, set)) or hasattr(obj, '__iter__') and not isinstance(obj, dict):
return [to_serializable_dict(item) for item in obj]
# Handle frozensets, which are not iterable
if isinstance(obj, frozenset):
return [to_serializable_dict(item) for item in list(obj)]
@@ -56,25 +57,25 @@ def to_serializable_dict(obj: Any, ignore_default_value : bool = False) -> Dict:
if isinstance(obj, dict):
return {
"type": "dict", # Mark as plain dictionary
"value": {str(k): to_serializable_dict(v) for k, v in obj.items()}
"value": {str(k): to_serializable_dict(v) for k, v in obj.items()},
}
_type = obj.__class__.__name__
# Handle class instances
if hasattr(obj, '__class__'):
if hasattr(obj, "__class__"):
# Get constructor signature
sig = inspect.signature(obj.__class__.__init__)
params = sig.parameters
# Get current values
current_values = {}
for name, param in params.items():
if name == 'self':
if name == "self":
continue
value = getattr(obj, name, param.default)
# Only include if different from default, considering empty values
if not (is_empty_value(value) and is_empty_value(param.default)):
if value != param.default and not ignore_default_value:
@@ -97,47 +98,50 @@ def to_serializable_dict(obj: Any, ignore_default_value : bool = False) -> Dict:
return str(obj)
def from_serializable_dict(data: Any) -> Any:
"""
Recursively convert a serializable dictionary back to an object instance.
"""
if data is None:
return None
# Handle basic types
if isinstance(data, (str, int, float, bool)):
return data
# Handle typed data
if isinstance(data, dict) and "type" in data:
# Handle plain dictionaries
if data["type"] == "dict":
return {k: from_serializable_dict(v) for k, v in data["value"].items()}
# Import from crawl4ai for class instances
import crawl4ai
cls = getattr(crawl4ai, data["type"])
# Handle Enum
if issubclass(cls, Enum):
return cls(data["params"])
# Handle class instances
constructor_args = {
k: from_serializable_dict(v) for k, v in data["params"].items()
}
return cls(**constructor_args)
# Handle lists
if isinstance(data, list):
return [from_serializable_dict(item) for item in data]
# Handle raw dictionaries (legacy support)
if isinstance(data, dict):
return {k: from_serializable_dict(v) for k, v in data.items()}
return data
def is_empty_value(value: Any) -> bool:
"""Check if a value is effectively empty/null."""
if value is None:
@@ -146,7 +150,8 @@ def is_empty_value(value: Any) -> bool:
return True
return False
class BrowserConfig():
class BrowserConfig:
"""
Configuration class for setting up a browser instance and its context in AsyncPlaywrightCrawlerStrategy.
@@ -224,7 +229,7 @@ class BrowserConfig():
viewport: dict = None,
accept_downloads: bool = False,
downloads_path: str = None,
storage_state : Union[str, dict, None]=None,
storage_state: Union[str, dict, None] = None,
ignore_https_errors: bool = True,
java_script_enabled: bool = True,
sleep_on_close: bool = False,
@@ -288,7 +293,7 @@ class BrowserConfig():
)
else:
pass
self.browser_hint = UAGen.generate_client_hints(self.user_agent)
self.headers.setdefault("sec-ch-ua", self.browser_hint)
@@ -364,10 +369,10 @@ class BrowserConfig():
def clone(self, **kwargs):
"""Create a copy of this configuration with updated values.
Args:
**kwargs: Key-value pairs of configuration options to update
Returns:
BrowserConfig: A new instance with the specified updates
"""
@@ -381,24 +386,33 @@ class BrowserConfig():
return to_serializable_dict(self)
@staticmethod
def load( data: dict) -> "BrowserConfig":
def load(data: dict) -> "BrowserConfig":
# Deserialize the object from a dictionary
config = from_serializable_dict(data)
config = from_serializable_dict(data)
if isinstance(config, BrowserConfig):
return config
return BrowserConfig.from_kwargs(config)
class HTTPCrawlerConfig():
class HTTPCrawlerConfig:
"""HTTP-specific crawler configuration"""
method: str = "GET"
headers: Optional[Dict[str, str]] = None
data: Optional[Dict[str, Any]] = None
json: Optional[Dict[str, Any]] = None
json: Optional[Dict[str, Any]] = None
follow_redirects: bool = True
verify_ssl: bool = True
def __init__(self, method: str = "GET", headers: Optional[Dict[str, str]] = None, data: Optional[Dict[str, Any]] = None, json: Optional[Dict[str, Any]] = None, follow_redirects: bool = True, verify_ssl: bool = True):
def __init__(
self,
method: str = "GET",
headers: Optional[Dict[str, str]] = None,
data: Optional[Dict[str, Any]] = None,
json: Optional[Dict[str, Any]] = None,
follow_redirects: bool = True,
verify_ssl: bool = True,
):
self.method = method
self.headers = headers
self.data = data
@@ -426,23 +440,23 @@ class HTTPCrawlerConfig():
"follow_redirects": self.follow_redirects,
"verify_ssl": self.verify_ssl,
}
def clone(self, **kwargs):
"""Create a copy of this configuration with updated values.
Args:
**kwargs: Key-value pairs of configuration options to update
Returns:
HTTPCrawlerConfig: A new instance with the specified updates
"""
config_dict = self.to_dict()
config_dict.update(kwargs)
return HTTPCrawlerConfig.from_kwargs(config_dict)
def dump(self) -> dict:
return to_serializable_dict(self)
@staticmethod
def load(data: dict) -> "HTTPCrawlerConfig":
config = from_serializable_dict(data)
@@ -469,7 +483,7 @@ class CrawlerRunConfig():
Attributes:
# Deep Crawl Parameters
deep_crawl_strategy (DeepCrawlStrategy or None): Strategy to use for deep crawling.
# Content Processing Parameters
word_count_threshold (int): Minimum word count threshold before processing content.
Default: MIN_WORD_THRESHOLD (typically 200).
@@ -606,20 +620,20 @@ class CrawlerRunConfig():
data (dict): Data to send in the request body, when using AsyncHTTPCrwalerStrategy.
Default: None.
json (dict): JSON data to send in the request body, when using AsyncHTTPCrwalerStrategy.
# Connection Parameters
stream (bool): If True, enables streaming of crawled URLs as they are processed when used with arun_many.
Default: False.
check_robots_txt (bool): Whether to check robots.txt rules before crawling. Default: False
Default: False.
user_agent (str): Custom User-Agent string to use.
Default: False.
user_agent (str): Custom User-Agent string to use.
Default: None.
user_agent_mode (str or None): Mode for generating the user agent (e.g., "random"). If None, use the provided user_agent as-is.
user_agent_mode (str or None): Mode for generating the user agent (e.g., "random"). If None, use the provided user_agent as-is.
Default: None.
user_agent_generator_config (dict or None): Configuration for user agent generation if user_agent_mode is set.
Default: None.
url: str = None # This is not a compulsory parameter
"""
@@ -700,7 +714,6 @@ class CrawlerRunConfig():
user_agent_generator_config: dict = {},
# Deep Crawl Parameters
deep_crawl_strategy: Optional[DeepCrawlStrategy] = None,
):
# TODO: Planning to set properties dynamically based on the __init__ signature
self.url = url
@@ -810,7 +823,6 @@ class CrawlerRunConfig():
if self.chunking_strategy is None:
self.chunking_strategy = RegexChunking()
# Deep Crawl Parameters
self.deep_crawl_strategy = deep_crawl_strategy
@@ -918,7 +930,6 @@ class CrawlerRunConfig():
user_agent_generator_config=kwargs.get("user_agent_generator_config", {}),
# Deep Crawl Parameters
deep_crawl_strategy=kwargs.get("deep_crawl_strategy"),
url=kwargs.get("url"),
)
@@ -930,7 +941,7 @@ class CrawlerRunConfig():
@staticmethod
def load(data: dict) -> "CrawlerRunConfig":
# Deserialize the object from a dictionary
config = from_serializable_dict(data)
config = from_serializable_dict(data)
if isinstance(config, CrawlerRunConfig):
return config
return CrawlerRunConfig.from_kwargs(config)
@@ -1006,18 +1017,18 @@ class CrawlerRunConfig():
def clone(self, **kwargs):
"""Create a copy of this configuration with updated values.
Args:
**kwargs: Key-value pairs of configuration options to update
Returns:
CrawlerRunConfig: A new instance with the specified updates
Example:
```python
# Create a new config with streaming enabled
stream_config = config.clone(stream=True)
# Create a new config with multiple updates
new_config = config.clone(
stream=True,
@@ -1031,3 +1042,50 @@ class CrawlerRunConfig():
return CrawlerRunConfig.from_kwargs(config_dict)
class LlmConfig:
def __init__(
self,
provider: str = DEFAULT_PROVIDER,
api_token: Optional[str] = None,
base_url: Optional[str] = None,
):
"""Configuaration class for LLM provider and API token."""
self.provider = provider
if api_token and not api_token.startswith("env:"):
self.api_token = api_token
elif api_token and api_token.startswith("env:"):
self.api_token = os.getenv(api_token[4:])
else:
self.api_token = PROVIDER_MODELS.get(provider, "no-token") or os.getenv(
"OPENAI_API_KEY"
)
self.base_url = base_url
@staticmethod
def from_kwargs(kwargs: dict) -> "LlmConfig":
return LlmConfig(
provider=kwargs.get("provider", DEFAULT_PROVIDER),
api_token=kwargs.get("api_token"),
base_url=kwargs.get("base_url"),
)
def to_dict(self):
return {
"provider": self.provider,
"api_token": self.api_token,
"base_url": self.base_url
}
def clone(self, **kwargs):
"""Create a copy of this configuration with updated values.
Args:
**kwargs: Key-value pairs of configuration options to update
Returns:
LLMConfig: A new instance with the specified updates
"""
config_dict = self.to_dict()
config_dict.update(kwargs)
return LlmConfig.from_kwargs(config_dict)

View File

@@ -21,6 +21,12 @@ PROVIDER_MODELS = {
"anthropic/claude-3-opus-20240229": os.getenv("ANTHROPIC_API_KEY"),
"anthropic/claude-3-sonnet-20240229": os.getenv("ANTHROPIC_API_KEY"),
"anthropic/claude-3-5-sonnet-20240620": os.getenv("ANTHROPIC_API_KEY"),
"gemini/gemini-pro": os.getenv("GEMINI_API_KEY"),
'gemini/gemini-1.5-pro': os.getenv("GEMINI_API_KEY"),
'gemini/gemini-2.0-flash': os.getenv("GEMINI_API_KEY"),
'gemini/gemini-2.0-flash-exp': os.getenv("GEMINI_API_KEY"),
'gemini/gemini-2.0-flash-lite-preview-02-05': os.getenv("GEMINI_API_KEY"),
"deepseek/deepseek-chat": os.getenv("DEEPSEEK_API_KEY"),
}
# Chunk token threshold

View File

@@ -1,3 +1,4 @@
import inspect
import re
import time
from bs4 import BeautifulSoup, Tag
@@ -5,7 +6,16 @@ from typing import List, Tuple, Dict, Optional
from rank_bm25 import BM25Okapi
from collections import deque
from bs4 import NavigableString, Comment
from .utils import clean_tokens, perform_completion_with_backoff, escape_json_string, sanitize_html, get_home_folder, extract_xml_data, merge_chunks
from .utils import (
clean_tokens,
perform_completion_with_backoff,
escape_json_string,
sanitize_html,
get_home_folder,
extract_xml_data,
merge_chunks,
)
from abc import ABC, abstractmethod
import math
from snowballstemmer import stemmer
@@ -20,10 +30,16 @@ from concurrent.futures import ThreadPoolExecutor
from .async_logger import AsyncLogger, LogLevel
from colorama import Fore, Style
class RelevantContentFilter(ABC):
"""Abstract base class for content filtering strategies"""
def __init__(self, user_query: str = None, verbose: bool = False, logger: Optional[AsyncLogger] = None):
def __init__(
self,
user_query: str = None,
verbose: bool = False,
logger: Optional[AsyncLogger] = None,
):
"""
Initializes the RelevantContentFilter class with optional user query.
@@ -362,6 +378,7 @@ class RelevantContentFilter(ABC):
except Exception:
return str(tag) # Fallback to original if anything fails
class BM25ContentFilter(RelevantContentFilter):
"""
Content filtering using BM25 algorithm with priority tag handling.
@@ -504,6 +521,7 @@ class BM25ContentFilter(RelevantContentFilter):
return [self.clean_element(tag) for _, _, tag in selected_candidates]
class PruningContentFilter(RelevantContentFilter):
"""
Content filtering using pruning algorithm with dynamic threshold.
@@ -750,13 +768,21 @@ class PruningContentFilter(RelevantContentFilter):
class_id_score -= 0.5
return class_id_score
class LLMContentFilter(RelevantContentFilter):
"""Content filtering using LLMs to generate relevant markdown."""
_UNWANTED_PROPS = {
'provider' : 'Instead, use llmConfig=LlmConfig(provider="...")',
'api_token' : 'Instead, use llmConfig=LlMConfig(api_token="...")',
'base_url' : 'Instead, use llmConfig=LlmConfig(base_url="...")',
'api_base' : 'Instead, use llmConfig=LlmConfig(base_url="...")',
}
def __init__(
self,
provider: str = DEFAULT_PROVIDER,
api_token: Optional[str] = None,
llmConfig: "LlmConfig" = None,
instruction: str = None,
chunk_token_threshold: int = int(1e9),
overlap_rate: float = OVERLAP_RATE,
@@ -768,15 +794,13 @@ class LLMContentFilter(RelevantContentFilter):
# chunk_mode: str = "char",
verbose: bool = False,
logger: Optional[AsyncLogger] = None,
ignore_cache: bool = False,
ignore_cache: bool = True,
):
super().__init__(None)
self.provider = provider
self.api_token = (
api_token
or PROVIDER_MODELS.get(provider, "no-token")
or os.getenv("OPENAI_API_KEY")
)
self.api_token = api_token
self.base_url = base_url or api_base
self.llmConfig = llmConfig
self.instruction = instruction
self.chunk_token_threshold = chunk_token_threshold
self.overlap_rate = overlap_rate
@@ -785,12 +809,10 @@ class LLMContentFilter(RelevantContentFilter):
# self.char_token_rate = char_token_rate or word_token_rate / 5
# self.token_rate = word_token_rate if chunk_mode == "word" else self.char_token_rate
self.token_rate = word_token_rate or WORD_TOKEN_RATE
self.base_url = base_url
self.api_base = api_base or base_url
self.extra_args = extra_args or {}
self.ignore_cache = ignore_cache
self.verbose = verbose
# Setup logger with custom styling for LLM operations
if logger:
self.logger = logger
@@ -801,19 +823,31 @@ class LLMContentFilter(RelevantContentFilter):
**AsyncLogger.DEFAULT_ICONS,
"LLM": "", # Star for LLM operations
"CHUNK": "", # Diamond for chunks
"CACHE": "", # Lightning for cache operations
"CACHE": "", # Lightning for cache operations
},
colors={
**AsyncLogger.DEFAULT_COLORS,
LogLevel.INFO: Fore.MAGENTA + Style.DIM, # Dimmed purple for LLM ops
}
LogLevel.INFO: Fore.MAGENTA
+ Style.DIM, # Dimmed purple for LLM ops
},
)
else:
self.logger = None
self.usages = []
self.total_usage = TokenUsage()
def __setattr__(self, name, value):
"""Handle attribute setting."""
# TODO: Planning to set properties dynamically based on the __init__ signature
sig = inspect.signature(self.__init__)
all_params = sig.parameters # Dictionary of parameter names and their details
if name in self._UNWANTED_PROPS and value is not all_params[name].default:
raise AttributeError(f"Setting '{name}' is deprecated. {self._UNWANTED_PROPS[name]}")
super().__setattr__(name, value)
def _get_cache_key(self, html: str, instruction: str) -> str:
"""Generate a unique cache key based on HTML and instruction"""
content = f"{html}{instruction}"
@@ -823,14 +857,12 @@ class LLMContentFilter(RelevantContentFilter):
"""Split text into chunks with overlap using char or word mode."""
ov = int(self.chunk_token_threshold * self.overlap_rate)
sections = merge_chunks(
docs = [text],
target_size= self.chunk_token_threshold,
docs=[text],
target_size=self.chunk_token_threshold,
overlap=ov,
word_token_ratio=self.word_token_rate
word_token_ratio=self.word_token_rate,
)
return sections
def filter_content(self, html: str, ignore_cache: bool = True) -> List[str]:
if not html or not isinstance(html, str):
@@ -838,10 +870,10 @@ class LLMContentFilter(RelevantContentFilter):
if self.logger:
self.logger.info(
"Starting LLM markdown content filtering process",
"Starting LLM markdown content filtering process",
tag="LLM",
params={"provider": self.provider},
colors={"provider": Fore.CYAN}
params={"provider": self.llmConfig.provider},
colors={"provider": Fore.CYAN},
)
# Cache handling
@@ -857,47 +889,47 @@ class LLMContentFilter(RelevantContentFilter):
if self.logger:
self.logger.info("Found cached markdown result", tag="CACHE")
try:
with cache_file.open('r') as f:
with cache_file.open("r") as f:
cached_data = json.load(f)
usage = TokenUsage(**cached_data['usage'])
usage = TokenUsage(**cached_data["usage"])
self.usages.append(usage)
self.total_usage.completion_tokens += usage.completion_tokens
self.total_usage.prompt_tokens += usage.prompt_tokens
self.total_usage.total_tokens += usage.total_tokens
return cached_data['blocks']
return cached_data["blocks"]
except Exception as e:
if self.logger:
self.logger.error(f"LLM markdown: Cache read error: {str(e)}", tag="CACHE")
self.logger.error(
f"LLM markdown: Cache read error: {str(e)}", tag="CACHE"
)
# Split into chunks
html_chunks = self._merge_chunks(html)
if self.logger:
self.logger.info(
"LLM markdown: Split content into {chunk_count} chunks",
"LLM markdown: Split content into {chunk_count} chunks",
tag="CHUNK",
params={"chunk_count": len(html_chunks)},
colors={"chunk_count": Fore.YELLOW}
colors={"chunk_count": Fore.YELLOW},
)
start_time = time.time()
# Process chunks in parallel
with ThreadPoolExecutor(max_workers=4) as executor:
futures = []
for i, chunk in enumerate(html_chunks):
if self.logger:
self.logger.debug(
"LLM markdown: Processing chunk {chunk_num}/{total_chunks}",
"LLM markdown: Processing chunk {chunk_num}/{total_chunks}",
tag="CHUNK",
params={
"chunk_num": i + 1,
"total_chunks": len(html_chunks)
}
params={"chunk_num": i + 1, "total_chunks": len(html_chunks)},
)
prompt_variables = {
"HTML": escape_json_string(sanitize_html(chunk)),
"REQUEST": self.instruction or "Convert this HTML into clean, relevant markdown, removing any noise or irrelevant content."
"REQUEST": self.instruction
or "Convert this HTML into clean, relevant markdown, removing any noise or irrelevant content.",
}
prompt = PROMPT_FILTER_CONTENT
@@ -905,95 +937,96 @@ class LLMContentFilter(RelevantContentFilter):
prompt = prompt.replace("{" + var + "}", value)
def _proceed_with_chunk(
provider: str,
prompt: str,
api_token: str,
base_url: Optional[str] = None,
extra_args: Dict = {}
) -> List[str]:
provider: str,
prompt: str,
api_token: str,
base_url: Optional[str] = None,
extra_args: Dict = {},
) -> List[str]:
if self.logger:
self.logger.info(
"LLM Markdown: Processing chunk {chunk_num}",
"LLM Markdown: Processing chunk {chunk_num}",
tag="CHUNK",
params={"chunk_num": i + 1}
params={"chunk_num": i + 1},
)
return perform_completion_with_backoff(
provider,
prompt,
api_token,
base_url=base_url,
extra_args=extra_args
extra_args=extra_args,
)
future = executor.submit(
_proceed_with_chunk,
self.provider,
self.llmConfig.provider,
prompt,
self.api_token,
self.api_base,
self.extra_args
self.llmConfig.api_token,
self.llmConfig.base_url,
self.extra_args,
)
futures.append((i, future))
# Collect results in order
ordered_results = []
for i, future in sorted(futures):
try:
response = future.result()
# Track usage
usage = TokenUsage(
completion_tokens=response.usage.completion_tokens,
prompt_tokens=response.usage.prompt_tokens,
total_tokens=response.usage.total_tokens,
completion_tokens_details=response.usage.completion_tokens_details.__dict__
if response.usage.completion_tokens_details else {},
prompt_tokens_details=response.usage.prompt_tokens_details.__dict__
if response.usage.prompt_tokens_details else {},
completion_tokens_details=(
response.usage.completion_tokens_details.__dict__
if response.usage.completion_tokens_details
else {}
),
prompt_tokens_details=(
response.usage.prompt_tokens_details.__dict__
if response.usage.prompt_tokens_details
else {}
),
)
self.usages.append(usage)
self.total_usage.completion_tokens += usage.completion_tokens
self.total_usage.prompt_tokens += usage.prompt_tokens
self.total_usage.total_tokens += usage.total_tokens
blocks = extract_xml_data(["content"], response.choices[0].message.content)["content"]
blocks = extract_xml_data(
["content"], response.choices[0].message.content
)["content"]
if blocks:
ordered_results.append(blocks)
if self.logger:
self.logger.success(
"LLM markdown: Successfully processed chunk {chunk_num}",
"LLM markdown: Successfully processed chunk {chunk_num}",
tag="CHUNK",
params={"chunk_num": i + 1}
params={"chunk_num": i + 1},
)
except Exception as e:
if self.logger:
self.logger.error(
"LLM markdown: Error processing chunk {chunk_num}: {error}",
"LLM markdown: Error processing chunk {chunk_num}: {error}",
tag="CHUNK",
params={
"chunk_num": i + 1,
"error": str(e)
}
params={"chunk_num": i + 1, "error": str(e)},
)
end_time = time.time()
if self.logger:
self.logger.success(
"LLM markdown: Completed processing in {time:.2f}s",
"LLM markdown: Completed processing in {time:.2f}s",
tag="LLM",
params={"time": end_time - start_time},
colors={"time": Fore.YELLOW}
colors={"time": Fore.YELLOW},
)
result = ordered_results if ordered_results else []
# Cache the final result
cache_data = {
'blocks': result,
'usage': self.total_usage.__dict__
}
with cache_file.open('w') as f:
cache_data = {"blocks": result, "usage": self.total_usage.__dict__}
with cache_file.open("w") as f:
json.dump(cache_data, f)
if self.logger:
self.logger.info("Cached results for future use", tag="CACHE")
@@ -1017,4 +1050,4 @@ class LLMContentFilter(RelevantContentFilter):
print(
f"{i:<10} {usage.completion_tokens:>12,} "
f"{usage.prompt_tokens:>12,} {usage.total_tokens:>12,}"
)
)

View File

@@ -1,4 +1,5 @@
from abc import ABC, abstractmethod
import inspect
from typing import Any, List, Dict, Optional
from concurrent.futures import ThreadPoolExecutor, as_completed
import json
@@ -496,20 +497,26 @@ class LLMExtractionStrategy(ExtractionStrategy):
usages: List of individual token usages.
total_usage: Accumulated token usage.
"""
_UNWANTED_PROPS = {
'provider' : 'Instead, use llmConfig=LlmConfig(provider="...")',
'api_token' : 'Instead, use llmConfig=LlMConfig(api_token="...")',
'base_url' : 'Instead, use llmConfig=LlmConfig(base_url="...")',
'api_base' : 'Instead, use llmConfig=LlmConfig(base_url="...")',
}
def __init__(
self,
llmConfig: 'LLMConfig' = None,
instruction: str = None,
provider: str = DEFAULT_PROVIDER,
api_token: Optional[str] = None,
instruction: str = None,
base_url: str = None,
api_base: str = None,
schema: Dict = None,
extraction_type="block",
chunk_token_threshold=CHUNK_TOKEN_THRESHOLD,
overlap_rate=OVERLAP_RATE,
word_token_rate=WORD_TOKEN_RATE,
apply_chunking=True,
api_base: str =None,
base_url: str =None,
input_format: str = "markdown",
verbose=False,
**kwargs,
@@ -518,6 +525,7 @@ class LLMExtractionStrategy(ExtractionStrategy):
Initialize the strategy with clustering parameters.
Args:
llmConfig: The LLM configuration object.
provider: The provider to use for extraction. It follows the format <provider_name>/<model_name>, e.g., "ollama/llama3.3".
api_token: The API token for the provider.
instruction: The instruction to use for the LLM model.
@@ -536,41 +544,39 @@ class LLMExtractionStrategy(ExtractionStrategy):
"""
super().__init__( input_format=input_format, **kwargs)
self.llmConfig = llmConfig
self.provider = provider
if api_token and not api_token.startswith("env:"):
self.api_token = api_token
elif api_token and api_token.startswith("env:"):
self.api_token = os.getenv(api_token[4:])
else:
self.api_token = (
PROVIDER_MODELS.get(provider, "no-token")
or os.getenv("OPENAI_API_KEY")
)
self.api_token = api_token
self.base_url = base_url
self.api_base = api_base
self.instruction = instruction
self.extract_type = extraction_type
self.schema = schema
if schema:
self.extract_type = "schema"
self.chunk_token_threshold = chunk_token_threshold or CHUNK_TOKEN_THRESHOLD
self.overlap_rate = overlap_rate
self.word_token_rate = word_token_rate
self.apply_chunking = apply_chunking
self.base_url = base_url
self.api_base = api_base or base_url
self.extra_args = kwargs.get("extra_args", {})
if not self.apply_chunking:
self.chunk_token_threshold = 1e9
self.verbose = verbose
self.usages = [] # Store individual usages
self.total_usage = TokenUsage() # Accumulated usage
if not self.api_token:
raise ValueError(
"API token must be provided for LLMExtractionStrategy. Update the config.py or set OPENAI_API_KEY environment variable."
)
def __setattr__(self, name, value):
"""Handle attribute setting."""
# TODO: Planning to set properties dynamically based on the __init__ signature
sig = inspect.signature(self.__init__)
all_params = sig.parameters # Dictionary of parameter names and their details
if name in self._UNWANTED_PROPS and value is not all_params[name].default:
raise AttributeError(f"Setting '{name}' is deprecated. {self._UNWANTED_PROPS[name]}")
super().__setattr__(name, value)
def extract(self, url: str, ix: int, html: str) -> List[Dict[str, Any]]:
"""
Extract meaningful blocks or chunks from the given HTML using an LLM.
@@ -603,7 +609,7 @@ class LLMExtractionStrategy(ExtractionStrategy):
prompt_with_variables = PROMPT_EXTRACT_BLOCKS_WITH_INSTRUCTION
if self.extract_type == "schema" and self.schema:
variable_values["SCHEMA"] = json.dumps(self.schema, indent=2)
variable_values["SCHEMA"] = json.dumps(self.schema, indent=2) # if type of self.schema is dict else self.schema
prompt_with_variables = PROMPT_EXTRACT_SCHEMA_WITH_INSTRUCTION
for variable in variable_values:
@@ -612,10 +618,10 @@ class LLMExtractionStrategy(ExtractionStrategy):
)
response = perform_completion_with_backoff(
self.provider,
self.llmConfig.provider,
prompt_with_variables,
self.api_token,
base_url=self.api_base or self.base_url,
self.llmConfig.api_token,
base_url=self.llmConfig.base_url,
extra_args=self.extra_args,
) # , json_response=self.extract_type == "schema")
# Track usage
@@ -695,7 +701,7 @@ class LLMExtractionStrategy(ExtractionStrategy):
overlap=int(self.chunk_token_threshold * self.overlap_rate),
)
extracted_content = []
if self.provider.startswith("groq/"):
if self.llmConfig.provider.startswith("groq/"):
# Sequential processing with a delay
for ix, section in enumerate(merged_sections):
extract_func = partial(self.extract, url)
@@ -1036,14 +1042,20 @@ class JsonElementExtractionStrategy(ExtractionStrategy):
"""Get attribute value from element"""
pass
_GENERATE_SCHEMA_UNWANTED_PROPS = {
'provider': 'Instead, use llmConfig=LlmConfig(provider="...")',
'api_token': 'Instead, use llmConfig=LlMConfig(api_token="...")',
}
@staticmethod
def generate_schema(
html: str,
schema_type: str = "CSS", # or XPATH
query: str = None,
target_json_example: str = None,
provider: str = "gpt-4o",
api_token: str = os.getenv("OPENAI_API_KEY"),
llmConfig: 'LLMConfig' = None,
provider: str = None,
api_token: str = None,
**kwargs
) -> dict:
"""
@@ -1052,8 +1064,9 @@ class JsonElementExtractionStrategy(ExtractionStrategy):
Args:
html (str): The HTML content to analyze
query (str, optional): Natural language description of what data to extract
provider (str): LLM provider to use
api_token (str): API token for LLM provider
provider (str): Legacy Parameter. LLM provider to use
api_token (str): Legacy Parameter. API token for LLM provider
llmConfig (LlmConfig): LLM configuration object
prompt (str, optional): Custom prompt template to use
**kwargs: Additional args passed to perform_completion_with_backoff
@@ -1062,6 +1075,9 @@ class JsonElementExtractionStrategy(ExtractionStrategy):
"""
from .prompts import JSON_SCHEMA_BUILDER
from .utils import perform_completion_with_backoff
for name, message in JsonElementExtractionStrategy._GENERATE_SCHEMA_UNWANTED_PROPS.items():
if locals()[name] is not None:
raise AttributeError(f"Setting '{name}' is deprecated. {message}")
# Use default or custom prompt
prompt_template = JSON_SCHEMA_BUILDER if schema_type == "CSS" else JSON_SCHEMA_BUILDER_XPATH
@@ -1114,10 +1130,10 @@ In this scenario, use your best judgment to generate the schema. Try to maximize
try:
# Call LLM with backoff handling
response = perform_completion_with_backoff(
provider=provider,
provider=llmConfig.provider,
prompt_with_variables="\n\n".join([system_message["content"], user_message["content"]]),
json_response = True,
api_token=api_token,
api_token=llmConfig.api_token,
**kwargs
)