Files
crawl4ai/tests/adaptive/test_embedding_strategy.py
UncleCode 1a73fb60db feat(crawl4ai): Implement adaptive crawling feature
This commit introduces the adaptive crawling feature to the crawl4ai project. The adaptive crawling feature intelligently determines when sufficient information has been gathered during a crawl, improving efficiency and reducing unnecessary resource usage.

The changes include the addition of new files related to the adaptive crawler, modifications to the existing files, and updates to the documentation. The new files include the main adaptive crawler script, utility functions, and various configuration and strategy scripts. The existing files that were modified include the project's initialization file and utility functions. The documentation has been updated to include detailed explanations and examples of the adaptive crawling feature.

The adaptive crawling feature will significantly enhance the capabilities of the crawl4ai project, providing users with a more efficient and intelligent web crawling tool.

Significant modifications:
- Added adaptive_crawler.py and related scripts
- Modified __init__.py and utils.py
- Updated documentation with details about the adaptive crawling feature
- Added tests for the new feature

BREAKING CHANGE: This is a significant feature addition that may affect the overall behavior of the crawl4ai project. Users are advised to review the updated documentation to understand how to use the new feature.

Refs: #123, #456
2025-07-04 15:16:53 +08:00

635 lines
25 KiB
Python

"""
Test and demo script for Embedding-based Adaptive Crawler
This script demonstrates the embedding-based adaptive crawling
with semantic space coverage and gap-driven expansion.
"""
import asyncio
import os
from pathlib import Path
import time
from rich.console import Console
from rich import print as rprint
import sys
# Add parent directory to path for imports
sys.path.append(str(Path(__file__).parent.parent.parent))
from crawl4ai import (
AsyncWebCrawler,
AdaptiveCrawler,
AdaptiveConfig,
CrawlState
)
console = Console()
async def test_basic_embedding_crawl():
    """Test basic embedding-based adaptive crawling.

    Runs one semantic adaptive crawl against the Python asyncio docs and
    reports timing, query expansion, gaps, and final confidence.
    """
    console.print("\n[bold yellow]Test 1: Basic Embedding-based Crawl[/bold yellow]")
    console.print("Testing semantic space coverage with query expansion")

    # Embedding-strategy configuration; the embedding strategy stops on
    # semantic coverage, so confidence_threshold is informational only.
    config = AdaptiveConfig(
        strategy="embedding",
        confidence_threshold=0.7,  # Not used for stopping in embedding strategy
        min_gain_threshold=0.01,
        max_pages=15,
        top_k_links=3,
        n_query_variations=8,
        embedding_model="sentence-transformers/all-MiniLM-L6-v2"  # Fast, good quality
    )

    # Query expansion is LLM-driven, so wire up an LLM provider.
    api_token = os.getenv('OPENAI_API_KEY')
    llm_config = {
        'provider': 'openai/gpt-4o-mini',
        'api_token': api_token
    }
    if not api_token:
        console.print("[red]Warning: OPENAI_API_KEY not set. Using mock data for demo.[/red]")
        # Continue with mock for demo purposes
    config.embedding_llm_config = llm_config

    async with AsyncWebCrawler() as crawler:
        adaptive = AdaptiveCrawler(
            crawler=crawler,
            config=config
        )

        started = time.time()
        console.print("\n[cyan]Starting semantic adaptive crawl...[/cyan]")
        state = await adaptive.digest(
            start_url="https://docs.python.org/3/library/asyncio.html",
            query="async await coroutines event loops"
        )
        elapsed = time.time() - started

        # Summarize the run.
        console.print(f"\n[green]Crawl completed in {elapsed:.2f} seconds[/green]")
        adaptive.print_stats(detailed=False)

        console.print("\n[bold cyan]Semantic Coverage Details:[/bold cyan]")
        if state.expanded_queries:
            console.print(f"Query expanded to {len(state.expanded_queries)} variations")
            console.print("Sample variations:")
            for idx, variation in enumerate(state.expanded_queries[:3], 1):
                console.print(f" {idx}. {variation}")
        if state.semantic_gaps:
            console.print(f"\nSemantic gaps identified: {len(state.semantic_gaps)}")

        console.print(f"\nFinal confidence: {adaptive.confidence:.2%}")
        console.print(f"Is Sufficient: {'Yes (Validated)' if adaptive.is_sufficient else 'No'}")
        console.print(f"Pages needed: {len(state.crawled_urls)}")
