Feat/llm config (#724)
* feature: Add LlmConfig to easily configure and pass LLM configs to different strategies * pulled in next branch and resolved conflicts * feat: Add gemini and deepseek providers. Make ignore_cache in llm content filter to true by default to avoid confusions * Refactor: Update LlmConfig in LLMExtractionStrategy class and deprecate old params * updated tests, docs and readme
This commit is contained in:
@@ -1,13 +1,16 @@
|
||||
import os
|
||||
from .config import (
|
||||
DEFAULT_PROVIDER,
|
||||
MIN_WORD_THRESHOLD,
|
||||
IMAGE_DESCRIPTION_MIN_WORD_THRESHOLD,
|
||||
PROVIDER_MODELS,
|
||||
SCREENSHOT_HEIGHT_TRESHOLD,
|
||||
PAGE_TIMEOUT,
|
||||
IMAGE_SCORE_THRESHOLD,
|
||||
SOCIAL_MEDIA_DOMAINS,
|
||||
)
|
||||
|
||||
from .user_agent_generator import UAGen, ValidUAGenerator # , OnlineUAGenerator
|
||||
from .user_agent_generator import UAGen, ValidUAGenerator # , OnlineUAGenerator
|
||||
from .extraction_strategy import ExtractionStrategy
|
||||
from .chunking_strategy import ChunkingStrategy, RegexChunking
|
||||
from .markdown_generation_strategy import MarkdownGenerationStrategy
|
||||
@@ -19,7 +22,8 @@ from .proxy_strategy import ProxyRotationStrategy
|
||||
|
||||
import inspect
|
||||
from typing import Any, Dict, Optional
|
||||
from enum import Enum
|
||||
from enum import Enum
|
||||
|
||||
|
||||
def to_serializable_dict(obj: Any, ignore_default_value : bool = False) -> Dict:
|
||||
"""
|
||||
@@ -28,26 +32,23 @@ def to_serializable_dict(obj: Any, ignore_default_value : bool = False) -> Dict:
|
||||
"""
|
||||
if obj is None:
|
||||
return None
|
||||
|
||||
|
||||
# Handle basic types
|
||||
if isinstance(obj, (str, int, float, bool)):
|
||||
return obj
|
||||
|
||||
|
||||
# Handle Enum
|
||||
if isinstance(obj, Enum):
|
||||
return {
|
||||
"type": obj.__class__.__name__,
|
||||
"params": obj.value
|
||||
}
|
||||
|
||||
return {"type": obj.__class__.__name__, "params": obj.value}
|
||||
|
||||
# Handle datetime objects
|
||||
if hasattr(obj, 'isoformat'):
|
||||
if hasattr(obj, "isoformat"):
|
||||
return obj.isoformat()
|
||||
|
||||
|
||||
# Handle lists, tuples, and sets, and basically any iterable
|
||||
if isinstance(obj, (list, tuple, set)) or hasattr(obj, '__iter__') and not isinstance(obj, dict):
|
||||
return [to_serializable_dict(item) for item in obj]
|
||||
|
||||
|
||||
# Handle frozensets, which are not iterable
|
||||
if isinstance(obj, frozenset):
|
||||
return [to_serializable_dict(item) for item in list(obj)]
|
||||
@@ -56,25 +57,25 @@ def to_serializable_dict(obj: Any, ignore_default_value : bool = False) -> Dict:
|
||||
if isinstance(obj, dict):
|
||||
return {
|
||||
"type": "dict", # Mark as plain dictionary
|
||||
"value": {str(k): to_serializable_dict(v) for k, v in obj.items()}
|
||||
"value": {str(k): to_serializable_dict(v) for k, v in obj.items()},
|
||||
}
|
||||
|
||||
_type = obj.__class__.__name__
|
||||
|
||||
# Handle class instances
|
||||
if hasattr(obj, '__class__'):
|
||||
if hasattr(obj, "__class__"):
|
||||
# Get constructor signature
|
||||
sig = inspect.signature(obj.__class__.__init__)
|
||||
params = sig.parameters
|
||||
|
||||
|
||||
# Get current values
|
||||
current_values = {}
|
||||
for name, param in params.items():
|
||||
if name == 'self':
|
||||
if name == "self":
|
||||
continue
|
||||
|
||||
|
||||
value = getattr(obj, name, param.default)
|
||||
|
||||
|
||||
# Only include if different from default, considering empty values
|
||||
if not (is_empty_value(value) and is_empty_value(param.default)):
|
||||
if value != param.default and not ignore_default_value:
|
||||
@@ -97,47 +98,50 @@ def to_serializable_dict(obj: Any, ignore_default_value : bool = False) -> Dict:
|
||||
|
||||
return str(obj)
|
||||
|
||||
|
||||
def from_serializable_dict(data: Any) -> Any:
|
||||
"""
|
||||
Recursively convert a serializable dictionary back to an object instance.
|
||||
"""
|
||||
if data is None:
|
||||
return None
|
||||
|
||||
|
||||
# Handle basic types
|
||||
if isinstance(data, (str, int, float, bool)):
|
||||
return data
|
||||
|
||||
|
||||
# Handle typed data
|
||||
if isinstance(data, dict) and "type" in data:
|
||||
# Handle plain dictionaries
|
||||
if data["type"] == "dict":
|
||||
return {k: from_serializable_dict(v) for k, v in data["value"].items()}
|
||||
|
||||
|
||||
# Import from crawl4ai for class instances
|
||||
import crawl4ai
|
||||
|
||||
cls = getattr(crawl4ai, data["type"])
|
||||
|
||||
|
||||
# Handle Enum
|
||||
if issubclass(cls, Enum):
|
||||
return cls(data["params"])
|
||||
|
||||
|
||||
# Handle class instances
|
||||
constructor_args = {
|
||||
k: from_serializable_dict(v) for k, v in data["params"].items()
|
||||
}
|
||||
return cls(**constructor_args)
|
||||
|
||||
|
||||
# Handle lists
|
||||
if isinstance(data, list):
|
||||
return [from_serializable_dict(item) for item in data]
|
||||
|
||||
|
||||
# Handle raw dictionaries (legacy support)
|
||||
if isinstance(data, dict):
|
||||
return {k: from_serializable_dict(v) for k, v in data.items()}
|
||||
|
||||
|
||||
return data
|
||||
|
||||
|
||||
|
||||
def is_empty_value(value: Any) -> bool:
|
||||
"""Check if a value is effectively empty/null."""
|
||||
if value is None:
|
||||
@@ -146,7 +150,8 @@ def is_empty_value(value: Any) -> bool:
|
||||
return True
|
||||
return False
|
||||
|
||||
class BrowserConfig():
|
||||
|
||||
class BrowserConfig:
|
||||
"""
|
||||
Configuration class for setting up a browser instance and its context in AsyncPlaywrightCrawlerStrategy.
|
||||
|
||||
@@ -224,7 +229,7 @@ class BrowserConfig():
|
||||
viewport: dict = None,
|
||||
accept_downloads: bool = False,
|
||||
downloads_path: str = None,
|
||||
storage_state : Union[str, dict, None]=None,
|
||||
storage_state: Union[str, dict, None] = None,
|
||||
ignore_https_errors: bool = True,
|
||||
java_script_enabled: bool = True,
|
||||
sleep_on_close: bool = False,
|
||||
@@ -288,7 +293,7 @@ class BrowserConfig():
|
||||
)
|
||||
else:
|
||||
pass
|
||||
|
||||
|
||||
self.browser_hint = UAGen.generate_client_hints(self.user_agent)
|
||||
self.headers.setdefault("sec-ch-ua", self.browser_hint)
|
||||
|
||||
@@ -364,10 +369,10 @@ class BrowserConfig():
|
||||
|
||||
def clone(self, **kwargs):
|
||||
"""Create a copy of this configuration with updated values.
|
||||
|
||||
|
||||
Args:
|
||||
**kwargs: Key-value pairs of configuration options to update
|
||||
|
||||
|
||||
Returns:
|
||||
BrowserConfig: A new instance with the specified updates
|
||||
"""
|
||||
@@ -381,24 +386,33 @@ class BrowserConfig():
|
||||
return to_serializable_dict(self)
|
||||
|
||||
@staticmethod
|
||||
def load( data: dict) -> "BrowserConfig":
|
||||
def load(data: dict) -> "BrowserConfig":
|
||||
# Deserialize the object from a dictionary
|
||||
config = from_serializable_dict(data)
|
||||
config = from_serializable_dict(data)
|
||||
if isinstance(config, BrowserConfig):
|
||||
return config
|
||||
return BrowserConfig.from_kwargs(config)
|
||||
|
||||
|
||||
class HTTPCrawlerConfig():
|
||||
class HTTPCrawlerConfig:
|
||||
"""HTTP-specific crawler configuration"""
|
||||
|
||||
method: str = "GET"
|
||||
headers: Optional[Dict[str, str]] = None
|
||||
data: Optional[Dict[str, Any]] = None
|
||||
json: Optional[Dict[str, Any]] = None
|
||||
json: Optional[Dict[str, Any]] = None
|
||||
follow_redirects: bool = True
|
||||
verify_ssl: bool = True
|
||||
|
||||
def __init__(self, method: str = "GET", headers: Optional[Dict[str, str]] = None, data: Optional[Dict[str, Any]] = None, json: Optional[Dict[str, Any]] = None, follow_redirects: bool = True, verify_ssl: bool = True):
|
||||
def __init__(
|
||||
self,
|
||||
method: str = "GET",
|
||||
headers: Optional[Dict[str, str]] = None,
|
||||
data: Optional[Dict[str, Any]] = None,
|
||||
json: Optional[Dict[str, Any]] = None,
|
||||
follow_redirects: bool = True,
|
||||
verify_ssl: bool = True,
|
||||
):
|
||||
self.method = method
|
||||
self.headers = headers
|
||||
self.data = data
|
||||
@@ -426,23 +440,23 @@ class HTTPCrawlerConfig():
|
||||
"follow_redirects": self.follow_redirects,
|
||||
"verify_ssl": self.verify_ssl,
|
||||
}
|
||||
|
||||
|
||||
def clone(self, **kwargs):
|
||||
"""Create a copy of this configuration with updated values.
|
||||
|
||||
|
||||
Args:
|
||||
**kwargs: Key-value pairs of configuration options to update
|
||||
|
||||
|
||||
Returns:
|
||||
HTTPCrawlerConfig: A new instance with the specified updates
|
||||
"""
|
||||
config_dict = self.to_dict()
|
||||
config_dict.update(kwargs)
|
||||
return HTTPCrawlerConfig.from_kwargs(config_dict)
|
||||
|
||||
|
||||
def dump(self) -> dict:
|
||||
return to_serializable_dict(self)
|
||||
|
||||
|
||||
@staticmethod
|
||||
def load(data: dict) -> "HTTPCrawlerConfig":
|
||||
config = from_serializable_dict(data)
|
||||
@@ -469,7 +483,7 @@ class CrawlerRunConfig():
|
||||
Attributes:
|
||||
# Deep Crawl Parameters
|
||||
deep_crawl_strategy (DeepCrawlStrategy or None): Strategy to use for deep crawling.
|
||||
|
||||
|
||||
# Content Processing Parameters
|
||||
word_count_threshold (int): Minimum word count threshold before processing content.
|
||||
Default: MIN_WORD_THRESHOLD (typically 200).
|
||||
@@ -606,20 +620,20 @@ class CrawlerRunConfig():
|
||||
data (dict): Data to send in the request body, when using AsyncHTTPCrwalerStrategy.
|
||||
Default: None.
|
||||
json (dict): JSON data to send in the request body, when using AsyncHTTPCrwalerStrategy.
|
||||
|
||||
|
||||
# Connection Parameters
|
||||
stream (bool): If True, enables streaming of crawled URLs as they are processed when used with arun_many.
|
||||
Default: False.
|
||||
|
||||
|
||||
check_robots_txt (bool): Whether to check robots.txt rules before crawling. Default: False
|
||||
Default: False.
|
||||
user_agent (str): Custom User-Agent string to use.
|
||||
Default: False.
|
||||
user_agent (str): Custom User-Agent string to use.
|
||||
Default: None.
|
||||
user_agent_mode (str or None): Mode for generating the user agent (e.g., "random"). If None, use the provided user_agent as-is.
|
||||
user_agent_mode (str or None): Mode for generating the user agent (e.g., "random"). If None, use the provided user_agent as-is.
|
||||
Default: None.
|
||||
user_agent_generator_config (dict or None): Configuration for user agent generation if user_agent_mode is set.
|
||||
Default: None.
|
||||
|
||||
|
||||
url: str = None # This is not a compulsory parameter
|
||||
"""
|
||||
|
||||
@@ -700,7 +714,6 @@ class CrawlerRunConfig():
|
||||
user_agent_generator_config: dict = {},
|
||||
# Deep Crawl Parameters
|
||||
deep_crawl_strategy: Optional[DeepCrawlStrategy] = None,
|
||||
|
||||
):
|
||||
# TODO: Planning to set properties dynamically based on the __init__ signature
|
||||
self.url = url
|
||||
@@ -810,7 +823,6 @@ class CrawlerRunConfig():
|
||||
if self.chunking_strategy is None:
|
||||
self.chunking_strategy = RegexChunking()
|
||||
|
||||
|
||||
# Deep Crawl Parameters
|
||||
self.deep_crawl_strategy = deep_crawl_strategy
|
||||
|
||||
@@ -918,7 +930,6 @@ class CrawlerRunConfig():
|
||||
user_agent_generator_config=kwargs.get("user_agent_generator_config", {}),
|
||||
# Deep Crawl Parameters
|
||||
deep_crawl_strategy=kwargs.get("deep_crawl_strategy"),
|
||||
|
||||
url=kwargs.get("url"),
|
||||
)
|
||||
|
||||
@@ -930,7 +941,7 @@ class CrawlerRunConfig():
|
||||
@staticmethod
|
||||
def load(data: dict) -> "CrawlerRunConfig":
|
||||
# Deserialize the object from a dictionary
|
||||
config = from_serializable_dict(data)
|
||||
config = from_serializable_dict(data)
|
||||
if isinstance(config, CrawlerRunConfig):
|
||||
return config
|
||||
return CrawlerRunConfig.from_kwargs(config)
|
||||
@@ -1006,18 +1017,18 @@ class CrawlerRunConfig():
|
||||
|
||||
def clone(self, **kwargs):
|
||||
"""Create a copy of this configuration with updated values.
|
||||
|
||||
|
||||
Args:
|
||||
**kwargs: Key-value pairs of configuration options to update
|
||||
|
||||
|
||||
Returns:
|
||||
CrawlerRunConfig: A new instance with the specified updates
|
||||
|
||||
|
||||
Example:
|
||||
```python
|
||||
# Create a new config with streaming enabled
|
||||
stream_config = config.clone(stream=True)
|
||||
|
||||
|
||||
# Create a new config with multiple updates
|
||||
new_config = config.clone(
|
||||
stream=True,
|
||||
@@ -1031,3 +1042,50 @@ class CrawlerRunConfig():
|
||||
return CrawlerRunConfig.from_kwargs(config_dict)
|
||||
|
||||
|
||||
class LlmConfig:
|
||||
def __init__(
|
||||
self,
|
||||
provider: str = DEFAULT_PROVIDER,
|
||||
api_token: Optional[str] = None,
|
||||
base_url: Optional[str] = None,
|
||||
):
|
||||
"""Configuaration class for LLM provider and API token."""
|
||||
self.provider = provider
|
||||
if api_token and not api_token.startswith("env:"):
|
||||
self.api_token = api_token
|
||||
elif api_token and api_token.startswith("env:"):
|
||||
self.api_token = os.getenv(api_token[4:])
|
||||
else:
|
||||
self.api_token = PROVIDER_MODELS.get(provider, "no-token") or os.getenv(
|
||||
"OPENAI_API_KEY"
|
||||
)
|
||||
self.base_url = base_url
|
||||
|
||||
|
||||
@staticmethod
|
||||
def from_kwargs(kwargs: dict) -> "LlmConfig":
|
||||
return LlmConfig(
|
||||
provider=kwargs.get("provider", DEFAULT_PROVIDER),
|
||||
api_token=kwargs.get("api_token"),
|
||||
base_url=kwargs.get("base_url"),
|
||||
)
|
||||
|
||||
def to_dict(self):
|
||||
return {
|
||||
"provider": self.provider,
|
||||
"api_token": self.api_token,
|
||||
"base_url": self.base_url
|
||||
}
|
||||
|
||||
def clone(self, **kwargs):
|
||||
"""Create a copy of this configuration with updated values.
|
||||
|
||||
Args:
|
||||
**kwargs: Key-value pairs of configuration options to update
|
||||
|
||||
Returns:
|
||||
LLMConfig: A new instance with the specified updates
|
||||
"""
|
||||
config_dict = self.to_dict()
|
||||
config_dict.update(kwargs)
|
||||
return LlmConfig.from_kwargs(config_dict)
|
||||
|
||||
@@ -21,6 +21,12 @@ PROVIDER_MODELS = {
|
||||
"anthropic/claude-3-opus-20240229": os.getenv("ANTHROPIC_API_KEY"),
|
||||
"anthropic/claude-3-sonnet-20240229": os.getenv("ANTHROPIC_API_KEY"),
|
||||
"anthropic/claude-3-5-sonnet-20240620": os.getenv("ANTHROPIC_API_KEY"),
|
||||
"gemini/gemini-pro": os.getenv("GEMINI_API_KEY"),
|
||||
'gemini/gemini-1.5-pro': os.getenv("GEMINI_API_KEY"),
|
||||
'gemini/gemini-2.0-flash': os.getenv("GEMINI_API_KEY"),
|
||||
'gemini/gemini-2.0-flash-exp': os.getenv("GEMINI_API_KEY"),
|
||||
'gemini/gemini-2.0-flash-lite-preview-02-05': os.getenv("GEMINI_API_KEY"),
|
||||
"deepseek/deepseek-chat": os.getenv("DEEPSEEK_API_KEY"),
|
||||
}
|
||||
|
||||
# Chunk token threshold
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
import inspect
|
||||
import re
|
||||
import time
|
||||
from bs4 import BeautifulSoup, Tag
|
||||
@@ -5,7 +6,16 @@ from typing import List, Tuple, Dict, Optional
|
||||
from rank_bm25 import BM25Okapi
|
||||
from collections import deque
|
||||
from bs4 import NavigableString, Comment
|
||||
from .utils import clean_tokens, perform_completion_with_backoff, escape_json_string, sanitize_html, get_home_folder, extract_xml_data, merge_chunks
|
||||
|
||||
from .utils import (
|
||||
clean_tokens,
|
||||
perform_completion_with_backoff,
|
||||
escape_json_string,
|
||||
sanitize_html,
|
||||
get_home_folder,
|
||||
extract_xml_data,
|
||||
merge_chunks,
|
||||
)
|
||||
from abc import ABC, abstractmethod
|
||||
import math
|
||||
from snowballstemmer import stemmer
|
||||
@@ -20,10 +30,16 @@ from concurrent.futures import ThreadPoolExecutor
|
||||
from .async_logger import AsyncLogger, LogLevel
|
||||
from colorama import Fore, Style
|
||||
|
||||
|
||||
class RelevantContentFilter(ABC):
|
||||
"""Abstract base class for content filtering strategies"""
|
||||
|
||||
def __init__(self, user_query: str = None, verbose: bool = False, logger: Optional[AsyncLogger] = None):
|
||||
def __init__(
|
||||
self,
|
||||
user_query: str = None,
|
||||
verbose: bool = False,
|
||||
logger: Optional[AsyncLogger] = None,
|
||||
):
|
||||
"""
|
||||
Initializes the RelevantContentFilter class with optional user query.
|
||||
|
||||
@@ -362,6 +378,7 @@ class RelevantContentFilter(ABC):
|
||||
except Exception:
|
||||
return str(tag) # Fallback to original if anything fails
|
||||
|
||||
|
||||
class BM25ContentFilter(RelevantContentFilter):
|
||||
"""
|
||||
Content filtering using BM25 algorithm with priority tag handling.
|
||||
@@ -504,6 +521,7 @@ class BM25ContentFilter(RelevantContentFilter):
|
||||
|
||||
return [self.clean_element(tag) for _, _, tag in selected_candidates]
|
||||
|
||||
|
||||
class PruningContentFilter(RelevantContentFilter):
|
||||
"""
|
||||
Content filtering using pruning algorithm with dynamic threshold.
|
||||
@@ -750,13 +768,21 @@ class PruningContentFilter(RelevantContentFilter):
|
||||
class_id_score -= 0.5
|
||||
return class_id_score
|
||||
|
||||
|
||||
class LLMContentFilter(RelevantContentFilter):
|
||||
"""Content filtering using LLMs to generate relevant markdown."""
|
||||
_UNWANTED_PROPS = {
|
||||
'provider' : 'Instead, use llmConfig=LlmConfig(provider="...")',
|
||||
'api_token' : 'Instead, use llmConfig=LlMConfig(api_token="...")',
|
||||
'base_url' : 'Instead, use llmConfig=LlmConfig(base_url="...")',
|
||||
'api_base' : 'Instead, use llmConfig=LlmConfig(base_url="...")',
|
||||
}
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
provider: str = DEFAULT_PROVIDER,
|
||||
api_token: Optional[str] = None,
|
||||
llmConfig: "LlmConfig" = None,
|
||||
instruction: str = None,
|
||||
chunk_token_threshold: int = int(1e9),
|
||||
overlap_rate: float = OVERLAP_RATE,
|
||||
@@ -768,15 +794,13 @@ class LLMContentFilter(RelevantContentFilter):
|
||||
# chunk_mode: str = "char",
|
||||
verbose: bool = False,
|
||||
logger: Optional[AsyncLogger] = None,
|
||||
ignore_cache: bool = False,
|
||||
ignore_cache: bool = True,
|
||||
):
|
||||
super().__init__(None)
|
||||
self.provider = provider
|
||||
self.api_token = (
|
||||
api_token
|
||||
or PROVIDER_MODELS.get(provider, "no-token")
|
||||
or os.getenv("OPENAI_API_KEY")
|
||||
)
|
||||
self.api_token = api_token
|
||||
self.base_url = base_url or api_base
|
||||
self.llmConfig = llmConfig
|
||||
self.instruction = instruction
|
||||
self.chunk_token_threshold = chunk_token_threshold
|
||||
self.overlap_rate = overlap_rate
|
||||
@@ -785,12 +809,10 @@ class LLMContentFilter(RelevantContentFilter):
|
||||
# self.char_token_rate = char_token_rate or word_token_rate / 5
|
||||
# self.token_rate = word_token_rate if chunk_mode == "word" else self.char_token_rate
|
||||
self.token_rate = word_token_rate or WORD_TOKEN_RATE
|
||||
self.base_url = base_url
|
||||
self.api_base = api_base or base_url
|
||||
self.extra_args = extra_args or {}
|
||||
self.ignore_cache = ignore_cache
|
||||
self.verbose = verbose
|
||||
|
||||
|
||||
# Setup logger with custom styling for LLM operations
|
||||
if logger:
|
||||
self.logger = logger
|
||||
@@ -801,19 +823,31 @@ class LLMContentFilter(RelevantContentFilter):
|
||||
**AsyncLogger.DEFAULT_ICONS,
|
||||
"LLM": "★", # Star for LLM operations
|
||||
"CHUNK": "◈", # Diamond for chunks
|
||||
"CACHE": "⚡", # Lightning for cache operations
|
||||
"CACHE": "⚡", # Lightning for cache operations
|
||||
},
|
||||
colors={
|
||||
**AsyncLogger.DEFAULT_COLORS,
|
||||
LogLevel.INFO: Fore.MAGENTA + Style.DIM, # Dimmed purple for LLM ops
|
||||
}
|
||||
LogLevel.INFO: Fore.MAGENTA
|
||||
+ Style.DIM, # Dimmed purple for LLM ops
|
||||
},
|
||||
)
|
||||
else:
|
||||
self.logger = None
|
||||
|
||||
|
||||
self.usages = []
|
||||
self.total_usage = TokenUsage()
|
||||
|
||||
def __setattr__(self, name, value):
|
||||
"""Handle attribute setting."""
|
||||
# TODO: Planning to set properties dynamically based on the __init__ signature
|
||||
sig = inspect.signature(self.__init__)
|
||||
all_params = sig.parameters # Dictionary of parameter names and their details
|
||||
|
||||
if name in self._UNWANTED_PROPS and value is not all_params[name].default:
|
||||
raise AttributeError(f"Setting '{name}' is deprecated. {self._UNWANTED_PROPS[name]}")
|
||||
|
||||
super().__setattr__(name, value)
|
||||
|
||||
def _get_cache_key(self, html: str, instruction: str) -> str:
|
||||
"""Generate a unique cache key based on HTML and instruction"""
|
||||
content = f"{html}{instruction}"
|
||||
@@ -823,14 +857,12 @@ class LLMContentFilter(RelevantContentFilter):
|
||||
"""Split text into chunks with overlap using char or word mode."""
|
||||
ov = int(self.chunk_token_threshold * self.overlap_rate)
|
||||
sections = merge_chunks(
|
||||
docs = [text],
|
||||
target_size= self.chunk_token_threshold,
|
||||
docs=[text],
|
||||
target_size=self.chunk_token_threshold,
|
||||
overlap=ov,
|
||||
word_token_ratio=self.word_token_rate
|
||||
word_token_ratio=self.word_token_rate,
|
||||
)
|
||||
return sections
|
||||
|
||||
|
||||
|
||||
def filter_content(self, html: str, ignore_cache: bool = True) -> List[str]:
|
||||
if not html or not isinstance(html, str):
|
||||
@@ -838,10 +870,10 @@ class LLMContentFilter(RelevantContentFilter):
|
||||
|
||||
if self.logger:
|
||||
self.logger.info(
|
||||
"Starting LLM markdown content filtering process",
|
||||
"Starting LLM markdown content filtering process",
|
||||
tag="LLM",
|
||||
params={"provider": self.provider},
|
||||
colors={"provider": Fore.CYAN}
|
||||
params={"provider": self.llmConfig.provider},
|
||||
colors={"provider": Fore.CYAN},
|
||||
)
|
||||
|
||||
# Cache handling
|
||||
@@ -857,47 +889,47 @@ class LLMContentFilter(RelevantContentFilter):
|
||||
if self.logger:
|
||||
self.logger.info("Found cached markdown result", tag="CACHE")
|
||||
try:
|
||||
with cache_file.open('r') as f:
|
||||
with cache_file.open("r") as f:
|
||||
cached_data = json.load(f)
|
||||
usage = TokenUsage(**cached_data['usage'])
|
||||
usage = TokenUsage(**cached_data["usage"])
|
||||
self.usages.append(usage)
|
||||
self.total_usage.completion_tokens += usage.completion_tokens
|
||||
self.total_usage.prompt_tokens += usage.prompt_tokens
|
||||
self.total_usage.total_tokens += usage.total_tokens
|
||||
return cached_data['blocks']
|
||||
return cached_data["blocks"]
|
||||
except Exception as e:
|
||||
if self.logger:
|
||||
self.logger.error(f"LLM markdown: Cache read error: {str(e)}", tag="CACHE")
|
||||
self.logger.error(
|
||||
f"LLM markdown: Cache read error: {str(e)}", tag="CACHE"
|
||||
)
|
||||
|
||||
# Split into chunks
|
||||
html_chunks = self._merge_chunks(html)
|
||||
if self.logger:
|
||||
self.logger.info(
|
||||
"LLM markdown: Split content into {chunk_count} chunks",
|
||||
"LLM markdown: Split content into {chunk_count} chunks",
|
||||
tag="CHUNK",
|
||||
params={"chunk_count": len(html_chunks)},
|
||||
colors={"chunk_count": Fore.YELLOW}
|
||||
colors={"chunk_count": Fore.YELLOW},
|
||||
)
|
||||
|
||||
|
||||
start_time = time.time()
|
||||
|
||||
|
||||
# Process chunks in parallel
|
||||
with ThreadPoolExecutor(max_workers=4) as executor:
|
||||
futures = []
|
||||
for i, chunk in enumerate(html_chunks):
|
||||
if self.logger:
|
||||
self.logger.debug(
|
||||
"LLM markdown: Processing chunk {chunk_num}/{total_chunks}",
|
||||
"LLM markdown: Processing chunk {chunk_num}/{total_chunks}",
|
||||
tag="CHUNK",
|
||||
params={
|
||||
"chunk_num": i + 1,
|
||||
"total_chunks": len(html_chunks)
|
||||
}
|
||||
params={"chunk_num": i + 1, "total_chunks": len(html_chunks)},
|
||||
)
|
||||
|
||||
prompt_variables = {
|
||||
"HTML": escape_json_string(sanitize_html(chunk)),
|
||||
"REQUEST": self.instruction or "Convert this HTML into clean, relevant markdown, removing any noise or irrelevant content."
|
||||
"REQUEST": self.instruction
|
||||
or "Convert this HTML into clean, relevant markdown, removing any noise or irrelevant content.",
|
||||
}
|
||||
|
||||
prompt = PROMPT_FILTER_CONTENT
|
||||
@@ -905,95 +937,96 @@ class LLMContentFilter(RelevantContentFilter):
|
||||
prompt = prompt.replace("{" + var + "}", value)
|
||||
|
||||
def _proceed_with_chunk(
|
||||
provider: str,
|
||||
prompt: str,
|
||||
api_token: str,
|
||||
base_url: Optional[str] = None,
|
||||
extra_args: Dict = {}
|
||||
) -> List[str]:
|
||||
provider: str,
|
||||
prompt: str,
|
||||
api_token: str,
|
||||
base_url: Optional[str] = None,
|
||||
extra_args: Dict = {},
|
||||
) -> List[str]:
|
||||
if self.logger:
|
||||
self.logger.info(
|
||||
"LLM Markdown: Processing chunk {chunk_num}",
|
||||
"LLM Markdown: Processing chunk {chunk_num}",
|
||||
tag="CHUNK",
|
||||
params={"chunk_num": i + 1}
|
||||
params={"chunk_num": i + 1},
|
||||
)
|
||||
return perform_completion_with_backoff(
|
||||
provider,
|
||||
prompt,
|
||||
api_token,
|
||||
base_url=base_url,
|
||||
extra_args=extra_args
|
||||
extra_args=extra_args,
|
||||
)
|
||||
|
||||
future = executor.submit(
|
||||
_proceed_with_chunk,
|
||||
self.provider,
|
||||
self.llmConfig.provider,
|
||||
prompt,
|
||||
self.api_token,
|
||||
self.api_base,
|
||||
self.extra_args
|
||||
self.llmConfig.api_token,
|
||||
self.llmConfig.base_url,
|
||||
self.extra_args,
|
||||
)
|
||||
futures.append((i, future))
|
||||
|
||||
|
||||
# Collect results in order
|
||||
ordered_results = []
|
||||
for i, future in sorted(futures):
|
||||
try:
|
||||
response = future.result()
|
||||
|
||||
|
||||
# Track usage
|
||||
usage = TokenUsage(
|
||||
completion_tokens=response.usage.completion_tokens,
|
||||
prompt_tokens=response.usage.prompt_tokens,
|
||||
total_tokens=response.usage.total_tokens,
|
||||
completion_tokens_details=response.usage.completion_tokens_details.__dict__
|
||||
if response.usage.completion_tokens_details else {},
|
||||
prompt_tokens_details=response.usage.prompt_tokens_details.__dict__
|
||||
if response.usage.prompt_tokens_details else {},
|
||||
completion_tokens_details=(
|
||||
response.usage.completion_tokens_details.__dict__
|
||||
if response.usage.completion_tokens_details
|
||||
else {}
|
||||
),
|
||||
prompt_tokens_details=(
|
||||
response.usage.prompt_tokens_details.__dict__
|
||||
if response.usage.prompt_tokens_details
|
||||
else {}
|
||||
),
|
||||
)
|
||||
self.usages.append(usage)
|
||||
self.total_usage.completion_tokens += usage.completion_tokens
|
||||
self.total_usage.prompt_tokens += usage.prompt_tokens
|
||||
self.total_usage.total_tokens += usage.total_tokens
|
||||
|
||||
blocks = extract_xml_data(["content"], response.choices[0].message.content)["content"]
|
||||
blocks = extract_xml_data(
|
||||
["content"], response.choices[0].message.content
|
||||
)["content"]
|
||||
if blocks:
|
||||
ordered_results.append(blocks)
|
||||
if self.logger:
|
||||
self.logger.success(
|
||||
"LLM markdown: Successfully processed chunk {chunk_num}",
|
||||
"LLM markdown: Successfully processed chunk {chunk_num}",
|
||||
tag="CHUNK",
|
||||
params={"chunk_num": i + 1}
|
||||
params={"chunk_num": i + 1},
|
||||
)
|
||||
except Exception as e:
|
||||
if self.logger:
|
||||
self.logger.error(
|
||||
"LLM markdown: Error processing chunk {chunk_num}: {error}",
|
||||
"LLM markdown: Error processing chunk {chunk_num}: {error}",
|
||||
tag="CHUNK",
|
||||
params={
|
||||
"chunk_num": i + 1,
|
||||
"error": str(e)
|
||||
}
|
||||
params={"chunk_num": i + 1, "error": str(e)},
|
||||
)
|
||||
|
||||
end_time = time.time()
|
||||
if self.logger:
|
||||
self.logger.success(
|
||||
"LLM markdown: Completed processing in {time:.2f}s",
|
||||
"LLM markdown: Completed processing in {time:.2f}s",
|
||||
tag="LLM",
|
||||
params={"time": end_time - start_time},
|
||||
colors={"time": Fore.YELLOW}
|
||||
colors={"time": Fore.YELLOW},
|
||||
)
|
||||
|
||||
result = ordered_results if ordered_results else []
|
||||
|
||||
# Cache the final result
|
||||
cache_data = {
|
||||
'blocks': result,
|
||||
'usage': self.total_usage.__dict__
|
||||
}
|
||||
with cache_file.open('w') as f:
|
||||
cache_data = {"blocks": result, "usage": self.total_usage.__dict__}
|
||||
with cache_file.open("w") as f:
|
||||
json.dump(cache_data, f)
|
||||
if self.logger:
|
||||
self.logger.info("Cached results for future use", tag="CACHE")
|
||||
@@ -1017,4 +1050,4 @@ class LLMContentFilter(RelevantContentFilter):
|
||||
print(
|
||||
f"{i:<10} {usage.completion_tokens:>12,} "
|
||||
f"{usage.prompt_tokens:>12,} {usage.total_tokens:>12,}"
|
||||
)
|
||||
)
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
from abc import ABC, abstractmethod
|
||||
import inspect
|
||||
from typing import Any, List, Dict, Optional
|
||||
from concurrent.futures import ThreadPoolExecutor, as_completed
|
||||
import json
|
||||
@@ -496,20 +497,26 @@ class LLMExtractionStrategy(ExtractionStrategy):
|
||||
usages: List of individual token usages.
|
||||
total_usage: Accumulated token usage.
|
||||
"""
|
||||
|
||||
_UNWANTED_PROPS = {
|
||||
'provider' : 'Instead, use llmConfig=LlmConfig(provider="...")',
|
||||
'api_token' : 'Instead, use llmConfig=LlMConfig(api_token="...")',
|
||||
'base_url' : 'Instead, use llmConfig=LlmConfig(base_url="...")',
|
||||
'api_base' : 'Instead, use llmConfig=LlmConfig(base_url="...")',
|
||||
}
|
||||
def __init__(
|
||||
self,
|
||||
llmConfig: 'LLMConfig' = None,
|
||||
instruction: str = None,
|
||||
provider: str = DEFAULT_PROVIDER,
|
||||
api_token: Optional[str] = None,
|
||||
instruction: str = None,
|
||||
base_url: str = None,
|
||||
api_base: str = None,
|
||||
schema: Dict = None,
|
||||
extraction_type="block",
|
||||
chunk_token_threshold=CHUNK_TOKEN_THRESHOLD,
|
||||
overlap_rate=OVERLAP_RATE,
|
||||
word_token_rate=WORD_TOKEN_RATE,
|
||||
apply_chunking=True,
|
||||
api_base: str =None,
|
||||
base_url: str =None,
|
||||
input_format: str = "markdown",
|
||||
verbose=False,
|
||||
**kwargs,
|
||||
@@ -518,6 +525,7 @@ class LLMExtractionStrategy(ExtractionStrategy):
|
||||
Initialize the strategy with clustering parameters.
|
||||
|
||||
Args:
|
||||
llmConfig: The LLM configuration object.
|
||||
provider: The provider to use for extraction. It follows the format <provider_name>/<model_name>, e.g., "ollama/llama3.3".
|
||||
api_token: The API token for the provider.
|
||||
instruction: The instruction to use for the LLM model.
|
||||
@@ -536,41 +544,39 @@ class LLMExtractionStrategy(ExtractionStrategy):
|
||||
|
||||
"""
|
||||
super().__init__( input_format=input_format, **kwargs)
|
||||
self.llmConfig = llmConfig
|
||||
self.provider = provider
|
||||
if api_token and not api_token.startswith("env:"):
|
||||
self.api_token = api_token
|
||||
elif api_token and api_token.startswith("env:"):
|
||||
self.api_token = os.getenv(api_token[4:])
|
||||
else:
|
||||
self.api_token = (
|
||||
PROVIDER_MODELS.get(provider, "no-token")
|
||||
or os.getenv("OPENAI_API_KEY")
|
||||
)
|
||||
self.api_token = api_token
|
||||
self.base_url = base_url
|
||||
self.api_base = api_base
|
||||
self.instruction = instruction
|
||||
self.extract_type = extraction_type
|
||||
self.schema = schema
|
||||
if schema:
|
||||
self.extract_type = "schema"
|
||||
|
||||
self.chunk_token_threshold = chunk_token_threshold or CHUNK_TOKEN_THRESHOLD
|
||||
self.overlap_rate = overlap_rate
|
||||
self.word_token_rate = word_token_rate
|
||||
self.apply_chunking = apply_chunking
|
||||
self.base_url = base_url
|
||||
self.api_base = api_base or base_url
|
||||
self.extra_args = kwargs.get("extra_args", {})
|
||||
if not self.apply_chunking:
|
||||
self.chunk_token_threshold = 1e9
|
||||
|
||||
self.verbose = verbose
|
||||
self.usages = [] # Store individual usages
|
||||
self.total_usage = TokenUsage() # Accumulated usage
|
||||
|
||||
if not self.api_token:
|
||||
raise ValueError(
|
||||
"API token must be provided for LLMExtractionStrategy. Update the config.py or set OPENAI_API_KEY environment variable."
|
||||
)
|
||||
|
||||
def __setattr__(self, name, value):
|
||||
"""Handle attribute setting."""
|
||||
# TODO: Planning to set properties dynamically based on the __init__ signature
|
||||
sig = inspect.signature(self.__init__)
|
||||
all_params = sig.parameters # Dictionary of parameter names and their details
|
||||
|
||||
if name in self._UNWANTED_PROPS and value is not all_params[name].default:
|
||||
raise AttributeError(f"Setting '{name}' is deprecated. {self._UNWANTED_PROPS[name]}")
|
||||
|
||||
super().__setattr__(name, value)
|
||||
|
||||
def extract(self, url: str, ix: int, html: str) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
Extract meaningful blocks or chunks from the given HTML using an LLM.
|
||||
@@ -603,7 +609,7 @@ class LLMExtractionStrategy(ExtractionStrategy):
|
||||
prompt_with_variables = PROMPT_EXTRACT_BLOCKS_WITH_INSTRUCTION
|
||||
|
||||
if self.extract_type == "schema" and self.schema:
|
||||
variable_values["SCHEMA"] = json.dumps(self.schema, indent=2)
|
||||
variable_values["SCHEMA"] = json.dumps(self.schema, indent=2) # if type of self.schema is dict else self.schema
|
||||
prompt_with_variables = PROMPT_EXTRACT_SCHEMA_WITH_INSTRUCTION
|
||||
|
||||
for variable in variable_values:
|
||||
@@ -612,10 +618,10 @@ class LLMExtractionStrategy(ExtractionStrategy):
|
||||
)
|
||||
|
||||
response = perform_completion_with_backoff(
|
||||
self.provider,
|
||||
self.llmConfig.provider,
|
||||
prompt_with_variables,
|
||||
self.api_token,
|
||||
base_url=self.api_base or self.base_url,
|
||||
self.llmConfig.api_token,
|
||||
base_url=self.llmConfig.base_url,
|
||||
extra_args=self.extra_args,
|
||||
) # , json_response=self.extract_type == "schema")
|
||||
# Track usage
|
||||
@@ -695,7 +701,7 @@ class LLMExtractionStrategy(ExtractionStrategy):
|
||||
overlap=int(self.chunk_token_threshold * self.overlap_rate),
|
||||
)
|
||||
extracted_content = []
|
||||
if self.provider.startswith("groq/"):
|
||||
if self.llmConfig.provider.startswith("groq/"):
|
||||
# Sequential processing with a delay
|
||||
for ix, section in enumerate(merged_sections):
|
||||
extract_func = partial(self.extract, url)
|
||||
@@ -1036,14 +1042,20 @@ class JsonElementExtractionStrategy(ExtractionStrategy):
|
||||
"""Get attribute value from element"""
|
||||
pass
|
||||
|
||||
_GENERATE_SCHEMA_UNWANTED_PROPS = {
|
||||
'provider': 'Instead, use llmConfig=LlmConfig(provider="...")',
|
||||
'api_token': 'Instead, use llmConfig=LlMConfig(api_token="...")',
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def generate_schema(
|
||||
html: str,
|
||||
schema_type: str = "CSS", # or XPATH
|
||||
query: str = None,
|
||||
target_json_example: str = None,
|
||||
provider: str = "gpt-4o",
|
||||
api_token: str = os.getenv("OPENAI_API_KEY"),
|
||||
llmConfig: 'LLMConfig' = None,
|
||||
provider: str = None,
|
||||
api_token: str = None,
|
||||
**kwargs
|
||||
) -> dict:
|
||||
"""
|
||||
@@ -1052,8 +1064,9 @@ class JsonElementExtractionStrategy(ExtractionStrategy):
|
||||
Args:
|
||||
html (str): The HTML content to analyze
|
||||
query (str, optional): Natural language description of what data to extract
|
||||
provider (str): LLM provider to use
|
||||
api_token (str): API token for LLM provider
|
||||
provider (str): Legacy Parameter. LLM provider to use
|
||||
api_token (str): Legacy Parameter. API token for LLM provider
|
||||
llmConfig (LlmConfig): LLM configuration object
|
||||
prompt (str, optional): Custom prompt template to use
|
||||
**kwargs: Additional args passed to perform_completion_with_backoff
|
||||
|
||||
@@ -1062,6 +1075,9 @@ class JsonElementExtractionStrategy(ExtractionStrategy):
|
||||
"""
|
||||
from .prompts import JSON_SCHEMA_BUILDER
|
||||
from .utils import perform_completion_with_backoff
|
||||
for name, message in JsonElementExtractionStrategy._GENERATE_SCHEMA_UNWANTED_PROPS.items():
|
||||
if locals()[name] is not None:
|
||||
raise AttributeError(f"Setting '{name}' is deprecated. {message}")
|
||||
|
||||
# Use default or custom prompt
|
||||
prompt_template = JSON_SCHEMA_BUILDER if schema_type == "CSS" else JSON_SCHEMA_BUILDER_XPATH
|
||||
@@ -1114,10 +1130,10 @@ In this scenario, use your best judgment to generate the schema. Try to maximize
|
||||
try:
|
||||
# Call LLM with backoff handling
|
||||
response = perform_completion_with_backoff(
|
||||
provider=provider,
|
||||
provider=llmConfig.provider,
|
||||
prompt_with_variables="\n\n".join([system_message["content"], user_message["content"]]),
|
||||
json_response = True,
|
||||
api_token=api_token,
|
||||
api_token=llmConfig.api_token,
|
||||
**kwargs
|
||||
)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user