Files
crawl4ai/crawl4ai/deep_crawling/bfs_strategy.py
unclecode f6897d1429 Add cancellation support for deep crawl strategies
- Add should_cancel callback parameter to BFS, DFS, and BestFirst strategies
- Add cancel() method for immediate cancellation (thread-safe)
- Add cancelled property to check cancellation status
- Add _check_cancellation() internal method supporting both sync/async callbacks
- Reset cancel event on strategy reuse for multiple crawls
- Include cancelled flag in state notifications via on_state_change
- Handle callback exceptions gracefully (fail-open, log warning)
- Add comprehensive test suite with 26 tests covering all edge cases

This enables external callers (e.g., cloud platforms) to stop a running
deep crawl mid-execution and retrieve partial results.
2026-01-22 06:08:25 +00:00

421 lines
17 KiB
Python

# bfs_strategy.py
import asyncio
import inspect
import logging
from datetime import datetime
from math import inf as infinity
from typing import AsyncGenerator, Optional, Set, Dict, List, Tuple, Any, Callable, Awaitable, Union
from urllib.parse import urlparse

from ..models import TraversalStats
from .filters import FilterChain
from .scorers import URLScorer
from . import DeepCrawlStrategy
from ..types import AsyncWebCrawler, CrawlerRunConfig, CrawlResult
from ..utils import normalize_url_for_deep_crawl, efficient_normalize_url_for_deep_crawl
class BFSDeepCrawlStrategy(DeepCrawlStrategy):
    """
    Breadth-First Search deep crawling strategy.

    The crawl advances one depth level at a time: every URL of the current
    level is passed to ``crawler.arun_many`` in a single call, links are
    extracted from the successful results, and the accepted links become the
    next level.

    Core functions:
      - arun: Main entry point; splits execution into batch or stream modes
        (not defined in this class; presumably inherited from
        DeepCrawlStrategy and dispatching to _arun_batch / _arun_stream).
      - link_discovery: Extracts, filters, and (if needed) scores the outgoing URLs.
      - can_process_url: Validates URL format and applies the filter chain.

    Cancellation and state reporting:
      - cancel() / the ``should_cancel`` callback stop the crawl at the next
        level boundary.
      - ``on_state_change`` receives a state dict after every processed
        result; the same dict is available through export_state().
    """

    def __init__(
        self,
        max_depth: int,
        # NOTE(review): this default instance is created once, when the class
        # is defined, and is shared by every strategy built without an
        # explicit chain. Left unchanged to keep the signature identical.
        filter_chain: FilterChain = FilterChain(),
        url_scorer: Optional[URLScorer] = None,
        include_external: bool = False,
        score_threshold: float = -infinity,
        max_pages: int = infinity,
        logger: Optional[logging.Logger] = None,
        # Optional resume/callback parameters for crash recovery
        resume_state: Optional[Dict[str, Any]] = None,
        on_state_change: Optional[Callable[[Dict[str, Any]], Awaitable[None]]] = None,
        # Optional cancellation callback - checked before each BFS level is crawled
        should_cancel: Optional[Callable[[], Union[bool, Awaitable[bool]]]] = None,
    ):
        """
        Configure the strategy.

        Args:
            max_depth: Links whose depth would exceed this value are not followed.
            filter_chain: Applied by can_process_url to every URL except those
                at depth 0.
            url_scorer: If given, ``url_scorer.score(url)`` ranks discovered
                links; without it every link scores 0.
            include_external: Also follow ``result.links["external"]``.
            score_threshold: Links scoring below this value are skipped.
            max_pages: Upper bound on successfully crawled pages.
            logger: Used only when it is a ``logging.Logger`` instance;
                anything else falls back to the module logger.
            resume_state: State dict to resume from. Read keys: "visited",
                "pending" (items with "url" and "parent_url"), "depths",
                "pages_crawled".
            on_state_change: Async callback awaited with the state dict after
                each processed result, and once more when a run ends cancelled.
            should_cancel: Callable returning a bool or an awaitable that
                resolves to one; a truthy value cancels the crawl.
        """
        self.max_depth = max_depth
        self.filter_chain = filter_chain
        self.url_scorer = url_scorer
        self.include_external = include_external
        self.score_threshold = score_threshold
        self.max_pages = max_pages
        # Ensure logger is always a Logger instance, not a dict from serialization
        if isinstance(logger, logging.Logger):
            self.logger = logger
        else:
            # Create a new logger if logger is None, dict, or any other non-Logger type
            self.logger = logging.getLogger(__name__)
        self.stats = TraversalStats(start_time=datetime.now())
        self._cancel_event = asyncio.Event()
        # NOTE(review): this counter is not reset when a run starts without
        # resume_state, so max_pages spans successive runs of one instance.
        self._pages_crawled = 0
        # Store for use in arun methods
        self._resume_state = resume_state
        self._on_state_change = on_state_change
        self._should_cancel = should_cancel
        self._last_state: Optional[Dict[str, Any]] = None

    async def can_process_url(self, url: str, depth: int) -> bool:
        """
        Validate the URL and apply the filter chain.

        A URL is rejected (with a warning) when it has no scheme or netloc,
        when the scheme is not http/https, or when the netloc contains no dot.
        For the start URL (depth 0) the filter chain is bypassed.

        Returns:
            True if the URL may be crawled, False otherwise.
        """
        try:
            parsed = urlparse(url)
            if not parsed.scheme or not parsed.netloc:
                raise ValueError("Missing scheme or netloc")
            if parsed.scheme not in ("http", "https"):
                raise ValueError("Invalid scheme")
            if "." not in parsed.netloc:
                raise ValueError("Invalid domain")
        except Exception as e:
            self.logger.warning(f"Invalid URL: {url}, error: {e}")
            return False

        if depth != 0 and not await self.filter_chain.apply(url):
            return False
        return True

    def cancel(self) -> None:
        """
        Request cancellation of the crawl.

        Only sets the internal cancel flag. The run loops test the flag at the
        start of each BFS level, so the level currently being crawled finishes
        before the crawl stops.

        NOTE(review): asyncio.Event is documented as not thread-safe. Nothing
        in this class awaits the event (it is only polled with is_set()), but
        confirm before relying on calls from other threads.
        """
        self._cancel_event.set()

    @property
    def cancelled(self) -> bool:
        """
        Whether the cancel flag is currently set.

        The flag is set by cancel(), shutdown(), or a truthy ``should_cancel``
        result, and is reset at the start of every run.

        Returns:
            True if the crawl has been cancelled, False otherwise.
        """
        return self._cancel_event.is_set()

    async def _check_cancellation(self) -> bool:
        """
        Check whether the crawl should be cancelled.

        Considers the internal cancel flag first, then the external
        ``should_cancel`` callback. The callback may return a plain value or
        any awaitable (coroutine, Future, Task, ...). A truthy result sets the
        cancel flag and stamps ``stats.end_time``.

        Returns:
            True if the crawl should be cancelled, False otherwise. A callback
            that raises is logged and treated as "do not cancel".
        """
        if self._cancel_event.is_set():
            return True

        if self._should_cancel:
            try:
                result = self._should_cancel()
                # isawaitable() also covers Futures/Tasks; an un-awaited
                # Future is truthy and would otherwise cancel immediately.
                if inspect.isawaitable(result):
                    result = await result

                if result:
                    self._cancel_event.set()
                    self.stats.end_time = datetime.now()
                    return True
            except Exception as e:
                # Fail-open: log warning and continue crawling
                self.logger.warning(f"should_cancel callback error: {e}")
        return False

    def _init_traversal(
        self, start_url: str
    ) -> Tuple[Set[str], List[Tuple[str, Optional[str]]], Dict[str, int]]:
        """Return (visited, current_level, depths), restored from resume_state if one was given."""
        if self._resume_state:
            visited = set(self._resume_state.get("visited", []))
            current_level = [
                (item["url"], item["parent_url"])
                for item in self._resume_state.get("pending", [])
            ]
            depths = dict(self._resume_state.get("depths", {}))
            self._pages_crawled = self._resume_state.get("pages_crawled", 0)
            return visited, current_level, depths

        # Fresh crawl: the start URL is the only entry, at depth 0, with no parent.
        return set(), [(start_url, None)], {start_url: 0}

    @staticmethod
    def _parent_lookup(
        level: List[Tuple[str, Optional[str]]]
    ) -> Dict[str, Optional[str]]:
        """Map each URL of a level to its parent; the first occurrence of a URL wins."""
        parents: Dict[str, Optional[str]] = {}
        for url, parent in level:
            parents.setdefault(url, parent)
        return parents

    async def _emit_state(
        self,
        visited: Set[str],
        pending: List[Tuple[str, Optional[str]]],
        depths: Dict[str, int],
        cancelled: bool,
    ) -> None:
        """Snapshot the crawl state, remember it, and await on_state_change (no-op without a callback)."""
        if not self._on_state_change:
            return
        state = {
            "strategy_type": "bfs",
            "visited": list(visited),
            "pending": [{"url": u, "parent_url": p} for u, p in pending],
            # Copied so a stored snapshot is not mutated by later discovery.
            "depths": dict(depths),
            "pages_crawled": self._pages_crawled,
            "cancelled": cancelled,
        }
        self._last_state = state
        await self._on_state_change(state)

    async def link_discovery(
        self,
        result: CrawlResult,
        source_url: str,
        current_depth: int,
        visited: Set[str],
        next_level: List[Tuple[str, Optional[str]]],
        depths: Dict[str, int],
    ) -> None:
        """
        Extract links from a crawl result, validate and score them, and queue
        the accepted ones for the next level.

        Nothing is queued when the next depth would exceed ``max_depth`` or
        when ``max_pages`` has already been reached. Each accepted URL is
        added to ``visited``, appended to ``next_level`` as a tuple
        ``(url, source_url)``, and recorded in ``depths``. Rejected URLs
        increment ``stats.urls_skipped``.

        Args:
            result: Crawl result of ``source_url``; its ``links`` mapping is read.
            source_url: URL the links were found on (becomes their parent).
            current_depth: Depth of ``source_url``.
            visited: Normalised URLs already seen; updated in place.
            next_level: Queue for the next BFS level; updated in place.
            depths: URL -> depth map; updated in place.
        """
        next_depth = current_depth + 1
        if next_depth > self.max_depth:
            return

        # If we've reached the max pages limit, don't discover new links
        remaining_capacity = self.max_pages - self._pages_crawled
        if remaining_capacity <= 0:
            self.logger.info(f"Max pages limit ({self.max_pages}) reached, stopping link discovery")
            return

        # Build a fresh list: `+=` on the list returned by .get() would extend
        # result.links["internal"] in place and leak external links into it.
        links = list(result.links.get("internal", []))
        if self.include_external:
            links.extend(result.links.get("external", []))

        valid_links = []
        for link in links:
            url = link.get("href")
            # Normalise so that equivalent URLs are not crawled twice.
            base_url = normalize_url_for_deep_crawl(url, source_url)
            if base_url in visited:
                continue
            if not await self.can_process_url(url, next_depth):
                self.stats.urls_skipped += 1
                continue

            # Score the URL if a scorer is provided
            score = self.url_scorer.score(base_url) if self.url_scorer else 0

            # Skip URLs with scores below the threshold
            if score < self.score_threshold:
                self.logger.debug(f"URL {url} skipped: score {score} below threshold {self.score_threshold}")
                self.stats.urls_skipped += 1
                continue

            visited.add(base_url)
            valid_links.append((base_url, score))

        # If we have more valid links than capacity, keep the best-scored ones
        # (or simply the first ones when no scorer is configured).
        if len(valid_links) > remaining_capacity:
            if self.url_scorer:
                valid_links.sort(key=lambda x: x[1], reverse=True)
            valid_links = valid_links[:remaining_capacity]
            self.logger.info(f"Limiting to {remaining_capacity} URLs due to max_pages limit")

        for url, score in valid_links:
            # NOTE(review): the score is written onto the *source* page's
            # result, so the last scored link overwrites earlier ones -
            # confirm this is the intended place for it.
            if score:
                result.metadata = result.metadata or {}
                result.metadata["score"] = score
            next_level.append((url, source_url))
            depths[url] = next_depth

    async def _arun_batch(
        self,
        start_url: str,
        crawler: AsyncWebCrawler,
        config: CrawlerRunConfig,
    ) -> List[CrawlResult]:
        """
        Batch (non-streaming) mode.

        Processes one BFS level at a time and returns all results (successful
        or not) once the crawl ends. The crawl ends when no URLs remain,
        ``max_pages`` is reached, or cancellation is observed at a level
        boundary.

        Returns:
            Every result received, each annotated with ``metadata["depth"]``
            and ``metadata["parent_url"]``.
        """
        # Reset cancel event for strategy reuse
        self._cancel_event = asyncio.Event()

        visited, current_level, depths = self._init_traversal(start_url)
        results: List[CrawlResult] = []

        while current_level and not self._cancel_event.is_set():
            # Check if we've already reached max_pages before starting a new level
            if self._pages_crawled >= self.max_pages:
                self.logger.info(f"Max pages limit ({self.max_pages}) reached, stopping crawl")
                break

            # Check external cancellation callback before processing this level
            if await self._check_cancellation():
                self.logger.info("Crawl cancelled by user")
                break

            next_level: List[Tuple[str, Optional[str]]] = []
            urls = [url for url, _ in current_level]
            # Mark the level as visited (as stream mode does) so that a
            # back-link to the start URL is not queued again.
            visited.update(urls)
            parent_of = self._parent_lookup(current_level)

            # Clone the config to disable deep crawling recursion and enforce batch mode.
            batch_config = config.clone(deep_crawl_strategy=None, stream=False)
            batch_results = await crawler.arun_many(urls=urls, config=batch_config)

            for result in batch_results:
                url = result.url
                depth = depths.get(url, 0)
                result.metadata = result.metadata or {}
                result.metadata["depth"] = depth
                result.metadata["parent_url"] = parent_of.get(url)
                results.append(result)

                # Only successful crawls count toward max_pages and yield links
                if result.success:
                    self._pages_crawled += 1
                    # Link discovery will handle the max pages limit internally
                    await self.link_discovery(result, url, depth, visited, next_level, depths)

                # Capture state after EACH URL processed (if callback set)
                await self._emit_state(
                    visited, next_level, depths, self._cancel_event.is_set()
                )

            current_level = next_level

        # Final state update if cancelled; pending is the level that was not started
        if self._cancel_event.is_set():
            await self._emit_state(visited, current_level, depths, True)

        return results

    async def _arun_stream(
        self,
        start_url: str,
        crawler: AsyncWebCrawler,
        config: CrawlerRunConfig,
    ) -> AsyncGenerator[CrawlResult, None]:
        """
        Streaming mode.

        Processes one BFS level at a time and yields each result as soon as it
        arrives, annotated with ``metadata["depth"]`` and
        ``metadata["parent_url"]``. The crawl ends when no URLs remain,
        ``max_pages`` is reached, or cancellation is observed at a level
        boundary.
        """
        # Reset cancel event for strategy reuse
        self._cancel_event = asyncio.Event()

        visited, current_level, depths = self._init_traversal(start_url)

        while current_level and not self._cancel_event.is_set():
            # Same guard as batch mode: never start a level once the page
            # budget is spent (e.g. after the limit was hit mid-level).
            if self._pages_crawled >= self.max_pages:
                self.logger.info(f"Max pages limit ({self.max_pages}) reached, stopping crawl")
                break

            # Check external cancellation callback before processing this level
            if await self._check_cancellation():
                self.logger.info("Crawl cancelled by user")
                break

            next_level: List[Tuple[str, Optional[str]]] = []
            urls = [url for url, _ in current_level]
            visited.update(urls)
            parent_of = self._parent_lookup(current_level)

            stream_config = config.clone(deep_crawl_strategy=None, stream=True)
            stream_gen = await crawler.arun_many(urls=urls, config=stream_config)

            # Keep track of processed results for this batch
            results_count = 0
            async for result in stream_gen:
                url = result.url
                depth = depths.get(url, 0)
                result.metadata = result.metadata or {}
                result.metadata["depth"] = depth
                result.metadata["parent_url"] = parent_of.get(url)

                # Count only successful crawls
                if result.success:
                    self._pages_crawled += 1

                results_count += 1
                # Yield before testing the limit so the page that reaches
                # max_pages is delivered instead of being counted and dropped.
                yield result

                # Only discover links from successful crawls
                if result.success:
                    # Link discovery will handle the max pages limit internally
                    await self.link_discovery(result, url, depth, visited, next_level, depths)

                # Capture state after EACH URL processed (if callback set)
                await self._emit_state(
                    visited, next_level, depths, self._cancel_event.is_set()
                )

                # Stop consuming this level once the page budget is spent
                if result.success and self._pages_crawled >= self.max_pages:
                    self.logger.info(f"Max pages limit ({self.max_pages}) reached during batch, stopping crawl")
                    break

            # A level that produced no results still advances to next_level
            # (empty in that case), so the outer loop cannot spin forever.
            if results_count == 0 and urls:
                self.logger.warning(f"No results returned for {len(urls)} URLs, marking as visited")

            current_level = next_level

        # Final state update if cancelled; pending is the level that was not started
        if self._cancel_event.is_set():
            await self._emit_state(visited, current_level, depths, True)

    async def shutdown(self) -> None:
        """
        Signal cancellation of the crawl and record the end time in ``stats``.
        """
        self._cancel_event.set()
        self.stats.end_time = datetime.now()

    def export_state(self) -> Optional[Dict[str, Any]]:
        """
        Export crawl state for external persistence.

        Note: This returns the last state captured for ``on_state_change``;
        states are only captured when that callback is set. For real-time
        state, use the callback itself.

        Returns:
            Dict with strategy state, or None if no state captured yet.
        """
        return self._last_state