async def test_embedding_vs_statistical(use_openai=False):
    """Compare embedding strategy with statistical strategy.

    Runs the same crawl twice — once per strategy — and reports pages
    crawled, wall time, confidence, and sufficiency for each.
    """
    console.print("\n[bold yellow]Test 2: Embedding vs Statistical Strategy Comparison[/bold yellow]")

    test_url = "https://httpbin.org"
    test_query = "http headers authentication api"

    # --- Run 1: statistical strategy ---
    console.print("\n[cyan]1. Statistical Strategy:[/cyan]")
    config_stat = AdaptiveConfig(
        strategy="statistical",
        confidence_threshold=0.7,
        max_pages=10
    )
    async with AsyncWebCrawler() as crawler:
        stat_crawler = AdaptiveCrawler(crawler=crawler, config=config_stat)
        t0 = time.time()
        state_stat = await stat_crawler.digest(start_url=test_url, query=test_query)
        stat_time = time.time() - t0
        stat_pages = len(state_stat.crawled_urls)
        stat_confidence = stat_crawler.confidence

    # --- Run 2: embedding strategy ---
    console.print("\n[cyan]2. Embedding Strategy:[/cyan]")
    config_emb = AdaptiveConfig(
        strategy="embedding",
        confidence_threshold=0.7,  # Not used for stopping
        max_pages=10,
        n_query_variations=5,
        min_gain_threshold=0.01
    )
    # Prefer OpenAI embeddings when requested and a key is present;
    # otherwise fall back to the default (sentence-transformers) path.
    if use_openai and os.getenv('OPENAI_API_KEY'):
        config_emb.embedding_llm_config = {
            'provider': 'openai/text-embedding-3-small',
            'api_token': os.getenv('OPENAI_API_KEY'),
            'embedding_model': 'text-embedding-3-small'
        }
        console.print("[cyan]Using OpenAI embeddings[/cyan]")
    else:
        config_emb.embedding_llm_config = {
            'provider': 'openai/gpt-4o-mini',
            'api_token': os.getenv('OPENAI_API_KEY', 'dummy-key')
        }

    async with AsyncWebCrawler() as crawler:
        emb_crawler = AdaptiveCrawler(crawler=crawler, config=config_emb)
        t0 = time.time()
        state_emb = await emb_crawler.digest(start_url=test_url, query=test_query)
        emb_time = time.time() - t0
        emb_pages = len(state_emb.crawled_urls)
        emb_confidence = emb_crawler.confidence

    # --- Side-by-side report ---
    console.print("\n[bold green]Comparison Results:[/bold green]")
    console.print(f"Statistical: {stat_pages} pages in {stat_time:.2f}s, confidence: {stat_confidence:.2%}, sufficient: {stat_crawler.is_sufficient}")
    console.print(f"Embedding: {emb_pages} pages in {emb_time:.2f}s, confidence: {emb_confidence:.2%}, sufficient: {emb_crawler.is_sufficient}")
    if emb_pages < stat_pages:
        efficiency = ((stat_pages - emb_pages) / stat_pages) * 100
        console.print(f"\n[green]Embedding strategy used {efficiency:.0f}% fewer pages![/green]")

    # Show validation info for embedding
    if hasattr(state_emb, 'metrics') and 'validation_confidence' in state_emb.metrics:
        console.print(f"Embedding validation score: {state_emb.metrics['validation_confidence']:.2%}")
