Refactor adaptive crawling state management

- Renamed `CrawlState` to `AdaptiveCrawlResult` to better reflect its purpose.
- Updated all references to `CrawlState` in the codebase, including method signatures and documentation.
- Modified the `AdaptiveCrawler` class to initialize and manage the new `AdaptiveCrawlResult` state.
- Adjusted example strategies and documentation to align with the new state class.
- Ensured all tests are updated to use `AdaptiveCrawlResult` instead of `CrawlState`.
Author: UncleCode
Date: 2025-07-24 20:11:43 +08:00
Commit: 843457a9cb (parent: d1de82a332)
12 changed files with 51 additions and 1898 deletions
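For downstream code this is a pure rename: the dataclass fields and behavior are unchanged, so only imports and type annotations need updating. A minimal migration sketch (the `summarize` helper is illustrative, not part of the library):

```python
# Before this commit: from crawl4ai import CrawlState
from crawl4ai import AdaptiveCrawlResult  # new name, same dataclass

def summarize(result: AdaptiveCrawlResult) -> None:
    # Fields such as crawled_urls carry over unchanged from CrawlState
    print(f"{len(result.crawled_urls)} pages crawled")
```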


````diff
@@ -216,7 +216,7 @@ Under certain assumptions about link preview accuracy:
 ### 8.1 Core Components
-1. **CrawlState**: Maintains crawl history and metrics
+1. **AdaptiveCrawlResult**: Maintains crawl history and metrics
 2. **AdaptiveConfig**: Configuration parameters
 3. **CrawlStrategy**: Pluggable strategy interface
 4. **AdaptiveCrawler**: Main orchestrator
````
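
How the four components fit together after the rename, pieced from the hunks below (the `AdaptiveCrawler` constructor arguments are an assumption; the diff only confirms that it accepts an optional `crawler` and builds its strategy from `config.strategy`):

```python
from crawl4ai import AsyncWebCrawler, AdaptiveCrawler, AdaptiveConfig

config = AdaptiveConfig()                    # 2. configuration parameters
crawler = AsyncWebCrawler()                  # underlying page fetcher
adaptive = AdaptiveCrawler(crawler, config)  # 4. orchestrator; creates a CrawlStrategy (3.) from config.strategy
# adaptive.digest(...) runs the crawl and returns an AdaptiveCrawlResult (1.)
```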


````diff
@@ -73,7 +73,7 @@ from .async_url_seeder import AsyncUrlSeeder
 from .adaptive_crawler import (
     AdaptiveCrawler,
     AdaptiveConfig,
-    CrawlState,
+    AdaptiveCrawlResult,
     CrawlStrategy,
     StatisticalStrategy
 )
@@ -108,7 +108,7 @@ __all__ = [
     # Adaptive Crawler
     "AdaptiveCrawler",
     "AdaptiveConfig",
-    "CrawlState",
+    "AdaptiveCrawlResult",
     "CrawlStrategy",
     "StatisticalStrategy",
     "DeepCrawlStrategy",
````

File diff suppressed because it is too large.


````diff
@@ -24,7 +24,7 @@ from crawl4ai.models import Link, CrawlResult
 import numpy as np
 @dataclass
-class CrawlState:
+class AdaptiveCrawlResult:
     """Tracks the current state of adaptive crawling"""
     crawled_urls: Set[str] = field(default_factory=set)
     knowledge_base: List[CrawlResult] = field(default_factory=list)
@@ -80,7 +80,7 @@ class CrawlState:
         json.dump(state_dict, f, indent=2)
     @classmethod
-    def load(cls, path: Union[str, Path]) -> 'CrawlState':
+    def load(cls, path: Union[str, Path]) -> 'AdaptiveCrawlResult':
         """Load state from disk"""
         path = Path(path)
         with open(path, 'r') as f:
````
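
The hunk above ends inside what appears to be the matching `save` method (`json.dump(state_dict, f, indent=2)` immediately precedes the `@classmethod`). Assuming `save(path)` exists as the counterpart to `load`, a state round-trip would look like this sketch:

```python
from crawl4ai import AdaptiveCrawlResult

result = AdaptiveCrawlResult(query="async await")
result.save("state.json")                          # assumed counterpart of load()
restored = AdaptiveCrawlResult.load("state.json")  # classmethod shown in the hunk above
assert restored.query == result.query
```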
````diff
@@ -256,22 +256,22 @@ class CrawlStrategy(ABC):
     """Abstract base class for crawling strategies"""
     @abstractmethod
-    async def calculate_confidence(self, state: CrawlState) -> float:
+    async def calculate_confidence(self, state: AdaptiveCrawlResult) -> float:
         """Calculate overall confidence that we have sufficient information"""
         pass
     @abstractmethod
-    async def rank_links(self, state: CrawlState, config: AdaptiveConfig) -> List[Tuple[Link, float]]:
+    async def rank_links(self, state: AdaptiveCrawlResult, config: AdaptiveConfig) -> List[Tuple[Link, float]]:
         """Rank pending links by expected information gain"""
         pass
     @abstractmethod
-    async def should_stop(self, state: CrawlState, config: AdaptiveConfig) -> bool:
+    async def should_stop(self, state: AdaptiveCrawlResult, config: AdaptiveConfig) -> bool:
         """Determine if crawling should stop"""
         pass
     @abstractmethod
-    async def update_state(self, state: CrawlState, new_results: List[CrawlResult]) -> None:
+    async def update_state(self, state: AdaptiveCrawlResult, new_results: List[CrawlResult]) -> None:
         """Update state with new crawl results"""
         pass
````
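
These four abstract methods are the whole strategy contract, so a custom strategy only needs to fill in these hooks. A minimal sketch under the signatures above (the budget-based logic is a toy placeholder, not library code):

```python
from typing import List, Tuple

from crawl4ai.adaptive_crawler import (
    AdaptiveConfig, AdaptiveCrawlResult, CrawlStrategy, Link,
)
from crawl4ai.models import CrawlResult

class FixedBudgetStrategy(CrawlStrategy):
    """Toy strategy: take links in discovery order, stop after a fixed page budget."""

    def __init__(self, budget: int = 20):
        self.budget = budget

    async def calculate_confidence(self, state: AdaptiveCrawlResult) -> float:
        # Confidence grows linearly with pages crawled, capped at 1.0
        return min(1.0, len(state.crawled_urls) / self.budget)

    async def rank_links(self, state: AdaptiveCrawlResult, config: AdaptiveConfig) -> List[Tuple[Link, float]]:
        # No real scoring: every pending link gets the same weight
        return [(link, 1.0) for link in state.pending_links]

    async def should_stop(self, state: AdaptiveCrawlResult, config: AdaptiveConfig) -> bool:
        return len(state.crawled_urls) >= self.budget

    async def update_state(self, state: AdaptiveCrawlResult, new_results: List[CrawlResult]) -> None:
        state.knowledge_base.extend(new_results)
```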
````diff
@@ -284,7 +284,7 @@ class StatisticalStrategy(CrawlStrategy):
         self.bm25_k1 = 1.2  # BM25 parameter
         self.bm25_b = 0.75  # BM25 parameter
-    async def calculate_confidence(self, state: CrawlState) -> float:
+    async def calculate_confidence(self, state: AdaptiveCrawlResult) -> float:
         """Calculate confidence using coverage, consistency, and saturation"""
         if not state.knowledge_base:
             return 0.0
@@ -303,7 +303,7 @@ class StatisticalStrategy(CrawlStrategy):
         return confidence
-    def _calculate_coverage(self, state: CrawlState) -> float:
+    def _calculate_coverage(self, state: AdaptiveCrawlResult) -> float:
         """Coverage scoring - measures query term presence across knowledge base
         Returns a score between 0 and 1, where:
@@ -344,7 +344,7 @@ class StatisticalStrategy(CrawlStrategy):
         # This helps differentiate between partial and good coverage
         return min(1.0, math.sqrt(coverage))
-    def _calculate_consistency(self, state: CrawlState) -> float:
+    def _calculate_consistency(self, state: AdaptiveCrawlResult) -> float:
         """Information overlap between pages - high overlap suggests coherent topic coverage"""
         if len(state.knowledge_base) < 2:
             return 1.0  # Single or no documents are perfectly consistent
@@ -371,7 +371,7 @@ class StatisticalStrategy(CrawlStrategy):
         return consistency
-    def _calculate_saturation(self, state: CrawlState) -> float:
+    def _calculate_saturation(self, state: AdaptiveCrawlResult) -> float:
         """Diminishing returns indicator - are we still discovering new information?"""
         if not state.new_terms_history:
             return 0.0
````
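
Read together, these helpers show how the statistical confidence is assembled: coverage is dampened to $\min(1, \sqrt{\text{coverage}})$ before being returned, and `calculate_confidence` combines the three signals. A natural reading is a weighted combination; the actual formula and weights sit outside these hunks, so treat $w_1, w_2, w_3$ as placeholders:

$$
\text{confidence} = w_1 \cdot \text{coverage} + w_2 \cdot \text{consistency} + w_3 \cdot \text{saturation}, \qquad w_1 + w_2 + w_3 = 1
$$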
````diff
@@ -388,7 +388,7 @@ class StatisticalStrategy(CrawlStrategy):
         return max(0.0, min(saturation, 1.0))
-    async def rank_links(self, state: CrawlState, config: AdaptiveConfig) -> List[Tuple[Link, float]]:
+    async def rank_links(self, state: AdaptiveCrawlResult, config: AdaptiveConfig) -> List[Tuple[Link, float]]:
         """Rank links by expected information gain"""
         scored_links = []
@@ -415,7 +415,7 @@ class StatisticalStrategy(CrawlStrategy):
         return scored_links
-    def _calculate_relevance(self, link: Link, state: CrawlState) -> float:
+    def _calculate_relevance(self, link: Link, state: AdaptiveCrawlResult) -> float:
         """BM25 relevance score between link preview and query"""
         if not state.query or not link:
             return 0.0
````
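
`_calculate_relevance` scores each link's preview text against the query with BM25, using the standard parameters set in the constructor ($k_1 = 1.2$, $b = 0.75$). For reference, the standard BM25 score of query $Q$ against document $D$ is:

$$
\text{BM25}(Q, D) = \sum_{t \in Q} \mathrm{IDF}(t) \cdot \frac{f(t, D)\,(k_1 + 1)}{f(t, D) + k_1 \left(1 - b + b \cdot \frac{|D|}{\mathrm{avgdl}}\right)}
$$

where $f(t, D)$ is the frequency of term $t$ in $D$, $|D|$ is the document length, and $\mathrm{avgdl}$ is the average document length over the collection.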
````diff
@@ -447,7 +447,7 @@ class StatisticalStrategy(CrawlStrategy):
         overlap = len(query_terms & link_terms) / len(query_terms)
         return overlap
-    def _calculate_novelty(self, link: Link, state: CrawlState) -> float:
+    def _calculate_novelty(self, link: Link, state: AdaptiveCrawlResult) -> float:
         """Estimate how much new information this link might provide"""
         if not state.knowledge_base:
             return 1.0  # First links are maximally novel
@@ -502,7 +502,7 @@ class StatisticalStrategy(CrawlStrategy):
         return min(score, 1.0)
-    async def should_stop(self, state: CrawlState, config: AdaptiveConfig) -> bool:
+    async def should_stop(self, state: AdaptiveCrawlResult, config: AdaptiveConfig) -> bool:
         """Determine if crawling should stop"""
         # Check confidence threshold
         confidence = state.metrics.get('confidence', 0.0)
@@ -523,7 +523,7 @@ class StatisticalStrategy(CrawlStrategy):
         return False
-    async def update_state(self, state: CrawlState, new_results: List[CrawlResult]) -> None:
+    async def update_state(self, state: AdaptiveCrawlResult, new_results: List[CrawlResult]) -> None:
         """Update state with new crawl results"""
         for result in new_results:
             # Track new terms
````
````diff
@@ -921,7 +921,7 @@ class EmbeddingStrategy(CrawlStrategy):
         return sorted(scored_links, key=lambda x: x[1], reverse=True)
-    async def calculate_confidence(self, state: CrawlState) -> float:
+    async def calculate_confidence(self, state: AdaptiveCrawlResult) -> float:
         """Coverage-based learning score (0–1)."""
         # Guard clauses
         if state.kb_embeddings is None or state.query_embeddings is None:
@@ -951,7 +951,7 @@ class EmbeddingStrategy(CrawlStrategy):
-    # async def calculate_confidence(self, state: CrawlState) -> float:
+    # async def calculate_confidence(self, state: AdaptiveCrawlResult) -> float:
     #     """Calculate learning score for adaptive crawling (used for stopping)"""
     #
@@ -1021,7 +1021,7 @@ class EmbeddingStrategy(CrawlStrategy):
     #     # For stopping criteria, return learning score
     #     return float(learning_score)
-    async def rank_links(self, state: CrawlState, config: AdaptiveConfig) -> List[Tuple[Link, float]]:
+    async def rank_links(self, state: AdaptiveCrawlResult, config: AdaptiveConfig) -> List[Tuple[Link, float]]:
         """Main entry point for link ranking"""
         # Store config for use in other methods
         self.config = config
@@ -1052,7 +1052,7 @@ class EmbeddingStrategy(CrawlStrategy):
             state.kb_embeddings
         )
-    async def validate_coverage(self, state: CrawlState) -> float:
+    async def validate_coverage(self, state: AdaptiveCrawlResult) -> float:
         """Validate coverage using held-out queries with caching"""
         if not hasattr(self, '_validation_queries') or not self._validation_queries:
             return state.metrics.get('confidence', 0.0)
@@ -1088,7 +1088,7 @@ class EmbeddingStrategy(CrawlStrategy):
         return validation_confidence
-    async def should_stop(self, state: CrawlState, config: AdaptiveConfig) -> bool:
+    async def should_stop(self, state: AdaptiveCrawlResult, config: AdaptiveConfig) -> bool:
         """Stop based on learning curve convergence"""
         confidence = state.metrics.get('confidence', 0.0)
@@ -1139,7 +1139,7 @@ class EmbeddingStrategy(CrawlStrategy):
         return False
-    def get_quality_confidence(self, state: CrawlState) -> float:
+    def get_quality_confidence(self, state: AdaptiveCrawlResult) -> float:
         """Calculate quality-based confidence score for display"""
         learning_score = state.metrics.get('learning_score', 0.0)
         validation_score = state.metrics.get('validation_confidence', 0.0)
@@ -1166,7 +1166,7 @@ class EmbeddingStrategy(CrawlStrategy):
         return confidence
-    async def update_state(self, state: CrawlState, new_results: List[CrawlResult]) -> None:
+    async def update_state(self, state: AdaptiveCrawlResult, new_results: List[CrawlResult]) -> None:
         """Update embeddings and coverage metrics with deduplication"""
         from .utils import get_text_embeddings
````
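
The embedding strategy keeps `state.query_embeddings` and `state.kb_embeddings` and derives coverage from vector similarity. The aggregation inside `calculate_confidence` is not fully visible in these hunks, so the following is an illustrative sketch of the idea, not the shipped formula:

```python
import numpy as np

def coverage_score(query_emb: np.ndarray, kb_emb: np.ndarray) -> float:
    """Mean over query vectors of their best cosine match in the knowledge base."""
    q = query_emb / np.linalg.norm(query_emb, axis=1, keepdims=True)
    k = kb_emb / np.linalg.norm(kb_emb, axis=1, keepdims=True)
    sims = q @ k.T                         # (n_query, n_docs) cosine similarities
    return float(sims.max(axis=1).mean())  # how well each query facet is covered
```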
````diff
@@ -1246,7 +1246,7 @@ class AdaptiveCrawler:
         self.strategy = self._create_strategy(self.config.strategy)
         # Initialize state
-        self.state: Optional[CrawlState] = None
+        self.state: Optional[AdaptiveCrawlResult] = None
         # Track if we own the crawler (for cleanup)
         self._owns_crawler = crawler is None
@@ -1266,14 +1266,14 @@ class AdaptiveCrawler:
     async def digest(self,
                      start_url: str,
                      query: str,
-                     resume_from: Optional[str] = None) -> CrawlState:
+                     resume_from: Optional[str] = None) -> AdaptiveCrawlResult:
         """Main entry point for adaptive crawling"""
         # Initialize or resume state
         if resume_from:
-            self.state = CrawlState.load(resume_from)
+            self.state = AdaptiveCrawlResult.load(resume_from)
             self.state.query = query  # Update query in case it changed
         else:
-            self.state = CrawlState(
+            self.state = AdaptiveCrawlResult(
                 crawled_urls=set(),
                 knowledge_base=[],
                 pending_links=[],
@@ -1803,7 +1803,7 @@ class AdaptiveCrawler:
         # Initialize state if needed
         if not self.state:
-            self.state = CrawlState()
+            self.state = AdaptiveCrawlResult()
         # Add imported results
         self.state.knowledge_base.extend(imported_results)
````
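
Putting the renamed pieces together: `digest()` now returns an `AdaptiveCrawlResult`, and `resume_from` reloads one from disk via `AdaptiveCrawlResult.load()`. A usage sketch based on the signatures above (the URL and query are placeholders, and `save()` is the assumed counterpart of `load()`):

```python
import asyncio
from crawl4ai import AsyncWebCrawler, AdaptiveCrawler, AdaptiveConfig

async def main() -> None:
    async with AsyncWebCrawler() as crawler:
        adaptive = AdaptiveCrawler(crawler, AdaptiveConfig())

        # Fresh crawl: returns the final AdaptiveCrawlResult
        result = await adaptive.digest(
            start_url="https://docs.example.com",
            query="authentication tokens",
        )
        result.save("crawl_state.json")

        # Later: resume; digest() calls AdaptiveCrawlResult.load() internally
        result = await adaptive.digest(
            start_url="https://docs.example.com",
            query="authentication tokens",
            resume_from="crawl_state.json",
        )

asyncio.run(main())
```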


````diff
@@ -9,7 +9,7 @@ import asyncio
 import re
 from typing import List, Dict, Set
 from crawl4ai import AsyncWebCrawler, AdaptiveCrawler, AdaptiveConfig
-from crawl4ai.adaptive_crawler import CrawlState, Link
+from crawl4ai.adaptive_crawler import AdaptiveCrawlResult, Link
 import math
@@ -45,7 +45,7 @@ class APIDocumentationStrategy:
         r'/legal/'
     ]
-    def score_link(self, link: Link, query: str, state: CrawlState) -> float:
+    def score_link(self, link: Link, query: str, state: AdaptiveCrawlResult) -> float:
         """Custom link scoring for API documentation"""
         score = 1.0
         url = link.href.lower()
@@ -77,7 +77,7 @@ class APIDocumentationStrategy:
         return score
-    def calculate_api_coverage(self, state: CrawlState, query: str) -> Dict[str, float]:
+    def calculate_api_coverage(self, state: AdaptiveCrawlResult, query: str) -> Dict[str, float]:
         """Calculate specialized coverage metrics for API documentation"""
         metrics = {
             'endpoint_coverage': 0.0,
````


````diff
@@ -130,7 +130,7 @@ Factors:
 ```python
 class CustomLinkScorer:
-    def score(self, link: Link, query: str, state: CrawlState) -> float:
+    def score(self, link: Link, query: str, state: AdaptiveCrawlResult) -> float:
         # Prioritize specific URL patterns
         if "/api/reference/" in link.href:
             return 2.0  # Double the score
@@ -325,17 +325,17 @@ with open("crawl_analysis.json", "w") as f:
 from crawl4ai.adaptive_crawler import BaseStrategy
 class DomainSpecificStrategy(BaseStrategy):
-    def calculate_coverage(self, state: CrawlState) -> float:
+    def calculate_coverage(self, state: AdaptiveCrawlResult) -> float:
         # Custom coverage calculation
         # e.g., weight certain terms more heavily
         pass
-    def calculate_consistency(self, state: CrawlState) -> float:
+    def calculate_consistency(self, state: AdaptiveCrawlResult) -> float:
         # Custom consistency logic
         # e.g., domain-specific validation
         pass
-    def rank_links(self, links: List[Link], state: CrawlState) -> List[Link]:
+    def rank_links(self, links: List[Link], state: AdaptiveCrawlResult) -> List[Link]:
         # Custom link ranking
         # e.g., prioritize specific URL patterns
         pass
@@ -359,7 +359,7 @@ class HybridStrategy(BaseStrategy):
         URLPatternStrategy()
     ]
-    def calculate_confidence(self, state: CrawlState) -> float:
+    def calculate_confidence(self, state: AdaptiveCrawlResult) -> float:
         # Weighted combination of strategies
         scores = [s.calculate_confidence(state) for s in self.strategies]
         weights = [0.5, 0.3, 0.2]
````


````diff
@@ -27,7 +27,7 @@ async def digest(
     start_url: str,
     query: str,
     resume_from: Optional[Union[str, Path]] = None
-) -> CrawlState
+) -> AdaptiveCrawlResult
 ```
 #### Parameters
@@ -38,7 +38,7 @@ async def digest(
 #### Returns
-- **CrawlState**: The final crawl state containing all crawled URLs, knowledge base, and metrics
+- **AdaptiveCrawlResult**: The final crawl state containing all crawled URLs, knowledge base, and metrics
 #### Example
@@ -92,7 +92,7 @@ Access to the current crawl state.
 ```python
 @property
-def state(self) -> CrawlState
+def state(self) -> AdaptiveCrawlResult
 ```
 ## Methods
````


````diff
@@ -9,7 +9,7 @@ async def digest(
     start_url: str,
     query: str,
     resume_from: Optional[Union[str, Path]] = None
-) -> CrawlState
+) -> AdaptiveCrawlResult
 ```
 ## Parameters
@@ -31,7 +31,7 @@ async def digest(
 ## Return Value
-Returns a `CrawlState` object containing:
+Returns a `AdaptiveCrawlResult` object containing:
 - **crawled_urls** (`Set[str]`): All URLs that have been crawled
 - **knowledge_base** (`List[CrawlResult]`): Collection of crawled pages with content
````
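
Given the documented fields, inspecting the returned object is straightforward. A small sketch (only `crawled_urls` and `knowledge_base` are confirmed by the docs above; `CrawlResult.url` is a standard crawl4ai field):

```python
from crawl4ai import AdaptiveCrawlResult

def report(result: AdaptiveCrawlResult) -> None:
    print(f"Crawled {len(result.crawled_urls)} URLs")  # crawled_urls: Set[str]
    for page in result.knowledge_base:                 # knowledge_base: List[CrawlResult]
        print(page.url)
```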


````diff
@@ -23,7 +23,7 @@ from crawl4ai import (
     AsyncWebCrawler,
     AdaptiveCrawler,
     AdaptiveConfig,
-    CrawlState
+    AdaptiveCrawlResult
 )
````


````diff
@@ -13,7 +13,7 @@ import math
 sys.path.append(str(Path(__file__).parent.parent))
 from crawl4ai import AsyncWebCrawler
-from crawl4ai.adaptive_crawler import CrawlState, StatisticalStrategy
+from crawl4ai.adaptive_crawler import AdaptiveCrawlResult, StatisticalStrategy
 from crawl4ai.models import CrawlResult
@@ -37,7 +37,7 @@ class ConfidenceTestHarness:
         print("=" * 80)
         # Initialize state
-        state = CrawlState(query=self.query)
+        state = AdaptiveCrawlResult(query=self.query)
         # Create crawler
         async with AsyncWebCrawler() as crawler:
@@ -107,7 +107,7 @@ class ConfidenceTestHarness:
         state.metrics['prev_confidence'] = confidence
-    def _debug_coverage_calculation(self, state: CrawlState, query_terms: List[str]):
+    def _debug_coverage_calculation(self, state: AdaptiveCrawlResult, query_terms: List[str]):
         """Debug coverage calculation step by step"""
         coverage_score = 0.0
         max_possible_score = 0.0
@@ -136,7 +136,7 @@ class ConfidenceTestHarness:
         new_coverage = self._calculate_coverage_new(state, query_terms)
         print(f"   → New Coverage: {new_coverage:.3f}")
-    def _calculate_coverage_new(self, state: CrawlState, query_terms: List[str]) -> float:
+    def _calculate_coverage_new(self, state: AdaptiveCrawlResult, query_terms: List[str]) -> float:
         """New coverage calculation without IDF"""
         if not query_terms or state.total_documents == 0:
             return 0.0
````


````diff
@@ -15,7 +15,7 @@ import os
 sys.path.append(str(Path(__file__).parent.parent.parent))
 from crawl4ai import AsyncWebCrawler, AdaptiveCrawler, AdaptiveConfig
-from crawl4ai.adaptive_crawler import EmbeddingStrategy, CrawlState
+from crawl4ai.adaptive_crawler import EmbeddingStrategy, AdaptiveCrawlResult
 from crawl4ai.models import CrawlResult
@@ -132,7 +132,7 @@ async def test_embedding_performance():
     strategy.config = config
     # Initialize state
-    state = CrawlState()
+    state = AdaptiveCrawlResult()
     state.query = "async await coroutines event loops tasks"
     # Start performance monitoring
````


````diff
@@ -20,7 +20,7 @@ from crawl4ai import (
     AsyncWebCrawler,
     AdaptiveCrawler,
     AdaptiveConfig,
-    CrawlState
+    AdaptiveCrawlResult
 )
 console = Console()
````