Compare commits
1 Commits
fix/adapti
...
fix/relati
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
813b1f5534 |
@@ -19,7 +19,7 @@ import re
|
|||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
from crawl4ai.async_webcrawler import AsyncWebCrawler
|
from crawl4ai.async_webcrawler import AsyncWebCrawler
|
||||||
from crawl4ai.async_configs import CrawlerRunConfig, LinkPreviewConfig, LLMConfig
|
from crawl4ai.async_configs import CrawlerRunConfig, LinkPreviewConfig
|
||||||
from crawl4ai.models import Link, CrawlResult
|
from crawl4ai.models import Link, CrawlResult
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
@@ -178,7 +178,7 @@ class AdaptiveConfig:
|
|||||||
|
|
||||||
# Embedding strategy parameters
|
# Embedding strategy parameters
|
||||||
embedding_model: str = "sentence-transformers/all-MiniLM-L6-v2"
|
embedding_model: str = "sentence-transformers/all-MiniLM-L6-v2"
|
||||||
embedding_llm_config: Optional[Union[LLMConfig, Dict]] = None # Separate config for embeddings
|
embedding_llm_config: Optional[Dict] = None # Separate config for embeddings
|
||||||
n_query_variations: int = 10
|
n_query_variations: int = 10
|
||||||
coverage_threshold: float = 0.85
|
coverage_threshold: float = 0.85
|
||||||
alpha_shape_alpha: float = 0.5
|
alpha_shape_alpha: float = 0.5
|
||||||
@@ -251,30 +251,6 @@ class AdaptiveConfig:
|
|||||||
assert self.embedding_quality_scale_factor > 0, "embedding_quality_scale_factor must be positive"
|
assert self.embedding_quality_scale_factor > 0, "embedding_quality_scale_factor must be positive"
|
||||||
assert 0 <= self.embedding_min_confidence_threshold <= 1, "embedding_min_confidence_threshold must be between 0 and 1"
|
assert 0 <= self.embedding_min_confidence_threshold <= 1, "embedding_min_confidence_threshold must be between 0 and 1"
|
||||||
|
|
||||||
@property
|
|
||||||
def _embedding_llm_config_dict(self) -> Optional[Dict]:
|
|
||||||
"""Convert LLMConfig to dict format for backward compatibility."""
|
|
||||||
if self.embedding_llm_config is None:
|
|
||||||
return None
|
|
||||||
|
|
||||||
if isinstance(self.embedding_llm_config, dict):
|
|
||||||
# Already a dict - return as-is for backward compatibility
|
|
||||||
return self.embedding_llm_config
|
|
||||||
|
|
||||||
# Convert LLMConfig object to dict format
|
|
||||||
return {
|
|
||||||
'provider': self.embedding_llm_config.provider,
|
|
||||||
'api_token': self.embedding_llm_config.api_token,
|
|
||||||
'base_url': getattr(self.embedding_llm_config, 'base_url', None),
|
|
||||||
'temperature': getattr(self.embedding_llm_config, 'temperature', None),
|
|
||||||
'max_tokens': getattr(self.embedding_llm_config, 'max_tokens', None),
|
|
||||||
'top_p': getattr(self.embedding_llm_config, 'top_p', None),
|
|
||||||
'frequency_penalty': getattr(self.embedding_llm_config, 'frequency_penalty', None),
|
|
||||||
'presence_penalty': getattr(self.embedding_llm_config, 'presence_penalty', None),
|
|
||||||
'stop': getattr(self.embedding_llm_config, 'stop', None),
|
|
||||||
'n': getattr(self.embedding_llm_config, 'n', None),
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
class CrawlStrategy(ABC):
|
class CrawlStrategy(ABC):
|
||||||
"""Abstract base class for crawling strategies"""
|
"""Abstract base class for crawling strategies"""
|
||||||
@@ -617,7 +593,7 @@ class StatisticalStrategy(CrawlStrategy):
|
|||||||
class EmbeddingStrategy(CrawlStrategy):
|
class EmbeddingStrategy(CrawlStrategy):
|
||||||
"""Embedding-based adaptive crawling using semantic space coverage"""
|
"""Embedding-based adaptive crawling using semantic space coverage"""
|
||||||
|
|
||||||
def __init__(self, embedding_model: str = None, llm_config: Union[LLMConfig, Dict] = None):
|
def __init__(self, embedding_model: str = None, llm_config: Dict = None):
|
||||||
self.embedding_model = embedding_model or "sentence-transformers/all-MiniLM-L6-v2"
|
self.embedding_model = embedding_model or "sentence-transformers/all-MiniLM-L6-v2"
|
||||||
self.llm_config = llm_config
|
self.llm_config = llm_config
|
||||||
self._embedding_cache = {}
|
self._embedding_cache = {}
|
||||||
@@ -630,23 +606,13 @@ class EmbeddingStrategy(CrawlStrategy):
|
|||||||
self._validation_embeddings_cache = None # Cache validation query embeddings
|
self._validation_embeddings_cache = None # Cache validation query embeddings
|
||||||
self._kb_similarity_threshold = 0.95 # Threshold for deduplication
|
self._kb_similarity_threshold = 0.95 # Threshold for deduplication
|
||||||
|
|
||||||
def _get_embedding_llm_config_dict(self) -> Dict:
|
|
||||||
"""Get embedding LLM config as dict with fallback to default."""
|
|
||||||
if hasattr(self, 'config') and self.config:
|
|
||||||
config_dict = self.config._embedding_llm_config_dict
|
|
||||||
if config_dict:
|
|
||||||
return config_dict
|
|
||||||
|
|
||||||
# Fallback to default if no config provided
|
|
||||||
return {
|
|
||||||
'provider': 'openai/text-embedding-3-small',
|
|
||||||
'api_token': os.getenv('OPENAI_API_KEY')
|
|
||||||
}
|
|
||||||
|
|
||||||
async def _get_embeddings(self, texts: List[str]) -> Any:
|
async def _get_embeddings(self, texts: List[str]) -> Any:
|
||||||
"""Get embeddings using configured method"""
|
"""Get embeddings using configured method"""
|
||||||
from .utils import get_text_embeddings
|
from .utils import get_text_embeddings
|
||||||
embedding_llm_config = self._get_embedding_llm_config_dict()
|
embedding_llm_config = {
|
||||||
|
'provider': 'openai/text-embedding-3-small',
|
||||||
|
'api_token': os.getenv('OPENAI_API_KEY')
|
||||||
|
}
|
||||||
return await get_text_embeddings(
|
return await get_text_embeddings(
|
||||||
texts,
|
texts,
|
||||||
embedding_llm_config,
|
embedding_llm_config,
|
||||||
@@ -713,20 +679,8 @@ class EmbeddingStrategy(CrawlStrategy):
|
|||||||
Return as a JSON array of strings."""
|
Return as a JSON array of strings."""
|
||||||
|
|
||||||
# Use the LLM for query generation
|
# Use the LLM for query generation
|
||||||
# Convert LLMConfig to dict if needed
|
provider = self.llm_config.get('provider', 'openai/gpt-4o-mini') if self.llm_config else 'openai/gpt-4o-mini'
|
||||||
llm_config_dict = None
|
api_token = self.llm_config.get('api_token') if self.llm_config else None
|
||||||
if self.llm_config:
|
|
||||||
if isinstance(self.llm_config, dict):
|
|
||||||
llm_config_dict = self.llm_config
|
|
||||||
else:
|
|
||||||
# Convert LLMConfig object to dict
|
|
||||||
llm_config_dict = {
|
|
||||||
'provider': self.llm_config.provider,
|
|
||||||
'api_token': self.llm_config.api_token
|
|
||||||
}
|
|
||||||
|
|
||||||
provider = llm_config_dict.get('provider', 'openai/gpt-4o-mini') if llm_config_dict else 'openai/gpt-4o-mini'
|
|
||||||
api_token = llm_config_dict.get('api_token') if llm_config_dict else None
|
|
||||||
|
|
||||||
# response = perform_completion_with_backoff(
|
# response = perform_completion_with_backoff(
|
||||||
# provider=provider,
|
# provider=provider,
|
||||||
@@ -889,7 +843,10 @@ class EmbeddingStrategy(CrawlStrategy):
|
|||||||
|
|
||||||
# Batch embed only uncached links
|
# Batch embed only uncached links
|
||||||
if texts_to_embed:
|
if texts_to_embed:
|
||||||
embedding_llm_config = self._get_embedding_llm_config_dict()
|
embedding_llm_config = {
|
||||||
|
'provider': 'openai/text-embedding-3-small',
|
||||||
|
'api_token': os.getenv('OPENAI_API_KEY')
|
||||||
|
}
|
||||||
new_embeddings = await get_text_embeddings(texts_to_embed, embedding_llm_config, self.embedding_model)
|
new_embeddings = await get_text_embeddings(texts_to_embed, embedding_llm_config, self.embedding_model)
|
||||||
|
|
||||||
# Cache the new embeddings
|
# Cache the new embeddings
|
||||||
@@ -1227,7 +1184,10 @@ class EmbeddingStrategy(CrawlStrategy):
|
|||||||
return
|
return
|
||||||
|
|
||||||
# Get embeddings for new texts
|
# Get embeddings for new texts
|
||||||
embedding_llm_config = self._get_embedding_llm_config_dict()
|
embedding_llm_config = {
|
||||||
|
'provider': 'openai/text-embedding-3-small',
|
||||||
|
'api_token': os.getenv('OPENAI_API_KEY')
|
||||||
|
}
|
||||||
new_embeddings = await get_text_embeddings(new_texts, embedding_llm_config, self.embedding_model)
|
new_embeddings = await get_text_embeddings(new_texts, embedding_llm_config, self.embedding_model)
|
||||||
|
|
||||||
# Deduplicate embeddings before adding to KB
|
# Deduplicate embeddings before adding to KB
|
||||||
@@ -1296,12 +1256,10 @@ class AdaptiveCrawler:
|
|||||||
if strategy_name == "statistical":
|
if strategy_name == "statistical":
|
||||||
return StatisticalStrategy()
|
return StatisticalStrategy()
|
||||||
elif strategy_name == "embedding":
|
elif strategy_name == "embedding":
|
||||||
strategy = EmbeddingStrategy(
|
return EmbeddingStrategy(
|
||||||
embedding_model=self.config.embedding_model,
|
embedding_model=self.config.embedding_model,
|
||||||
llm_config=self.config.embedding_llm_config
|
llm_config=self.config.embedding_llm_config
|
||||||
)
|
)
|
||||||
strategy.config = self.config # Pass config to strategy
|
|
||||||
return strategy
|
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"Unknown strategy: {strategy_name}")
|
raise ValueError(f"Unknown strategy: {strategy_name}")
|
||||||
|
|
||||||
|
|||||||
@@ -1037,7 +1037,7 @@ class AsyncPlaywrightCrawlerStrategy(AsyncCrawlerStrategy):
|
|||||||
downloaded_files=(
|
downloaded_files=(
|
||||||
self._downloaded_files if self._downloaded_files else None
|
self._downloaded_files if self._downloaded_files else None
|
||||||
),
|
),
|
||||||
redirected_url=redirected_url,
|
redirected_url=page.url, # Update to current URL in case of JavaScript navigation
|
||||||
# Include captured data if enabled
|
# Include captured data if enabled
|
||||||
network_requests=captured_requests if config.capture_network_requests else None,
|
network_requests=captured_requests if config.capture_network_requests else None,
|
||||||
console_messages=captured_console if config.capture_console_messages else None,
|
console_messages=captured_console if config.capture_console_messages else None,
|
||||||
|
|||||||
@@ -480,7 +480,7 @@ class AsyncWebCrawler:
|
|||||||
# Scraping Strategy Execution #
|
# Scraping Strategy Execution #
|
||||||
################################
|
################################
|
||||||
result: ScrapingResult = scraping_strategy.scrap(
|
result: ScrapingResult = scraping_strategy.scrap(
|
||||||
url, html, **params)
|
kwargs.get("redirected_url", url), html, **params)
|
||||||
|
|
||||||
if result is None:
|
if result is None:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
|
|||||||
@@ -2149,8 +2149,10 @@ def normalize_url(
|
|||||||
*,
|
*,
|
||||||
drop_query_tracking=True,
|
drop_query_tracking=True,
|
||||||
sort_query=True,
|
sort_query=True,
|
||||||
keep_fragment=False,
|
keep_fragment=True,
|
||||||
|
remove_fragments=None, # alias for keep_fragment=False
|
||||||
extra_drop_params=None,
|
extra_drop_params=None,
|
||||||
|
params_to_remove=None, # alias for extra_drop_params
|
||||||
preserve_https=False,
|
preserve_https=False,
|
||||||
original_scheme=None
|
original_scheme=None
|
||||||
):
|
):
|
||||||
@@ -2175,10 +2177,20 @@ def normalize_url(
|
|||||||
Returns
|
Returns
|
||||||
-------
|
-------
|
||||||
str | None
|
str | None
|
||||||
A clean, canonical URL or None if href is empty/None.
|
A clean, canonical URL or the base URL if href is empty/None.
|
||||||
"""
|
"""
|
||||||
if not href:
|
if not href:
|
||||||
return None
|
# For empty href, return the base URL (matching urljoin behavior)
|
||||||
|
return base_url
|
||||||
|
|
||||||
|
# Validate base URL format
|
||||||
|
parsed_base = urlparse(base_url)
|
||||||
|
if not parsed_base.scheme or not parsed_base.netloc:
|
||||||
|
raise ValueError(f"Invalid base URL format: {base_url}")
|
||||||
|
|
||||||
|
if parsed_base.scheme.lower() not in ["http", "https"]:
|
||||||
|
# Handle special protocols
|
||||||
|
raise ValueError(f"Invalid base URL format: {base_url}")
|
||||||
|
|
||||||
# Resolve relative paths first
|
# Resolve relative paths first
|
||||||
full_url = urljoin(base_url, href.strip())
|
full_url = urljoin(base_url, href.strip())
|
||||||
@@ -2200,13 +2212,29 @@ def normalize_url(
|
|||||||
# ── netloc ──
|
# ── netloc ──
|
||||||
netloc = parsed.netloc.lower()
|
netloc = parsed.netloc.lower()
|
||||||
|
|
||||||
|
# Remove default ports (80 for http, 443 for https)
|
||||||
|
if ':' in netloc:
|
||||||
|
host, port = netloc.rsplit(':', 1)
|
||||||
|
if (parsed.scheme == 'http' and port == '80') or (parsed.scheme == 'https' and port == '443'):
|
||||||
|
netloc = host
|
||||||
|
|
||||||
# ── path ──
|
# ── path ──
|
||||||
# Strip duplicate slashes and trailing "/" (except root)
|
# Strip duplicate slashes and trailing "/" (except root)
|
||||||
# IMPORTANT: Don't use quote(unquote()) as it mangles + signs in URLs
|
# IMPORTANT: Don't use quote(unquote()) as it mangles + signs in URLs
|
||||||
# The path from urlparse is already properly encoded
|
# The path from urlparse is already properly encoded
|
||||||
path = parsed.path
|
path = parsed.path
|
||||||
if path.endswith('/') and path != '/':
|
if path.endswith('/') and path != '/':
|
||||||
path = path.rstrip('/')
|
# Only strip trailing slash if the original href didn't have a trailing slash
|
||||||
|
# and the base_url didn't end with a slash
|
||||||
|
base_parsed = urlparse(base_url)
|
||||||
|
if not href.strip().endswith('/') and not base_parsed.path.endswith('/'):
|
||||||
|
path = path.rstrip('/')
|
||||||
|
# Add trailing slash for URLs without explicit paths (indicates directory)
|
||||||
|
# But skip this for special protocols that don't use standard URL structure
|
||||||
|
elif not path:
|
||||||
|
special_protocols = {"javascript:", "mailto:", "tel:", "file:", "data:"}
|
||||||
|
if not any(href.strip().lower().startswith(p) for p in special_protocols):
|
||||||
|
path = '/'
|
||||||
|
|
||||||
# ── query ──
|
# ── query ──
|
||||||
query = parsed.query
|
query = parsed.query
|
||||||
@@ -2221,6 +2249,8 @@ def normalize_url(
|
|||||||
}
|
}
|
||||||
if extra_drop_params:
|
if extra_drop_params:
|
||||||
default_tracking |= {p.lower() for p in extra_drop_params}
|
default_tracking |= {p.lower() for p in extra_drop_params}
|
||||||
|
if params_to_remove:
|
||||||
|
default_tracking |= {p.lower() for p in params_to_remove}
|
||||||
params = [(k, v) for k, v in params if k not in default_tracking]
|
params = [(k, v) for k, v in params if k not in default_tracking]
|
||||||
|
|
||||||
if sort_query:
|
if sort_query:
|
||||||
@@ -2229,7 +2259,10 @@ def normalize_url(
|
|||||||
query = urlencode(params, doseq=True) if params else ''
|
query = urlencode(params, doseq=True) if params else ''
|
||||||
|
|
||||||
# ── fragment ──
|
# ── fragment ──
|
||||||
fragment = parsed.fragment if keep_fragment else ''
|
if remove_fragments is True:
|
||||||
|
fragment = ''
|
||||||
|
else:
|
||||||
|
fragment = parsed.fragment if keep_fragment else ''
|
||||||
|
|
||||||
# Re-assemble
|
# Re-assemble
|
||||||
normalized = urlunparse((
|
normalized = urlunparse((
|
||||||
@@ -2453,9 +2486,19 @@ def is_external_url(url: str, base_domain: str) -> bool:
|
|||||||
if not parsed.netloc: # Relative URL
|
if not parsed.netloc: # Relative URL
|
||||||
return False
|
return False
|
||||||
|
|
||||||
# Strip 'www.' from both domains for comparison
|
# Don't strip 'www.' from domains for comparison - treat www.example.com and example.com as different
|
||||||
url_domain = parsed.netloc.lower().replace("www.", "")
|
url_domain = parsed.netloc.lower()
|
||||||
base = base_domain.lower().replace("www.", "")
|
base = base_domain.lower()
|
||||||
|
|
||||||
|
# Strip user credentials from URL domain
|
||||||
|
if '@' in url_domain:
|
||||||
|
url_domain = url_domain.split('@', 1)[1]
|
||||||
|
|
||||||
|
# Strip ports from both for comparison (any port should be considered same domain)
|
||||||
|
if ':' in url_domain:
|
||||||
|
url_domain = url_domain.rsplit(':', 1)[0]
|
||||||
|
if ':' in base:
|
||||||
|
base = base.rsplit(':', 1)[0]
|
||||||
|
|
||||||
# Check if URL domain ends with base domain
|
# Check if URL domain ends with base domain
|
||||||
return not url_domain.endswith(base)
|
return not url_domain.endswith(base)
|
||||||
|
|||||||
@@ -1,154 +0,0 @@
|
|||||||
import asyncio
|
|
||||||
import os
|
|
||||||
from crawl4ai import AsyncWebCrawler, AdaptiveCrawler, AdaptiveConfig, LLMConfig
|
|
||||||
|
|
||||||
|
|
||||||
async def test_configuration(name: str, config: AdaptiveConfig, url: str, query: str):
|
|
||||||
"""Test a specific configuration"""
|
|
||||||
print(f"\n{'='*60}")
|
|
||||||
print(f"Configuration: {name}")
|
|
||||||
print(f"{'='*60}")
|
|
||||||
|
|
||||||
async with AsyncWebCrawler(verbose=False) as crawler:
|
|
||||||
adaptive = AdaptiveCrawler(crawler, config)
|
|
||||||
result = await adaptive.digest(start_url=url, query=query)
|
|
||||||
|
|
||||||
print("\n" + "="*50)
|
|
||||||
print("CRAWL STATISTICS")
|
|
||||||
print("="*50)
|
|
||||||
adaptive.print_stats(detailed=False)
|
|
||||||
|
|
||||||
# Get the most relevant content found
|
|
||||||
print("\n" + "="*50)
|
|
||||||
print("MOST RELEVANT PAGES")
|
|
||||||
print("="*50)
|
|
||||||
|
|
||||||
relevant_pages = adaptive.get_relevant_content(top_k=5)
|
|
||||||
for i, page in enumerate(relevant_pages, 1):
|
|
||||||
print(f"\n{i}. {page['url']}")
|
|
||||||
print(f" Relevance Score: {page['score']:.2%}")
|
|
||||||
|
|
||||||
# Show a snippet of the content
|
|
||||||
content = page['content'] or ""
|
|
||||||
if content:
|
|
||||||
snippet = content[:200].replace('\n', ' ')
|
|
||||||
if len(content) > 200:
|
|
||||||
snippet += "..."
|
|
||||||
print(f" Preview: {snippet}")
|
|
||||||
|
|
||||||
print(f"\n{'='*50}")
|
|
||||||
print(f"Pages crawled: {len(result.crawled_urls)}")
|
|
||||||
print(f"Final confidence: {adaptive.confidence:.1%}")
|
|
||||||
print(f"Stopped reason: {result.metrics.get('stopped_reason', 'max_pages')}")
|
|
||||||
|
|
||||||
if result.metrics.get('is_irrelevant', False):
|
|
||||||
print("⚠️ Query detected as irrelevant!")
|
|
||||||
|
|
||||||
return result
|
|
||||||
|
|
||||||
|
|
||||||
async def llm_embedding():
|
|
||||||
"""Demonstrate various embedding configurations"""
|
|
||||||
|
|
||||||
print("EMBEDDING STRATEGY CONFIGURATION EXAMPLES")
|
|
||||||
print("=" * 60)
|
|
||||||
|
|
||||||
# Base URL and query for testing
|
|
||||||
test_url = "https://docs.python.org/3/library/asyncio.html"
|
|
||||||
|
|
||||||
openai_llm_config = LLMConfig(
|
|
||||||
provider='openai/text-embedding-3-small',
|
|
||||||
api_token=os.getenv('OPENAI_API_KEY'),
|
|
||||||
temperature=0.7,
|
|
||||||
max_tokens=2000
|
|
||||||
)
|
|
||||||
config_openai = AdaptiveConfig(
|
|
||||||
strategy="embedding",
|
|
||||||
max_pages=10,
|
|
||||||
|
|
||||||
# Use OpenAI embeddings
|
|
||||||
embedding_llm_config=openai_llm_config,
|
|
||||||
# embedding_llm_config={
|
|
||||||
# 'provider': 'openai/text-embedding-3-small',
|
|
||||||
# 'api_token': os.getenv('OPENAI_API_KEY')
|
|
||||||
# },
|
|
||||||
|
|
||||||
# OpenAI embeddings are high quality, can be stricter
|
|
||||||
embedding_k_exp=4.0,
|
|
||||||
n_query_variations=12
|
|
||||||
)
|
|
||||||
|
|
||||||
await test_configuration(
|
|
||||||
"OpenAI Embeddings",
|
|
||||||
config_openai,
|
|
||||||
test_url,
|
|
||||||
# "event-driven architecture patterns"
|
|
||||||
"async await context managers coroutines"
|
|
||||||
)
|
|
||||||
return
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
async def basic_adaptive_crawling():
|
|
||||||
"""Basic adaptive crawling example"""
|
|
||||||
|
|
||||||
# Initialize the crawler
|
|
||||||
async with AsyncWebCrawler(verbose=True) as crawler:
|
|
||||||
# Create an adaptive crawler with default settings (statistical strategy)
|
|
||||||
adaptive = AdaptiveCrawler(crawler)
|
|
||||||
|
|
||||||
# Note: You can also use embedding strategy for semantic understanding:
|
|
||||||
# from crawl4ai import AdaptiveConfig
|
|
||||||
# config = AdaptiveConfig(strategy="embedding")
|
|
||||||
# adaptive = AdaptiveCrawler(crawler, config)
|
|
||||||
|
|
||||||
# Start adaptive crawling
|
|
||||||
print("Starting adaptive crawl for Python async programming information...")
|
|
||||||
result = await adaptive.digest(
|
|
||||||
start_url="https://docs.python.org/3/library/asyncio.html",
|
|
||||||
query="async await context managers coroutines"
|
|
||||||
)
|
|
||||||
|
|
||||||
# Display crawl statistics
|
|
||||||
print("\n" + "="*50)
|
|
||||||
print("CRAWL STATISTICS")
|
|
||||||
print("="*50)
|
|
||||||
adaptive.print_stats(detailed=False)
|
|
||||||
|
|
||||||
# Get the most relevant content found
|
|
||||||
print("\n" + "="*50)
|
|
||||||
print("MOST RELEVANT PAGES")
|
|
||||||
print("="*50)
|
|
||||||
|
|
||||||
relevant_pages = adaptive.get_relevant_content(top_k=5)
|
|
||||||
for i, page in enumerate(relevant_pages, 1):
|
|
||||||
print(f"\n{i}. {page['url']}")
|
|
||||||
print(f" Relevance Score: {page['score']:.2%}")
|
|
||||||
|
|
||||||
# Show a snippet of the content
|
|
||||||
content = page['content'] or ""
|
|
||||||
if content:
|
|
||||||
snippet = content[:200].replace('\n', ' ')
|
|
||||||
if len(content) > 200:
|
|
||||||
snippet += "..."
|
|
||||||
print(f" Preview: {snippet}")
|
|
||||||
|
|
||||||
# Show final confidence
|
|
||||||
print(f"\n{'='*50}")
|
|
||||||
print(f"Final Confidence: {adaptive.confidence:.2%}")
|
|
||||||
print(f"Total Pages Crawled: {len(result.crawled_urls)}")
|
|
||||||
print(f"Knowledge Base Size: {len(adaptive.state.knowledge_base)} documents")
|
|
||||||
|
|
||||||
|
|
||||||
if adaptive.confidence >= 0.8:
|
|
||||||
print("✓ High confidence - can answer detailed questions about async Python")
|
|
||||||
elif adaptive.confidence >= 0.6:
|
|
||||||
print("~ Moderate confidence - can answer basic questions")
|
|
||||||
else:
|
|
||||||
print("✗ Low confidence - need more information")
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
asyncio.run(llm_embedding())
|
|
||||||
# asyncio.run(basic_adaptive_crawling())
|
|
||||||
@@ -108,19 +108,7 @@ config = AdaptiveConfig(
|
|||||||
embedding_min_confidence_threshold=0.1 # Stop if completely irrelevant
|
embedding_min_confidence_threshold=0.1 # Stop if completely irrelevant
|
||||||
)
|
)
|
||||||
|
|
||||||
# With custom LLM provider for query expansion (recommended)
|
# With custom embedding provider (e.g., OpenAI)
|
||||||
from crawl4ai import LLMConfig
|
|
||||||
|
|
||||||
config = AdaptiveConfig(
|
|
||||||
strategy="embedding",
|
|
||||||
embedding_llm_config=LLMConfig(
|
|
||||||
provider='openai/text-embedding-3-small',
|
|
||||||
api_token='your-api-key',
|
|
||||||
temperature=0.7
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
# Alternative: Dictionary format (backward compatible)
|
|
||||||
config = AdaptiveConfig(
|
config = AdaptiveConfig(
|
||||||
strategy="embedding",
|
strategy="embedding",
|
||||||
embedding_llm_config={
|
embedding_llm_config={
|
||||||
|
|||||||
@@ -1,154 +0,0 @@
|
|||||||
import asyncio
|
|
||||||
import os
|
|
||||||
from crawl4ai import AsyncWebCrawler, AdaptiveCrawler, AdaptiveConfig, LLMConfig
|
|
||||||
|
|
||||||
|
|
||||||
async def test_configuration(name: str, config: AdaptiveConfig, url: str, query: str):
|
|
||||||
"""Test a specific configuration"""
|
|
||||||
print(f"\n{'='*60}")
|
|
||||||
print(f"Configuration: {name}")
|
|
||||||
print(f"{'='*60}")
|
|
||||||
|
|
||||||
async with AsyncWebCrawler(verbose=False) as crawler:
|
|
||||||
adaptive = AdaptiveCrawler(crawler, config)
|
|
||||||
result = await adaptive.digest(start_url=url, query=query)
|
|
||||||
|
|
||||||
print("\n" + "="*50)
|
|
||||||
print("CRAWL STATISTICS")
|
|
||||||
print("="*50)
|
|
||||||
adaptive.print_stats(detailed=False)
|
|
||||||
|
|
||||||
# Get the most relevant content found
|
|
||||||
print("\n" + "="*50)
|
|
||||||
print("MOST RELEVANT PAGES")
|
|
||||||
print("="*50)
|
|
||||||
|
|
||||||
relevant_pages = adaptive.get_relevant_content(top_k=5)
|
|
||||||
for i, page in enumerate(relevant_pages, 1):
|
|
||||||
print(f"\n{i}. {page['url']}")
|
|
||||||
print(f" Relevance Score: {page['score']:.2%}")
|
|
||||||
|
|
||||||
# Show a snippet of the content
|
|
||||||
content = page['content'] or ""
|
|
||||||
if content:
|
|
||||||
snippet = content[:200].replace('\n', ' ')
|
|
||||||
if len(content) > 200:
|
|
||||||
snippet += "..."
|
|
||||||
print(f" Preview: {snippet}")
|
|
||||||
|
|
||||||
print(f"\n{'='*50}")
|
|
||||||
print(f"Pages crawled: {len(result.crawled_urls)}")
|
|
||||||
print(f"Final confidence: {adaptive.confidence:.1%}")
|
|
||||||
print(f"Stopped reason: {result.metrics.get('stopped_reason', 'max_pages')}")
|
|
||||||
|
|
||||||
if result.metrics.get('is_irrelevant', False):
|
|
||||||
print("⚠️ Query detected as irrelevant!")
|
|
||||||
|
|
||||||
return result
|
|
||||||
|
|
||||||
|
|
||||||
async def llm_embedding():
|
|
||||||
"""Demonstrate various embedding configurations"""
|
|
||||||
|
|
||||||
print("EMBEDDING STRATEGY CONFIGURATION EXAMPLES")
|
|
||||||
print("=" * 60)
|
|
||||||
|
|
||||||
# Base URL and query for testing
|
|
||||||
test_url = "https://docs.python.org/3/library/asyncio.html"
|
|
||||||
|
|
||||||
openai_llm_config = LLMConfig(
|
|
||||||
provider='openai/text-embedding-3-small',
|
|
||||||
api_token=os.getenv('OPENAI_API_KEY'),
|
|
||||||
temperature=0.7,
|
|
||||||
max_tokens=2000
|
|
||||||
)
|
|
||||||
config_openai = AdaptiveConfig(
|
|
||||||
strategy="embedding",
|
|
||||||
max_pages=10,
|
|
||||||
|
|
||||||
# Use OpenAI embeddings
|
|
||||||
embedding_llm_config=openai_llm_config,
|
|
||||||
# embedding_llm_config={
|
|
||||||
# 'provider': 'openai/text-embedding-3-small',
|
|
||||||
# 'api_token': os.getenv('OPENAI_API_KEY')
|
|
||||||
# },
|
|
||||||
|
|
||||||
# OpenAI embeddings are high quality, can be stricter
|
|
||||||
embedding_k_exp=4.0,
|
|
||||||
n_query_variations=12
|
|
||||||
)
|
|
||||||
|
|
||||||
await test_configuration(
|
|
||||||
"OpenAI Embeddings",
|
|
||||||
config_openai,
|
|
||||||
test_url,
|
|
||||||
# "event-driven architecture patterns"
|
|
||||||
"async await context managers coroutines"
|
|
||||||
)
|
|
||||||
return
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
async def basic_adaptive_crawling():
|
|
||||||
"""Basic adaptive crawling example"""
|
|
||||||
|
|
||||||
# Initialize the crawler
|
|
||||||
async with AsyncWebCrawler(verbose=True) as crawler:
|
|
||||||
# Create an adaptive crawler with default settings (statistical strategy)
|
|
||||||
adaptive = AdaptiveCrawler(crawler)
|
|
||||||
|
|
||||||
# Note: You can also use embedding strategy for semantic understanding:
|
|
||||||
# from crawl4ai import AdaptiveConfig
|
|
||||||
# config = AdaptiveConfig(strategy="embedding")
|
|
||||||
# adaptive = AdaptiveCrawler(crawler, config)
|
|
||||||
|
|
||||||
# Start adaptive crawling
|
|
||||||
print("Starting adaptive crawl for Python async programming information...")
|
|
||||||
result = await adaptive.digest(
|
|
||||||
start_url="https://docs.python.org/3/library/asyncio.html",
|
|
||||||
query="async await context managers coroutines"
|
|
||||||
)
|
|
||||||
|
|
||||||
# Display crawl statistics
|
|
||||||
print("\n" + "="*50)
|
|
||||||
print("CRAWL STATISTICS")
|
|
||||||
print("="*50)
|
|
||||||
adaptive.print_stats(detailed=False)
|
|
||||||
|
|
||||||
# Get the most relevant content found
|
|
||||||
print("\n" + "="*50)
|
|
||||||
print("MOST RELEVANT PAGES")
|
|
||||||
print("="*50)
|
|
||||||
|
|
||||||
relevant_pages = adaptive.get_relevant_content(top_k=5)
|
|
||||||
for i, page in enumerate(relevant_pages, 1):
|
|
||||||
print(f"\n{i}. {page['url']}")
|
|
||||||
print(f" Relevance Score: {page['score']:.2%}")
|
|
||||||
|
|
||||||
# Show a snippet of the content
|
|
||||||
content = page['content'] or ""
|
|
||||||
if content:
|
|
||||||
snippet = content[:200].replace('\n', ' ')
|
|
||||||
if len(content) > 200:
|
|
||||||
snippet += "..."
|
|
||||||
print(f" Preview: {snippet}")
|
|
||||||
|
|
||||||
# Show final confidence
|
|
||||||
print(f"\n{'='*50}")
|
|
||||||
print(f"Final Confidence: {adaptive.confidence:.2%}")
|
|
||||||
print(f"Total Pages Crawled: {len(result.crawled_urls)}")
|
|
||||||
print(f"Knowledge Base Size: {len(adaptive.state.knowledge_base)} documents")
|
|
||||||
|
|
||||||
|
|
||||||
if adaptive.confidence >= 0.8:
|
|
||||||
print("✓ High confidence - can answer detailed questions about async Python")
|
|
||||||
elif adaptive.confidence >= 0.6:
|
|
||||||
print("~ Moderate confidence - can answer basic questions")
|
|
||||||
else:
|
|
||||||
print("✗ Low confidence - need more information")
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
asyncio.run(llm_embedding())
|
|
||||||
# asyncio.run(basic_adaptive_crawling())
|
|
||||||
Reference in New Issue
Block a user