async def test_custom_embedding_provider():
    """Test with different embedding providers.

    Demonstrates swapping the default sentence-transformers model for
    OpenAI's text-embedding-3-small; skipped when no API key is set.
    """
    console.print("\n[bold yellow]Test 3: Custom Embedding Provider[/bold yellow]")

    # Guard clause: this test needs an OpenAI key.
    api_token = os.getenv('OPENAI_API_KEY')
    if not api_token:
        console.print("[yellow]Skipping OpenAI embedding test - no API key[/yellow]")
        return

    config = AdaptiveConfig(
        strategy="embedding",
        confidence_threshold=0.8,  # Not used for stopping
        max_pages=10,
        min_gain_threshold=0.01,
        n_query_variations=5
    )
    # Route embeddings through OpenAI instead of sentence-transformers.
    config.embedding_llm_config = {
        'provider': 'openai/text-embedding-3-small',
        'api_token': api_token,
        'embedding_model': 'text-embedding-3-small'
    }

    async with AsyncWebCrawler() as crawler:
        adaptive = AdaptiveCrawler(crawler=crawler, config=config)
        console.print("Using OpenAI embeddings for semantic analysis...")
        await adaptive.digest(
            start_url="https://httpbin.org",
            query="api endpoints json response"
        )
        adaptive.print_stats(detailed=False)
async def test_knowledge_export_import():
    """Test exporting and importing semantic knowledge bases.

    Builds a knowledge base in one crawler session, persists it to JSONL,
    reloads it in a fresh session, and extends it with a second query.
    """
    console.print("\n[bold yellow]Test 4: Semantic Knowledge Base Export/Import[/bold yellow]")

    config = AdaptiveConfig(
        strategy="embedding",
        confidence_threshold=0.7,  # Not used for stopping
        max_pages=5,
        min_gain_threshold=0.01,
        n_query_variations=4
    )
    export_path = "semantic_kb.jsonl"

    # Session 1: build and persist the knowledge base.
    async with AsyncWebCrawler() as crawler:
        builder = AdaptiveCrawler(crawler=crawler, config=config)
        console.print("\n[cyan]Building initial knowledge base...[/cyan]")
        initial_state = await builder.digest(
            start_url="https://httpbin.org",
            query="http methods headers"
        )
        builder.export_knowledge_base(export_path)
        console.print(f"[green]Exported {len(initial_state.knowledge_base)} documents with embeddings[/green]")

    # Session 2: reload the persisted base and extend it with a new query.
    async with AsyncWebCrawler() as crawler:
        resumer = AdaptiveCrawler(crawler=crawler, config=config)
        console.print("\n[cyan]Importing knowledge base...[/cyan]")
        resumer.import_knowledge_base(export_path)
        console.print("\n[cyan]Extending with new query...[/cyan]")
        extended_state = await resumer.digest(
            start_url="https://httpbin.org",
            query="authentication oauth tokens"
        )
        console.print(f"[green]Total knowledge base: {len(extended_state.knowledge_base)} documents[/green]")

    # Remove the temporary export file.
    Path(export_path).unlink(missing_ok=True)
async def test_gap_visualization():
    """Visualize semantic gaps and coverage.

    Runs one embedding-strategy crawl, then reports query variations,
    knowledge-base size, the largest semantic gaps, and the gap-driven
    order in which pages were crawled.
    """
    console.print("\n[bold yellow]Test 5: Semantic Gap Analysis[/bold yellow]")

    config = AdaptiveConfig(
        strategy="embedding",
        confidence_threshold=0.9,  # Not used for stopping
        max_pages=8,
        n_query_variations=6,
        min_gain_threshold=0.01
    )

    async with AsyncWebCrawler() as crawler:
        adaptive = AdaptiveCrawler(crawler=crawler, config=config)
        # Initial crawl
        state = await adaptive.digest(
            start_url="https://docs.python.org/3/library/",
            query="concurrency threading multiprocessing"
        )

        # Coverage summary
        console.print("\n[bold cyan]Semantic Gap Analysis:[/bold cyan]")
        console.print(f"Query variations: {len(state.expanded_queries)}")
        console.print(f"Knowledge documents: {len(state.knowledge_base)}")
        console.print(f"Identified gaps: {len(state.semantic_gaps)}")

        if state.semantic_gaps:
            console.print("\n[yellow]Gap sizes (distance from coverage):[/yellow]")
            for rank, (_, distance) in enumerate(state.semantic_gaps[:5], 1):
                console.print(f" Gap {rank}: {distance:.3f}")

        # Show crawl progression
        console.print("\n[cyan]Crawl Order (gap-driven selection):[/cyan]")
        for rank, url in enumerate(state.crawl_order[:5], 1):
            console.print(f" {rank}. {url}")
