Add cancellation support for deep crawl strategies

- Add should_cancel callback parameter to BFS, DFS, and BestFirst strategies
- Add cancel() method for immediate cancellation (thread-safe)
- Add cancelled property to check cancellation status
- Add _check_cancellation() internal method supporting both sync/async callbacks
- Reset cancel event on strategy reuse for multiple crawls
- Include cancelled flag in state notifications via on_state_change
- Handle callback exceptions gracefully (fail-open, log warning)
- Add comprehensive test suite with 26 tests covering all edge cases

This enables external callers (e.g., cloud platforms) to stop a running
deep crawl mid-execution and retrieve partial results.
This commit is contained in:
unclecode
2026-01-22 06:08:25 +00:00
parent 418bfcfd3b
commit f6897d1429
4 changed files with 828 additions and 6 deletions

View File

@@ -2,7 +2,7 @@
import asyncio
import logging
from datetime import datetime
from typing import AsyncGenerator, Optional, Set, Dict, List, Tuple, Any, Callable, Awaitable
from typing import AsyncGenerator, Optional, Set, Dict, List, Tuple, Any, Callable, Awaitable, Union
from urllib.parse import urlparse
from ..models import TraversalStats
@@ -44,6 +44,8 @@ class BestFirstCrawlingStrategy(DeepCrawlStrategy):
# Optional resume/callback parameters for crash recovery
resume_state: Optional[Dict[str, Any]] = None,
on_state_change: Optional[Callable[[Dict[str, Any]], Awaitable[None]]] = None,
# Optional cancellation callback - checked before each URL is processed
should_cancel: Optional[Callable[[], Union[bool, Awaitable[bool]]]] = None,
):
self.max_depth = max_depth
self.filter_chain = filter_chain
@@ -63,6 +65,7 @@ class BestFirstCrawlingStrategy(DeepCrawlStrategy):
# Store for use in arun methods
self._resume_state = resume_state
self._on_state_change = on_state_change
self._should_cancel = should_cancel
self._last_state: Optional[Dict[str, Any]] = None
# Shadow list for queue items (only used when on_state_change is set)
self._queue_shadow: Optional[List[Tuple[float, int, str, Optional[str]]]] = None
@@ -89,6 +92,55 @@ class BestFirstCrawlingStrategy(DeepCrawlStrategy):
return True
    def cancel(self) -> None:
        """
        Cancel the crawl. Thread-safe, can be called from any context.

        Sets the internal cancel event, which the crawl loop checks before
        processing the next URL. The current URL being processed (if any)
        will complete before the crawl stops.
        """
        self._cancel_event.set()
    @property
    def cancelled(self) -> bool:
        """
        Check if the crawl was/is cancelled. Thread-safe.

        Reflects both explicit cancel() calls and cancellations triggered by
        the should_cancel callback (the callback latches the same event).

        Returns:
            True if the crawl has been cancelled, False otherwise.
        """
        return self._cancel_event.is_set()
