From 3bc56dd028fa35ba31a62d06fe0c0de09b731bdb Mon Sep 17 00:00:00 2001 From: ntohidi Date: Tue, 9 Sep 2025 12:49:55 +0800 Subject: [PATCH] fix: allow custom LLM providers for adaptive crawler embedding config. ref: #1291 - Change embedding_llm_config from Dict to Union[LLMConfig, Dict] for type safety - Add backward-compatible conversion property _embedding_llm_config_dict - Replace all hardcoded OpenAI embedding configs with configurable options - Fix LLMConfig object attribute access in query expansion logic - Add comprehensive example demonstrating multiple provider configurations - Update documentation with both LLMConfig object and dictionary usage patterns Users can now specify any LLM provider for query expansion in embedding strategy: - New: embedding_llm_config=LLMConfig(provider='anthropic/claude-3', api_token='key') - Old: embedding_llm_config={'provider': 'openai/gpt-4', 'api_token': 'key'} (still works) --- crawl4ai/adaptive_crawler.py | 78 +++++++-- .../adaptive_crawling/llm_config_example.py | 154 ++++++++++++++++++ docs/md_v2/core/adaptive-crawling.md | 14 +- tests/adaptive/test_llm_embedding.py | 154 ++++++++++++++++++ 4 files changed, 381 insertions(+), 19 deletions(-) create mode 100644 docs/examples/adaptive_crawling/llm_config_example.py create mode 100644 tests/adaptive/test_llm_embedding.py diff --git a/crawl4ai/adaptive_crawler.py b/crawl4ai/adaptive_crawler.py index a0b8fa9c..bce1da23 100644 --- a/crawl4ai/adaptive_crawler.py +++ b/crawl4ai/adaptive_crawler.py @@ -19,7 +19,7 @@ import re from pathlib import Path from crawl4ai.async_webcrawler import AsyncWebCrawler -from crawl4ai.async_configs import CrawlerRunConfig, LinkPreviewConfig +from crawl4ai.async_configs import CrawlerRunConfig, LinkPreviewConfig, LLMConfig from crawl4ai.models import Link, CrawlResult import numpy as np @@ -178,7 +178,7 @@ class AdaptiveConfig: # Embedding strategy parameters embedding_model: str = "sentence-transformers/all-MiniLM-L6-v2" - embedding_llm_config: Optional[Dict] = None # Separate config for embeddings + embedding_llm_config: Optional[Union[LLMConfig, Dict]] = None # Separate config for embeddings n_query_variations: int = 10 coverage_threshold: float = 0.85 alpha_shape_alpha: float = 0.5 @@ -250,6 +250,30 @@ class AdaptiveConfig: assert 0 <= self.embedding_quality_max_confidence <= 1, "embedding_quality_max_confidence must be between 0 and 1" assert self.embedding_quality_scale_factor > 0, "embedding_quality_scale_factor must be positive" assert 0 <= self.embedding_min_confidence_threshold <= 1, "embedding_min_confidence_threshold must be between 0 and 1" + + @property + def _embedding_llm_config_dict(self) -> Optional[Dict]: + """Convert LLMConfig to dict format for backward compatibility.""" + if self.embedding_llm_config is None: + return None + + if isinstance(self.embedding_llm_config, dict): + # Already a dict - return as-is for backward compatibility + return self.embedding_llm_config + + # Convert LLMConfig object to dict format + return { + 'provider': self.embedding_llm_config.provider, + 'api_token': self.embedding_llm_config.api_token, + 'base_url': getattr(self.embedding_llm_config, 'base_url', None), + 'temperature': getattr(self.embedding_llm_config, 'temperature', None), + 'max_tokens': getattr(self.embedding_llm_config, 'max_tokens', None), + 'top_p': getattr(self.embedding_llm_config, 'top_p', None), + 'frequency_penalty': getattr(self.embedding_llm_config, 'frequency_penalty', None), + 'presence_penalty': getattr(self.embedding_llm_config, 'presence_penalty', None), + 'stop': getattr(self.embedding_llm_config, 'stop', None), + 'n': getattr(self.embedding_llm_config, 'n', None), + } class CrawlStrategy(ABC): @@ -593,7 +617,7 @@ class StatisticalStrategy(CrawlStrategy): class EmbeddingStrategy(CrawlStrategy): """Embedding-based adaptive crawling using semantic space coverage""" - def __init__(self, embedding_model: str = None, llm_config: Dict = None): + def __init__(self, embedding_model: str = None, llm_config: Union[LLMConfig, Dict] = None): self.embedding_model = embedding_model or "sentence-transformers/all-MiniLM-L6-v2" self.llm_config = llm_config self._embedding_cache = {} @@ -605,14 +629,24 @@ class EmbeddingStrategy(CrawlStrategy): self._kb_embeddings_hash = None # Track KB changes self._validation_embeddings_cache = None # Cache validation query embeddings self._kb_similarity_threshold = 0.95 # Threshold for deduplication + + def _get_embedding_llm_config_dict(self) -> Dict: + """Get embedding LLM config as dict with fallback to default.""" + if hasattr(self, 'config') and self.config: + config_dict = self.config._embedding_llm_config_dict + if config_dict: + return config_dict + + # Fallback to default if no config provided + return { + 'provider': 'openai/text-embedding-3-small', + 'api_token': os.getenv('OPENAI_API_KEY') + } async def _get_embeddings(self, texts: List[str]) -> Any: """Get embeddings using configured method""" from .utils import get_text_embeddings - embedding_llm_config = { - 'provider': 'openai/text-embedding-3-small', - 'api_token': os.getenv('OPENAI_API_KEY') - } + embedding_llm_config = self._get_embedding_llm_config_dict() return await get_text_embeddings( texts, embedding_llm_config, @@ -679,8 +713,20 @@ class EmbeddingStrategy(CrawlStrategy): Return as a JSON array of strings.""" # Use the LLM for query generation - provider = self.llm_config.get('provider', 'openai/gpt-4o-mini') if self.llm_config else 'openai/gpt-4o-mini' - api_token = self.llm_config.get('api_token') if self.llm_config else None + # Convert LLMConfig to dict if needed + llm_config_dict = None + if self.llm_config: + if isinstance(self.llm_config, dict): + llm_config_dict = self.llm_config + else: + # Convert LLMConfig object to dict + llm_config_dict = { + 'provider': self.llm_config.provider, + 'api_token': self.llm_config.api_token + } + + provider = llm_config_dict.get('provider', 'openai/gpt-4o-mini') if llm_config_dict else 'openai/gpt-4o-mini' + api_token = llm_config_dict.get('api_token') if llm_config_dict else None # response = perform_completion_with_backoff( # provider=provider, @@ -843,10 +889,7 @@ class EmbeddingStrategy(CrawlStrategy): # Batch embed only uncached links if texts_to_embed: - embedding_llm_config = { - 'provider': 'openai/text-embedding-3-small', - 'api_token': os.getenv('OPENAI_API_KEY') - } + embedding_llm_config = self._get_embedding_llm_config_dict() new_embeddings = await get_text_embeddings(texts_to_embed, embedding_llm_config, self.embedding_model) # Cache the new embeddings @@ -1184,10 +1227,7 @@ class EmbeddingStrategy(CrawlStrategy): return # Get embeddings for new texts - embedding_llm_config = { - 'provider': 'openai/text-embedding-3-small', - 'api_token': os.getenv('OPENAI_API_KEY') - } + embedding_llm_config = self._get_embedding_llm_config_dict() new_embeddings = await get_text_embeddings(new_texts, embedding_llm_config, self.embedding_model) # Deduplicate embeddings before adding to KB @@ -1256,10 +1296,12 @@ class AdaptiveCrawler: if strategy_name == "statistical": return StatisticalStrategy() elif strategy_name == "embedding": - return EmbeddingStrategy( + strategy = EmbeddingStrategy( embedding_model=self.config.embedding_model, llm_config=self.config.embedding_llm_config ) + strategy.config = self.config # Pass config to strategy + return strategy else: raise ValueError(f"Unknown strategy: {strategy_name}") diff --git a/docs/examples/adaptive_crawling/llm_config_example.py b/docs/examples/adaptive_crawling/llm_config_example.py new file mode 100644 index 00000000..52794744 --- /dev/null +++ b/docs/examples/adaptive_crawling/llm_config_example.py @@ -0,0 +1,154 @@ +import asyncio +import os +from crawl4ai import AsyncWebCrawler, AdaptiveCrawler, AdaptiveConfig, LLMConfig + + +async def test_configuration(name: str, config: AdaptiveConfig, url: str, query: str): + """Test a specific configuration""" + print(f"\n{'='*60}") + print(f"Configuration: {name}") + print(f"{'='*60}") + + async with AsyncWebCrawler(verbose=False) as crawler: + adaptive = AdaptiveCrawler(crawler, config) + result = await adaptive.digest(start_url=url, query=query) + + print("\n" + "="*50) + print("CRAWL STATISTICS") + print("="*50) + adaptive.print_stats(detailed=False) + + # Get the most relevant content found + print("\n" + "="*50) + print("MOST RELEVANT PAGES") + print("="*50) + + relevant_pages = adaptive.get_relevant_content(top_k=5) + for i, page in enumerate(relevant_pages, 1): + print(f"\n{i}. {page['url']}") + print(f" Relevance Score: {page['score']:.2%}") + + # Show a snippet of the content + content = page['content'] or "" + if content: + snippet = content[:200].replace('\n', ' ') + if len(content) > 200: + snippet += "..." + print(f" Preview: {snippet}") + + print(f"\n{'='*50}") + print(f"Pages crawled: {len(result.crawled_urls)}") + print(f"Final confidence: {adaptive.confidence:.1%}") + print(f"Stopped reason: {result.metrics.get('stopped_reason', 'max_pages')}") + + if result.metrics.get('is_irrelevant', False): + print("⚠️ Query detected as irrelevant!") + + return result + + +async def llm_embedding(): + """Demonstrate various embedding configurations""" + + print("EMBEDDING STRATEGY CONFIGURATION EXAMPLES") + print("=" * 60) + + # Base URL and query for testing + test_url = "https://docs.python.org/3/library/asyncio.html" + + openai_llm_config = LLMConfig( + provider='openai/text-embedding-3-small', + api_token=os.getenv('OPENAI_API_KEY'), + temperature=0.7, + max_tokens=2000 + ) + config_openai = AdaptiveConfig( + strategy="embedding", + max_pages=10, + + # Use OpenAI embeddings + embedding_llm_config=openai_llm_config, + # embedding_llm_config={ + # 'provider': 'openai/text-embedding-3-small', + # 'api_token': os.getenv('OPENAI_API_KEY') + # }, + + # OpenAI embeddings are high quality, can be stricter + embedding_k_exp=4.0, + n_query_variations=12 + ) + + await test_configuration( + "OpenAI Embeddings", + config_openai, + test_url, + # "event-driven architecture patterns" + "async await context managers coroutines" + ) + return + + + +async def basic_adaptive_crawling(): + """Basic adaptive crawling example""" + + # Initialize the crawler + async with AsyncWebCrawler(verbose=True) as crawler: + # Create an adaptive crawler with default settings (statistical strategy) + adaptive = AdaptiveCrawler(crawler) + + # Note: You can also use embedding strategy for semantic understanding: + # from crawl4ai import AdaptiveConfig + # config = AdaptiveConfig(strategy="embedding") + # adaptive = AdaptiveCrawler(crawler, config) + + # Start adaptive crawling + print("Starting adaptive crawl for Python async programming information...") + result = await adaptive.digest( + start_url="https://docs.python.org/3/library/asyncio.html", + query="async await context managers coroutines" + ) + + # Display crawl statistics + print("\n" + "="*50) + print("CRAWL STATISTICS") + print("="*50) + adaptive.print_stats(detailed=False) + + # Get the most relevant content found + print("\n" + "="*50) + print("MOST RELEVANT PAGES") + print("="*50) + + relevant_pages = adaptive.get_relevant_content(top_k=5) + for i, page in enumerate(relevant_pages, 1): + print(f"\n{i}. {page['url']}") + print(f" Relevance Score: {page['score']:.2%}") + + # Show a snippet of the content + content = page['content'] or "" + if content: + snippet = content[:200].replace('\n', ' ') + if len(content) > 200: + snippet += "..." + print(f" Preview: {snippet}") + + # Show final confidence + print(f"\n{'='*50}") + print(f"Final Confidence: {adaptive.confidence:.2%}") + print(f"Total Pages Crawled: {len(result.crawled_urls)}") + print(f"Knowledge Base Size: {len(adaptive.state.knowledge_base)} documents") + + + if adaptive.confidence >= 0.8: + print("✓ High confidence - can answer detailed questions about async Python") + elif adaptive.confidence >= 0.6: + print("~ Moderate confidence - can answer basic questions") + else: + print("✗ Low confidence - need more information") + + + +if __name__ == "__main__": + asyncio.run(llm_embedding()) + # asyncio.run(basic_adaptive_crawling()) \ No newline at end of file diff --git a/docs/md_v2/core/adaptive-crawling.md b/docs/md_v2/core/adaptive-crawling.md index ea1674c2..6f05416d 100644 --- a/docs/md_v2/core/adaptive-crawling.md +++ b/docs/md_v2/core/adaptive-crawling.md @@ -108,7 +108,19 @@ config = AdaptiveConfig( embedding_min_confidence_threshold=0.1 # Stop if completely irrelevant ) -# With custom embedding provider (e.g., OpenAI) +# With custom LLM provider for query expansion (recommended) +from crawl4ai import LLMConfig + +config = AdaptiveConfig( + strategy="embedding", + embedding_llm_config=LLMConfig( + provider='openai/text-embedding-3-small', + api_token='your-api-key', + temperature=0.7 + ) +) + +# Alternative: Dictionary format (backward compatible) config = AdaptiveConfig( strategy="embedding", embedding_llm_config={ diff --git a/tests/adaptive/test_llm_embedding.py b/tests/adaptive/test_llm_embedding.py new file mode 100644 index 00000000..52794744 --- /dev/null +++ b/tests/adaptive/test_llm_embedding.py @@ -0,0 +1,154 @@ +import asyncio +import os +from crawl4ai import AsyncWebCrawler, AdaptiveCrawler, AdaptiveConfig, LLMConfig + + +async def test_configuration(name: str, config: AdaptiveConfig, url: str, query: str): + """Test a specific configuration""" + print(f"\n{'='*60}") + print(f"Configuration: {name}") + print(f"{'='*60}") + + async with AsyncWebCrawler(verbose=False) as crawler: + adaptive = AdaptiveCrawler(crawler, config) + result = await adaptive.digest(start_url=url, query=query) + + print("\n" + "="*50) + print("CRAWL STATISTICS") + print("="*50) + adaptive.print_stats(detailed=False) + + # Get the most relevant content found + print("\n" + "="*50) + print("MOST RELEVANT PAGES") + print("="*50) + + relevant_pages = adaptive.get_relevant_content(top_k=5) + for i, page in enumerate(relevant_pages, 1): + print(f"\n{i}. {page['url']}") + print(f" Relevance Score: {page['score']:.2%}") + + # Show a snippet of the content + content = page['content'] or "" + if content: + snippet = content[:200].replace('\n', ' ') + if len(content) > 200: + snippet += "..." + print(f" Preview: {snippet}") + + print(f"\n{'='*50}") + print(f"Pages crawled: {len(result.crawled_urls)}") + print(f"Final confidence: {adaptive.confidence:.1%}") + print(f"Stopped reason: {result.metrics.get('stopped_reason', 'max_pages')}") + + if result.metrics.get('is_irrelevant', False): + print("⚠️ Query detected as irrelevant!") + + return result + + +async def llm_embedding(): + """Demonstrate various embedding configurations""" + + print("EMBEDDING STRATEGY CONFIGURATION EXAMPLES") + print("=" * 60) + + # Base URL and query for testing + test_url = "https://docs.python.org/3/library/asyncio.html" + + openai_llm_config = LLMConfig( + provider='openai/text-embedding-3-small', + api_token=os.getenv('OPENAI_API_KEY'), + temperature=0.7, + max_tokens=2000 + ) + config_openai = AdaptiveConfig( + strategy="embedding", + max_pages=10, + + # Use OpenAI embeddings + embedding_llm_config=openai_llm_config, + # embedding_llm_config={ + # 'provider': 'openai/text-embedding-3-small', + # 'api_token': os.getenv('OPENAI_API_KEY') + # }, + + # OpenAI embeddings are high quality, can be stricter + embedding_k_exp=4.0, + n_query_variations=12 + ) + + await test_configuration( + "OpenAI Embeddings", + config_openai, + test_url, + # "event-driven architecture patterns" + "async await context managers coroutines" + ) + return + + + +async def basic_adaptive_crawling(): + """Basic adaptive crawling example""" + + # Initialize the crawler + async with AsyncWebCrawler(verbose=True) as crawler: + # Create an adaptive crawler with default settings (statistical strategy) + adaptive = AdaptiveCrawler(crawler) + + # Note: You can also use embedding strategy for semantic understanding: + # from crawl4ai import AdaptiveConfig + # config = AdaptiveConfig(strategy="embedding") + # adaptive = AdaptiveCrawler(crawler, config) + + # Start adaptive crawling + print("Starting adaptive crawl for Python async programming information...") + result = await adaptive.digest( + start_url="https://docs.python.org/3/library/asyncio.html", + query="async await context managers coroutines" + ) + + # Display crawl statistics + print("\n" + "="*50) + print("CRAWL STATISTICS") + print("="*50) + adaptive.print_stats(detailed=False) + + # Get the most relevant content found + print("\n" + "="*50) + print("MOST RELEVANT PAGES") + print("="*50) + + relevant_pages = adaptive.get_relevant_content(top_k=5) + for i, page in enumerate(relevant_pages, 1): + print(f"\n{i}. {page['url']}") + print(f" Relevance Score: {page['score']:.2%}") + + # Show a snippet of the content + content = page['content'] or "" + if content: + snippet = content[:200].replace('\n', ' ') + if len(content) > 200: + snippet += "..." + print(f" Preview: {snippet}") + + # Show final confidence + print(f"\n{'='*50}") + print(f"Final Confidence: {adaptive.confidence:.2%}") + print(f"Total Pages Crawled: {len(result.crawled_urls)}") + print(f"Knowledge Base Size: {len(adaptive.state.knowledge_base)} documents") + + + if adaptive.confidence >= 0.8: + print("✓ High confidence - can answer detailed questions about async Python") + elif adaptive.confidence >= 0.6: + print("~ Moderate confidence - can answer basic questions") + else: + print("✗ Low confidence - need more information") + + + +if __name__ == "__main__": + asyncio.run(llm_embedding()) + # asyncio.run(basic_adaptive_crawling()) \ No newline at end of file