def _print_batch_estimates(state):
    """Print an approximate per-batch confidence progression for a finished crawl.

    Per-batch confidence history is not recorded by the crawler, so each
    batch's value is estimated by linearly scaling the final confidence by
    the fraction of pages crawled so far. Batches are groups of 3 pages,
    matching top_k_links=3 used by the tests below.
    """
    total_pages = len(state.crawled_urls)
    final_conf = state.metrics.get('confidence', 0)
    for start in range(0, total_pages, 3):
        batch_num = (start // 3) + 1
        batch_pages = min(3, total_pages - start)
        pages_so_far = start + batch_pages
        estimated_confidence = final_conf * (pages_so_far / total_pages)
        # BUGFIX: the original ternary produced '' in both branches; mark
        # batches that cross the 80% target with a check mark instead.
        marker = '✓' if estimated_confidence >= 0.8 else ''
        console.print(f"Batch {batch_num}: {batch_pages} pages → Confidence: {estimated_confidence:.1%} {marker}")


async def test_fast_convergence_with_relevant_query():
    """Test that both strategies reach high confidence quickly with relevant queries.

    Runs the embedding and statistical strategies on the same relevant
    query and prints a per-batch confidence estimate for each. Cleanup vs.
    the previous revision: dead '' / '' ternaries now emit a check mark,
    unused locals (batch_results, current_pages, start_time) were removed,
    and the duplicated batch-breakdown loop lives in _print_batch_estimates.
    """
    console.print("\n[bold yellow]Test 7: Fast Convergence with Relevant Query[/bold yellow]")
    console.print("Testing that strategies reach 80%+ confidence within 2-3 batches")

    # Test scenarios
    test_cases = [
        {
            "name": "Python Async Documentation",
            "url": "https://docs.python.org/3/library/asyncio.html",
            "query": "async await coroutines event loops tasks"
        }
    ]

    for test_case in test_cases:
        console.print(f"\n[bold cyan]Testing: {test_case['name']}[/bold cyan]")
        console.print(f"URL: {test_case['url']}")
        console.print(f"Query: {test_case['query']}")

        # --- Embedding strategy ---
        console.print("\n[yellow]Embedding Strategy:[/yellow]")
        config_emb = AdaptiveConfig(
            strategy="embedding",
            confidence_threshold=0.8,
            max_pages=9,
            top_k_links=3,
            min_gain_threshold=0.01,
            n_query_variations=5
        )
        # Query expansion needs an LLM provider.
        config_emb.embedding_llm_config = {
            'provider': 'openai/gpt-4o-mini',
            'api_token': os.getenv('OPENAI_API_KEY'),
        }
        async with AsyncWebCrawler() as crawler:
            emb_crawler = AdaptiveCrawler(crawler=crawler, config=config_emb)
            state = await emb_crawler.digest(
                start_url=test_case['url'],
                query=test_case['query']
            )

            _print_batch_estimates(state)
            total_pages = len(state.crawled_urls)
            final_confidence = emb_crawler.confidence
            console.print(f"[green]Final: {total_pages} pages → Confidence: {final_confidence:.1%} {'✅ (Sufficient!)' if emb_crawler.is_sufficient else ''}[/green]")

            # Show learning metrics for embedding
            if 'avg_min_distance' in state.metrics:
                console.print(f"[dim]Avg gap distance: {state.metrics['avg_min_distance']:.3f}[/dim]")
            if 'validation_confidence' in state.metrics:
                console.print(f"[dim]Validation score: {state.metrics['validation_confidence']:.1%}[/dim]")

        # --- Statistical strategy ---
        console.print("\n[yellow]Statistical Strategy:[/yellow]")
        config_stat = AdaptiveConfig(
            strategy="statistical",
            confidence_threshold=0.8,
            max_pages=9,
            top_k_links=3,
            min_gain_threshold=0.01
        )
        async with AsyncWebCrawler() as crawler:
            stat_crawler = AdaptiveCrawler(crawler=crawler, config=config_stat)
            state = await stat_crawler.digest(
                start_url=test_case['url'],
                query=test_case['query']
            )

            _print_batch_estimates(state)
            total_pages = len(state.crawled_urls)
            final_confidence = stat_crawler.confidence
            console.print(f"[green]Final: {total_pages} pages → Confidence: {final_confidence:.1%} {'✅ (Sufficient!)' if stat_crawler.is_sufficient else ''}[/green]")
async def test_irrelevant_query_behavior():
    """Test how embedding strategy handles completely irrelevant queries.

    Crawls the Python asyncio docs with a cooking query and checks that
    the strategy reports low confidence and large semantic distances
    instead of falsely converging.

    Fix vs. the previous revision: the sufficiency ternary after the
    confidence print produced '' for both branches (dead code); it now
    emits a check mark when (unexpectedly) sufficient.
    """
    console.print("\n[bold yellow]Test 8: Irrelevant Query Behavior[/bold yellow]")
    console.print("Testing embedding strategy with a query that has no semantic relevance to the content")

    # Test with irrelevant query on Python async documentation
    test_case = {
        "name": "Irrelevant Query on Python Docs",
        "url": "https://docs.python.org/3/library/asyncio.html",
        "query": "how to cook fried rice with vegetables"
    }
    console.print(f"\n[bold cyan]Testing: {test_case['name']}[/bold cyan]")
    console.print(f"URL: {test_case['url']} (Python async documentation)")
    console.print(f"Query: '{test_case['query']}' (completely irrelevant)")
    console.print("\n[dim]Expected behavior: Low confidence, high distances, no convergence[/dim]")

    # Configure embedding strategy
    config_emb = AdaptiveConfig(
        strategy="embedding",
        confidence_threshold=0.8,
        max_pages=9,
        top_k_links=3,
        min_gain_threshold=0.01,
        n_query_variations=5,
        embedding_min_relative_improvement=0.05,  # Lower threshold to see more iterations
        embedding_min_confidence_threshold=0.1  # Will stop if confidence < 10%
    )
    # Configure embeddings using the correct format
    config_emb.embedding_llm_config = {
        'provider': 'openai/gpt-4o-mini',
        'api_token': os.getenv('OPENAI_API_KEY'),
    }

    async with AsyncWebCrawler() as crawler:
        emb_crawler = AdaptiveCrawler(crawler=crawler, config=config_emb)
        start_time = time.time()
        state = await emb_crawler.digest(
            start_url=test_case['url'],
            query=test_case['query']
        )
        elapsed = time.time() - start_time

        # Analyze results
        console.print(f"\n[bold]Results after {elapsed:.1f} seconds:[/bold]")

        # Basic metrics
        total_pages = len(state.crawled_urls)
        final_confidence = emb_crawler.confidence
        console.print(f"\nPages crawled: {total_pages}")
        console.print(f"Final confidence: {final_confidence:.1%} {'✓' if emb_crawler.is_sufficient else ''}")

        # Distance metrics
        if 'avg_min_distance' in state.metrics:
            console.print("\n[yellow]Distance Metrics:[/yellow]")
            console.print(f" Average minimum distance: {state.metrics['avg_min_distance']:.3f}")
            console.print(f" Close neighbors (<0.3): {state.metrics.get('avg_close_neighbors', 0):.1f}")
            console.print(f" Very close neighbors (<0.2): {state.metrics.get('avg_very_close_neighbors', 0):.1f}")

            # Interpret distances (larger = worse semantic match).
            avg_dist = state.metrics['avg_min_distance']
            if avg_dist > 0.8:
                console.print(" [red]→ Very poor match (distance > 0.8)[/red]")
            elif avg_dist > 0.6:
                console.print(" [yellow]→ Poor match (distance > 0.6)[/yellow]")
            elif avg_dist > 0.4:
                console.print(" [blue]→ Moderate match (distance > 0.4)[/blue]")
            else:
                console.print(" [green]→ Good match (distance < 0.4)[/green]")

        # Show sample expanded queries
        if state.expanded_queries:
            console.print("\n[yellow]Sample Query Variations Generated:[/yellow]")
            for i, q in enumerate(state.expanded_queries[:3], 1):
                console.print(f" {i}. {q}")

        # Show crawl progression
        console.print("\n[yellow]Crawl Progression:[/yellow]")
        for i, url in enumerate(state.crawl_order[:5], 1):
            console.print(f" {i}. {url}")
        if len(state.crawl_order) > 5:
            console.print(f" ... and {len(state.crawl_order) - 5} more")

        # Validation score
        if 'validation_confidence' in state.metrics:
            console.print("\n[yellow]Validation:[/yellow]")
            console.print(f" Validation score: {state.metrics['validation_confidence']:.1%}")

        # Why it stopped
        if 'stopped_reason' in state.metrics:
            console.print(f"\n[yellow]Stopping Reason:[/yellow] {state.metrics['stopped_reason']}")
            if state.metrics.get('is_irrelevant', False):
                console.print("[red]→ Query and content are completely unrelated![/red]")
        elif total_pages >= config_emb.max_pages:
            console.print(f"\n[yellow]Stopping Reason:[/yellow] Reached max pages limit ({config_emb.max_pages})")

        # Summary
        console.print("\n[bold]Summary:[/bold]")
        if final_confidence < 0.2:
            console.print("[red]✗ As expected: Query is completely irrelevant to content[/red]")
            console.print("[green]✓ The embedding strategy correctly identified no semantic match[/green]")
        else:
            console.print(f"[yellow]⚠ Unexpected: Got {final_confidence:.1%} confidence for irrelevant query[/yellow]")
            console.print("[yellow] This may indicate the query variations are too broad[/yellow]")
async def test_high_dimensional_handling():
    """Test handling of high-dimensional embedding spaces.

    Verifies the system copes with 384-dimensional embeddings by modelling
    coverage as a centroid + radius (alpha shapes break down in high
    dimensions — the except branch demonstrates that failure mode).

    Fix vs. the previous revision: `{...get('radius', 'N/A'):.3f}` raised
    ValueError whenever 'radius' was missing, because '.3f' is an invalid
    format spec for the 'N/A' string.
    """
    console.print("\n[bold yellow]Test 6: High-Dimensional Embedding Space Handling[/bold yellow]")
    console.print("Testing how the system handles 384+ dimensional embeddings")

    config = AdaptiveConfig(
        strategy="embedding",
        confidence_threshold=0.8,  # Not used for stopping
        max_pages=5,
        n_query_variations=8,  # Will create 9 points total
        min_gain_threshold=0.01,
        embedding_model="sentence-transformers/all-MiniLM-L6-v2"  # 384 dimensions
    )

    # Use OpenAI if available, otherwise mock
    if os.getenv('OPENAI_API_KEY'):
        config.embedding_llm_config = {
            'provider': 'openai/text-embedding-3-small',
            'api_token': os.getenv('OPENAI_API_KEY'),
            'embedding_model': 'text-embedding-3-small'
        }
    else:
        config.embedding_llm_config = {
            'provider': 'openai/gpt-4o-mini',
            'api_token': 'mock-key'
        }

    async with AsyncWebCrawler() as crawler:
        prog_crawler = AdaptiveCrawler(crawler=crawler, config=config)
        console.print("\n[cyan]Testing with high-dimensional embeddings (384D)...[/cyan]")
        try:
            state = await prog_crawler.digest(
                start_url="https://httpbin.org",
                query="api endpoints json"
            )
            console.print(f"[green]✓ Successfully handled {len(state.expanded_queries)} queries in 384D space[/green]")
            console.print(f"Coverage shape type: {type(state.coverage_shape)}")
            if isinstance(state.coverage_shape, dict):
                console.print("Coverage model: centroid + radius")
                console.print(f" - Center shape: {state.coverage_shape['center'].shape if 'center' in state.coverage_shape else 'N/A'}")
                # BUGFIX: only apply the float format when a numeric radius exists.
                radius = state.coverage_shape.get('radius')
                radius_text = f"{radius:.3f}" if isinstance(radius, (int, float)) else 'N/A'
                console.print(f" - Radius: {radius_text}")
        except Exception as e:
            console.print(f"[red]Error: {e}[/red]")
            console.print("[yellow]This demonstrates why alpha shapes don't work in high dimensions[/yellow]")
async def main():
    """Run all embedding strategy tests.

    Verifies optional dependencies first (NumPy is mandatory; embeddings
    need either sentence-transformers or an OPENAI_API_KEY), then runs the
    currently enabled tests. Most calls are deliberately commented out to
    keep a demo run short.

    Cleanup vs. the previous revision: removed the unused `as e` binding
    on the sentence-transformers probe and a pointless f-prefix on a
    constant warning string.
    """
    console.print("[bold magenta]Embedding-based Adaptive Crawler Test Suite[/bold magenta]")
    console.print("=" * 60)
    try:
        # Check if we have required dependencies
        has_sentence_transformers = True
        has_numpy = True
        try:
            import numpy
            console.print("[green]✓ NumPy installed[/green]")
        except ImportError:
            has_numpy = False
            console.print("[red]Missing numpy[/red]")
        # Probe sentence_transformers; RuntimeError/ValueError are caught
        # too because numpy-compatibility problems surface that way.
        try:
            import sentence_transformers
            console.print("[green]✓ Sentence-transformers installed[/green]")
        except (ImportError, RuntimeError, ValueError):
            has_sentence_transformers = False
            console.print("[yellow]Warning: sentence-transformers not available[/yellow]")
            console.print("[yellow]Tests will use OpenAI embeddings if available or mock data[/yellow]")
        # Run tests based on available dependencies
        if has_numpy:
            # Whether OpenAI must supply embeddings; consumed by the
            # currently-disabled comparison test below.
            use_openai = not has_sentence_transformers and os.getenv('OPENAI_API_KEY')
            if not has_sentence_transformers and not os.getenv('OPENAI_API_KEY'):
                console.print("\n[red]Neither sentence-transformers nor OpenAI API key available[/red]")
                console.print("[yellow]Please set OPENAI_API_KEY or fix sentence-transformers installation[/yellow]")
                return
            # Run all tests
            # await test_basic_embedding_crawl()
            # await test_embedding_vs_statistical(use_openai=use_openai)
            # Run the fast convergence test - this is the most important one
            # await test_fast_convergence_with_relevant_query()
            # Test with irrelevant query
            await test_irrelevant_query_behavior()
            # Only run OpenAI-specific test if we have API key
            # if os.getenv('OPENAI_API_KEY'):
            #     await test_custom_embedding_provider()
            # # Skip tests that require sentence-transformers when it's not available
            # if has_sentence_transformers:
            #     await test_knowledge_export_import()
            #     await test_gap_visualization()
            # else:
            #     console.print("\n[yellow]Skipping tests that require sentence-transformers due to numpy compatibility issues[/yellow]")
            # This test should work with mock data
            # await test_high_dimensional_handling()
        else:
            console.print("\n[red]Cannot run tests without NumPy[/red]")
            return
        console.print("\n[bold green]✅ All tests completed![/bold green]")
    except Exception as e:
        console.print(f"\n[bold red]❌ Test failed: {e}[/bold red]")
        import traceback
        traceback.print_exc()
if __name__ == "__main__":
asyncio.run(main())