Refactor adaptive crawling state management
- Renamed `CrawlState` to `AdaptiveCrawlResult` to better reflect its purpose.
- Updated all references to `CrawlState` in the codebase, including method signatures and documentation.
- Modified the `AdaptiveCrawler` class to initialize and manage the new `AdaptiveCrawlResult` state.
- Adjusted example strategies and documentation to align with the new state class.
- Ensured all tests are updated to use `AdaptiveCrawlResult` instead of `CrawlState`.
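For downstream code the rename is mechanical: the class keeps its fields, its `load()` classmethod, and its role as the return type of `AdaptiveCrawler.digest()`. A minimal sketch of the import-level change:

```python
# Before this commit:
#   from crawl4ai.adaptive_crawler import CrawlState
# After this commit:
from crawl4ai.adaptive_crawler import AdaptiveCrawlResult

# The name is also re-exported from the package root (see the __init__.py
# hunks below), so `from crawl4ai import AdaptiveCrawlResult` works as well.
```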
@@ -216,7 +216,7 @@ Under certain assumptions about link preview accuracy:
 
 ### 8.1 Core Components
 
-1. **CrawlState**: Maintains crawl history and metrics
+1. **AdaptiveCrawlResult**: Maintains crawl history and metrics
 2. **AdaptiveConfig**: Configuration parameters
 3. **CrawlStrategy**: Pluggable strategy interface
 4. **AdaptiveCrawler**: Main orchestrator
@@ -73,7 +73,7 @@ from .async_url_seeder import AsyncUrlSeeder
 from .adaptive_crawler import (
     AdaptiveCrawler,
     AdaptiveConfig,
-    CrawlState,
+    AdaptiveCrawlResult,
     CrawlStrategy,
     StatisticalStrategy
 )
@@ -108,7 +108,7 @@ __all__ = [
     # Adaptive Crawler
     "AdaptiveCrawler",
     "AdaptiveConfig",
-    "CrawlState",
+    "AdaptiveCrawlResult",
     "CrawlStrategy",
     "StatisticalStrategy",
     "DeepCrawlStrategy",
File diff suppressed because it is too large
@@ -24,7 +24,7 @@ from crawl4ai.models import Link, CrawlResult
 import numpy as np
 
 @dataclass
-class CrawlState:
+class AdaptiveCrawlResult:
     """Tracks the current state of adaptive crawling"""
     crawled_urls: Set[str] = field(default_factory=set)
     knowledge_base: List[CrawlResult] = field(default_factory=list)
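Since `AdaptiveCrawlResult` remains a `@dataclass` with `default_factory` fields, it can still be constructed empty and filled incrementally. A small sketch using only the fields visible in this hunk:

```python
from crawl4ai.adaptive_crawler import AdaptiveCrawlResult

state = AdaptiveCrawlResult()                        # all fields take their defaults
state.crawled_urls.add("https://example.com/docs")   # Set[str], placeholder URL
print(len(state.knowledge_base))                     # List[CrawlResult], starts empty
```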
@@ -80,7 +80,7 @@ class CrawlState:
         json.dump(state_dict, f, indent=2)
 
     @classmethod
-    def load(cls, path: Union[str, Path]) -> 'CrawlState':
+    def load(cls, path: Union[str, Path]) -> 'AdaptiveCrawlResult':
         """Load state from disk"""
         path = Path(path)
         with open(path, 'r') as f:
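State persistence survives the rename; only the annotation changes. A hedged sketch of restoring a session (the checkpoint path is a placeholder, and the name of the save-side method is not shown in this hunk):

```python
from crawl4ai.adaptive_crawler import AdaptiveCrawlResult

# load() is the classmethod shown above; it reads the JSON written by the
# save side of this class (the json.dump(...) call in the preceding lines).
state = AdaptiveCrawlResult.load("crawl_state.json")  # placeholder path
print(len(state.crawled_urls))

# Alternatively, digest(..., resume_from="crawl_state.json") performs this
# load internally, as the AdaptiveCrawler.digest hunk further down shows.
```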
@@ -256,22 +256,22 @@ class CrawlStrategy(ABC):
     """Abstract base class for crawling strategies"""
 
     @abstractmethod
-    async def calculate_confidence(self, state: CrawlState) -> float:
+    async def calculate_confidence(self, state: AdaptiveCrawlResult) -> float:
         """Calculate overall confidence that we have sufficient information"""
         pass
 
     @abstractmethod
-    async def rank_links(self, state: CrawlState, config: AdaptiveConfig) -> List[Tuple[Link, float]]:
+    async def rank_links(self, state: AdaptiveCrawlResult, config: AdaptiveConfig) -> List[Tuple[Link, float]]:
         """Rank pending links by expected information gain"""
         pass
 
     @abstractmethod
-    async def should_stop(self, state: CrawlState, config: AdaptiveConfig) -> bool:
+    async def should_stop(self, state: AdaptiveCrawlResult, config: AdaptiveConfig) -> bool:
         """Determine if crawling should stop"""
         pass
 
     @abstractmethod
-    async def update_state(self, state: CrawlState, new_results: List[CrawlResult]) -> None:
+    async def update_state(self, state: AdaptiveCrawlResult, new_results: List[CrawlResult]) -> None:
         """Update state with new crawl results"""
         pass
 
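Since only the annotations change, a custom strategy is written against `AdaptiveCrawlResult` exactly as it was against `CrawlState`. A minimal sketch of the four required methods, with a deliberately trivial (hypothetical) page-budget policy standing in for real logic:

```python
from typing import List, Tuple

from crawl4ai.adaptive_crawler import AdaptiveConfig, AdaptiveCrawlResult, CrawlStrategy
from crawl4ai.models import CrawlResult, Link

class FixedBudgetStrategy(CrawlStrategy):
    """Hypothetical strategy: crawl a fixed number of pages, no ranking."""

    def __init__(self, budget: int = 20):
        self.budget = budget

    async def calculate_confidence(self, state: AdaptiveCrawlResult) -> float:
        # Confidence grows linearly with pages crawled, capped at 1.0
        return min(1.0, len(state.crawled_urls) / self.budget)

    async def rank_links(self, state: AdaptiveCrawlResult, config: AdaptiveConfig) -> List[Tuple[Link, float]]:
        # No preference: every pending link gets the same score
        return [(link, 1.0) for link in state.pending_links]

    async def should_stop(self, state: AdaptiveCrawlResult, config: AdaptiveConfig) -> bool:
        return len(state.crawled_urls) >= self.budget

    async def update_state(self, state: AdaptiveCrawlResult, new_results: List[CrawlResult]) -> None:
        state.knowledge_base.extend(new_results)
```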
@@ -284,7 +284,7 @@ class StatisticalStrategy(CrawlStrategy):
         self.bm25_k1 = 1.2  # BM25 parameter
         self.bm25_b = 0.75  # BM25 parameter
 
-    async def calculate_confidence(self, state: CrawlState) -> float:
+    async def calculate_confidence(self, state: AdaptiveCrawlResult) -> float:
         """Calculate confidence using coverage, consistency, and saturation"""
         if not state.knowledge_base:
             return 0.0
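The docstring names the three signals that feed the confidence score, but the combination itself falls outside this hunk. A hedged sketch, assuming a simple weighted sum (the weights below are illustrative, not the library's):

```python
def combine_confidence(coverage: float, consistency: float, saturation: float) -> float:
    """Illustrative combination of StatisticalStrategy's three signals."""
    w_cov, w_con, w_sat = 0.5, 0.3, 0.2   # assumed weights
    confidence = w_cov * coverage + w_con * consistency + w_sat * saturation
    return max(0.0, min(confidence, 1.0))
```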
@@ -303,7 +303,7 @@ class StatisticalStrategy(CrawlStrategy):
 
         return confidence
 
-    def _calculate_coverage(self, state: CrawlState) -> float:
+    def _calculate_coverage(self, state: AdaptiveCrawlResult) -> float:
         """Coverage scoring - measures query term presence across knowledge base
 
         Returns a score between 0 and 1, where:
@@ -344,7 +344,7 @@ class StatisticalStrategy(CrawlStrategy):
         # This helps differentiate between partial and good coverage
         return min(1.0, math.sqrt(coverage))
 
-    def _calculate_consistency(self, state: CrawlState) -> float:
+    def _calculate_consistency(self, state: AdaptiveCrawlResult) -> float:
         """Information overlap between pages - high overlap suggests coherent topic coverage"""
         if len(state.knowledge_base) < 2:
             return 1.0  # Single or no documents are perfectly consistent
@@ -371,7 +371,7 @@ class StatisticalStrategy(CrawlStrategy):
 
         return consistency
 
-    def _calculate_saturation(self, state: CrawlState) -> float:
+    def _calculate_saturation(self, state: AdaptiveCrawlResult) -> float:
         """Diminishing returns indicator - are we still discovering new information?"""
         if not state.new_terms_history:
             return 0.0
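`new_terms_history` records how many previously unseen terms each crawled page contributed, so saturation can be read from the tail of that series. The actual formula sits in the elided part of the method; the sketch below is one plausible reading of "diminishing returns":

```python
from typing import List

def saturation_from_history(new_terms_history: List[int], window: int = 3) -> float:
    """Illustrative: saturation is high when recent pages add few new terms."""
    if not new_terms_history:
        return 0.0                      # matches the guard clause above
    recent = new_terms_history[-window:]
    peak = max(max(new_terms_history), 1)
    # How far below the peak discovery rate the recent pages have fallen
    saturation = 1.0 - (sum(recent) / len(recent)) / peak
    return max(0.0, min(saturation, 1.0))   # same clamping as the method above
```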
@@ -388,7 +388,7 @@ class StatisticalStrategy(CrawlStrategy):
 
         return max(0.0, min(saturation, 1.0))
 
-    async def rank_links(self, state: CrawlState, config: AdaptiveConfig) -> List[Tuple[Link, float]]:
+    async def rank_links(self, state: AdaptiveCrawlResult, config: AdaptiveConfig) -> List[Tuple[Link, float]]:
         """Rank links by expected information gain"""
         scored_links = []
 
@@ -415,7 +415,7 @@ class StatisticalStrategy(CrawlStrategy):
 
         return scored_links
 
-    def _calculate_relevance(self, link: Link, state: CrawlState) -> float:
+    def _calculate_relevance(self, link: Link, state: AdaptiveCrawlResult) -> float:
         """BM25 relevance score between link preview and query"""
         if not state.query or not link:
             return 0.0
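The relevance score is BM25 between the query and the link's preview text, using the `bm25_k1 = 1.2` and `bm25_b = 0.75` parameters set in the constructor earlier in this diff. For reference, a self-contained sketch of the standard BM25 form (tokenization and the IDF table are simplified assumptions here):

```python
from collections import Counter
from typing import Dict, List

def bm25_score(query_terms: List[str], doc_terms: List[str],
               idf: Dict[str, float], avg_doc_len: float,
               k1: float = 1.2, b: float = 0.75) -> float:
    """Standard BM25 of one document (here: a link preview) against a query."""
    tf = Counter(doc_terms)
    doc_len = len(doc_terms) or 1
    score = 0.0
    for term in query_terms:
        freq = tf.get(term, 0)
        if freq == 0:
            continue
        norm = freq * (k1 + 1) / (freq + k1 * (1 - b + b * doc_len / avg_doc_len))
        score += idf.get(term, 0.0) * norm
    return score
```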
@@ -447,7 +447,7 @@ class StatisticalStrategy(CrawlStrategy):
         overlap = len(query_terms & link_terms) / len(query_terms)
         return overlap
 
-    def _calculate_novelty(self, link: Link, state: CrawlState) -> float:
+    def _calculate_novelty(self, link: Link, state: AdaptiveCrawlResult) -> float:
         """Estimate how much new information this link might provide"""
         if not state.knowledge_base:
             return 1.0  # First links are maximally novel
@@ -502,7 +502,7 @@ class StatisticalStrategy(CrawlStrategy):
 
         return min(score, 1.0)
 
-    async def should_stop(self, state: CrawlState, config: AdaptiveConfig) -> bool:
+    async def should_stop(self, state: AdaptiveCrawlResult, config: AdaptiveConfig) -> bool:
         """Determine if crawling should stop"""
         # Check confidence threshold
         confidence = state.metrics.get('confidence', 0.0)
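Stopping is keyed off the confidence value cached in `state.metrics`. A hedged sketch of the threshold check that opens the method (the config attribute name is assumed; the later criteria are elided from this hunk):

```python
async def should_stop_sketch(state, config) -> bool:
    """Illustrative first stop criterion, mirroring the hunk above."""
    confidence = state.metrics.get('confidence', 0.0)
    threshold = getattr(config, 'confidence_threshold', 0.8)  # assumed attribute name
    return confidence >= threshold
```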
@@ -523,7 +523,7 @@ class StatisticalStrategy(CrawlStrategy):
 
         return False
 
-    async def update_state(self, state: CrawlState, new_results: List[CrawlResult]) -> None:
+    async def update_state(self, state: AdaptiveCrawlResult, new_results: List[CrawlResult]) -> None:
         """Update state with new crawl results"""
         for result in new_results:
             # Track new terms
@@ -921,7 +921,7 @@ class EmbeddingStrategy(CrawlStrategy):
 
         return sorted(scored_links, key=lambda x: x[1], reverse=True)
 
-    async def calculate_confidence(self, state: CrawlState) -> float:
+    async def calculate_confidence(self, state: AdaptiveCrawlResult) -> float:
         """Coverage-based learning score (0–1)."""
         # Guard clauses
         if state.kb_embeddings is None or state.query_embeddings is None:
@@ -951,7 +951,7 @@ class EmbeddingStrategy(CrawlStrategy):
 
 
 
-    # async def calculate_confidence(self, state: CrawlState) -> float:
+    # async def calculate_confidence(self, state: AdaptiveCrawlResult) -> float:
     #     """Calculate learning score for adaptive crawling (used for stopping)"""
     #
 
@@ -1021,7 +1021,7 @@ class EmbeddingStrategy(CrawlStrategy):
     #     # For stopping criteria, return learning score
    #     return float(learning_score)
 
-    async def rank_links(self, state: CrawlState, config: AdaptiveConfig) -> List[Tuple[Link, float]]:
+    async def rank_links(self, state: AdaptiveCrawlResult, config: AdaptiveConfig) -> List[Tuple[Link, float]]:
         """Main entry point for link ranking"""
         # Store config for use in other methods
         self.config = config
@@ -1052,7 +1052,7 @@ class EmbeddingStrategy(CrawlStrategy):
             state.kb_embeddings
         )
 
-    async def validate_coverage(self, state: CrawlState) -> float:
+    async def validate_coverage(self, state: AdaptiveCrawlResult) -> float:
         """Validate coverage using held-out queries with caching"""
         if not hasattr(self, '_validation_queries') or not self._validation_queries:
             return state.metrics.get('confidence', 0.0)
@@ -1088,7 +1088,7 @@ class EmbeddingStrategy(CrawlStrategy):
 
         return validation_confidence
 
-    async def should_stop(self, state: CrawlState, config: AdaptiveConfig) -> bool:
+    async def should_stop(self, state: AdaptiveCrawlResult, config: AdaptiveConfig) -> bool:
         """Stop based on learning curve convergence"""
         confidence = state.metrics.get('confidence', 0.0)
 
@@ -1139,7 +1139,7 @@ class EmbeddingStrategy(CrawlStrategy):
 
         return False
 
-    def get_quality_confidence(self, state: CrawlState) -> float:
+    def get_quality_confidence(self, state: AdaptiveCrawlResult) -> float:
         """Calculate quality-based confidence score for display"""
         learning_score = state.metrics.get('learning_score', 0.0)
         validation_score = state.metrics.get('validation_confidence', 0.0)
@@ -1166,7 +1166,7 @@ class EmbeddingStrategy(CrawlStrategy):
 
         return confidence
 
-    async def update_state(self, state: CrawlState, new_results: List[CrawlResult]) -> None:
+    async def update_state(self, state: AdaptiveCrawlResult, new_results: List[CrawlResult]) -> None:
         """Update embeddings and coverage metrics with deduplication"""
         from .utils import get_text_embeddings
 
@@ -1246,7 +1246,7 @@ class AdaptiveCrawler:
         self.strategy = self._create_strategy(self.config.strategy)
 
         # Initialize state
-        self.state: Optional[CrawlState] = None
+        self.state: Optional[AdaptiveCrawlResult] = None
 
         # Track if we own the crawler (for cleanup)
         self._owns_crawler = crawler is None
@@ -1266,14 +1266,14 @@ class AdaptiveCrawler:
     async def digest(self,
                      start_url: str,
                      query: str,
-                     resume_from: Optional[str] = None) -> CrawlState:
+                     resume_from: Optional[str] = None) -> AdaptiveCrawlResult:
         """Main entry point for adaptive crawling"""
         # Initialize or resume state
         if resume_from:
-            self.state = CrawlState.load(resume_from)
+            self.state = AdaptiveCrawlResult.load(resume_from)
             self.state.query = query  # Update query in case it changed
         else:
-            self.state = CrawlState(
+            self.state = AdaptiveCrawlResult(
                 crawled_urls=set(),
                 knowledge_base=[],
                 pending_links=[],
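The new return annotation flows straight into caller code. A usage sketch covering both the fresh and resumed paths (URLs, query, and checkpoint path are placeholders; the `AdaptiveCrawler` constructor keyword is assumed):

```python
import asyncio

from crawl4ai import AsyncWebCrawler, AdaptiveCrawler, AdaptiveCrawlResult

async def main() -> None:
    async with AsyncWebCrawler() as crawler:
        adaptive = AdaptiveCrawler(crawler=crawler)  # keyword name assumed

        # Fresh run: digest() builds a new AdaptiveCrawlResult internally
        result: AdaptiveCrawlResult = await adaptive.digest(
            start_url="https://example.com/docs",
            query="adaptive crawling",
        )

        # Resumed run: digest() rebuilds state via AdaptiveCrawlResult.load()
        resumed = await adaptive.digest(
            start_url="https://example.com/docs",
            query="adaptive crawling",
            resume_from="crawl_state.json",  # placeholder checkpoint path
        )
        print(len(result.crawled_urls), len(resumed.crawled_urls))

asyncio.run(main())
```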
@@ -1803,7 +1803,7 @@ class AdaptiveCrawler:
 
         # Initialize state if needed
         if not self.state:
-            self.state = CrawlState()
+            self.state = AdaptiveCrawlResult()
 
         # Add imported results
         self.state.knowledge_base.extend(imported_results)
@@ -9,7 +9,7 @@ import asyncio
 import re
 from typing import List, Dict, Set
 from crawl4ai import AsyncWebCrawler, AdaptiveCrawler, AdaptiveConfig
-from crawl4ai.adaptive_crawler import CrawlState, Link
+from crawl4ai.adaptive_crawler import AdaptiveCrawlResult, Link
 import math
 
 
@@ -45,7 +45,7 @@ class APIDocumentationStrategy:
             r'/legal/'
         ]
 
-    def score_link(self, link: Link, query: str, state: CrawlState) -> float:
+    def score_link(self, link: Link, query: str, state: AdaptiveCrawlResult) -> float:
         """Custom link scoring for API documentation"""
         score = 1.0
         url = link.href.lower()
@@ -77,7 +77,7 @@ class APIDocumentationStrategy:
 
         return score
 
-    def calculate_api_coverage(self, state: CrawlState, query: str) -> Dict[str, float]:
+    def calculate_api_coverage(self, state: AdaptiveCrawlResult, query: str) -> Dict[str, float]:
         """Calculate specialized coverage metrics for API documentation"""
         metrics = {
             'endpoint_coverage': 0.0,
@@ -130,7 +130,7 @@ Factors:
 
 ```python
 class CustomLinkScorer:
-    def score(self, link: Link, query: str, state: CrawlState) -> float:
+    def score(self, link: Link, query: str, state: AdaptiveCrawlResult) -> float:
         # Prioritize specific URL patterns
         if "/api/reference/" in link.href:
             return 2.0  # Double the score
@@ -325,17 +325,17 @@ with open("crawl_analysis.json", "w") as f:
 from crawl4ai.adaptive_crawler import BaseStrategy
 
 class DomainSpecificStrategy(BaseStrategy):
-    def calculate_coverage(self, state: CrawlState) -> float:
+    def calculate_coverage(self, state: AdaptiveCrawlResult) -> float:
         # Custom coverage calculation
         # e.g., weight certain terms more heavily
         pass
 
-    def calculate_consistency(self, state: CrawlState) -> float:
+    def calculate_consistency(self, state: AdaptiveCrawlResult) -> float:
         # Custom consistency logic
         # e.g., domain-specific validation
         pass
 
-    def rank_links(self, links: List[Link], state: CrawlState) -> List[Link]:
+    def rank_links(self, links: List[Link], state: AdaptiveCrawlResult) -> List[Link]:
         # Custom link ranking
         # e.g., prioritize specific URL patterns
         pass
@@ -359,7 +359,7 @@ class HybridStrategy(BaseStrategy):
             URLPatternStrategy()
         ]
 
-    def calculate_confidence(self, state: CrawlState) -> float:
+    def calculate_confidence(self, state: AdaptiveCrawlResult) -> float:
         # Weighted combination of strategies
         scores = [s.calculate_confidence(state) for s in self.strategies]
         weights = [0.5, 0.3, 0.2]
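This hunk ends just before `scores` and `weights` are combined. Under the obvious assumption that they are summed pairwise, the continuation would look like:

```python
def weighted_confidence(scores: list[float], weights: list[float]) -> float:
    """Illustrative: likely combination step for HybridStrategy."""
    return sum(s * w for s, w in zip(scores, weights))

# weighted_confidence([0.9, 0.6, 0.4], [0.5, 0.3, 0.2]) == 0.71
```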
@@ -27,7 +27,7 @@ async def digest(
     start_url: str,
     query: str,
     resume_from: Optional[Union[str, Path]] = None
-) -> CrawlState
+) -> AdaptiveCrawlResult
 ```
 
 #### Parameters
@@ -38,7 +38,7 @@ async def digest(
 
 #### Returns
 
-- **CrawlState**: The final crawl state containing all crawled URLs, knowledge base, and metrics
+- **AdaptiveCrawlResult**: The final crawl state containing all crawled URLs, knowledge base, and metrics
 
 #### Example
 
@@ -92,7 +92,7 @@ Access to the current crawl state.
 
 ```python
 @property
-def state(self) -> CrawlState
+def state(self) -> AdaptiveCrawlResult
 ```
 
 ## Methods
@@ -9,7 +9,7 @@ async def digest(
     start_url: str,
     query: str,
     resume_from: Optional[Union[str, Path]] = None
-) -> CrawlState
+) -> AdaptiveCrawlResult
 ```
 
 ## Parameters
@@ -31,7 +31,7 @@ async def digest(
 
 ## Return Value
 
-Returns a `CrawlState` object containing:
+Returns an `AdaptiveCrawlResult` object containing:
 
 - **crawled_urls** (`Set[str]`): All URLs that have been crawled
 - **knowledge_base** (`List[CrawlResult]`): Collection of crawled pages with content
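Consuming the renamed return object is unchanged. A short sketch of reading the documented fields, continuing the `digest()` usage shown earlier on this page:

```python
# `result` is the AdaptiveCrawlResult returned by digest() above.
print(f"Pages crawled: {len(result.crawled_urls)}")   # Set[str]
for page in result.knowledge_base:                    # List[CrawlResult]
    print(page.url)                                   # CrawlResult carries the page URL
print(result.metrics.get("confidence", 0.0))          # same key the strategies read
```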
@@ -23,7 +23,7 @@ from crawl4ai import (
     AsyncWebCrawler,
     AdaptiveCrawler,
     AdaptiveConfig,
-    CrawlState
+    AdaptiveCrawlResult
 )
 
@@ -13,7 +13,7 @@ import math
 sys.path.append(str(Path(__file__).parent.parent))
 
 from crawl4ai import AsyncWebCrawler
-from crawl4ai.adaptive_crawler import CrawlState, StatisticalStrategy
+from crawl4ai.adaptive_crawler import AdaptiveCrawlResult, StatisticalStrategy
 from crawl4ai.models import CrawlResult
 
@@ -37,7 +37,7 @@ class ConfidenceTestHarness:
         print("=" * 80)
 
         # Initialize state
-        state = CrawlState(query=self.query)
+        state = AdaptiveCrawlResult(query=self.query)
 
         # Create crawler
         async with AsyncWebCrawler() as crawler:
@@ -107,7 +107,7 @@ class ConfidenceTestHarness:
 
         state.metrics['prev_confidence'] = confidence
 
-    def _debug_coverage_calculation(self, state: CrawlState, query_terms: List[str]):
+    def _debug_coverage_calculation(self, state: AdaptiveCrawlResult, query_terms: List[str]):
         """Debug coverage calculation step by step"""
         coverage_score = 0.0
         max_possible_score = 0.0
@@ -136,7 +136,7 @@ class ConfidenceTestHarness:
         new_coverage = self._calculate_coverage_new(state, query_terms)
         print(f"   → New Coverage: {new_coverage:.3f}")
 
-    def _calculate_coverage_new(self, state: CrawlState, query_terms: List[str]) -> float:
+    def _calculate_coverage_new(self, state: AdaptiveCrawlResult, query_terms: List[str]) -> float:
         """New coverage calculation without IDF"""
         if not query_terms or state.total_documents == 0:
             return 0.0
@@ -15,7 +15,7 @@ import os
 sys.path.append(str(Path(__file__).parent.parent.parent))
 
 from crawl4ai import AsyncWebCrawler, AdaptiveCrawler, AdaptiveConfig
-from crawl4ai.adaptive_crawler import EmbeddingStrategy, CrawlState
+from crawl4ai.adaptive_crawler import EmbeddingStrategy, AdaptiveCrawlResult
 from crawl4ai.models import CrawlResult
 
@@ -132,7 +132,7 @@ async def test_embedding_performance():
     strategy.config = config
 
     # Initialize state
-    state = CrawlState()
+    state = AdaptiveCrawlResult()
     state.query = "async await coroutines event loops tasks"
 
     # Start performance monitoring
@@ -20,7 +20,7 @@ from crawl4ai import (
     AsyncWebCrawler,
     AdaptiveCrawler,
     AdaptiveConfig,
-    CrawlState
+    AdaptiveCrawlResult
 )
 
 console = Console()