Merge pull request #1471 from unclecode/fix/adaptive-crawler-llm-config
Fix: allow custom LLM providers for adaptive crawler embedding config…
This commit is contained in:
@@ -19,7 +19,7 @@ import re
|
||||
from pathlib import Path
|
||||
|
||||
from crawl4ai.async_webcrawler import AsyncWebCrawler
|
||||
from crawl4ai.async_configs import CrawlerRunConfig, LinkPreviewConfig
|
||||
from crawl4ai.async_configs import CrawlerRunConfig, LinkPreviewConfig, LLMConfig
|
||||
from crawl4ai.models import Link, CrawlResult
|
||||
import numpy as np
|
||||
|
||||
@@ -178,7 +178,7 @@ class AdaptiveConfig:
|
||||
|
||||
# Embedding strategy parameters
|
||||
embedding_model: str = "sentence-transformers/all-MiniLM-L6-v2"
|
||||
embedding_llm_config: Optional[Dict] = None # Separate config for embeddings
|
||||
embedding_llm_config: Optional[Union[LLMConfig, Dict]] = None # Separate config for embeddings
|
||||
n_query_variations: int = 10
|
||||
coverage_threshold: float = 0.85
|
||||
alpha_shape_alpha: float = 0.5
|
||||
@@ -250,6 +250,30 @@ class AdaptiveConfig:
|
||||
assert 0 <= self.embedding_quality_max_confidence <= 1, "embedding_quality_max_confidence must be between 0 and 1"
|
||||
assert self.embedding_quality_scale_factor > 0, "embedding_quality_scale_factor must be positive"
|
||||
assert 0 <= self.embedding_min_confidence_threshold <= 1, "embedding_min_confidence_threshold must be between 0 and 1"
|
||||
|
||||
@property
|
||||
def _embedding_llm_config_dict(self) -> Optional[Dict]:
|
||||
"""Convert LLMConfig to dict format for backward compatibility."""
|
||||
if self.embedding_llm_config is None:
|
||||
return None
|
||||
|
||||
if isinstance(self.embedding_llm_config, dict):
|
||||
# Already a dict - return as-is for backward compatibility
|
||||
return self.embedding_llm_config
|
||||
|
||||
# Convert LLMConfig object to dict format
|
||||
return {
|
||||
'provider': self.embedding_llm_config.provider,
|
||||
'api_token': self.embedding_llm_config.api_token,
|
||||
'base_url': getattr(self.embedding_llm_config, 'base_url', None),
|
||||
'temperature': getattr(self.embedding_llm_config, 'temperature', None),
|
||||
'max_tokens': getattr(self.embedding_llm_config, 'max_tokens', None),
|
||||
'top_p': getattr(self.embedding_llm_config, 'top_p', None),
|
||||
'frequency_penalty': getattr(self.embedding_llm_config, 'frequency_penalty', None),
|
||||
'presence_penalty': getattr(self.embedding_llm_config, 'presence_penalty', None),
|
||||
'stop': getattr(self.embedding_llm_config, 'stop', None),
|
||||
'n': getattr(self.embedding_llm_config, 'n', None),
|
||||
}
|
||||
|
||||
|
||||
class CrawlStrategy(ABC):
|
||||
@@ -593,7 +617,7 @@ class StatisticalStrategy(CrawlStrategy):
|
||||
class EmbeddingStrategy(CrawlStrategy):
|
||||
"""Embedding-based adaptive crawling using semantic space coverage"""
|
||||
|
||||
def __init__(self, embedding_model: str = None, llm_config: Dict = None):
|
||||
def __init__(self, embedding_model: str = None, llm_config: Union[LLMConfig, Dict] = None):
|
||||
self.embedding_model = embedding_model or "sentence-transformers/all-MiniLM-L6-v2"
|
||||
self.llm_config = llm_config
|
||||
self._embedding_cache = {}
|
||||
@@ -605,14 +629,24 @@ class EmbeddingStrategy(CrawlStrategy):
|
||||
self._kb_embeddings_hash = None # Track KB changes
|
||||
self._validation_embeddings_cache = None # Cache validation query embeddings
|
||||
self._kb_similarity_threshold = 0.95 # Threshold for deduplication
|
||||
|
||||
def _get_embedding_llm_config_dict(self) -> Dict:
|
||||
"""Get embedding LLM config as dict with fallback to default."""
|
||||
if hasattr(self, 'config') and self.config:
|
||||
config_dict = self.config._embedding_llm_config_dict
|
||||
if config_dict:
|
||||
return config_dict
|
||||
|
||||
# Fallback to default if no config provided
|
||||
return {
|
||||
'provider': 'openai/text-embedding-3-small',
|
||||
'api_token': os.getenv('OPENAI_API_KEY')
|
||||
}
|
||||
|
||||
async def _get_embeddings(self, texts: List[str]) -> Any:
|
||||
"""Get embeddings using configured method"""
|
||||
from .utils import get_text_embeddings
|
||||
embedding_llm_config = {
|
||||
'provider': 'openai/text-embedding-3-small',
|
||||
'api_token': os.getenv('OPENAI_API_KEY')
|
||||
}
|
||||
embedding_llm_config = self._get_embedding_llm_config_dict()
|
||||
return await get_text_embeddings(
|
||||
texts,
|
||||
embedding_llm_config,
|
||||
@@ -679,8 +713,20 @@ class EmbeddingStrategy(CrawlStrategy):
|
||||
Return as a JSON array of strings."""
|
||||
|
||||
# Use the LLM for query generation
|
||||
provider = self.llm_config.get('provider', 'openai/gpt-4o-mini') if self.llm_config else 'openai/gpt-4o-mini'
|
||||
api_token = self.llm_config.get('api_token') if self.llm_config else None
|
||||
# Convert LLMConfig to dict if needed
|
||||
llm_config_dict = None
|
||||
if self.llm_config:
|
||||
if isinstance(self.llm_config, dict):
|
||||
llm_config_dict = self.llm_config
|
||||
else:
|
||||
# Convert LLMConfig object to dict
|
||||
llm_config_dict = {
|
||||
'provider': self.llm_config.provider,
|
||||
'api_token': self.llm_config.api_token
|
||||
}
|
||||
|
||||
provider = llm_config_dict.get('provider', 'openai/gpt-4o-mini') if llm_config_dict else 'openai/gpt-4o-mini'
|
||||
api_token = llm_config_dict.get('api_token') if llm_config_dict else None
|
||||
|
||||
# response = perform_completion_with_backoff(
|
||||
# provider=provider,
|
||||
@@ -843,10 +889,7 @@ class EmbeddingStrategy(CrawlStrategy):
|
||||
|
||||
# Batch embed only uncached links
|
||||
if texts_to_embed:
|
||||
embedding_llm_config = {
|
||||
'provider': 'openai/text-embedding-3-small',
|
||||
'api_token': os.getenv('OPENAI_API_KEY')
|
||||
}
|
||||
embedding_llm_config = self._get_embedding_llm_config_dict()
|
||||
new_embeddings = await get_text_embeddings(texts_to_embed, embedding_llm_config, self.embedding_model)
|
||||
|
||||
# Cache the new embeddings
|
||||
@@ -1184,10 +1227,7 @@ class EmbeddingStrategy(CrawlStrategy):
|
||||
return
|
||||
|
||||
# Get embeddings for new texts
|
||||
embedding_llm_config = {
|
||||
'provider': 'openai/text-embedding-3-small',
|
||||
'api_token': os.getenv('OPENAI_API_KEY')
|
||||
}
|
||||
embedding_llm_config = self._get_embedding_llm_config_dict()
|
||||
new_embeddings = await get_text_embeddings(new_texts, embedding_llm_config, self.embedding_model)
|
||||
|
||||
# Deduplicate embeddings before adding to KB
|
||||
@@ -1256,10 +1296,12 @@ class AdaptiveCrawler:
|
||||
if strategy_name == "statistical":
|
||||
return StatisticalStrategy()
|
||||
elif strategy_name == "embedding":
|
||||
return EmbeddingStrategy(
|
||||
strategy = EmbeddingStrategy(
|
||||
embedding_model=self.config.embedding_model,
|
||||
llm_config=self.config.embedding_llm_config
|
||||
)
|
||||
strategy.config = self.config # Pass config to strategy
|
||||
return strategy
|
||||
else:
|
||||
raise ValueError(f"Unknown strategy: {strategy_name}")
|
||||
|
||||
|
||||
154
docs/examples/adaptive_crawling/llm_config_example.py
Normal file
154
docs/examples/adaptive_crawling/llm_config_example.py
Normal file
@@ -0,0 +1,154 @@
|
||||
import asyncio
|
||||
import os
|
||||
from crawl4ai import AsyncWebCrawler, AdaptiveCrawler, AdaptiveConfig, LLMConfig
|
||||
|
||||
|
||||
async def test_configuration(name: str, config: AdaptiveConfig, url: str, query: str):
|
||||
"""Test a specific configuration"""
|
||||
print(f"\n{'='*60}")
|
||||
print(f"Configuration: {name}")
|
||||
print(f"{'='*60}")
|
||||
|
||||
async with AsyncWebCrawler(verbose=False) as crawler:
|
||||
adaptive = AdaptiveCrawler(crawler, config)
|
||||
result = await adaptive.digest(start_url=url, query=query)
|
||||
|
||||
print("\n" + "="*50)
|
||||
print("CRAWL STATISTICS")
|
||||
print("="*50)
|
||||
adaptive.print_stats(detailed=False)
|
||||
|
||||
# Get the most relevant content found
|
||||
print("\n" + "="*50)
|
||||
print("MOST RELEVANT PAGES")
|
||||
print("="*50)
|
||||
|
||||
relevant_pages = adaptive.get_relevant_content(top_k=5)
|
||||
for i, page in enumerate(relevant_pages, 1):
|
||||
print(f"\n{i}. {page['url']}")
|
||||
print(f" Relevance Score: {page['score']:.2%}")
|
||||
|
||||
# Show a snippet of the content
|
||||
content = page['content'] or ""
|
||||
if content:
|
||||
snippet = content[:200].replace('\n', ' ')
|
||||
if len(content) > 200:
|
||||
snippet += "..."
|
||||
print(f" Preview: {snippet}")
|
||||
|
||||
print(f"\n{'='*50}")
|
||||
print(f"Pages crawled: {len(result.crawled_urls)}")
|
||||
print(f"Final confidence: {adaptive.confidence:.1%}")
|
||||
print(f"Stopped reason: {result.metrics.get('stopped_reason', 'max_pages')}")
|
||||
|
||||
if result.metrics.get('is_irrelevant', False):
|
||||
print("⚠️ Query detected as irrelevant!")
|
||||
|
||||
return result
|
||||
|
||||
|
||||
async def llm_embedding():
|
||||
"""Demonstrate various embedding configurations"""
|
||||
|
||||
print("EMBEDDING STRATEGY CONFIGURATION EXAMPLES")
|
||||
print("=" * 60)
|
||||
|
||||
# Base URL and query for testing
|
||||
test_url = "https://docs.python.org/3/library/asyncio.html"
|
||||
|
||||
openai_llm_config = LLMConfig(
|
||||
provider='openai/text-embedding-3-small',
|
||||
api_token=os.getenv('OPENAI_API_KEY'),
|
||||
temperature=0.7,
|
||||
max_tokens=2000
|
||||
)
|
||||
config_openai = AdaptiveConfig(
|
||||
strategy="embedding",
|
||||
max_pages=10,
|
||||
|
||||
# Use OpenAI embeddings
|
||||
embedding_llm_config=openai_llm_config,
|
||||
# embedding_llm_config={
|
||||
# 'provider': 'openai/text-embedding-3-small',
|
||||
# 'api_token': os.getenv('OPENAI_API_KEY')
|
||||
# },
|
||||
|
||||
# OpenAI embeddings are high quality, can be stricter
|
||||
embedding_k_exp=4.0,
|
||||
n_query_variations=12
|
||||
)
|
||||
|
||||
await test_configuration(
|
||||
"OpenAI Embeddings",
|
||||
config_openai,
|
||||
test_url,
|
||||
# "event-driven architecture patterns"
|
||||
"async await context managers coroutines"
|
||||
)
|
||||
return
|
||||
|
||||
|
||||
|
||||
async def basic_adaptive_crawling():
|
||||
"""Basic adaptive crawling example"""
|
||||
|
||||
# Initialize the crawler
|
||||
async with AsyncWebCrawler(verbose=True) as crawler:
|
||||
# Create an adaptive crawler with default settings (statistical strategy)
|
||||
adaptive = AdaptiveCrawler(crawler)
|
||||
|
||||
# Note: You can also use embedding strategy for semantic understanding:
|
||||
# from crawl4ai import AdaptiveConfig
|
||||
# config = AdaptiveConfig(strategy="embedding")
|
||||
# adaptive = AdaptiveCrawler(crawler, config)
|
||||
|
||||
# Start adaptive crawling
|
||||
print("Starting adaptive crawl for Python async programming information...")
|
||||
result = await adaptive.digest(
|
||||
start_url="https://docs.python.org/3/library/asyncio.html",
|
||||
query="async await context managers coroutines"
|
||||
)
|
||||
|
||||
# Display crawl statistics
|
||||
print("\n" + "="*50)
|
||||
print("CRAWL STATISTICS")
|
||||
print("="*50)
|
||||
adaptive.print_stats(detailed=False)
|
||||
|
||||
# Get the most relevant content found
|
||||
print("\n" + "="*50)
|
||||
print("MOST RELEVANT PAGES")
|
||||
print("="*50)
|
||||
|
||||
relevant_pages = adaptive.get_relevant_content(top_k=5)
|
||||
for i, page in enumerate(relevant_pages, 1):
|
||||
print(f"\n{i}. {page['url']}")
|
||||
print(f" Relevance Score: {page['score']:.2%}")
|
||||
|
||||
# Show a snippet of the content
|
||||
content = page['content'] or ""
|
||||
if content:
|
||||
snippet = content[:200].replace('\n', ' ')
|
||||
if len(content) > 200:
|
||||
snippet += "..."
|
||||
print(f" Preview: {snippet}")
|
||||
|
||||
# Show final confidence
|
||||
print(f"\n{'='*50}")
|
||||
print(f"Final Confidence: {adaptive.confidence:.2%}")
|
||||
print(f"Total Pages Crawled: {len(result.crawled_urls)}")
|
||||
print(f"Knowledge Base Size: {len(adaptive.state.knowledge_base)} documents")
|
||||
|
||||
|
||||
if adaptive.confidence >= 0.8:
|
||||
print("✓ High confidence - can answer detailed questions about async Python")
|
||||
elif adaptive.confidence >= 0.6:
|
||||
print("~ Moderate confidence - can answer basic questions")
|
||||
else:
|
||||
print("✗ Low confidence - need more information")
|
||||
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
asyncio.run(llm_embedding())
|
||||
# asyncio.run(basic_adaptive_crawling())
|
||||
@@ -108,7 +108,19 @@ config = AdaptiveConfig(
|
||||
embedding_min_confidence_threshold=0.1 # Stop if completely irrelevant
|
||||
)
|
||||
|
||||
# With custom embedding provider (e.g., OpenAI)
|
||||
# With custom LLM provider for query expansion (recommended)
|
||||
from crawl4ai import LLMConfig
|
||||
|
||||
config = AdaptiveConfig(
|
||||
strategy="embedding",
|
||||
embedding_llm_config=LLMConfig(
|
||||
provider='openai/text-embedding-3-small',
|
||||
api_token='your-api-key',
|
||||
temperature=0.7
|
||||
)
|
||||
)
|
||||
|
||||
# Alternative: Dictionary format (backward compatible)
|
||||
config = AdaptiveConfig(
|
||||
strategy="embedding",
|
||||
embedding_llm_config={
|
||||
|
||||
154
tests/adaptive/test_llm_embedding.py
Normal file
154
tests/adaptive/test_llm_embedding.py
Normal file
@@ -0,0 +1,154 @@
|
||||
import asyncio
|
||||
import os
|
||||
from crawl4ai import AsyncWebCrawler, AdaptiveCrawler, AdaptiveConfig, LLMConfig
|
||||
|
||||
|
||||
async def test_configuration(name: str, config: AdaptiveConfig, url: str, query: str):
|
||||
"""Test a specific configuration"""
|
||||
print(f"\n{'='*60}")
|
||||
print(f"Configuration: {name}")
|
||||
print(f"{'='*60}")
|
||||
|
||||
async with AsyncWebCrawler(verbose=False) as crawler:
|
||||
adaptive = AdaptiveCrawler(crawler, config)
|
||||
result = await adaptive.digest(start_url=url, query=query)
|
||||
|
||||
print("\n" + "="*50)
|
||||
print("CRAWL STATISTICS")
|
||||
print("="*50)
|
||||
adaptive.print_stats(detailed=False)
|
||||
|
||||
# Get the most relevant content found
|
||||
print("\n" + "="*50)
|
||||
print("MOST RELEVANT PAGES")
|
||||
print("="*50)
|
||||
|
||||
relevant_pages = adaptive.get_relevant_content(top_k=5)
|
||||
for i, page in enumerate(relevant_pages, 1):
|
||||
print(f"\n{i}. {page['url']}")
|
||||
print(f" Relevance Score: {page['score']:.2%}")
|
||||
|
||||
# Show a snippet of the content
|
||||
content = page['content'] or ""
|
||||
if content:
|
||||
snippet = content[:200].replace('\n', ' ')
|
||||
if len(content) > 200:
|
||||
snippet += "..."
|
||||
print(f" Preview: {snippet}")
|
||||
|
||||
print(f"\n{'='*50}")
|
||||
print(f"Pages crawled: {len(result.crawled_urls)}")
|
||||
print(f"Final confidence: {adaptive.confidence:.1%}")
|
||||
print(f"Stopped reason: {result.metrics.get('stopped_reason', 'max_pages')}")
|
||||
|
||||
if result.metrics.get('is_irrelevant', False):
|
||||
print("⚠️ Query detected as irrelevant!")
|
||||
|
||||
return result
|
||||
|
||||
|
||||
async def llm_embedding():
|
||||
"""Demonstrate various embedding configurations"""
|
||||
|
||||
print("EMBEDDING STRATEGY CONFIGURATION EXAMPLES")
|
||||
print("=" * 60)
|
||||
|
||||
# Base URL and query for testing
|
||||
test_url = "https://docs.python.org/3/library/asyncio.html"
|
||||
|
||||
openai_llm_config = LLMConfig(
|
||||
provider='openai/text-embedding-3-small',
|
||||
api_token=os.getenv('OPENAI_API_KEY'),
|
||||
temperature=0.7,
|
||||
max_tokens=2000
|
||||
)
|
||||
config_openai = AdaptiveConfig(
|
||||
strategy="embedding",
|
||||
max_pages=10,
|
||||
|
||||
# Use OpenAI embeddings
|
||||
embedding_llm_config=openai_llm_config,
|
||||
# embedding_llm_config={
|
||||
# 'provider': 'openai/text-embedding-3-small',
|
||||
# 'api_token': os.getenv('OPENAI_API_KEY')
|
||||
# },
|
||||
|
||||
# OpenAI embeddings are high quality, can be stricter
|
||||
embedding_k_exp=4.0,
|
||||
n_query_variations=12
|
||||
)
|
||||
|
||||
await test_configuration(
|
||||
"OpenAI Embeddings",
|
||||
config_openai,
|
||||
test_url,
|
||||
# "event-driven architecture patterns"
|
||||
"async await context managers coroutines"
|
||||
)
|
||||
return
|
||||
|
||||
|
||||
|
||||
async def basic_adaptive_crawling():
|
||||
"""Basic adaptive crawling example"""
|
||||
|
||||
# Initialize the crawler
|
||||
async with AsyncWebCrawler(verbose=True) as crawler:
|
||||
# Create an adaptive crawler with default settings (statistical strategy)
|
||||
adaptive = AdaptiveCrawler(crawler)
|
||||
|
||||
# Note: You can also use embedding strategy for semantic understanding:
|
||||
# from crawl4ai import AdaptiveConfig
|
||||
# config = AdaptiveConfig(strategy="embedding")
|
||||
# adaptive = AdaptiveCrawler(crawler, config)
|
||||
|
||||
# Start adaptive crawling
|
||||
print("Starting adaptive crawl for Python async programming information...")
|
||||
result = await adaptive.digest(
|
||||
start_url="https://docs.python.org/3/library/asyncio.html",
|
||||
query="async await context managers coroutines"
|
||||
)
|
||||
|
||||
# Display crawl statistics
|
||||
print("\n" + "="*50)
|
||||
print("CRAWL STATISTICS")
|
||||
print("="*50)
|
||||
adaptive.print_stats(detailed=False)
|
||||
|
||||
# Get the most relevant content found
|
||||
print("\n" + "="*50)
|
||||
print("MOST RELEVANT PAGES")
|
||||
print("="*50)
|
||||
|
||||
relevant_pages = adaptive.get_relevant_content(top_k=5)
|
||||
for i, page in enumerate(relevant_pages, 1):
|
||||
print(f"\n{i}. {page['url']}")
|
||||
print(f" Relevance Score: {page['score']:.2%}")
|
||||
|
||||
# Show a snippet of the content
|
||||
content = page['content'] or ""
|
||||
if content:
|
||||
snippet = content[:200].replace('\n', ' ')
|
||||
if len(content) > 200:
|
||||
snippet += "..."
|
||||
print(f" Preview: {snippet}")
|
||||
|
||||
# Show final confidence
|
||||
print(f"\n{'='*50}")
|
||||
print(f"Final Confidence: {adaptive.confidence:.2%}")
|
||||
print(f"Total Pages Crawled: {len(result.crawled_urls)}")
|
||||
print(f"Knowledge Base Size: {len(adaptive.state.knowledge_base)} documents")
|
||||
|
||||
|
||||
if adaptive.confidence >= 0.8:
|
||||
print("✓ High confidence - can answer detailed questions about async Python")
|
||||
elif adaptive.confidence >= 0.6:
|
||||
print("~ Moderate confidence - can answer basic questions")
|
||||
else:
|
||||
print("✗ Low confidence - need more information")
|
||||
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
asyncio.run(llm_embedding())
|
||||
# asyncio.run(basic_adaptive_crawling())
|
||||
Reference in New Issue
Block a user