async def _check_cancellation(self) -> bool:
"""
Check if crawl should be cancelled.
Handles both internal cancel flag and external should_cancel callback.
Supports both sync and async callbacks.
Returns:
True if crawl should be cancelled, False otherwise.
"""
if self._cancel_event.is_set():
return True
if self._should_cancel:
try:
# Handle both sync and async callbacks
result = self._should_cancel()
if asyncio.iscoroutine(result):
result = await result
if result:
self._cancel_event.set()
self.stats.end_time = datetime.now()
return True
except Exception as e:
# Fail-open: log warning and continue crawling
self.logger.warning(f"should_cancel callback error: {e}")
return False
async def link_discovery(
self,
result: CrawlResult,
@@ -148,6 +200,9 @@ class BestFirstCrawlingStrategy(DeepCrawlStrategy):
The queue items are tuples of (score, depth, url, parent_url). Lower scores
are treated as higher priority. URLs are processed in batches for efficiency.
"""
# Reset cancel event for strategy reuse
self._cancel_event = asyncio.Event()
queue: asyncio.PriorityQueue = asyncio.PriorityQueue()
# Conditional state initialization for resume support
@@ -180,7 +235,12 @@ class BestFirstCrawlingStrategy(DeepCrawlStrategy):
if self._pages_crawled >= self.max_pages:
self.logger.info(f"Max pages limit ({self.max_pages}) reached, stopping crawl")
break
# Check external cancellation callback before processing this batch
if await self._check_cancellation():
self.logger.info("Crawl cancelled by user")
break
# Calculate how many more URLs we can process in this batch
remaining = self.max_pages - self._pages_crawled
batch_size = min(BATCH_SIZE, remaining)
@@ -262,11 +322,26 @@ class BestFirstCrawlingStrategy(DeepCrawlStrategy):
],
"depths": depths,
"pages_crawled": self._pages_crawled,
"cancelled": self._cancel_event.is_set(),
}
self._last_state = state
await self._on_state_change(state)
# End of crawl.
# Final state update if cancelled
if self._cancel_event.is_set() and self._on_state_change and self._queue_shadow is not None:
state = {
"strategy_type": "best_first",
"visited": list(visited),
"queue_items": [
{"score": s, "depth": d, "url": u, "parent_url": p}
for s, d, u, p in self._queue_shadow
],
"depths": depths,
"pages_crawled": self._pages_crawled,
"cancelled": True,
}
self._last_state = state
await self._on_state_change(state)
async def _arun_batch(
self,

View File

@@ -2,7 +2,7 @@
import asyncio
import logging
from datetime import datetime
from typing import AsyncGenerator, Optional, Set, Dict, List, Tuple, Any, Callable, Awaitable
from typing import AsyncGenerator, Optional, Set, Dict, List, Tuple, Any, Callable, Awaitable, Union
from urllib.parse import urlparse
from ..models import TraversalStats
@@ -34,6 +34,8 @@ class BFSDeepCrawlStrategy(DeepCrawlStrategy):
# Optional resume/callback parameters for crash recovery
resume_state: Optional[Dict[str, Any]] = None,
on_state_change: Optional[Callable[[Dict[str, Any]], Awaitable[None]]] = None,
# Optional cancellation callback - checked before each URL is processed
should_cancel: Optional[Callable[[], Union[bool, Awaitable[bool]]]] = None,
):
self.max_depth = max_depth
self.filter_chain = filter_chain
@@ -54,6 +56,7 @@ class BFSDeepCrawlStrategy(DeepCrawlStrategy):
# Store for use in arun methods
self._resume_state = resume_state
self._on_state_change = on_state_change
self._should_cancel = should_cancel
self._last_state: Optional[Dict[str, Any]] = None
async def can_process_url(self, url: str, depth: int) -> bool:
@@ -78,6 +81,55 @@ class BFSDeepCrawlStrategy(DeepCrawlStrategy):
return True
    def cancel(self) -> None:
        """
        Cancel the crawl. Thread-safe, can be called from any context.

        Sets the internal cancel event, which the crawl loop checks before
        processing the next URL. The current URL being processed (if any)
        will complete before the crawl stops.
        """
        self._cancel_event.set()
    @property
    def cancelled(self) -> bool:
        """
        Check if the crawl was/is cancelled. Thread-safe.

        Reflects both explicit cancel() calls and cancellations triggered by
        the should_cancel callback (the callback latches the same event).

        Returns:
            True if the crawl has been cancelled, False otherwise.
        """
        return self._cancel_event.is_set()
async def _check_cancellation(self) -> bool:
"""
Check if crawl should be cancelled.
Handles both internal cancel flag and external should_cancel callback.
Supports both sync and async callbacks.
Returns:
True if crawl should be cancelled, False otherwise.
"""
if self._cancel_event.is_set():
return True
if self._should_cancel:
try:
# Handle both sync and async callbacks
result = self._should_cancel()
if asyncio.iscoroutine(result):
result = await result
if result:
self._cancel_event.set()
self.stats.end_time = datetime.now()
return True
except Exception as e:
# Fail-open: log warning and continue crawling
self.logger.warning(f"should_cancel callback error: {e}")
return False
async def link_discovery(
self,
result: CrawlResult,
@@ -162,6 +214,9 @@ class BFSDeepCrawlStrategy(DeepCrawlStrategy):
Batch (non-streaming) mode:
Processes one BFS level at a time, then yields all the results.
"""
# Reset cancel event for strategy reuse
self._cancel_event = asyncio.Event()
# Conditional state initialization for resume support
if self._resume_state:
visited = set(self._resume_state.get("visited", []))
@@ -185,7 +240,12 @@ class BFSDeepCrawlStrategy(DeepCrawlStrategy):
if self._pages_crawled >= self.max_pages:
self.logger.info(f"Max pages limit ({self.max_pages}) reached, stopping crawl")
break
# Check external cancellation callback before processing this level
if await self._check_cancellation():
self.logger.info("Crawl cancelled by user")
break
next_level: List[Tuple[str, Optional[str]]] = []
urls = [url for url, _ in current_level]
@@ -218,12 +278,26 @@ class BFSDeepCrawlStrategy(DeepCrawlStrategy):
"pending": [{"url": u, "parent_url": p} for u, p in next_level],
"depths": depths,
"pages_crawled": self._pages_crawled,
"cancelled": self._cancel_event.is_set(),
}
self._last_state = state
await self._on_state_change(state)
current_level = next_level
# Final state update if cancelled
if self._cancel_event.is_set() and self._on_state_change:
state = {
"strategy_type": "bfs",
"visited": list(visited),
"pending": [{"url": u, "parent_url": p} for u, p in current_level],
"depths": depths,
"pages_crawled": self._pages_crawled,
"cancelled": True,
}
self._last_state = state
await self._on_state_change(state)
return results
async def _arun_stream(
@@ -236,6 +310,9 @@ class BFSDeepCrawlStrategy(DeepCrawlStrategy):
Streaming mode:
Processes one BFS level at a time and yields results immediately as they arrive.
"""
# Reset cancel event for strategy reuse
self._cancel_event = asyncio.Event()
# Conditional state initialization for resume support
if self._resume_state:
visited = set(self._resume_state.get("visited", []))
@@ -252,6 +329,11 @@ class BFSDeepCrawlStrategy(DeepCrawlStrategy):
depths: Dict[str, int] = {start_url: 0}
while current_level and not self._cancel_event.is_set():
# Check external cancellation callback before processing this level
if await self._check_cancellation():
self.logger.info("Crawl cancelled by user")
break
next_level: List[Tuple[str, Optional[str]]] = []
urls = [url for url, _ in current_level]
visited.update(urls)
@@ -293,6 +375,7 @@ class BFSDeepCrawlStrategy(DeepCrawlStrategy):
"pending": [{"url": u, "parent_url": p} for u, p in next_level],
"depths": depths,
"pages_crawled": self._pages_crawled,
"cancelled": self._cancel_event.is_set(),
}
self._last_state = state
await self._on_state_change(state)
@@ -301,9 +384,22 @@ class BFSDeepCrawlStrategy(DeepCrawlStrategy):
# by considering these URLs as visited but not counting them toward the max_pages limit
if results_count == 0 and urls:
self.logger.warning(f"No results returned for {len(urls)} URLs, marking as visited")
current_level = next_level
# Final state update if cancelled
if self._cancel_event.is_set() and self._on_state_change:
state = {
"strategy_type": "bfs",
"visited": list(visited),
"pending": [{"url": u, "parent_url": p} for u, p in current_level],
"depths": depths,
"pages_crawled": self._pages_crawled,
"cancelled": True,
}
self._last_state = state
await self._on_state_change(state)
async def shutdown(self) -> None:
"""
Clean up resources and signal cancellation of the crawl.

View File

@@ -1,4 +1,5 @@
# dfs_deep_crawl_strategy.py
import asyncio
from typing import AsyncGenerator, Optional, Set, Dict, List, Tuple
from ..models import CrawlResult
@@ -38,6 +39,9 @@ class DFSDeepCrawlStrategy(BFSDeepCrawlStrategy):
in control of traversal. Every successful page bumps ``_pages_crawled`` and
seeds new stack items discovered via :meth:`link_discovery`.
"""
# Reset cancel event for strategy reuse
self._cancel_event = asyncio.Event()
# Conditional state initialization for resume support
if self._resume_state:
visited = set(self._resume_state.get("visited", []))
@@ -59,6 +63,11 @@ class DFSDeepCrawlStrategy(BFSDeepCrawlStrategy):
self._reset_seen(start_url)
while stack and not self._cancel_event.is_set():
# Check external cancellation callback before processing this URL
if await self._check_cancellation():
self.logger.info("Crawl cancelled by user")
break
url, parent, depth = stack.pop()
if url in visited or depth > self.max_depth:
continue
@@ -105,9 +114,28 @@ class DFSDeepCrawlStrategy(BFSDeepCrawlStrategy):
"depths": depths,
"pages_crawled": self._pages_crawled,
"dfs_seen": list(self._dfs_seen),
"cancelled": self._cancel_event.is_set(),
}
self._last_state = state
await self._on_state_change(state)
# Final state update if cancelled
if self._cancel_event.is_set() and self._on_state_change:
state = {
"strategy_type": "dfs",
"visited": list(visited),
"stack": [
{"url": u, "parent_url": p, "depth": d}
for u, p, d in stack
],
"depths": depths,
"pages_crawled": self._pages_crawled,
"dfs_seen": list(self._dfs_seen),
"cancelled": True,
}
self._last_state = state
await self._on_state_change(state)
return results
async def _arun_stream(
@@ -123,6 +151,9 @@ class DFSDeepCrawlStrategy(BFSDeepCrawlStrategy):
yielded before we even look at the next stack entry. Successful crawls
still feed :meth:`link_discovery`, keeping DFS order intact.
"""
# Reset cancel event for strategy reuse
self._cancel_event = asyncio.Event()
# Conditional state initialization for resume support
if self._resume_state:
visited = set(self._resume_state.get("visited", []))
@@ -141,6 +172,11 @@ class DFSDeepCrawlStrategy(BFSDeepCrawlStrategy):
self._reset_seen(start_url)
while stack and not self._cancel_event.is_set():
# Check external cancellation callback before processing this URL
if await self._check_cancellation():
self.logger.info("Crawl cancelled by user")
break
url, parent, depth = stack.pop()
if url in visited or depth > self.max_depth:
continue
@@ -183,10 +219,28 @@ class DFSDeepCrawlStrategy(BFSDeepCrawlStrategy):
"depths": depths,
"pages_crawled": self._pages_crawled,
"dfs_seen": list(self._dfs_seen),
"cancelled": self._cancel_event.is_set(),
}
self._last_state = state
await self._on_state_change(state)
# Final state update if cancelled
if self._cancel_event.is_set() and self._on_state_change:
state = {
"strategy_type": "dfs",
"visited": list(visited),
"stack": [
{"url": u, "parent_url": p, "depth": d}
for u, p, d in stack
],
"depths": depths,
"pages_crawled": self._pages_crawled,
"dfs_seen": list(self._dfs_seen),
"cancelled": True,
}
self._last_state = state
await self._on_state_change(state)
async def link_discovery(
self,
result: CrawlResult,

View File

@@ -0,0 +1,597 @@
"""
Test Suite: Deep Crawl Cancellation Tests
Tests that verify:
1. should_cancel callback is called before each URL
2. cancel() method immediately stops the crawl
3. cancelled property correctly reflects state
4. Strategy reuse works after cancellation
5. Both sync and async should_cancel callbacks work
6. Callback exceptions don't crash the crawl
7. State notifications include cancelled flag
"""
import pytest
import asyncio
from typing import Dict, Any, List
from unittest.mock import MagicMock
from crawl4ai.deep_crawling import (
BFSDeepCrawlStrategy,
DFSDeepCrawlStrategy,
BestFirstCrawlingStrategy,
)
# ============================================================================
# Helper Functions for Mock Crawler
# ============================================================================
def create_mock_config(stream=False):
    """Create a mock CrawlerRunConfig whose ``clone`` behaves like the real one.

    ``clone(**kwargs)`` returns a new config that inherits the *current*
    config's values and applies the overrides. The previous implementation
    closed over the factory-level ``stream`` default, so cloning a clone
    without arguments silently reverted to the original default instead of
    keeping the parent's value.
    """
    def _make(current_stream):
        # Each config carries its own stream flag; clone() inherits it
        # unless explicitly overridden.
        cfg = MagicMock()
        cfg.stream = current_stream
        cfg.clone = MagicMock(
            side_effect=lambda **kwargs: _make(kwargs.get('stream', cfg.stream))
        )
        return cfg

    return _make(stream)
def create_mock_crawler_with_links(num_links: int = 3):
    """Build a mock crawler whose results each expose ``num_links`` child links.

    Child hrefs are numbered by a counter that advances once per crawled URL,
    so every page yields a distinct set of synthetic links.
    """
    call_count = 0

    def _page_result(url):
        # One successful result per URL with generated internal links.
        page = MagicMock()
        page.url = url
        page.success = True
        page.metadata = {}
        page.links = {
            "internal": [
                {"href": f"{url}/child{call_count}_{idx}"}
                for idx in range(num_links)
            ],
            "external": [],
        }
        return page

    async def mock_arun_many(urls, config):
        nonlocal call_count
        results = []
        for url in urls:
            call_count += 1
            results.append(_page_result(url))

        # Batch mode: return the list directly.
        if not config.stream:
            return results

        # Streaming mode: wrap the results in an async generator.
        async def _stream():
            for item in results:
                yield item

        return _stream()

    crawler = MagicMock()
    crawler.arun_many = mock_arun_many
    return crawler
def create_mock_crawler_tracking(crawl_order: List[str], return_no_links: bool = False):
    """Build a mock crawler that appends every crawled URL to ``crawl_order``.

    When ``return_no_links`` is True each result has no internal links,
    otherwise each page exposes a single ``<url>/child`` link.
    """
    async def mock_arun_many(urls, config):
        results = []
        for url in urls:
            crawl_order.append(url)
            page = MagicMock()
            page.url = url
            page.success = True
            page.metadata = {}
            internal = [] if return_no_links else [{"href": f"{url}/child"}]
            page.links = {"internal": internal, "external": []}
            results.append(page)

        # Batch mode: return the list directly.
        if not config.stream:
            return results

        # Streaming mode: wrap the results in an async generator.
        async def _stream():
            for item in results:
                yield item

        return _stream()

    crawler = MagicMock()
    crawler.arun_many = mock_arun_many
    return crawler
# ============================================================================
# TEST SUITE: Cancellation via should_cancel Callback
# ============================================================================
class TestBFSCancellation:
    """BFS strategy cancellation tests."""

    @pytest.mark.asyncio
    async def test_cancel_via_async_callback(self):
        """Verify async should_cancel callback stops crawl."""
        pages_crawled = 0
        cancel_after = 3

        # Callback closes over pages_crawled, updated by the state hook below.
        async def check_cancel():
            return pages_crawled >= cancel_after

        async def track_pages(state: Dict[str, Any]):
            nonlocal pages_crawled
            pages_crawled = state.get("pages_crawled", 0)

        strategy = BFSDeepCrawlStrategy(
            max_depth=5,
            max_pages=100,
            should_cancel=check_cancel,
            on_state_change=track_pages,
        )
        mock_crawler = create_mock_crawler_with_links(num_links=5)
        mock_config = create_mock_config()
        results = await strategy._arun_batch("https://example.com", mock_crawler, mock_config)
        # Should have stopped after cancel_after pages
        assert strategy.cancelled == True
        assert strategy._pages_crawled >= cancel_after
        assert strategy._pages_crawled < 100  # Should not have crawled all pages

    @pytest.mark.asyncio
    async def test_cancel_via_sync_callback(self):
        """Verify sync should_cancel callback works."""
        cancel_flag = False

        # Plain (non-async) callback: exercises the sync branch of
        # _check_cancellation.
        def check_cancel():
            return cancel_flag

        async def set_cancel_after_3(state: Dict[str, Any]):
            nonlocal cancel_flag
            if state.get("pages_crawled", 0) >= 3:
                cancel_flag = True

        strategy = BFSDeepCrawlStrategy(
            max_depth=5,
            max_pages=100,
            should_cancel=check_cancel,
            on_state_change=set_cancel_after_3,
        )
        mock_crawler = create_mock_crawler_with_links(num_links=5)
        mock_config = create_mock_config()
        await strategy._arun_batch("https://example.com", mock_crawler, mock_config)
        assert strategy.cancelled == True
        assert strategy._pages_crawled >= 3

    @pytest.mark.asyncio
    async def test_cancel_method_stops_crawl(self):
        """Verify cancel() method immediately stops the crawl."""
        strategy = BFSDeepCrawlStrategy(
            max_depth=5,
            max_pages=100,
        )

        # The state-change hook calls cancel() mid-crawl.
        async def cancel_after_2_pages(state: Dict[str, Any]):
            if state.get("pages_crawled", 0) >= 2:
                strategy.cancel()

        strategy._on_state_change = cancel_after_2_pages
        mock_crawler = create_mock_crawler_with_links(num_links=5)
        mock_config = create_mock_config()
        await strategy._arun_batch("https://example.com", mock_crawler, mock_config)
        assert strategy.cancelled == True
        assert strategy._pages_crawled >= 2
        assert strategy._pages_crawled < 100

    @pytest.mark.asyncio
    async def test_cancelled_property_reflects_state(self):
        """Verify cancelled property correctly reflects state."""
        strategy = BFSDeepCrawlStrategy(max_depth=2, max_pages=10)
        # Before cancel
        assert strategy.cancelled == False
        # After cancel()
        strategy.cancel()
        assert strategy.cancelled == True

    @pytest.mark.asyncio
    async def test_strategy_reuse_after_cancellation(self):
        """Verify strategy can be reused after cancellation."""
        call_count = 0

        async def cancel_first_time():
            return call_count == 1

        strategy = BFSDeepCrawlStrategy(
            max_depth=1,
            max_pages=5,
            should_cancel=cancel_first_time,
        )
        mock_crawler = create_mock_crawler_with_links(num_links=2)
        mock_config = create_mock_config()
        # First crawl - should be cancelled
        call_count = 1
        results1 = await strategy._arun_batch("https://example.com", mock_crawler, mock_config)
        assert strategy.cancelled == True
        # Second crawl - should work normally (cancel_first_time returns False)
        # because the strategy resets its cancel event on each run.
        call_count = 2
        results2 = await strategy._arun_batch("https://example.com", mock_crawler, mock_config)
        assert strategy.cancelled == False
        assert len(results2) > len(results1)

    @pytest.mark.asyncio
    async def test_callback_exception_continues_crawl(self):
        """Verify callback exception doesn't crash crawl (fail-open)."""
        exception_count = 0

        async def failing_callback():
            nonlocal exception_count
            exception_count += 1
            raise ConnectionError("Redis connection failed")

        strategy = BFSDeepCrawlStrategy(
            max_depth=1,
            max_pages=3,
            should_cancel=failing_callback,
        )
        mock_crawler = create_mock_crawler_with_links(num_links=2)
        mock_config = create_mock_config()
        # Should not raise, should complete crawl
        results = await strategy._arun_batch("https://example.com", mock_crawler, mock_config)
        assert exception_count > 0  # Callback was called
        assert len(results) > 0  # Crawl completed
        assert strategy.cancelled == False  # Not cancelled due to exception

    @pytest.mark.asyncio
    async def test_state_includes_cancelled_flag(self):
        """Verify state notifications include cancelled flag."""
        states: List[Dict] = []
        cancel_at = 3

        async def capture_state(state: Dict[str, Any]):
            states.append(state)

        async def cancel_after_3():
            return len(states) >= cancel_at

        strategy = BFSDeepCrawlStrategy(
            max_depth=5,
            max_pages=100,
            should_cancel=cancel_after_3,
            on_state_change=capture_state,
        )
        mock_crawler = create_mock_crawler_with_links(num_links=5)
        mock_config = create_mock_config()
        await strategy._arun_batch("https://example.com", mock_crawler, mock_config)
        # Last state should have cancelled=True
        assert len(states) > 0
        assert states[-1].get("cancelled") == True

    @pytest.mark.asyncio
    async def test_cancel_before_first_url(self):
        """Verify cancel before first URL returns empty results."""
        async def always_cancel():
            return True

        strategy = BFSDeepCrawlStrategy(
            max_depth=5,
            max_pages=100,
            should_cancel=always_cancel,
        )
        mock_crawler = create_mock_crawler_with_links(num_links=5)
        mock_config = create_mock_config()
        results = await strategy._arun_batch("https://example.com", mock_crawler, mock_config)
        assert strategy.cancelled == True
        assert len(results) == 0
class TestDFSCancellation:
    """DFS strategy cancellation tests."""

    @pytest.mark.asyncio
    async def test_cancel_via_callback(self):
        """Verify DFS respects should_cancel callback."""
        pages_crawled = 0
        cancel_after = 3

        # Callback closes over pages_crawled, updated by the state hook below.
        async def check_cancel():
            return pages_crawled >= cancel_after

        async def track_pages(state: Dict[str, Any]):
            nonlocal pages_crawled
            pages_crawled = state.get("pages_crawled", 0)

        strategy = DFSDeepCrawlStrategy(
            max_depth=5,
            max_pages=100,
            should_cancel=check_cancel,
            on_state_change=track_pages,
        )
        mock_crawler = create_mock_crawler_with_links(num_links=3)
        mock_config = create_mock_config()
        await strategy._arun_batch("https://example.com", mock_crawler, mock_config)
        assert strategy.cancelled == True
        assert strategy._pages_crawled >= cancel_after
        assert strategy._pages_crawled < 100

    @pytest.mark.asyncio
    async def test_cancel_method_inherited(self):
        """Verify DFS inherits cancel() from BFS."""
        strategy = DFSDeepCrawlStrategy(max_depth=2, max_pages=10)
        assert hasattr(strategy, 'cancel')
        assert hasattr(strategy, 'cancelled')
        assert hasattr(strategy, '_check_cancellation')
        strategy.cancel()
        assert strategy.cancelled == True

    @pytest.mark.asyncio
    async def test_stream_mode_cancellation(self):
        """Verify DFS stream mode respects cancellation."""
        results_count = 0
        cancel_after = 2

        async def check_cancel():
            return results_count >= cancel_after

        strategy = DFSDeepCrawlStrategy(
            max_depth=5,
            max_pages=100,
            should_cancel=check_cancel,
        )
        mock_crawler = create_mock_crawler_with_links(num_links=3)
        mock_config = create_mock_config(stream=True)
        # Consume the stream; the counter feeds the cancellation callback.
        async for result in strategy._arun_stream("https://example.com", mock_crawler, mock_config):
            results_count += 1
        assert strategy.cancelled == True
        assert results_count >= cancel_after
        assert results_count < 100
class TestBestFirstCancellation:
    """Best-First strategy cancellation tests."""

    @pytest.mark.asyncio
    async def test_cancel_via_callback(self):
        """Verify Best-First respects should_cancel callback."""
        pages_crawled = 0
        cancel_after = 3

        # Callback closes over pages_crawled, updated by the state hook below.
        async def check_cancel():
            return pages_crawled >= cancel_after

        async def track_pages(state: Dict[str, Any]):
            nonlocal pages_crawled
            pages_crawled = state.get("pages_crawled", 0)

        strategy = BestFirstCrawlingStrategy(
            max_depth=5,
            max_pages=100,
            should_cancel=check_cancel,
            on_state_change=track_pages,
        )
        mock_crawler = create_mock_crawler_with_links(num_links=3)
        mock_config = create_mock_config(stream=True)
        # Drain the stream; results themselves are not inspected here.
        async for _ in strategy._arun_stream("https://example.com", mock_crawler, mock_config):
            pass
        assert strategy.cancelled == True
        assert strategy._pages_crawled >= cancel_after
        assert strategy._pages_crawled < 100

    @pytest.mark.asyncio
    async def test_cancel_method_works(self):
        """Verify Best-First cancel() method works."""
        strategy = BestFirstCrawlingStrategy(max_depth=2, max_pages=10)
        assert strategy.cancelled == False
        strategy.cancel()
        assert strategy.cancelled == True

    @pytest.mark.asyncio
    async def test_batch_mode_cancellation(self):
        """Verify Best-First batch mode respects cancellation."""
        pages_crawled = 0
        cancel_after = 2

        async def check_cancel():
            return pages_crawled >= cancel_after

        async def track_pages(state: Dict[str, Any]):
            nonlocal pages_crawled
            pages_crawled = state.get("pages_crawled", 0)

        strategy = BestFirstCrawlingStrategy(
            max_depth=5,
            max_pages=100,
            should_cancel=check_cancel,
            on_state_change=track_pages,
        )
        mock_crawler = create_mock_crawler_with_links(num_links=3)
        mock_config = create_mock_config(stream=False)
        results = await strategy._arun_batch("https://example.com", mock_crawler, mock_config)
        assert strategy.cancelled == True
        assert len(results) >= cancel_after
        assert len(results) < 100
class TestCrossStrategyCancellation:
    """Tests that apply to all strategies."""

    @pytest.mark.asyncio
    @pytest.mark.parametrize("strategy_class", [
        BFSDeepCrawlStrategy,
        DFSDeepCrawlStrategy,
        BestFirstCrawlingStrategy,
    ])
    async def test_no_cancel_callback_means_no_cancellation(self, strategy_class):
        """Verify crawl completes normally without should_cancel."""
        strategy = strategy_class(max_depth=1, max_pages=5)
        mock_crawler = create_mock_crawler_with_links(num_links=2)
        # Best-First is exercised in stream mode, the others in batch mode.
        if strategy_class == BestFirstCrawlingStrategy:
            mock_config = create_mock_config(stream=True)
            results = []
            async for r in strategy._arun_stream("https://example.com", mock_crawler, mock_config):
                results.append(r)
        else:
            mock_config = create_mock_config()
            results = await strategy._arun_batch("https://example.com", mock_crawler, mock_config)
        assert strategy.cancelled == False
        assert len(results) > 0

    @pytest.mark.asyncio
    @pytest.mark.parametrize("strategy_class", [
        BFSDeepCrawlStrategy,
        DFSDeepCrawlStrategy,
        BestFirstCrawlingStrategy,
    ])
    async def test_cancel_thread_safety(self, strategy_class):
        """Verify cancel() is thread-safe (doesn't raise)."""
        strategy = strategy_class(max_depth=2, max_pages=10)
        # Call cancel from multiple "threads" (simulated)
        for _ in range(10):
            strategy.cancel()
        # Should be cancelled without errors
        assert strategy.cancelled == True

    @pytest.mark.asyncio
    @pytest.mark.parametrize("strategy_class", [
        BFSDeepCrawlStrategy,
        DFSDeepCrawlStrategy,
        BestFirstCrawlingStrategy,
    ])
    async def test_should_cancel_param_accepted(self, strategy_class):
        """Verify should_cancel parameter is accepted by constructor."""
        async def dummy_cancel():
            return False

        # Should not raise
        strategy = strategy_class(
            max_depth=2,
            max_pages=10,
            should_cancel=dummy_cancel,
        )
        assert strategy._should_cancel == dummy_cancel
class TestCancellationEdgeCases:
    """Edge case tests for cancellation."""

    @pytest.mark.asyncio
    async def test_cancel_during_batch_processing(self):
        """Verify cancellation during batch doesn't lose results."""
        results_count = 0

        async def cancel_mid_batch():
            # Cancel after receiving first result
            return results_count >= 1

        strategy = BFSDeepCrawlStrategy(
            max_depth=2,
            max_pages=100,
            should_cancel=cancel_mid_batch,
        )

        async def track_results(state):
            nonlocal results_count
            results_count = state.get("pages_crawled", 0)

        strategy._on_state_change = track_results
        mock_crawler = create_mock_crawler_with_links(num_links=5)
        mock_config = create_mock_config()
        results = await strategy._arun_batch("https://example.com", mock_crawler, mock_config)
        # Should have at least the first batch of results
        assert len(results) >= 1
        assert strategy.cancelled == True

    @pytest.mark.asyncio
    async def test_partial_results_on_cancel(self):
        """Verify partial results are returned on cancellation."""
        cancel_after = 5

        # Reads strategy._pages_crawled directly instead of a state hook.
        async def check_cancel():
            return strategy._pages_crawled >= cancel_after

        strategy = BFSDeepCrawlStrategy(
            max_depth=10,
            max_pages=1000,
            should_cancel=check_cancel,
        )
        mock_crawler = create_mock_crawler_with_links(num_links=10)
        mock_config = create_mock_config()
        results = await strategy._arun_batch("https://example.com", mock_crawler, mock_config)
        # Should have results up to cancellation point
        assert len(results) >= cancel_after
        assert strategy.cancelled == True

    @pytest.mark.asyncio
    async def test_cancel_callback_called_once_per_level_bfs(self):
        """Verify BFS checks cancellation once per level."""
        check_count = 0

        async def count_checks():
            nonlocal check_count
            check_count += 1
            return False  # Never cancel

        strategy = BFSDeepCrawlStrategy(
            max_depth=2,
            max_pages=10,
            should_cancel=count_checks,
        )
        mock_crawler = create_mock_crawler_with_links(num_links=2)
        mock_config = create_mock_config()
        await strategy._arun_batch("https://example.com", mock_crawler, mock_config)
        # Should have checked at least once per level
        assert check_count >= 1