Compare commits
13 Commits
fix/reques
...
fix/case_s
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
89679cee67 | ||
|
|
84ba78c852 | ||
|
|
3899ac3d3b | ||
|
|
23431d8109 | ||
|
|
1717827732 | ||
|
|
f8eaf01ed1 | ||
|
|
14b42b1f9a | ||
|
|
3bc56dd028 | ||
|
|
0482c1eafc | ||
|
|
6e728096fa | ||
|
|
4ed33fce9e | ||
|
|
f7a3366f72 | ||
|
|
88a9fbbb7e |
2
.gitignore
vendored
2
.gitignore
vendored
@@ -265,7 +265,7 @@ CLAUDE.md
|
||||
tests/**/test_site
|
||||
tests/**/reports
|
||||
tests/**/benchmark_reports
|
||||
|
||||
test_scripts/
|
||||
docs/**/data
|
||||
.codecat/
|
||||
|
||||
|
||||
@@ -19,7 +19,7 @@ import re
|
||||
from pathlib import Path
|
||||
|
||||
from crawl4ai.async_webcrawler import AsyncWebCrawler
|
||||
from crawl4ai.async_configs import CrawlerRunConfig, LinkPreviewConfig
|
||||
from crawl4ai.async_configs import CrawlerRunConfig, LinkPreviewConfig, LLMConfig
|
||||
from crawl4ai.models import Link, CrawlResult
|
||||
import numpy as np
|
||||
|
||||
@@ -178,7 +178,7 @@ class AdaptiveConfig:
|
||||
|
||||
# Embedding strategy parameters
|
||||
embedding_model: str = "sentence-transformers/all-MiniLM-L6-v2"
|
||||
embedding_llm_config: Optional[Dict] = None # Separate config for embeddings
|
||||
embedding_llm_config: Optional[Union[LLMConfig, Dict]] = None # Separate config for embeddings
|
||||
n_query_variations: int = 10
|
||||
coverage_threshold: float = 0.85
|
||||
alpha_shape_alpha: float = 0.5
|
||||
@@ -250,6 +250,30 @@ class AdaptiveConfig:
|
||||
assert 0 <= self.embedding_quality_max_confidence <= 1, "embedding_quality_max_confidence must be between 0 and 1"
|
||||
assert self.embedding_quality_scale_factor > 0, "embedding_quality_scale_factor must be positive"
|
||||
assert 0 <= self.embedding_min_confidence_threshold <= 1, "embedding_min_confidence_threshold must be between 0 and 1"
|
||||
|
||||
@property
|
||||
def _embedding_llm_config_dict(self) -> Optional[Dict]:
|
||||
"""Convert LLMConfig to dict format for backward compatibility."""
|
||||
if self.embedding_llm_config is None:
|
||||
return None
|
||||
|
||||
if isinstance(self.embedding_llm_config, dict):
|
||||
# Already a dict - return as-is for backward compatibility
|
||||
return self.embedding_llm_config
|
||||
|
||||
# Convert LLMConfig object to dict format
|
||||
return {
|
||||
'provider': self.embedding_llm_config.provider,
|
||||
'api_token': self.embedding_llm_config.api_token,
|
||||
'base_url': getattr(self.embedding_llm_config, 'base_url', None),
|
||||
'temperature': getattr(self.embedding_llm_config, 'temperature', None),
|
||||
'max_tokens': getattr(self.embedding_llm_config, 'max_tokens', None),
|
||||
'top_p': getattr(self.embedding_llm_config, 'top_p', None),
|
||||
'frequency_penalty': getattr(self.embedding_llm_config, 'frequency_penalty', None),
|
||||
'presence_penalty': getattr(self.embedding_llm_config, 'presence_penalty', None),
|
||||
'stop': getattr(self.embedding_llm_config, 'stop', None),
|
||||
'n': getattr(self.embedding_llm_config, 'n', None),
|
||||
}
|
||||
|
||||
|
||||
class CrawlStrategy(ABC):
|
||||
@@ -593,7 +617,7 @@ class StatisticalStrategy(CrawlStrategy):
|
||||
class EmbeddingStrategy(CrawlStrategy):
|
||||
"""Embedding-based adaptive crawling using semantic space coverage"""
|
||||
|
||||
def __init__(self, embedding_model: str = None, llm_config: Dict = None):
|
||||
def __init__(self, embedding_model: str = None, llm_config: Union[LLMConfig, Dict] = None):
|
||||
self.embedding_model = embedding_model or "sentence-transformers/all-MiniLM-L6-v2"
|
||||
self.llm_config = llm_config
|
||||
self._embedding_cache = {}
|
||||
@@ -605,14 +629,24 @@ class EmbeddingStrategy(CrawlStrategy):
|
||||
self._kb_embeddings_hash = None # Track KB changes
|
||||
self._validation_embeddings_cache = None # Cache validation query embeddings
|
||||
self._kb_similarity_threshold = 0.95 # Threshold for deduplication
|
||||
|
||||
def _get_embedding_llm_config_dict(self) -> Dict:
|
||||
"""Get embedding LLM config as dict with fallback to default."""
|
||||
if hasattr(self, 'config') and self.config:
|
||||
config_dict = self.config._embedding_llm_config_dict
|
||||
if config_dict:
|
||||
return config_dict
|
||||
|
||||
# Fallback to default if no config provided
|
||||
return {
|
||||
'provider': 'openai/text-embedding-3-small',
|
||||
'api_token': os.getenv('OPENAI_API_KEY')
|
||||
}
|
||||
|
||||
async def _get_embeddings(self, texts: List[str]) -> Any:
|
||||
"""Get embeddings using configured method"""
|
||||
from .utils import get_text_embeddings
|
||||
embedding_llm_config = {
|
||||
'provider': 'openai/text-embedding-3-small',
|
||||
'api_token': os.getenv('OPENAI_API_KEY')
|
||||
}
|
||||
embedding_llm_config = self._get_embedding_llm_config_dict()
|
||||
return await get_text_embeddings(
|
||||
texts,
|
||||
embedding_llm_config,
|
||||
@@ -679,8 +713,20 @@ class EmbeddingStrategy(CrawlStrategy):
|
||||
Return as a JSON array of strings."""
|
||||
|
||||
# Use the LLM for query generation
|
||||
provider = self.llm_config.get('provider', 'openai/gpt-4o-mini') if self.llm_config else 'openai/gpt-4o-mini'
|
||||
api_token = self.llm_config.get('api_token') if self.llm_config else None
|
||||
# Convert LLMConfig to dict if needed
|
||||
llm_config_dict = None
|
||||
if self.llm_config:
|
||||
if isinstance(self.llm_config, dict):
|
||||
llm_config_dict = self.llm_config
|
||||
else:
|
||||
# Convert LLMConfig object to dict
|
||||
llm_config_dict = {
|
||||
'provider': self.llm_config.provider,
|
||||
'api_token': self.llm_config.api_token
|
||||
}
|
||||
|
||||
provider = llm_config_dict.get('provider', 'openai/gpt-4o-mini') if llm_config_dict else 'openai/gpt-4o-mini'
|
||||
api_token = llm_config_dict.get('api_token') if llm_config_dict else None
|
||||
|
||||
# response = perform_completion_with_backoff(
|
||||
# provider=provider,
|
||||
@@ -843,10 +889,7 @@ class EmbeddingStrategy(CrawlStrategy):
|
||||
|
||||
# Batch embed only uncached links
|
||||
if texts_to_embed:
|
||||
embedding_llm_config = {
|
||||
'provider': 'openai/text-embedding-3-small',
|
||||
'api_token': os.getenv('OPENAI_API_KEY')
|
||||
}
|
||||
embedding_llm_config = self._get_embedding_llm_config_dict()
|
||||
new_embeddings = await get_text_embeddings(texts_to_embed, embedding_llm_config, self.embedding_model)
|
||||
|
||||
# Cache the new embeddings
|
||||
@@ -1184,10 +1227,7 @@ class EmbeddingStrategy(CrawlStrategy):
|
||||
return
|
||||
|
||||
# Get embeddings for new texts
|
||||
embedding_llm_config = {
|
||||
'provider': 'openai/text-embedding-3-small',
|
||||
'api_token': os.getenv('OPENAI_API_KEY')
|
||||
}
|
||||
embedding_llm_config = self._get_embedding_llm_config_dict()
|
||||
new_embeddings = await get_text_embeddings(new_texts, embedding_llm_config, self.embedding_model)
|
||||
|
||||
# Deduplicate embeddings before adding to KB
|
||||
@@ -1256,10 +1296,12 @@ class AdaptiveCrawler:
|
||||
if strategy_name == "statistical":
|
||||
return StatisticalStrategy()
|
||||
elif strategy_name == "embedding":
|
||||
return EmbeddingStrategy(
|
||||
strategy = EmbeddingStrategy(
|
||||
embedding_model=self.config.embedding_model,
|
||||
llm_config=self.config.embedding_llm_config
|
||||
)
|
||||
strategy.config = self.config # Pass config to strategy
|
||||
return strategy
|
||||
else:
|
||||
raise ValueError(f"Unknown strategy: {strategy_name}")
|
||||
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
import os
|
||||
from typing import Union
|
||||
import warnings
|
||||
from .config import (
|
||||
DEFAULT_PROVIDER,
|
||||
DEFAULT_PROVIDER_API_KEY,
|
||||
@@ -257,24 +258,39 @@ class ProxyConfig:
|
||||
|
||||
@staticmethod
|
||||
def from_string(proxy_str: str) -> "ProxyConfig":
|
||||
"""Create a ProxyConfig from a string in the format 'ip:port:username:password'."""
|
||||
parts = proxy_str.split(":")
|
||||
if len(parts) == 4: # ip:port:username:password
|
||||
"""Create a ProxyConfig from a string.
|
||||
|
||||
Supported formats:
|
||||
- 'http://username:password@ip:port'
|
||||
- 'http://ip:port'
|
||||
- 'socks5://ip:port'
|
||||
- 'ip:port:username:password'
|
||||
- 'ip:port'
|
||||
"""
|
||||
s = (proxy_str or "").strip()
|
||||
# URL with credentials
|
||||
if "@" in s and "://" in s:
|
||||
auth_part, server_part = s.split("@", 1)
|
||||
protocol, credentials = auth_part.split("://", 1)
|
||||
if ":" in credentials:
|
||||
username, password = credentials.split(":", 1)
|
||||
return ProxyConfig(
|
||||
server=f"{protocol}://{server_part}",
|
||||
username=username,
|
||||
password=password,
|
||||
)
|
||||
# URL without credentials (keep scheme)
|
||||
if "://" in s and "@" not in s:
|
||||
return ProxyConfig(server=s)
|
||||
# Colon separated forms
|
||||
parts = s.split(":")
|
||||
if len(parts) == 4:
|
||||
ip, port, username, password = parts
|
||||
return ProxyConfig(
|
||||
server=f"http://{ip}:{port}",
|
||||
username=username,
|
||||
password=password,
|
||||
ip=ip
|
||||
)
|
||||
elif len(parts) == 2: # ip:port only
|
||||
return ProxyConfig(server=f"http://{ip}:{port}", username=username, password=password)
|
||||
if len(parts) == 2:
|
||||
ip, port = parts
|
||||
return ProxyConfig(
|
||||
server=f"http://{ip}:{port}",
|
||||
ip=ip
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Invalid proxy string format: {proxy_str}")
|
||||
return ProxyConfig(server=f"http://{ip}:{port}")
|
||||
raise ValueError(f"Invalid proxy string format: {proxy_str}")
|
||||
|
||||
@staticmethod
|
||||
def from_dict(proxy_dict: Dict) -> "ProxyConfig":
|
||||
@@ -438,6 +454,7 @@ class BrowserConfig:
|
||||
host: str = "localhost",
|
||||
enable_stealth: bool = False,
|
||||
):
|
||||
|
||||
self.browser_type = browser_type
|
||||
self.headless = headless
|
||||
self.browser_mode = browser_mode
|
||||
@@ -450,13 +467,22 @@ class BrowserConfig:
|
||||
if self.browser_type in ["firefox", "webkit"]:
|
||||
self.channel = ""
|
||||
self.chrome_channel = ""
|
||||
if proxy:
|
||||
warnings.warn("The 'proxy' parameter is deprecated and will be removed in a future release. Use 'proxy_config' instead.", UserWarning)
|
||||
self.proxy = proxy
|
||||
self.proxy_config = proxy_config
|
||||
if isinstance(self.proxy_config, dict):
|
||||
self.proxy_config = ProxyConfig.from_dict(self.proxy_config)
|
||||
if isinstance(self.proxy_config, str):
|
||||
self.proxy_config = ProxyConfig.from_string(self.proxy_config)
|
||||
|
||||
|
||||
if self.proxy and self.proxy_config:
|
||||
warnings.warn("Both 'proxy' and 'proxy_config' are provided. 'proxy_config' will take precedence.", UserWarning)
|
||||
self.proxy = None
|
||||
elif self.proxy:
|
||||
# Convert proxy string to ProxyConfig if proxy_config is not provided
|
||||
self.proxy_config = ProxyConfig.from_string(self.proxy)
|
||||
self.proxy = None
|
||||
|
||||
self.viewport_width = viewport_width
|
||||
self.viewport_height = viewport_height
|
||||
|
||||
@@ -15,6 +15,7 @@ from .js_snippet import load_js_script
|
||||
from .config import DOWNLOAD_PAGE_TIMEOUT
|
||||
from .async_configs import BrowserConfig, CrawlerRunConfig
|
||||
from .utils import get_chromium_path
|
||||
import warnings
|
||||
|
||||
|
||||
BROWSER_DISABLE_OPTIONS = [
|
||||
@@ -741,17 +742,18 @@ class BrowserManager:
|
||||
)
|
||||
os.makedirs(browser_args["downloads_path"], exist_ok=True)
|
||||
|
||||
if self.config.proxy or self.config.proxy_config:
|
||||
if self.config.proxy:
|
||||
warnings.warn(
|
||||
"BrowserConfig.proxy is deprecated and ignored. Use proxy_config instead.",
|
||||
DeprecationWarning,
|
||||
)
|
||||
if self.config.proxy_config:
|
||||
from playwright.async_api import ProxySettings
|
||||
|
||||
proxy_settings = (
|
||||
ProxySettings(server=self.config.proxy)
|
||||
if self.config.proxy
|
||||
else ProxySettings(
|
||||
server=self.config.proxy_config.server,
|
||||
username=self.config.proxy_config.username,
|
||||
password=self.config.proxy_config.password,
|
||||
)
|
||||
proxy_settings = ProxySettings(
|
||||
server=self.config.proxy_config.server,
|
||||
username=self.config.proxy_config.username,
|
||||
password=self.config.proxy_config.password,
|
||||
)
|
||||
browser_args["proxy"] = proxy_settings
|
||||
|
||||
|
||||
@@ -122,11 +122,6 @@ class BestFirstCrawlingStrategy(DeepCrawlStrategy):
|
||||
|
||||
valid_links.append(base_url)
|
||||
|
||||
# If we have more valid links than capacity, limit them
|
||||
if len(valid_links) > remaining_capacity:
|
||||
valid_links = valid_links[:remaining_capacity]
|
||||
self.logger.info(f"Limiting to {remaining_capacity} URLs due to max_pages limit")
|
||||
|
||||
# Record the new depths and add to next_links
|
||||
for url in valid_links:
|
||||
depths[url] = new_depth
|
||||
@@ -146,7 +141,8 @@ class BestFirstCrawlingStrategy(DeepCrawlStrategy):
|
||||
"""
|
||||
queue: asyncio.PriorityQueue = asyncio.PriorityQueue()
|
||||
# Push the initial URL with score 0 and depth 0.
|
||||
await queue.put((0, 0, start_url, None))
|
||||
initial_score = self.url_scorer.score(start_url) if self.url_scorer else 0
|
||||
await queue.put((-initial_score, 0, start_url, None))
|
||||
visited: Set[str] = set()
|
||||
depths: Dict[str, int] = {start_url: 0}
|
||||
|
||||
@@ -193,7 +189,7 @@ class BestFirstCrawlingStrategy(DeepCrawlStrategy):
|
||||
result.metadata = result.metadata or {}
|
||||
result.metadata["depth"] = depth
|
||||
result.metadata["parent_url"] = parent_url
|
||||
result.metadata["score"] = score
|
||||
result.metadata["score"] = -score
|
||||
|
||||
# Count only successful crawls toward max_pages limit
|
||||
if result.success:
|
||||
@@ -214,7 +210,7 @@ class BestFirstCrawlingStrategy(DeepCrawlStrategy):
|
||||
for new_url, new_parent in new_links:
|
||||
new_depth = depths.get(new_url, depth + 1)
|
||||
new_score = self.url_scorer.score(new_url) if self.url_scorer else 0
|
||||
await queue.put((new_score, new_depth, new_url, new_parent))
|
||||
await queue.put((-new_score, new_depth, new_url, new_parent))
|
||||
|
||||
# End of crawl.
|
||||
|
||||
|
||||
@@ -2177,19 +2177,19 @@ def normalize_url(
|
||||
str | None
|
||||
A clean, canonical URL or None if href is empty/None.
|
||||
"""
|
||||
if not href:
|
||||
if not href or not href.strip():
|
||||
return None
|
||||
|
||||
# Resolve relative paths first
|
||||
full_url = urljoin(base_url, href.strip())
|
||||
|
||||
|
||||
# Preserve HTTPS if requested and original scheme was HTTPS
|
||||
if preserve_https and original_scheme == 'https':
|
||||
parsed_full = urlparse(full_url)
|
||||
parsed_base = urlparse(base_url)
|
||||
# Only preserve HTTPS for same-domain links (not protocol-relative URLs)
|
||||
# Protocol-relative URLs (//example.com) should follow the base URL's scheme
|
||||
if (parsed_full.scheme == 'http' and
|
||||
if (parsed_full.scheme == 'http' and
|
||||
parsed_full.netloc == parsed_base.netloc and
|
||||
not href.strip().startswith('//')):
|
||||
full_url = full_url.replace('http://', 'https://', 1)
|
||||
@@ -2199,6 +2199,14 @@ def normalize_url(
|
||||
|
||||
# ── netloc ──
|
||||
netloc = parsed.netloc.lower()
|
||||
|
||||
# Remove default ports
|
||||
if ':' in netloc:
|
||||
host, port = netloc.rsplit(':', 1)
|
||||
if (parsed.scheme == 'http' and port == '80') or (parsed.scheme == 'https' and port == '443'):
|
||||
netloc = host
|
||||
else:
|
||||
netloc = f"{host}:{port}"
|
||||
|
||||
# ── path ──
|
||||
# Strip duplicate slashes and trailing "/" (except root)
|
||||
@@ -2212,21 +2220,25 @@ def normalize_url(
|
||||
query = parsed.query
|
||||
if query:
|
||||
# explode, mutate, then rebuild
|
||||
params = [(k.lower(), v) for k, v in parse_qsl(query, keep_blank_values=True)]
|
||||
params = list(parse_qsl(query, keep_blank_values=True)) # Parse query string into key-value pairs, preserving blank values
|
||||
|
||||
if drop_query_tracking:
|
||||
# Define default tracking parameters to remove for cleaner URLs
|
||||
default_tracking = {
|
||||
'utm_source', 'utm_medium', 'utm_campaign', 'utm_term',
|
||||
'utm_content', 'gclid', 'fbclid', 'ref', 'ref_src'
|
||||
}
|
||||
if extra_drop_params:
|
||||
default_tracking |= {p.lower() for p in extra_drop_params}
|
||||
params = [(k, v) for k, v in params if k not in default_tracking]
|
||||
default_tracking |= {p.lower() for p in extra_drop_params} # Add any extra parameters to drop, case-insensitive
|
||||
params = [(k, v) for k, v in params if k not in default_tracking] # Filter out tracking parameters
|
||||
|
||||
# Normalize parameter keys
|
||||
params = [(k, v) for k, v in params]
|
||||
|
||||
if sort_query:
|
||||
params.sort(key=lambda kv: kv[0])
|
||||
params.sort(key=lambda kv: kv[0]) # Sort parameters alphabetically by key (now lowercase)
|
||||
|
||||
query = urlencode(params, doseq=True) if params else ''
|
||||
query = urlencode(params, doseq=True) if params else '' # Rebuild query string, handling sequences properly
|
||||
|
||||
# ── fragment ──
|
||||
fragment = parsed.fragment if keep_fragment else ''
|
||||
|
||||
@@ -28,25 +28,43 @@ def create_access_token(data: dict, expires_delta: Optional[timedelta] = None) -
|
||||
signing_key = get_jwk_from_secret(SECRET_KEY)
|
||||
return instance.encode(to_encode, signing_key, alg='HS256')
|
||||
|
||||
def verify_token(credentials: HTTPAuthorizationCredentials = Depends(security)) -> Dict:
|
||||
def verify_token(credentials: HTTPAuthorizationCredentials) -> Dict:
|
||||
"""Verify the JWT token from the Authorization header."""
|
||||
|
||||
if credentials is None:
|
||||
return None
|
||||
|
||||
if not credentials or not credentials.credentials:
|
||||
raise HTTPException(
|
||||
status_code=401,
|
||||
detail="No token provided",
|
||||
headers={"WWW-Authenticate": "Bearer"}
|
||||
)
|
||||
|
||||
token = credentials.credentials
|
||||
verifying_key = get_jwk_from_secret(SECRET_KEY)
|
||||
try:
|
||||
payload = instance.decode(token, verifying_key, do_time_check=True, algorithms='HS256')
|
||||
return payload
|
||||
except Exception:
|
||||
raise HTTPException(status_code=401, detail="Invalid or expired token")
|
||||
except Exception as e:
|
||||
raise HTTPException(
|
||||
status_code=401,
|
||||
detail=f"Invalid or expired token: {str(e)}",
|
||||
headers={"WWW-Authenticate": "Bearer"}
|
||||
)
|
||||
|
||||
|
||||
def get_token_dependency(config: Dict):
|
||||
"""Return the token dependency if JWT is enabled, else a function that returns None."""
|
||||
|
||||
|
||||
if config.get("security", {}).get("jwt_enabled", False):
|
||||
return verify_token
|
||||
def jwt_required(credentials: HTTPAuthorizationCredentials = Depends(security)) -> Dict:
|
||||
"""Enforce JWT authentication when enabled."""
|
||||
if credentials is None:
|
||||
raise HTTPException(
|
||||
status_code=401,
|
||||
detail="Authentication required. Please provide a valid Bearer token.",
|
||||
headers={"WWW-Authenticate": "Bearer"}
|
||||
)
|
||||
return verify_token(credentials)
|
||||
return jwt_required
|
||||
else:
|
||||
return lambda: None
|
||||
|
||||
|
||||
@@ -7520,17 +7520,18 @@ class BrowserManager:
|
||||
)
|
||||
os.makedirs(browser_args["downloads_path"], exist_ok=True)
|
||||
|
||||
if self.config.proxy or self.config.proxy_config:
|
||||
if self.config.proxy:
|
||||
warnings.warn(
|
||||
"BrowserConfig.proxy is deprecated and ignored. Use proxy_config instead.",
|
||||
DeprecationWarning,
|
||||
)
|
||||
if self.config.proxy_config:
|
||||
from playwright.async_api import ProxySettings
|
||||
|
||||
proxy_settings = (
|
||||
ProxySettings(server=self.config.proxy)
|
||||
if self.config.proxy
|
||||
else ProxySettings(
|
||||
server=self.config.proxy_config.server,
|
||||
username=self.config.proxy_config.username,
|
||||
password=self.config.proxy_config.password,
|
||||
)
|
||||
proxy_settings = ProxySettings(
|
||||
server=self.config.proxy_config.server,
|
||||
username=self.config.proxy_config.username,
|
||||
password=self.config.proxy_config.password,
|
||||
)
|
||||
browser_args["proxy"] = proxy_settings
|
||||
|
||||
|
||||
@@ -38,8 +38,8 @@ rate_limiting:
|
||||
|
||||
# Security Configuration
|
||||
security:
|
||||
enabled: false
|
||||
jwt_enabled: false
|
||||
enabled: false
|
||||
jwt_enabled: false
|
||||
https_redirect: false
|
||||
trusted_hosts: ["*"]
|
||||
headers:
|
||||
|
||||
154
docs/examples/adaptive_crawling/llm_config_example.py
Normal file
154
docs/examples/adaptive_crawling/llm_config_example.py
Normal file
@@ -0,0 +1,154 @@
|
||||
import asyncio
|
||||
import os
|
||||
from crawl4ai import AsyncWebCrawler, AdaptiveCrawler, AdaptiveConfig, LLMConfig
|
||||
|
||||
|
||||
async def test_configuration(name: str, config: AdaptiveConfig, url: str, query: str):
|
||||
"""Test a specific configuration"""
|
||||
print(f"\n{'='*60}")
|
||||
print(f"Configuration: {name}")
|
||||
print(f"{'='*60}")
|
||||
|
||||
async with AsyncWebCrawler(verbose=False) as crawler:
|
||||
adaptive = AdaptiveCrawler(crawler, config)
|
||||
result = await adaptive.digest(start_url=url, query=query)
|
||||
|
||||
print("\n" + "="*50)
|
||||
print("CRAWL STATISTICS")
|
||||
print("="*50)
|
||||
adaptive.print_stats(detailed=False)
|
||||
|
||||
# Get the most relevant content found
|
||||
print("\n" + "="*50)
|
||||
print("MOST RELEVANT PAGES")
|
||||
print("="*50)
|
||||
|
||||
relevant_pages = adaptive.get_relevant_content(top_k=5)
|
||||
for i, page in enumerate(relevant_pages, 1):
|
||||
print(f"\n{i}. {page['url']}")
|
||||
print(f" Relevance Score: {page['score']:.2%}")
|
||||
|
||||
# Show a snippet of the content
|
||||
content = page['content'] or ""
|
||||
if content:
|
||||
snippet = content[:200].replace('\n', ' ')
|
||||
if len(content) > 200:
|
||||
snippet += "..."
|
||||
print(f" Preview: {snippet}")
|
||||
|
||||
print(f"\n{'='*50}")
|
||||
print(f"Pages crawled: {len(result.crawled_urls)}")
|
||||
print(f"Final confidence: {adaptive.confidence:.1%}")
|
||||
print(f"Stopped reason: {result.metrics.get('stopped_reason', 'max_pages')}")
|
||||
|
||||
if result.metrics.get('is_irrelevant', False):
|
||||
print("⚠️ Query detected as irrelevant!")
|
||||
|
||||
return result
|
||||
|
||||
|
||||
async def llm_embedding():
|
||||
"""Demonstrate various embedding configurations"""
|
||||
|
||||
print("EMBEDDING STRATEGY CONFIGURATION EXAMPLES")
|
||||
print("=" * 60)
|
||||
|
||||
# Base URL and query for testing
|
||||
test_url = "https://docs.python.org/3/library/asyncio.html"
|
||||
|
||||
openai_llm_config = LLMConfig(
|
||||
provider='openai/text-embedding-3-small',
|
||||
api_token=os.getenv('OPENAI_API_KEY'),
|
||||
temperature=0.7,
|
||||
max_tokens=2000
|
||||
)
|
||||
config_openai = AdaptiveConfig(
|
||||
strategy="embedding",
|
||||
max_pages=10,
|
||||
|
||||
# Use OpenAI embeddings
|
||||
embedding_llm_config=openai_llm_config,
|
||||
# embedding_llm_config={
|
||||
# 'provider': 'openai/text-embedding-3-small',
|
||||
# 'api_token': os.getenv('OPENAI_API_KEY')
|
||||
# },
|
||||
|
||||
# OpenAI embeddings are high quality, can be stricter
|
||||
embedding_k_exp=4.0,
|
||||
n_query_variations=12
|
||||
)
|
||||
|
||||
await test_configuration(
|
||||
"OpenAI Embeddings",
|
||||
config_openai,
|
||||
test_url,
|
||||
# "event-driven architecture patterns"
|
||||
"async await context managers coroutines"
|
||||
)
|
||||
return
|
||||
|
||||
|
||||
|
||||
async def basic_adaptive_crawling():
|
||||
"""Basic adaptive crawling example"""
|
||||
|
||||
# Initialize the crawler
|
||||
async with AsyncWebCrawler(verbose=True) as crawler:
|
||||
# Create an adaptive crawler with default settings (statistical strategy)
|
||||
adaptive = AdaptiveCrawler(crawler)
|
||||
|
||||
# Note: You can also use embedding strategy for semantic understanding:
|
||||
# from crawl4ai import AdaptiveConfig
|
||||
# config = AdaptiveConfig(strategy="embedding")
|
||||
# adaptive = AdaptiveCrawler(crawler, config)
|
||||
|
||||
# Start adaptive crawling
|
||||
print("Starting adaptive crawl for Python async programming information...")
|
||||
result = await adaptive.digest(
|
||||
start_url="https://docs.python.org/3/library/asyncio.html",
|
||||
query="async await context managers coroutines"
|
||||
)
|
||||
|
||||
# Display crawl statistics
|
||||
print("\n" + "="*50)
|
||||
print("CRAWL STATISTICS")
|
||||
print("="*50)
|
||||
adaptive.print_stats(detailed=False)
|
||||
|
||||
# Get the most relevant content found
|
||||
print("\n" + "="*50)
|
||||
print("MOST RELEVANT PAGES")
|
||||
print("="*50)
|
||||
|
||||
relevant_pages = adaptive.get_relevant_content(top_k=5)
|
||||
for i, page in enumerate(relevant_pages, 1):
|
||||
print(f"\n{i}. {page['url']}")
|
||||
print(f" Relevance Score: {page['score']:.2%}")
|
||||
|
||||
# Show a snippet of the content
|
||||
content = page['content'] or ""
|
||||
if content:
|
||||
snippet = content[:200].replace('\n', ' ')
|
||||
if len(content) > 200:
|
||||
snippet += "..."
|
||||
print(f" Preview: {snippet}")
|
||||
|
||||
# Show final confidence
|
||||
print(f"\n{'='*50}")
|
||||
print(f"Final Confidence: {adaptive.confidence:.2%}")
|
||||
print(f"Total Pages Crawled: {len(result.crawled_urls)}")
|
||||
print(f"Knowledge Base Size: {len(adaptive.state.knowledge_base)} documents")
|
||||
|
||||
|
||||
if adaptive.confidence >= 0.8:
|
||||
print("✓ High confidence - can answer detailed questions about async Python")
|
||||
elif adaptive.confidence >= 0.6:
|
||||
print("~ Moderate confidence - can answer basic questions")
|
||||
else:
|
||||
print("✗ Low confidence - need more information")
|
||||
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
asyncio.run(llm_embedding())
|
||||
# asyncio.run(basic_adaptive_crawling())
|
||||
@@ -7,13 +7,13 @@ Simple proxy configuration with `BrowserConfig`:
|
||||
```python
|
||||
from crawl4ai.async_configs import BrowserConfig
|
||||
|
||||
# Using proxy URL
|
||||
browser_config = BrowserConfig(proxy="http://proxy.example.com:8080")
|
||||
# Using HTTP proxy
|
||||
browser_config = BrowserConfig(proxy_config={"server": "http://proxy.example.com:8080"})
|
||||
async with AsyncWebCrawler(config=browser_config) as crawler:
|
||||
result = await crawler.arun(url="https://example.com")
|
||||
|
||||
# Using SOCKS proxy
|
||||
browser_config = BrowserConfig(proxy="socks5://proxy.example.com:1080")
|
||||
browser_config = BrowserConfig(proxy_config={"server": "socks5://proxy.example.com:1080"})
|
||||
async with AsyncWebCrawler(config=browser_config) as crawler:
|
||||
result = await crawler.arun(url="https://example.com")
|
||||
```
|
||||
@@ -25,7 +25,11 @@ Use an authenticated proxy with `BrowserConfig`:
|
||||
```python
|
||||
from crawl4ai.async_configs import BrowserConfig
|
||||
|
||||
browser_config = BrowserConfig(proxy="http://[username]:[password]@[host]:[port]")
|
||||
browser_config = BrowserConfig(proxy_config={
|
||||
"server": "http://[host]:[port]",
|
||||
"username": "[username]",
|
||||
"password": "[password]",
|
||||
})
|
||||
async with AsyncWebCrawler(config=browser_config) as crawler:
|
||||
result = await crawler.arun(url="https://example.com")
|
||||
```
|
||||
|
||||
@@ -23,7 +23,7 @@ browser_cfg = BrowserConfig(
|
||||
| **`headless`** | `bool` (default: `True`) | Headless means no visible UI. `False` is handy for debugging. |
|
||||
| **`viewport_width`** | `int` (default: `1080`) | Initial page width (in px). Useful for testing responsive layouts. |
|
||||
| **`viewport_height`** | `int` (default: `600`) | Initial page height (in px). |
|
||||
| **`proxy`** | `str` (default: `None`) | Single-proxy URL if you want all traffic to go through it, e.g. `"http://user:pass@proxy:8080"`. |
|
||||
| **`proxy`** | `str` (deprecated) | Deprecated. Use `proxy_config` instead. If set, it will be auto-converted internally. |
|
||||
| **`proxy_config`** | `dict` (default: `None`) | For advanced or multi-proxy needs, specify details like `{"server": "...", "username": "...", ...}`. |
|
||||
| **`use_persistent_context`** | `bool` (default: `False`) | If `True`, uses a **persistent** browser context (keep cookies, sessions across runs). Also sets `use_managed_browser=True`. |
|
||||
| **`user_data_dir`** | `str or None` (default: `None`) | Directory to store user data (profiles, cookies). Must be set if you want permanent sessions. |
|
||||
|
||||
@@ -108,7 +108,19 @@ config = AdaptiveConfig(
|
||||
embedding_min_confidence_threshold=0.1 # Stop if completely irrelevant
|
||||
)
|
||||
|
||||
# With custom embedding provider (e.g., OpenAI)
|
||||
# With custom LLM provider for query expansion (recommended)
|
||||
from crawl4ai import LLMConfig
|
||||
|
||||
config = AdaptiveConfig(
|
||||
strategy="embedding",
|
||||
embedding_llm_config=LLMConfig(
|
||||
provider='openai/text-embedding-3-small',
|
||||
api_token='your-api-key',
|
||||
temperature=0.7
|
||||
)
|
||||
)
|
||||
|
||||
# Alternative: Dictionary format (backward compatible)
|
||||
config = AdaptiveConfig(
|
||||
strategy="embedding",
|
||||
embedding_llm_config={
|
||||
|
||||
154
tests/adaptive/test_llm_embedding.py
Normal file
154
tests/adaptive/test_llm_embedding.py
Normal file
@@ -0,0 +1,154 @@
|
||||
import asyncio
|
||||
import os
|
||||
from crawl4ai import AsyncWebCrawler, AdaptiveCrawler, AdaptiveConfig, LLMConfig
|
||||
|
||||
|
||||
async def test_configuration(name: str, config: AdaptiveConfig, url: str, query: str):
|
||||
"""Test a specific configuration"""
|
||||
print(f"\n{'='*60}")
|
||||
print(f"Configuration: {name}")
|
||||
print(f"{'='*60}")
|
||||
|
||||
async with AsyncWebCrawler(verbose=False) as crawler:
|
||||
adaptive = AdaptiveCrawler(crawler, config)
|
||||
result = await adaptive.digest(start_url=url, query=query)
|
||||
|
||||
print("\n" + "="*50)
|
||||
print("CRAWL STATISTICS")
|
||||
print("="*50)
|
||||
adaptive.print_stats(detailed=False)
|
||||
|
||||
# Get the most relevant content found
|
||||
print("\n" + "="*50)
|
||||
print("MOST RELEVANT PAGES")
|
||||
print("="*50)
|
||||
|
||||
relevant_pages = adaptive.get_relevant_content(top_k=5)
|
||||
for i, page in enumerate(relevant_pages, 1):
|
||||
print(f"\n{i}. {page['url']}")
|
||||
print(f" Relevance Score: {page['score']:.2%}")
|
||||
|
||||
# Show a snippet of the content
|
||||
content = page['content'] or ""
|
||||
if content:
|
||||
snippet = content[:200].replace('\n', ' ')
|
||||
if len(content) > 200:
|
||||
snippet += "..."
|
||||
print(f" Preview: {snippet}")
|
||||
|
||||
print(f"\n{'='*50}")
|
||||
print(f"Pages crawled: {len(result.crawled_urls)}")
|
||||
print(f"Final confidence: {adaptive.confidence:.1%}")
|
||||
print(f"Stopped reason: {result.metrics.get('stopped_reason', 'max_pages')}")
|
||||
|
||||
if result.metrics.get('is_irrelevant', False):
|
||||
print("⚠️ Query detected as irrelevant!")
|
||||
|
||||
return result
|
||||
|
||||
|
||||
async def llm_embedding():
|
||||
"""Demonstrate various embedding configurations"""
|
||||
|
||||
print("EMBEDDING STRATEGY CONFIGURATION EXAMPLES")
|
||||
print("=" * 60)
|
||||
|
||||
# Base URL and query for testing
|
||||
test_url = "https://docs.python.org/3/library/asyncio.html"
|
||||
|
||||
openai_llm_config = LLMConfig(
|
||||
provider='openai/text-embedding-3-small',
|
||||
api_token=os.getenv('OPENAI_API_KEY'),
|
||||
temperature=0.7,
|
||||
max_tokens=2000
|
||||
)
|
||||
config_openai = AdaptiveConfig(
|
||||
strategy="embedding",
|
||||
max_pages=10,
|
||||
|
||||
# Use OpenAI embeddings
|
||||
embedding_llm_config=openai_llm_config,
|
||||
# embedding_llm_config={
|
||||
# 'provider': 'openai/text-embedding-3-small',
|
||||
# 'api_token': os.getenv('OPENAI_API_KEY')
|
||||
# },
|
||||
|
||||
# OpenAI embeddings are high quality, can be stricter
|
||||
embedding_k_exp=4.0,
|
||||
n_query_variations=12
|
||||
)
|
||||
|
||||
await test_configuration(
|
||||
"OpenAI Embeddings",
|
||||
config_openai,
|
||||
test_url,
|
||||
# "event-driven architecture patterns"
|
||||
"async await context managers coroutines"
|
||||
)
|
||||
return
|
||||
|
||||
|
||||
|
||||
async def basic_adaptive_crawling():
    """Basic adaptive crawling example using the default statistical strategy."""

    def _preview(text):
        # First 200 chars flattened onto one line, ellipsis when truncated.
        clipped = text[:200].replace('\n', ' ')
        return (clipped + "...") if len(text) > 200 else clipped

    async with AsyncWebCrawler(verbose=True) as crawler:
        # Default config -> statistical strategy.  For semantic matching:
        #   AdaptiveCrawler(crawler, AdaptiveConfig(strategy="embedding"))
        adaptive = AdaptiveCrawler(crawler)

        print("Starting adaptive crawl for Python async programming information...")
        result = await adaptive.digest(
            start_url="https://docs.python.org/3/library/asyncio.html",
            query="async await context managers coroutines"
        )

        # Crawl statistics summary.
        print("\n" + "="*50)
        print("CRAWL STATISTICS")
        print("="*50)
        adaptive.print_stats(detailed=False)

        # Highest-relevance pages discovered during the crawl.
        print("\n" + "="*50)
        print("MOST RELEVANT PAGES")
        print("="*50)

        for rank, page in enumerate(adaptive.get_relevant_content(top_k=5), 1):
            print(f"\n{rank}. {page['url']}")
            print(f" Relevance Score: {page['score']:.2%}")
            body = page['content'] or ""
            if body:
                print(f" Preview: {_preview(body)}")

        # Final confidence report.
        print(f"\n{'='*50}")
        print(f"Final Confidence: {adaptive.confidence:.2%}")
        print(f"Total Pages Crawled: {len(result.crawled_urls)}")
        print(f"Knowledge Base Size: {len(adaptive.state.knowledge_base)} documents")

        if adaptive.confidence >= 0.8:
            print("✓ High confidence - can answer detailed questions about async Python")
        elif adaptive.confidence >= 0.6:
            print("~ Moderate confidence - can answer basic questions")
        else:
            print("✗ Low confidence - need more information")
|
||||
|
||||
|
||||
|
||||
# Entry point: runs the embedding demo by default; swap the comment to run
# the basic statistical demo instead.
if __name__ == "__main__":
    asyncio.run(llm_embedding())
    # asyncio.run(basic_adaptive_crawling())
|
||||
@@ -112,7 +112,7 @@ async def test_proxy_settings():
|
||||
headless=True,
|
||||
verbose=False,
|
||||
user_agent="Mozilla/5.0 (X11; Linux x86_64) AppleWebKit/537.36",
|
||||
proxy="http://127.0.0.1:8080", # Assuming local proxy server for test
|
||||
proxy_config={"server": "http://127.0.0.1:8080"}, # Assuming local proxy server for test
|
||||
use_managed_browser=False,
|
||||
use_persistent_context=False,
|
||||
) as crawler:
|
||||
|
||||
117
tests/general/test_bff_scoring.py
Normal file
117
tests/general/test_bff_scoring.py
Normal file
@@ -0,0 +1,117 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Simple test to verify BestFirstCrawlingStrategy fixes.
|
||||
This test crawls a real website and shows that:
|
||||
1. Higher-scoring pages are crawled first (priority queue fix)
|
||||
2. Links are scored before truncation (link discovery fix)
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
from crawl4ai import AsyncWebCrawler, BrowserConfig, CrawlerRunConfig
|
||||
from crawl4ai.deep_crawling import BestFirstCrawlingStrategy
|
||||
from crawl4ai.deep_crawling.scorers import KeywordRelevanceScorer
|
||||
|
||||
async def test_best_first_strategy():
    """Test BestFirstCrawlingStrategy with keyword scoring.

    Crawls the Python docs with a keyword scorer and prints the crawl order
    so it can be verified that higher-scoring pages were visited first.

    Fix over the original: the statistics section divided by
    ``len(crawled_urls)`` and called ``max()`` on the score list, which
    raised ZeroDivisionError / ValueError whenever the crawl returned no
    results; an empty crawl is now reported and handled gracefully.
    """
    print("=" * 70)
    print("Testing BestFirstCrawlingStrategy with Real URL")
    print("=" * 70)
    print("\nThis test will:")
    print("1. Crawl Python.org documentation")
    print("2. Score pages based on keywords: 'tutorial', 'guide', 'reference'")
    print("3. Show that higher-scoring pages are crawled first")
    print("-" * 70)

    # Keyword scorer that prioritizes tutorial/guide pages.
    scorer = KeywordRelevanceScorer(
        keywords=["tutorial", "guide", "reference", "documentation"],
        weight=1.0,
        case_sensitive=False
    )

    # Best-first strategy: bounded depth and page count, internal links only.
    strategy = BestFirstCrawlingStrategy(
        max_depth=2,        # Crawl 2 levels deep
        max_pages=10,       # Limit to 10 pages total
        url_scorer=scorer,  # Use keyword scoring
        include_external=False  # Only internal links
    )

    browser_config = BrowserConfig(
        headless=True,  # Run in background
        verbose=False   # Reduce output noise
    )

    crawler_config = CrawlerRunConfig(
        deep_crawl_strategy=strategy,
        verbose=False
    )

    print("\nStarting crawl of https://docs.python.org/3/")
    print("Looking for pages with keywords: tutorial, guide, reference, documentation")
    print("-" * 70)

    crawled_urls = []

    async with AsyncWebCrawler(config=browser_config) as crawler:
        results = await crawler.arun(
            url="https://docs.python.org/3/",
            config=crawler_config
        )

        # Collect (url, score, depth, success) per result; metadata may be
        # absent on failed results, so guard each access.
        if isinstance(results, list):
            for result in results:
                score = result.metadata.get('score', 0) if result.metadata else 0
                depth = result.metadata.get('depth', 0) if result.metadata else 0
                crawled_urls.append({
                    'url': result.url,
                    'score': score,
                    'depth': depth,
                    'success': result.success
                })

    print("\n" + "=" * 70)
    print("CRAWL RESULTS (in order of crawling)")
    print("=" * 70)

    for i, item in enumerate(crawled_urls, 1):
        status = "✓" if item['success'] else "✗"
        print(f"{i:2}. [{status}] Score: {item['score']:.2f} | Depth: {item['depth']} | {item['url']}")
        if item['score'] > 0.5:
            print(f" ^ HIGH SCORE - Contains keywords!")

    print("\n" + "=" * 70)
    print("ANALYSIS")
    print("=" * 70)

    # FIX: bail out before computing averages/maxima over an empty list.
    if not crawled_urls:
        print("⚠️ No pages were crawled - cannot analyze crawl order")
        return

    # Check if higher scores appear early in the crawl (skip the seed URL).
    scores = [item['score'] for item in crawled_urls[1:]]
    high_score_indices = [i for i, s in enumerate(scores) if s > 0.3]

    if high_score_indices and high_score_indices[0] < len(scores) / 2:
        print("✅ SUCCESS: Higher-scoring pages (with keywords) were crawled early!")
        print(" This confirms the priority queue fix is working.")
    else:
        print("⚠️ Check the crawl order above - higher scores should appear early")

    # Score distribution summary (safe now: crawled_urls is non-empty).
    print(f"\nScore Statistics:")
    print(f" - Total pages crawled: {len(crawled_urls)}")
    print(f" - Average score: {sum(item['score'] for item in crawled_urls) / len(crawled_urls):.2f}")
    print(f" - Max score: {max(item['score'] for item in crawled_urls):.2f}")
    print(f" - Pages with keywords: {sum(1 for item in crawled_urls if item['score'] > 0.3)}")

    print("\n" + "=" * 70)
    print("TEST COMPLETE")
    print("=" * 70)
|
||||
|
||||
# Allow running this crawl-order check directly as a script.
if __name__ == "__main__":
    print("\n🔍 BestFirstCrawlingStrategy Simple Test\n")
    asyncio.run(test_best_first_strategy())
|
||||
@@ -24,7 +24,7 @@ CASES = [
|
||||
# --- BrowserConfig variants ---
|
||||
"BrowserConfig()",
|
||||
"BrowserConfig(headless=False, extra_args=['--disable-gpu'])",
|
||||
"BrowserConfig(browser_mode='builtin', proxy='http://1.2.3.4:8080')",
|
||||
"BrowserConfig(browser_mode='builtin', proxy_config={'server': 'http://1.2.3.4:8080'})",
|
||||
]
|
||||
|
||||
for code in CASES:
|
||||
|
||||
42
tests/proxy/test_proxy_deprecation.py
Normal file
42
tests/proxy/test_proxy_deprecation.py
Normal file
@@ -0,0 +1,42 @@
|
||||
import warnings
|
||||
|
||||
import pytest
|
||||
|
||||
from crawl4ai.async_configs import BrowserConfig, ProxyConfig
|
||||
|
||||
|
||||
def test_browser_config_proxy_string_emits_deprecation_and_autoconverts():
    """A legacy ``proxy=`` string should warn and auto-convert to ProxyConfig.

    Fix over the original: ``warnings.simplefilter`` was called *before*
    entering ``catch_warnings``, so the "always" filter was installed
    globally and leaked into every subsequent test.  Setting it inside the
    context means it is rolled back on exit.
    """
    proxy_str = "23.95.150.145:6114:username:password"
    with warnings.catch_warnings(record=True) as caught:
        warnings.simplefilter("always", DeprecationWarning)
        cfg = BrowserConfig(proxy=proxy_str, headless=True)

    dep_warnings = [w for w in caught if issubclass(w.category, DeprecationWarning)]
    assert dep_warnings, "Expected DeprecationWarning when using BrowserConfig(proxy=...)"

    # The legacy string must be fully migrated onto proxy_config.
    assert cfg.proxy is None, "cfg.proxy should be None after auto-conversion"
    assert isinstance(cfg.proxy_config, ProxyConfig), "cfg.proxy_config should be ProxyConfig instance"
    assert cfg.proxy_config.username == "username"
    assert cfg.proxy_config.password == "password"
    assert cfg.proxy_config.server.startswith("http://")
    assert cfg.proxy_config.server.endswith(":6114")
||||
|
||||
|
||||
def test_browser_config_with_proxy_config_emits_no_deprecation():
    """Explicit ``proxy_config`` must not trigger the deprecation path.

    Fix over the original: ``warnings.simplefilter`` was called outside the
    ``catch_warnings`` context, permanently mutating the global filter list.
    Inside the context the change is restored on exit.
    """
    with warnings.catch_warnings(record=True) as caught:
        warnings.simplefilter("always", DeprecationWarning)
        cfg = BrowserConfig(
            headless=True,
            proxy_config={
                "server": "http://127.0.0.1:8080",
                "username": "u",
                "password": "p",
            },
        )

    dep_warnings = [w for w in caught if issubclass(w.category, DeprecationWarning)]
    assert not dep_warnings, "Did not expect DeprecationWarning when using proxy_config"
    assert cfg.proxy is None
    assert isinstance(cfg.proxy_config, ProxyConfig)
|
||||
849
tests/test_url_normalization_comprehensive.py
Normal file
849
tests/test_url_normalization_comprehensive.py
Normal file
@@ -0,0 +1,849 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Comprehensive test suite for URL normalization functions in utils.py
|
||||
Tests all scenarios and edge cases for the updated normalize_url functions.
|
||||
"""
|
||||
|
||||
import sys
|
||||
import os
|
||||
import time
|
||||
from pathlib import Path
|
||||
from urllib.parse import urljoin, urlparse, urlunparse, parse_qsl, urlencode
|
||||
|
||||
# Add the crawl4ai package to the path
|
||||
sys.path.insert(0, str(Path(__file__).parent.parent))
|
||||
|
||||
# Import only the specific functions we need to test
|
||||
from crawl4ai.utils import get_base_domain, is_external_url
|
||||
|
||||
|
||||
# ANSI Color codes for beautiful console output
|
||||
class Colors:
    """ANSI escape sequences and Unicode icons for colorful console output."""

    # Basic foreground colors (bright 9x codes)
    RED = '\x1b[91m'
    GREEN = '\x1b[92m'
    YELLOW = '\x1b[93m'
    BLUE = '\x1b[94m'
    MAGENTA = '\x1b[95m'
    CYAN = '\x1b[96m'
    WHITE = '\x1b[97m'

    # Bold variants of the same colors
    BRIGHT_RED = '\x1b[91;1m'
    BRIGHT_GREEN = '\x1b[92;1m'
    BRIGHT_YELLOW = '\x1b[93;1m'
    BRIGHT_BLUE = '\x1b[94;1m'
    BRIGHT_MAGENTA = '\x1b[95;1m'
    BRIGHT_CYAN = '\x1b[96;1m'
    BRIGHT_WHITE = '\x1b[97;1m'

    # Background colors
    BG_RED = '\x1b[41m'
    BG_GREEN = '\x1b[42m'
    BG_YELLOW = '\x1b[43m'
    BG_BLUE = '\x1b[44m'

    # Text styles
    BOLD = '\x1b[1m'
    UNDERLINE = '\x1b[4m'
    RESET = '\x1b[0m'

    # Icons used by the printing helpers
    CHECK = '✓'
    CROSS = '✗'
    WARNING = '⚠'
    INFO = 'ℹ'
    STAR = '⭐'
    FIRE = '🔥'
    ROCKET = '🚀'
    TARGET = '🎯'
|
||||
|
||||
|
||||
def colorize(text, color):
    """Wrap *text* in *color* and terminate it with a reset sequence."""
    return "{}{}{}".format(color, text, Colors.RESET)
|
||||
|
||||
|
||||
def print_header(title, icon=""):
    """Print an 80-column banner with *title* (and optional *icon*) centered.

    Fix over the original: the left and right padding were both computed
    with the same floor division, so any odd-length title rendered one
    column short of the banner width.  ``str.center`` pads correctly.
    """
    width = 80
    banner = f"{icon} {title}" if icon else title
    bar = f"{Colors.BG_BLUE}{Colors.WHITE}{Colors.BOLD}{'=' * width}{Colors.RESET}"
    print(f"\n{bar}")
    print(f"{Colors.BG_BLUE}{Colors.WHITE}{Colors.BOLD}{banner.center(width)}{Colors.RESET}")
    print(bar)
|
||||
|
||||
|
||||
def print_section(title, icon=""):
    """Print a cyan section heading underlined with dashes."""
    heading = f"{icon} {title}" if icon else title
    print(f"\n{Colors.CYAN}{Colors.BOLD}{heading}{Colors.RESET}")
    # Underline matches the heading length exactly (icon + space + title).
    print(f"{Colors.CYAN}{'-' * len(heading)}{Colors.RESET}")
|
||||
|
||||
|
||||
def print_success(message):
    """Print a green, check-marked success line."""
    print(colorize(f"{Colors.CHECK} {message}", Colors.GREEN))


def print_error(message):
    """Print a red, cross-marked error line."""
    print(colorize(f"{Colors.CROSS} {message}", Colors.RED))


def print_warning(message):
    """Print a yellow warning line."""
    print(colorize(f"{Colors.WARNING} {message}", Colors.YELLOW))


def print_info(message):
    """Print a blue informational line."""
    print(colorize(f"{Colors.INFO} {message}", Colors.BLUE))
|
||||
|
||||
|
||||
def print_test_result(test_name, passed, expected=None, actual=None):
    """Print one indented pass/fail line; on failure also show the mismatch."""
    if passed:
        print(f" {Colors.GREEN}{Colors.CHECK} {test_name}{Colors.RESET}")
        return
    print(f" {Colors.RED}{Colors.CROSS} {test_name}{Colors.RESET}")
    # Only show the diff when both sides were supplied by the caller.
    if expected is not None and actual is not None:
        print(f" {Colors.BRIGHT_RED}Expected: {expected}{Colors.RESET}")
        print(f" {Colors.BRIGHT_RED}Actual: {actual}{Colors.RESET}")
|
||||
|
||||
|
||||
def print_progress(current, total, test_name=""):
    """Render an in-place progress bar like ``[████░░] 42.1% (5/12) name``."""
    filled = int(40 * current // total)
    bar = '█' * filled + '░' * (40 - filled)
    percentage = (current / total) * 100

    # \r rewrites the same console line on every call.
    sys.stdout.write(f'\r{Colors.CYAN}Progress: [{bar}] {percentage:.1f}% ({current}/{total}) {test_name}{Colors.RESET}')
    sys.stdout.flush()

    if current == total:
        print()  # New line when complete
|
||||
|
||||
# Copy the normalize_url functions directly to avoid import issues
|
||||
def normalize_url(
    href: str,
    base_url: str,
    *,
    drop_query_tracking=True,
    sort_query=True,
    keep_fragment=False,
    extra_drop_params=None,
    preserve_https=False,
    original_scheme=None
):
    """
    Extended URL normalizer with fixes for edge cases - copied from utils.py for testing

    Steps: resolve *href* against *base_url*, optionally upgrade same-domain
    http links back to https, lowercase the host, strip default ports, trim
    trailing slashes (except root), drop tracking query parameters,
    lowercase and optionally sort query keys, and drop the fragment unless
    *keep_fragment* is set.  Returns None for empty/blank input.

    Fix over the original: tracking parameters are matched
    case-insensitively.  The drop set is all-lowercase (extra_drop_params
    are lowered when added), but the original compared the raw key, so
    e.g. ``UTM_SOURCE=x`` survived filtering.
    """
    if not href or not href.strip():
        return None

    # Resolve relative paths first
    full_url = urljoin(base_url, href.strip())

    # Preserve HTTPS if requested and original scheme was HTTPS
    if preserve_https and original_scheme == 'https':
        parsed_full = urlparse(full_url)
        parsed_base = urlparse(base_url)
        # Only preserve HTTPS for same-domain links (not protocol-relative URLs)
        # Protocol-relative URLs (//example.com) should follow the base URL's scheme
        if (parsed_full.scheme == 'http' and
                parsed_full.netloc == parsed_base.netloc and
                not href.strip().startswith('//')):
            full_url = full_url.replace('http://', 'https://', 1)

    # Parse once, edit parts, then rebuild
    parsed = urlparse(full_url)

    # ── netloc ──
    netloc = parsed.netloc.lower()

    # Remove default ports (http:80 / https:443)
    if ':' in netloc:
        host, port = netloc.rsplit(':', 1)
        if (parsed.scheme == 'http' and port == '80') or (parsed.scheme == 'https' and port == '443'):
            netloc = host
        else:
            netloc = f"{host}:{port}"

    # ── path ──
    # Strip trailing "/" (except root).
    # IMPORTANT: Don't use quote(unquote()) as it mangles + signs in URLs;
    # the path from urlparse is already properly encoded.
    path = parsed.path
    if path.endswith('/') and path != '/':
        path = path.rstrip('/')

    # ── query ──
    query = parsed.query
    if query:
        # explode, mutate, then rebuild; keep blank values such as "?page="
        params = list(parse_qsl(query, keep_blank_values=True))

        if drop_query_tracking:
            # Default tracking parameters removed for cleaner URLs
            default_tracking = {
                'utm_source', 'utm_medium', 'utm_campaign', 'utm_term',
                'utm_content', 'gclid', 'fbclid', 'ref', 'ref_src'
            }
            if extra_drop_params:
                default_tracking |= {p.lower() for p in extra_drop_params}
            # FIX: case-insensitive comparison against the lowercase drop set.
            params = [(k, v) for k, v in params if k.lower() not in default_tracking]

        # Normalize parameter keys to lowercase
        params = [(k.lower(), v) for k, v in params]

        if sort_query:
            params.sort(key=lambda kv: kv[0])  # deterministic alphabetical order

        query = urlencode(params, doseq=True) if params else ''

    # ── fragment ──
    fragment = parsed.fragment if keep_fragment else ''

    # Re-assemble
    normalized = urlunparse((
        parsed.scheme,
        netloc,
        path,
        parsed.params,
        query,
        fragment
    ))

    return normalized
|
||||
|
||||
|
||||
def normalize_url_for_deep_crawl(href, base_url, preserve_https=False, original_scheme=None):
    """Normalize URLs for deep crawling - copied from utils.py for testing.

    Resolves *href* against *base_url*, optionally keeps same-domain links
    on https, lowercases the host, drops the fragment, removes a small set
    of tracking query parameters (the remaining parameters keep their
    original order), and trims trailing slashes.
    """
    if not href:
        return None

    resolved = urljoin(base_url, href.strip())

    # Upgrade same-domain http links back to https when requested.
    # Protocol-relative hrefs (//host/...) deliberately follow the base
    # URL's scheme and are therefore excluded.
    if preserve_https and original_scheme == 'https':
        target = urlparse(resolved)
        origin = urlparse(base_url)
        same_host_http = (target.scheme == 'http'
                          and target.netloc == origin.netloc
                          and not href.strip().startswith('//'))
        if same_host_http:
            resolved = resolved.replace('http://', 'https://', 1)

    parts = urlparse(resolved)

    # Rebuild the query without known tracking parameters.
    # NOTE: parse_qsl() without keep_blank_values drops blank values,
    # matching the original behavior.
    query = parts.query
    if query:
        tracking_params = ['utm_source', 'utm_medium', 'utm_campaign', 'ref', 'fbclid']
        kept = [(k, v) for k, v in parse_qsl(query) if k not in tracking_params]
        query = urlencode(kept, doseq=True) if kept else ''

    normalized = urlunparse((
        parts.scheme,
        parts.netloc.lower(),     # hostnames are case-insensitive
        parts.path.rstrip('/'),   # normalize trailing slash
        parts.params,
        query,
        ''                        # fragment removed entirely
    ))

    return normalized
|
||||
|
||||
def efficient_normalize_url_for_deep_crawl(href, base_url, preserve_https=False, original_scheme=None):
    """Efficient URL normalization with proper parsing - copied from utils.py for testing.

    Cheaper variant: only lowercases the host, trims trailing slashes and
    drops the fragment; the query string is left untouched.
    """
    if not href:
        return None

    resolved = urljoin(base_url, href.strip())

    # Same https-preservation rule as the full normalizer: upgrade only
    # same-domain http links, never protocol-relative ones.
    if preserve_https and original_scheme == 'https':
        target = urlparse(resolved)
        origin = urlparse(base_url)
        if (target.scheme == 'http'
                and target.netloc == origin.netloc
                and not href.strip().startswith('//')):
            resolved = resolved.replace('http://', 'https://', 1)

    parts = urlparse(resolved)
    return urlunparse((
        parts.scheme,
        parts.netloc.lower(),
        parts.path.rstrip('/'),
        parts.params,
        parts.query,
        ''  # fragment removed
    ))
|
||||
|
||||
|
||||
class URLNormalizationTestSuite:
|
||||
"""Comprehensive test suite for URL normalization functions"""
|
||||
|
||||
def __init__(self):
|
||||
self.base_url = "https://example.com/path/page.html"
|
||||
self.https_base_url = "https://example.com/path/page.html"
|
||||
self.http_base_url = "http://example.com/path/page.html"
|
||||
self.tests_run = 0
|
||||
self.tests_passed = 0
|
||||
self.tests_failed = []
|
||||
self.test_start_time = None
|
||||
self.section_stats = {}
|
||||
self.current_section = None
|
||||
|
||||
def start_section(self, section_name, icon=""):
|
||||
"""Start a new test section"""
|
||||
self.current_section = section_name
|
||||
if section_name not in self.section_stats:
|
||||
self.section_stats[section_name] = {'run': 0, 'passed': 0, 'failed': 0}
|
||||
print_section(section_name, icon)
|
||||
|
||||
def assert_equal(self, actual, expected, test_name):
|
||||
"""Assert that actual equals expected"""
|
||||
self.tests_run += 1
|
||||
if self.current_section:
|
||||
self.section_stats[self.current_section]['run'] += 1
|
||||
|
||||
if actual == expected:
|
||||
self.tests_passed += 1
|
||||
if self.current_section:
|
||||
self.section_stats[self.current_section]['passed'] += 1
|
||||
print_test_result(test_name, True)
|
||||
else:
|
||||
self.tests_failed.append({
|
||||
'name': test_name,
|
||||
'expected': expected,
|
||||
'actual': actual,
|
||||
'section': self.current_section
|
||||
})
|
||||
if self.current_section:
|
||||
self.section_stats[self.current_section]['failed'] += 1
|
||||
print_test_result(test_name, False, expected, actual)
|
||||
|
||||
def assert_none(self, actual, test_name):
|
||||
"""Assert that actual is None"""
|
||||
self.assert_equal(actual, None, test_name)
|
||||
|
||||
def test_basic_url_resolution(self):
|
||||
"""Test basic relative and absolute URL resolution"""
|
||||
self.start_section("Basic URL Resolution", Colors.TARGET)
|
||||
|
||||
# Absolute URLs should remain unchanged
|
||||
self.assert_equal(
|
||||
normalize_url("https://other.com/page.html", self.base_url),
|
||||
"https://other.com/page.html",
|
||||
"Absolute URL unchanged"
|
||||
)
|
||||
|
||||
# Relative URLs
|
||||
self.assert_equal(
|
||||
normalize_url("relative.html", self.base_url),
|
||||
"https://example.com/path/relative.html",
|
||||
"Relative URL resolution"
|
||||
)
|
||||
|
||||
self.assert_equal(
|
||||
normalize_url("./relative.html", self.base_url),
|
||||
"https://example.com/path/relative.html",
|
||||
"Relative URL with dot"
|
||||
)
|
||||
|
||||
self.assert_equal(
|
||||
normalize_url("../relative.html", self.base_url),
|
||||
"https://example.com/relative.html",
|
||||
"Parent directory resolution"
|
||||
)
|
||||
|
||||
# Root-relative URLs
|
||||
self.assert_equal(
|
||||
normalize_url("/root.html", self.base_url),
|
||||
"https://example.com/root.html",
|
||||
"Root-relative URL"
|
||||
)
|
||||
|
||||
# Protocol-relative URLs
|
||||
self.assert_equal(
|
||||
normalize_url("//cdn.example.com/asset.js", self.base_url),
|
||||
"https://cdn.example.com/asset.js",
|
||||
"Protocol-relative URL"
|
||||
)
|
||||
|
||||
def test_query_parameter_handling(self):
|
||||
"""Test query parameter sorting and tracking removal"""
|
||||
self.start_section("Query Parameter Handling", Colors.STAR)
|
||||
|
||||
# Basic query parameters
|
||||
self.assert_equal(
|
||||
normalize_url("https://example.com?page=1&sort=name", self.base_url),
|
||||
"https://example.com?page=1&sort=name",
|
||||
"Basic query parameters sorted"
|
||||
)
|
||||
|
||||
# Tracking parameters removal
|
||||
self.assert_equal(
|
||||
normalize_url("https://example.com?utm_source=google&utm_medium=email&page=1", self.base_url),
|
||||
"https://example.com?page=1",
|
||||
"Tracking parameters removed"
|
||||
)
|
||||
|
||||
# Mixed tracking and valid parameters
|
||||
self.assert_equal(
|
||||
normalize_url("https://example.com?fbclid=123&utm_campaign=test&category=news&id=456", self.base_url),
|
||||
"https://example.com?category=news&id=456",
|
||||
"Mixed tracking and valid parameters"
|
||||
)
|
||||
|
||||
# Empty query values
|
||||
self.assert_equal(
|
||||
normalize_url("https://example.com?page=&sort=name", self.base_url),
|
||||
"https://example.com?page=&sort=name",
|
||||
"Empty query values preserved"
|
||||
)
|
||||
|
||||
# Disable tracking removal
|
||||
self.assert_equal(
|
||||
normalize_url("https://example.com?utm_source=google&page=1", self.base_url, drop_query_tracking=False),
|
||||
"https://example.com?page=1&utm_source=google",
|
||||
"Tracking parameters preserved when disabled"
|
||||
)
|
||||
|
||||
# Disable sorting
|
||||
self.assert_equal(
|
||||
normalize_url("https://example.com?z=1&a=2", self.base_url, sort_query=False),
|
||||
"https://example.com?z=1&a=2",
|
||||
"Query parameters not sorted when disabled"
|
||||
)
|
||||
|
||||
def test_fragment_handling(self):
|
||||
"""Test fragment/hash handling"""
|
||||
self.start_section("Fragment Handling", Colors.FIRE)
|
||||
|
||||
# Fragments removed by default
|
||||
self.assert_equal(
|
||||
normalize_url("https://example.com/page.html#section", self.base_url),
|
||||
"https://example.com/page.html",
|
||||
"Fragment removed by default"
|
||||
)
|
||||
|
||||
# Fragments preserved when requested
|
||||
self.assert_equal(
|
||||
normalize_url("https://example.com/page.html#section", self.base_url, keep_fragment=True),
|
||||
"https://example.com/page.html#section",
|
||||
"Fragment preserved when requested"
|
||||
)
|
||||
|
||||
# Fragments with query parameters
|
||||
self.assert_equal(
|
||||
normalize_url("https://example.com?page=1#section", self.base_url, keep_fragment=True),
|
||||
"https://example.com?page=1#section",
|
||||
"Fragment with query parameters"
|
||||
)
|
||||
|
||||
def test_https_preservation(self):
|
||||
"""Test HTTPS preservation logic"""
|
||||
self.start_section("HTTPS Preservation", Colors.ROCKET)
|
||||
|
||||
# Same domain HTTP to HTTPS
|
||||
self.assert_equal(
|
||||
normalize_url("http://example.com/page.html", self.https_base_url, preserve_https=True, original_scheme='https'),
|
||||
"https://example.com/page.html",
|
||||
"HTTP to HTTPS for same domain"
|
||||
)
|
||||
|
||||
# Different domain should not change
|
||||
self.assert_equal(
|
||||
normalize_url("http://other.com/page.html", self.https_base_url, preserve_https=True, original_scheme='https'),
|
||||
"http://other.com/page.html",
|
||||
"Different domain HTTP unchanged"
|
||||
)
|
||||
|
||||
# Protocol-relative should follow base
|
||||
self.assert_equal(
|
||||
normalize_url("//example.com/page.html", self.https_base_url, preserve_https=True, original_scheme='https'),
|
||||
"https://example.com/page.html",
|
||||
"Protocol-relative follows base scheme"
|
||||
)
|
||||
|
||||
def test_edge_cases(self):
|
||||
"""Test edge cases and error conditions"""
|
||||
self.start_section("Edge Cases", Colors.WARNING)
|
||||
|
||||
# None and empty inputs
|
||||
result = normalize_url(None, self.base_url) # type: ignore
|
||||
self.assert_none(result, "None input")
|
||||
|
||||
self.assert_none(normalize_url("", self.base_url), "Empty string input")
|
||||
self.assert_none(normalize_url(" ", self.base_url), "Whitespace only input")
|
||||
|
||||
# Malformed URLs
|
||||
try:
|
||||
normalize_url("not-a-url", "invalid-base")
|
||||
print("✗ Should have raised ValueError for invalid base URL")
|
||||
except ValueError:
|
||||
print("✓ Correctly raised ValueError for invalid base URL")
|
||||
|
||||
# Special protocols
|
||||
self.assert_equal(
|
||||
normalize_url("mailto:test@example.com", self.base_url),
|
||||
"mailto:test@example.com",
|
||||
"Mailto protocol preserved"
|
||||
)
|
||||
|
||||
self.assert_equal(
|
||||
normalize_url("tel:+1234567890", self.base_url),
|
||||
"tel:+1234567890",
|
||||
"Tel protocol preserved"
|
||||
)
|
||||
|
||||
self.assert_equal(
|
||||
normalize_url("javascript:void(0)", self.base_url),
|
||||
"javascript:void(0)",
|
||||
"JavaScript protocol preserved"
|
||||
)
|
||||
|
||||
def test_case_sensitivity(self):
|
||||
"""Test case sensitivity handling"""
|
||||
self.start_section("Case Sensitivity", Colors.INFO)
|
||||
|
||||
# Domain case normalization
|
||||
self.assert_equal(
|
||||
normalize_url("https://EXAMPLE.COM/page.html", self.base_url),
|
||||
"https://example.com/page.html",
|
||||
"Domain case normalization"
|
||||
)
|
||||
|
||||
# Mixed case paths
|
||||
self.assert_equal(
|
||||
normalize_url("https://example.com/PATH/Page.HTML", self.base_url),
|
||||
"https://example.com/PATH/Page.HTML",
|
||||
"Path case preserved"
|
||||
)
|
||||
|
||||
# Query parameter case
|
||||
self.assert_equal(
|
||||
normalize_url("https://example.com?PARAM=value", self.base_url),
|
||||
"https://example.com?param=value",
|
||||
"Query parameter case normalization"
|
||||
)
|
||||
|
||||
def test_unicode_and_special_chars(self):
|
||||
"""Test Unicode and special characters"""
|
||||
self.start_section("Unicode & Special Characters", "🌍")
|
||||
|
||||
# Unicode in path
|
||||
self.assert_equal(
|
||||
normalize_url("https://example.com/café.html", self.base_url),
|
||||
"https://example.com/café.html",
|
||||
"Unicode characters in path"
|
||||
)
|
||||
|
||||
# Encoded characters
|
||||
self.assert_equal(
|
||||
normalize_url("https://example.com/caf%C3%A9.html", self.base_url),
|
||||
"https://example.com/caf%C3%A9.html",
|
||||
"URL-encoded characters preserved"
|
||||
)
|
||||
|
||||
# Spaces in URLs
|
||||
self.assert_equal(
|
||||
normalize_url("https://example.com/page with spaces.html", self.base_url),
|
||||
"https://example.com/page with spaces.html",
|
||||
"Spaces in URLs handled"
|
||||
)
|
||||
|
||||
def test_port_numbers(self):
|
||||
"""Test port number handling"""
|
||||
self.start_section("Port Numbers", "🔌")
|
||||
|
||||
# Default ports
|
||||
self.assert_equal(
|
||||
normalize_url("https://example.com:443/page.html", self.base_url),
|
||||
"https://example.com/page.html",
|
||||
"Default HTTPS port removed"
|
||||
)
|
||||
|
||||
self.assert_equal(
|
||||
normalize_url("http://example.com:80/page.html", self.base_url),
|
||||
"http://example.com/page.html",
|
||||
"Default HTTP port removed"
|
||||
)
|
||||
|
||||
# Non-default ports
|
||||
self.assert_equal(
|
||||
normalize_url("https://example.com:8443/page.html", self.base_url),
|
||||
"https://example.com:8443/page.html",
|
||||
"Non-default port preserved"
|
||||
)
|
||||
|
||||
def test_trailing_slashes(self):
|
||||
"""Test trailing slash normalization"""
|
||||
self.start_section("Trailing Slashes", "📁")
|
||||
|
||||
# Remove trailing slash from paths
|
||||
self.assert_equal(
|
||||
normalize_url("https://example.com/path/", self.base_url),
|
||||
"https://example.com/path",
|
||||
"Trailing slash removed from path"
|
||||
)
|
||||
|
||||
# Preserve root trailing slash
|
||||
self.assert_equal(
|
||||
normalize_url("https://example.com/", self.base_url),
|
||||
"https://example.com/",
|
||||
"Root trailing slash preserved"
|
||||
)
|
||||
|
||||
# Multiple trailing slashes
|
||||
self.assert_equal(
|
||||
normalize_url("https://example.com/path//", self.base_url),
|
||||
"https://example.com/path",
|
||||
"Multiple trailing slashes normalized"
|
||||
)
|
||||
|
||||
def test_deep_crawl_functions(self):
    """Exercise the deep-crawl-specific normalization helpers.

    Both helpers are expected to lowercase the host and drop the
    trailing slash; the standard variant also strips tracking params
    while keeping meaningful query parameters, and the efficient
    variant removes fragments.
    """
    self.start_section("Deep Crawl Functions", "🔍")

    # normalize_url_for_deep_crawl: utm_* dropped, real params kept.
    self.assert_equal(
        normalize_url_for_deep_crawl(
            "https://EXAMPLE.COM/path/?utm_source=test&page=1",
            self.base_url,
        ),
        "https://example.com/path?page=1",
        "Deep crawl normalization",
    )

    # efficient_normalize_url_for_deep_crawl: fragment stripped.
    self.assert_equal(
        efficient_normalize_url_for_deep_crawl(
            "https://EXAMPLE.COM/path/#fragment",
            self.base_url,
        ),
        "https://example.com/path",
        "Efficient deep crawl normalization",
    )
|
||||
def test_base_domain_extraction(self):
    """Verify base-domain extraction from full URLs.

    The www prefix and any explicit port must be discarded, and
    multi-part public suffixes (e.g. .co.uk) must be handled so that
    only the registrable domain remains.
    """
    self.start_section("Base Domain Extraction", "🏠")

    # (input URL, expected base domain, case description)
    expectations = [
        ("https://www.example.com/path", "example.com", "WWW prefix removed"),
        ("https://sub.example.co.uk/path", "example.co.uk", "Special TLD handled"),
        ("https://example.com:8080/path", "example.com", "Port removed"),
    ]

    for url, expected_domain, label in expectations:
        self.assert_equal(get_base_domain(url), expected_domain, label)
|
||||
def test_external_url_detection(self):
    """Verify classification of URLs as external or internal.

    A different registrable domain and non-http(s) schemes count as
    external; a www-prefixed variant of the base domain is internal.
    """
    self.start_section("External URL Detection", "🌐")

    # (candidate URL, expected is_external result, case description)
    checks = [
        ("https://other.com/page.html", True, "Different domain is external"),
        ("https://www.example.com/page.html", False, "Same domain with www is internal"),
        ("mailto:test@example.com", True, "Special protocol is external"),
    ]

    for url, expected, label in checks:
        self.assert_equal(is_external_url(url, "example.com"), expected, label)
|
||||
def run_all_tests(self):
    """Run every test section in order, then print the final report.

    Returns:
        True when no test recorded a failure, False otherwise.
    """
    print_header("🚀 URL Normalization Test Suite", Colors.ROCKET)
    self.test_start_time = time.time()

    # Ordered (section title, icon, bound test method) triples.
    sections = [
        ("Basic URL Resolution", Colors.TARGET, self.test_basic_url_resolution),
        ("Query Parameter Handling", Colors.STAR, self.test_query_parameter_handling),
        ("Fragment Handling", Colors.FIRE, self.test_fragment_handling),
        ("HTTPS Preservation", Colors.ROCKET, self.test_https_preservation),
        ("Edge Cases", Colors.WARNING, self.test_edge_cases),
        ("Case Sensitivity", Colors.INFO, self.test_case_sensitivity),
        ("Unicode & Special Characters", "🌍", self.test_unicode_and_special_chars),
        ("Port Numbers", "🔌", self.test_port_numbers),
        ("Trailing Slashes", "📁", self.test_trailing_slashes),
        ("Deep Crawl Functions", "🔍", self.test_deep_crawl_functions),
        ("Base Domain Extraction", "🏠", self.test_base_domain_extraction),
        ("External URL Detection", "🌐", self.test_external_url_detection),
    ]

    section_count = len(sections)
    for idx, (title, _icon, runner) in enumerate(sections, start=1):
        # Show progress before and after each section so long-running
        # sections are visible while they execute.
        print_progress(idx - 1, section_count, f"Running {title}")
        runner()
        print_progress(idx, section_count, f"Completed {title}")

    # Wall-clock duration of the whole run.
    elapsed = time.time() - self.test_start_time

    self.print_comprehensive_stats(elapsed)

    # An empty failure list means a fully green run.
    return not self.tests_failed
|
||||
def print_comprehensive_stats(self, execution_time):
    """Print the full end-of-run report: overall totals, a one-line
    verdict, per-section breakdown, failed-test details, and
    recommendations.

    Args:
        execution_time: Wall-clock duration of the run, in seconds.
    """
    print_header("📊 Test Results Summary", "📈")

    # Overall statistics — guard against division by zero when no tests ran.
    success_rate = (self.tests_passed / self.tests_run * 100) if self.tests_run > 0 else 0

    print(f"{Colors.BOLD}Overall Statistics:{Colors.RESET}")
    print(f" Total Tests: {Colors.CYAN}{self.tests_run}{Colors.RESET}")
    print(f" Passed: {Colors.GREEN}{self.tests_passed}{Colors.RESET}")
    print(f" Failed: {Colors.RED}{len(self.tests_failed)}{Colors.RESET}")
    print(f" Success Rate: {Colors.BRIGHT_CYAN}{success_rate:.1f}%{Colors.RESET}")
    print(f" Execution Time: {Colors.YELLOW}{execution_time:.2f}s{Colors.RESET}")

    # Performance indicator: one-line verdict keyed to the success rate.
    if success_rate == 100:
        print_success("🎉 Perfect! All tests passed!")
    elif success_rate >= 90:
        print_success("✅ Excellent! Nearly perfect results!")
    elif success_rate >= 75:
        print_warning("⚠️ Good results, but some improvements needed")
    else:
        print_error("❌ Significant issues detected - review failures below")

    # Section-by-section breakdown (only when sections were recorded).
    if self.section_stats:
        print(f"\n{Colors.BOLD}Section Breakdown:{Colors.RESET}")
        for section_name, stats in self.section_stats.items():
            # Per-section pass rate, again guarding the zero denominator.
            section_success_rate = (stats['passed'] / stats['run'] * 100) if stats['run'] > 0 else 0
            status_icon = Colors.CHECK if stats['failed'] == 0 else Colors.CROSS
            status_color = Colors.GREEN if stats['failed'] == 0 else Colors.RED

            print(f" {status_icon} {section_name}: {Colors.CYAN}{stats['run']}{Colors.RESET} tests, "
                  f"{status_color}{stats['passed']} passed{Colors.RESET}, "
                  f"{Colors.RED}{stats['failed']} failed{Colors.RESET} "
                  f"({Colors.BRIGHT_CYAN}{section_success_rate:.1f}%{Colors.RESET})")

    # Failed-test details: name, owning section (when present),
    # and the expected-vs-actual pair captured at assertion time.
    if self.tests_failed:
        print(f"\n{Colors.BOLD}{Colors.RED}Failed Tests Details:{Colors.RESET}")
        for i, failure in enumerate(self.tests_failed, 1):
            print(f" {Colors.RED}{i}. {failure['name']}{Colors.RESET}")
            if 'section' in failure and failure['section']:
                print(f" Section: {Colors.YELLOW}{failure['section']}{Colors.RESET}")
            print(f" Expected: {Colors.BRIGHT_RED}{failure['expected']}{Colors.RESET}")
            print(f" Actual: {Colors.BRIGHT_RED}{failure['actual']}{Colors.RESET}")
            print()

    # Recommendations: remediation hints on failure, next steps on success.
    if self.tests_failed:
        print(f"{Colors.BOLD}{Colors.YELLOW}Recommendations:{Colors.RESET}")
        print(f" • Review the {len(self.tests_failed)} failed test(s) above")
        print(" • Check URL normalization logic for edge cases")
        print(" • Verify query parameter handling")
        print(" • Test with real-world URLs")
    else:
        print(f"\n{Colors.BOLD}{Colors.GREEN}Recommendations:{Colors.RESET}")
        print(" • All tests passed! URL normalization is working correctly")
        print(" • Consider adding more edge cases for future robustness")
        print(" • Monitor performance with large-scale crawling")
|
||||
|
||||
def test_crawling_integration():
    """Smoke-test normalize_url against link values a real crawl meets.

    Prints each input alongside its normalized form, or the error raised
    for degenerate inputs (empty string, None); nothing is asserted.
    """
    print_section("Crawling Integration Test", "🔗")

    # Representative link values: absolute URLs with tracking params,
    # relative paths, a protocol-relative URL, a non-http scheme,
    # a bare fragment, and degenerate inputs.
    test_urls = [
        "https://example.com/blog/post?utm_source=newsletter&utm_medium=email",
        "https://example.com/products?page=1&sort=price&ref=search",
        "/about.html",
        "../contact.html",
        "//cdn.example.com/js/main.js",
        "mailto:support@example.com",
        "#top",
        "",
        None,
    ]

    base_url = "https://example.com/current/page.html"

    print("Testing real-world URL scenarios:")
    for candidate in test_urls:
        try:
            outcome = normalize_url(candidate, base_url)
        except (ValueError, TypeError) as exc:
            print(f" {candidate} -> ERROR: {exc}")
        else:
            print(f" {candidate} -> {outcome}")
|
||||
|
||||
if __name__ == "__main__":
    print_header("🧪 URL Normalization Comprehensive Test Suite", "🧪")
    print_info("Testing URL normalization functions with comprehensive scenarios and edge cases")
    print()

    # Run the full unit-test suite first; its boolean result drives
    # both the final banner and the process exit status.
    suite = URLNormalizationTestSuite()
    all_passed = suite.run_all_tests()

    # Follow up with the (non-asserting) real-world integration pass.
    print()
    test_crawling_integration()

    # Final summary banner reflecting the unit-suite outcome.
    print()
    print_header("🏁 Final Test Summary", "🏁")

    if all_passed:
        print_success("🎉 ALL TESTS PASSED! URL normalization is working perfectly!")
        print_info("The updated URL normalization functions are ready for production use.")
    else:
        print_error("❌ SOME TESTS FAILED! Please review the issues above.")
        print_warning("URL normalization may have issues that need to be addressed before deployment.")

    print()
    print_info("Test suite completed. Check the results above for detailed analysis.")

    # Conventional exit status: 0 on success, 1 on any failure.
    sys.exit(0 if all_passed else 1)
|
||||
Reference in New Issue
Block a user