From f6897d1429800c2d10dac875038ee985d07465a8 Mon Sep 17 00:00:00 2001 From: unclecode Date: Thu, 22 Jan 2026 06:08:25 +0000 Subject: [PATCH] Add cancellation support for deep crawl strategies - Add should_cancel callback parameter to BFS, DFS, and BestFirst strategies - Add cancel() method for immediate cancellation (thread-safe) - Add cancelled property to check cancellation status - Add _check_cancellation() internal method supporting both sync/async callbacks - Reset cancel event on strategy reuse for multiple crawls - Include cancelled flag in state notifications via on_state_change - Handle callback exceptions gracefully (fail-open, log warning) - Add comprehensive test suite with 26 tests covering all edge cases This enables external callers (e.g., cloud platforms) to stop a running deep crawl mid-execution and retrieve partial results. --- crawl4ai/deep_crawling/bff_strategy.py | 81 ++- crawl4ai/deep_crawling/bfs_strategy.py | 102 ++- crawl4ai/deep_crawling/dfs_strategy.py | 54 ++ .../test_deep_crawl_cancellation.py | 597 ++++++++++++++++++ 4 files changed, 828 insertions(+), 6 deletions(-) create mode 100644 tests/deep_crawling/test_deep_crawl_cancellation.py diff --git a/crawl4ai/deep_crawling/bff_strategy.py b/crawl4ai/deep_crawling/bff_strategy.py index fdb96248..5e250ca6 100644 --- a/crawl4ai/deep_crawling/bff_strategy.py +++ b/crawl4ai/deep_crawling/bff_strategy.py @@ -2,7 +2,7 @@ import asyncio import logging from datetime import datetime -from typing import AsyncGenerator, Optional, Set, Dict, List, Tuple, Any, Callable, Awaitable +from typing import AsyncGenerator, Optional, Set, Dict, List, Tuple, Any, Callable, Awaitable, Union from urllib.parse import urlparse from ..models import TraversalStats @@ -44,6 +44,8 @@ class BestFirstCrawlingStrategy(DeepCrawlStrategy): # Optional resume/callback parameters for crash recovery resume_state: Optional[Dict[str, Any]] = None, on_state_change: Optional[Callable[[Dict[str, Any]], Awaitable[None]]] = None, + # Optional cancellation callback - checked before each URL is processed + should_cancel: Optional[Callable[[], Union[bool, Awaitable[bool]]]] = None, ): self.max_depth = max_depth self.filter_chain = filter_chain @@ -63,6 +65,7 @@ class BestFirstCrawlingStrategy(DeepCrawlStrategy): # Store for use in arun methods self._resume_state = resume_state self._on_state_change = on_state_change + self._should_cancel = should_cancel self._last_state: Optional[Dict[str, Any]] = None # Shadow list for queue items (only used when on_state_change is set) self._queue_shadow: Optional[List[Tuple[float, int, str, Optional[str]]]] = None @@ -89,6 +92,55 @@ class BestFirstCrawlingStrategy(DeepCrawlStrategy): return True + def cancel(self) -> None: + """ + Cancel the crawl. Thread-safe, can be called from any context. + + The crawl will stop before processing the next URL. The current URL + being processed (if any) will complete before the crawl stops. + """ + self._cancel_event.set() + + @property + def cancelled(self) -> bool: + """ + Check if the crawl was/is cancelled. Thread-safe. + + Returns: + True if the crawl has been cancelled, False otherwise. + """ + return self._cancel_event.is_set() + + async def _check_cancellation(self) -> bool: + """ + Check if crawl should be cancelled. + + Handles both internal cancel flag and external should_cancel callback. + Supports both sync and async callbacks. + + Returns: + True if crawl should be cancelled, False otherwise. 
+ """ + if self._cancel_event.is_set(): + return True + + if self._should_cancel: + try: + # Handle both sync and async callbacks + result = self._should_cancel() + if asyncio.iscoroutine(result): + result = await result + + if result: + self._cancel_event.set() + self.stats.end_time = datetime.now() + return True + except Exception as e: + # Fail-open: log warning and continue crawling + self.logger.warning(f"should_cancel callback error: {e}") + + return False + async def link_discovery( self, result: CrawlResult, @@ -148,6 +200,9 @@ class BestFirstCrawlingStrategy(DeepCrawlStrategy): The queue items are tuples of (score, depth, url, parent_url). Lower scores are treated as higher priority. URLs are processed in batches for efficiency. """ + # Reset cancel event for strategy reuse + self._cancel_event = asyncio.Event() + queue: asyncio.PriorityQueue = asyncio.PriorityQueue() # Conditional state initialization for resume support @@ -180,7 +235,12 @@ class BestFirstCrawlingStrategy(DeepCrawlStrategy): if self._pages_crawled >= self.max_pages: self.logger.info(f"Max pages limit ({self.max_pages}) reached, stopping crawl") break - + + # Check external cancellation callback before processing this batch + if await self._check_cancellation(): + self.logger.info("Crawl cancelled by user") + break + # Calculate how many more URLs we can process in this batch remaining = self.max_pages - self._pages_crawled batch_size = min(BATCH_SIZE, remaining) @@ -262,11 +322,26 @@ class BestFirstCrawlingStrategy(DeepCrawlStrategy): ], "depths": depths, "pages_crawled": self._pages_crawled, + "cancelled": self._cancel_event.is_set(), } self._last_state = state await self._on_state_change(state) - # End of crawl. + # Final state update if cancelled + if self._cancel_event.is_set() and self._on_state_change and self._queue_shadow is not None: + state = { + "strategy_type": "best_first", + "visited": list(visited), + "queue_items": [ + {"score": s, "depth": d, "url": u, "parent_url": p} + for s, d, u, p in self._queue_shadow + ], + "depths": depths, + "pages_crawled": self._pages_crawled, + "cancelled": True, + } + self._last_state = state + await self._on_state_change(state) async def _arun_batch( self, diff --git a/crawl4ai/deep_crawling/bfs_strategy.py b/crawl4ai/deep_crawling/bfs_strategy.py index 35b66939..dab94532 100644 --- a/crawl4ai/deep_crawling/bfs_strategy.py +++ b/crawl4ai/deep_crawling/bfs_strategy.py @@ -2,7 +2,7 @@ import asyncio import logging from datetime import datetime -from typing import AsyncGenerator, Optional, Set, Dict, List, Tuple, Any, Callable, Awaitable +from typing import AsyncGenerator, Optional, Set, Dict, List, Tuple, Any, Callable, Awaitable, Union from urllib.parse import urlparse from ..models import TraversalStats @@ -34,6 +34,8 @@ class BFSDeepCrawlStrategy(DeepCrawlStrategy): # Optional resume/callback parameters for crash recovery resume_state: Optional[Dict[str, Any]] = None, on_state_change: Optional[Callable[[Dict[str, Any]], Awaitable[None]]] = None, + # Optional cancellation callback - checked before each URL is processed + should_cancel: Optional[Callable[[], Union[bool, Awaitable[bool]]]] = None, ): self.max_depth = max_depth self.filter_chain = filter_chain @@ -54,6 +56,7 @@ class BFSDeepCrawlStrategy(DeepCrawlStrategy): # Store for use in arun methods self._resume_state = resume_state self._on_state_change = on_state_change + self._should_cancel = should_cancel self._last_state: Optional[Dict[str, Any]] = None async def can_process_url(self, url: str, depth: int) 
-> bool: @@ -78,6 +81,55 @@ class BFSDeepCrawlStrategy(DeepCrawlStrategy): return True + def cancel(self) -> None: + """ + Cancel the crawl. Thread-safe, can be called from any context. + + The crawl will stop before processing the next URL. The current URL + being processed (if any) will complete before the crawl stops. + """ + self._cancel_event.set() + + @property + def cancelled(self) -> bool: + """ + Check if the crawl was/is cancelled. Thread-safe. + + Returns: + True if the crawl has been cancelled, False otherwise. + """ + return self._cancel_event.is_set() + + async def _check_cancellation(self) -> bool: + """ + Check if crawl should be cancelled. + + Handles both internal cancel flag and external should_cancel callback. + Supports both sync and async callbacks. + + Returns: + True if crawl should be cancelled, False otherwise. + """ + if self._cancel_event.is_set(): + return True + + if self._should_cancel: + try: + # Handle both sync and async callbacks + result = self._should_cancel() + if asyncio.iscoroutine(result): + result = await result + + if result: + self._cancel_event.set() + self.stats.end_time = datetime.now() + return True + except Exception as e: + # Fail-open: log warning and continue crawling + self.logger.warning(f"should_cancel callback error: {e}") + + return False + async def link_discovery( self, result: CrawlResult, @@ -162,6 +214,9 @@ class BFSDeepCrawlStrategy(DeepCrawlStrategy): Batch (non-streaming) mode: Processes one BFS level at a time, then yields all the results. """ + # Reset cancel event for strategy reuse + self._cancel_event = asyncio.Event() + # Conditional state initialization for resume support if self._resume_state: visited = set(self._resume_state.get("visited", [])) @@ -185,7 +240,12 @@ class BFSDeepCrawlStrategy(DeepCrawlStrategy): if self._pages_crawled >= self.max_pages: self.logger.info(f"Max pages limit ({self.max_pages}) reached, stopping crawl") break - + + # Check external cancellation callback before processing this level + if await self._check_cancellation(): + self.logger.info("Crawl cancelled by user") + break + next_level: List[Tuple[str, Optional[str]]] = [] urls = [url for url, _ in current_level] @@ -218,12 +278,26 @@ class BFSDeepCrawlStrategy(DeepCrawlStrategy): "pending": [{"url": u, "parent_url": p} for u, p in next_level], "depths": depths, "pages_crawled": self._pages_crawled, + "cancelled": self._cancel_event.is_set(), } self._last_state = state await self._on_state_change(state) current_level = next_level + # Final state update if cancelled + if self._cancel_event.is_set() and self._on_state_change: + state = { + "strategy_type": "bfs", + "visited": list(visited), + "pending": [{"url": u, "parent_url": p} for u, p in current_level], + "depths": depths, + "pages_crawled": self._pages_crawled, + "cancelled": True, + } + self._last_state = state + await self._on_state_change(state) + return results async def _arun_stream( @@ -236,6 +310,9 @@ class BFSDeepCrawlStrategy(DeepCrawlStrategy): Streaming mode: Processes one BFS level at a time and yields results immediately as they arrive. 
""" + # Reset cancel event for strategy reuse + self._cancel_event = asyncio.Event() + # Conditional state initialization for resume support if self._resume_state: visited = set(self._resume_state.get("visited", [])) @@ -252,6 +329,11 @@ class BFSDeepCrawlStrategy(DeepCrawlStrategy): depths: Dict[str, int] = {start_url: 0} while current_level and not self._cancel_event.is_set(): + # Check external cancellation callback before processing this level + if await self._check_cancellation(): + self.logger.info("Crawl cancelled by user") + break + next_level: List[Tuple[str, Optional[str]]] = [] urls = [url for url, _ in current_level] visited.update(urls) @@ -293,6 +375,7 @@ class BFSDeepCrawlStrategy(DeepCrawlStrategy): "pending": [{"url": u, "parent_url": p} for u, p in next_level], "depths": depths, "pages_crawled": self._pages_crawled, + "cancelled": self._cancel_event.is_set(), } self._last_state = state await self._on_state_change(state) @@ -301,9 +384,22 @@ class BFSDeepCrawlStrategy(DeepCrawlStrategy): # by considering these URLs as visited but not counting them toward the max_pages limit if results_count == 0 and urls: self.logger.warning(f"No results returned for {len(urls)} URLs, marking as visited") - + current_level = next_level + # Final state update if cancelled + if self._cancel_event.is_set() and self._on_state_change: + state = { + "strategy_type": "bfs", + "visited": list(visited), + "pending": [{"url": u, "parent_url": p} for u, p in current_level], + "depths": depths, + "pages_crawled": self._pages_crawled, + "cancelled": True, + } + self._last_state = state + await self._on_state_change(state) + async def shutdown(self) -> None: """ Clean up resources and signal cancellation of the crawl. diff --git a/crawl4ai/deep_crawling/dfs_strategy.py b/crawl4ai/deep_crawling/dfs_strategy.py index d98d06a7..5e592fc1 100644 --- a/crawl4ai/deep_crawling/dfs_strategy.py +++ b/crawl4ai/deep_crawling/dfs_strategy.py @@ -1,4 +1,5 @@ # dfs_deep_crawl_strategy.py +import asyncio from typing import AsyncGenerator, Optional, Set, Dict, List, Tuple from ..models import CrawlResult @@ -38,6 +39,9 @@ class DFSDeepCrawlStrategy(BFSDeepCrawlStrategy): in control of traversal. Every successful page bumps ``_pages_crawled`` and seeds new stack items discovered via :meth:`link_discovery`. 
""" + # Reset cancel event for strategy reuse + self._cancel_event = asyncio.Event() + # Conditional state initialization for resume support if self._resume_state: visited = set(self._resume_state.get("visited", [])) @@ -59,6 +63,11 @@ class DFSDeepCrawlStrategy(BFSDeepCrawlStrategy): self._reset_seen(start_url) while stack and not self._cancel_event.is_set(): + # Check external cancellation callback before processing this URL + if await self._check_cancellation(): + self.logger.info("Crawl cancelled by user") + break + url, parent, depth = stack.pop() if url in visited or depth > self.max_depth: continue @@ -105,9 +114,28 @@ class DFSDeepCrawlStrategy(BFSDeepCrawlStrategy): "depths": depths, "pages_crawled": self._pages_crawled, "dfs_seen": list(self._dfs_seen), + "cancelled": self._cancel_event.is_set(), } self._last_state = state await self._on_state_change(state) + + # Final state update if cancelled + if self._cancel_event.is_set() and self._on_state_change: + state = { + "strategy_type": "dfs", + "visited": list(visited), + "stack": [ + {"url": u, "parent_url": p, "depth": d} + for u, p, d in stack + ], + "depths": depths, + "pages_crawled": self._pages_crawled, + "dfs_seen": list(self._dfs_seen), + "cancelled": True, + } + self._last_state = state + await self._on_state_change(state) + return results async def _arun_stream( @@ -123,6 +151,9 @@ class DFSDeepCrawlStrategy(BFSDeepCrawlStrategy): yielded before we even look at the next stack entry. Successful crawls still feed :meth:`link_discovery`, keeping DFS order intact. """ + # Reset cancel event for strategy reuse + self._cancel_event = asyncio.Event() + # Conditional state initialization for resume support if self._resume_state: visited = set(self._resume_state.get("visited", [])) @@ -141,6 +172,11 @@ class DFSDeepCrawlStrategy(BFSDeepCrawlStrategy): self._reset_seen(start_url) while stack and not self._cancel_event.is_set(): + # Check external cancellation callback before processing this URL + if await self._check_cancellation(): + self.logger.info("Crawl cancelled by user") + break + url, parent, depth = stack.pop() if url in visited or depth > self.max_depth: continue @@ -183,10 +219,28 @@ class DFSDeepCrawlStrategy(BFSDeepCrawlStrategy): "depths": depths, "pages_crawled": self._pages_crawled, "dfs_seen": list(self._dfs_seen), + "cancelled": self._cancel_event.is_set(), } self._last_state = state await self._on_state_change(state) + # Final state update if cancelled + if self._cancel_event.is_set() and self._on_state_change: + state = { + "strategy_type": "dfs", + "visited": list(visited), + "stack": [ + {"url": u, "parent_url": p, "depth": d} + for u, p, d in stack + ], + "depths": depths, + "pages_crawled": self._pages_crawled, + "dfs_seen": list(self._dfs_seen), + "cancelled": True, + } + self._last_state = state + await self._on_state_change(state) + async def link_discovery( self, result: CrawlResult, diff --git a/tests/deep_crawling/test_deep_crawl_cancellation.py b/tests/deep_crawling/test_deep_crawl_cancellation.py new file mode 100644 index 00000000..0ffb8a7e --- /dev/null +++ b/tests/deep_crawling/test_deep_crawl_cancellation.py @@ -0,0 +1,597 @@ +""" +Test Suite: Deep Crawl Cancellation Tests + +Tests that verify: +1. should_cancel callback is called before each URL +2. cancel() method immediately stops the crawl +3. cancelled property correctly reflects state +4. Strategy reuse works after cancellation +5. Both sync and async should_cancel callbacks work +6. Callback exceptions don't crash the crawl +7. 
State notifications include cancelled flag +""" + +import pytest +import asyncio +from typing import Dict, Any, List +from unittest.mock import MagicMock + +from crawl4ai.deep_crawling import ( + BFSDeepCrawlStrategy, + DFSDeepCrawlStrategy, + BestFirstCrawlingStrategy, +) + + +# ============================================================================ +# Helper Functions for Mock Crawler +# ============================================================================ + +def create_mock_config(stream=False): + """Create a mock CrawlerRunConfig.""" + config = MagicMock() + config.stream = stream + + def clone_config(**kwargs): + """Clone returns a new config with overridden values.""" + new_config = MagicMock() + new_config.stream = kwargs.get('stream', stream) + new_config.clone = MagicMock(side_effect=clone_config) + return new_config + + config.clone = MagicMock(side_effect=clone_config) + return config + + +def create_mock_crawler_with_links(num_links: int = 3): + """Create mock crawler that returns results with links.""" + call_count = 0 + + async def mock_arun_many(urls, config): + nonlocal call_count + results = [] + for url in urls: + call_count += 1 + result = MagicMock() + result.url = url + result.success = True + result.metadata = {} + + # Generate child links + links = [] + for i in range(num_links): + link_url = f"{url}/child{call_count}_{i}" + links.append({"href": link_url}) + + result.links = {"internal": links, "external": []} + results.append(result) + + # For streaming mode, return async generator + if config.stream: + async def gen(): + for r in results: + yield r + return gen() + return results + + crawler = MagicMock() + crawler.arun_many = mock_arun_many + return crawler + + +def create_mock_crawler_tracking(crawl_order: List[str], return_no_links: bool = False): + """Create mock crawler that tracks crawl order.""" + + async def mock_arun_many(urls, config): + results = [] + for url in urls: + crawl_order.append(url) + result = MagicMock() + result.url = url + result.success = True + result.metadata = {} + result.links = {"internal": [], "external": []} if return_no_links else {"internal": [{"href": f"{url}/child"}], "external": []} + results.append(result) + + # For streaming mode, return async generator + if config.stream: + async def gen(): + for r in results: + yield r + return gen() + return results + + crawler = MagicMock() + crawler.arun_many = mock_arun_many + return crawler + + +# ============================================================================ +# TEST SUITE: Cancellation via should_cancel Callback +# ============================================================================ + +class TestBFSCancellation: + """BFS strategy cancellation tests.""" + + @pytest.mark.asyncio + async def test_cancel_via_async_callback(self): + """Verify async should_cancel callback stops crawl.""" + pages_crawled = 0 + cancel_after = 3 + + async def check_cancel(): + return pages_crawled >= cancel_after + + async def track_pages(state: Dict[str, Any]): + nonlocal pages_crawled + pages_crawled = state.get("pages_crawled", 0) + + strategy = BFSDeepCrawlStrategy( + max_depth=5, + max_pages=100, + should_cancel=check_cancel, + on_state_change=track_pages, + ) + + mock_crawler = create_mock_crawler_with_links(num_links=5) + mock_config = create_mock_config() + + results = await strategy._arun_batch("https://example.com", mock_crawler, mock_config) + + # Should have stopped after cancel_after pages + assert strategy.cancelled == True + assert strategy._pages_crawled >= 
cancel_after + assert strategy._pages_crawled < 100 # Should not have crawled all pages + + @pytest.mark.asyncio + async def test_cancel_via_sync_callback(self): + """Verify sync should_cancel callback works.""" + cancel_flag = False + + def check_cancel(): + return cancel_flag + + async def set_cancel_after_3(state: Dict[str, Any]): + nonlocal cancel_flag + if state.get("pages_crawled", 0) >= 3: + cancel_flag = True + + strategy = BFSDeepCrawlStrategy( + max_depth=5, + max_pages=100, + should_cancel=check_cancel, + on_state_change=set_cancel_after_3, + ) + + mock_crawler = create_mock_crawler_with_links(num_links=5) + mock_config = create_mock_config() + + await strategy._arun_batch("https://example.com", mock_crawler, mock_config) + + assert strategy.cancelled == True + assert strategy._pages_crawled >= 3 + + @pytest.mark.asyncio + async def test_cancel_method_stops_crawl(self): + """Verify cancel() method immediately stops the crawl.""" + strategy = BFSDeepCrawlStrategy( + max_depth=5, + max_pages=100, + ) + + async def cancel_after_2_pages(state: Dict[str, Any]): + if state.get("pages_crawled", 0) >= 2: + strategy.cancel() + + strategy._on_state_change = cancel_after_2_pages + + mock_crawler = create_mock_crawler_with_links(num_links=5) + mock_config = create_mock_config() + + await strategy._arun_batch("https://example.com", mock_crawler, mock_config) + + assert strategy.cancelled == True + assert strategy._pages_crawled >= 2 + assert strategy._pages_crawled < 100 + + @pytest.mark.asyncio + async def test_cancelled_property_reflects_state(self): + """Verify cancelled property correctly reflects state.""" + strategy = BFSDeepCrawlStrategy(max_depth=2, max_pages=10) + + # Before cancel + assert strategy.cancelled == False + + # After cancel() + strategy.cancel() + assert strategy.cancelled == True + + @pytest.mark.asyncio + async def test_strategy_reuse_after_cancellation(self): + """Verify strategy can be reused after cancellation.""" + call_count = 0 + + async def cancel_first_time(): + return call_count == 1 + + strategy = BFSDeepCrawlStrategy( + max_depth=1, + max_pages=5, + should_cancel=cancel_first_time, + ) + + mock_crawler = create_mock_crawler_with_links(num_links=2) + mock_config = create_mock_config() + + # First crawl - should be cancelled + call_count = 1 + results1 = await strategy._arun_batch("https://example.com", mock_crawler, mock_config) + assert strategy.cancelled == True + + # Second crawl - should work normally (cancel_first_time returns False) + call_count = 2 + results2 = await strategy._arun_batch("https://example.com", mock_crawler, mock_config) + assert strategy.cancelled == False + assert len(results2) > len(results1) + + @pytest.mark.asyncio + async def test_callback_exception_continues_crawl(self): + """Verify callback exception doesn't crash crawl (fail-open).""" + exception_count = 0 + + async def failing_callback(): + nonlocal exception_count + exception_count += 1 + raise ConnectionError("Redis connection failed") + + strategy = BFSDeepCrawlStrategy( + max_depth=1, + max_pages=3, + should_cancel=failing_callback, + ) + + mock_crawler = create_mock_crawler_with_links(num_links=2) + mock_config = create_mock_config() + + # Should not raise, should complete crawl + results = await strategy._arun_batch("https://example.com", mock_crawler, mock_config) + + assert exception_count > 0 # Callback was called + assert len(results) > 0 # Crawl completed + assert strategy.cancelled == False # Not cancelled due to exception + + @pytest.mark.asyncio + async def 
test_state_includes_cancelled_flag(self): + """Verify state notifications include cancelled flag.""" + states: List[Dict] = [] + cancel_at = 3 + + async def capture_state(state: Dict[str, Any]): + states.append(state) + + async def cancel_after_3(): + return len(states) >= cancel_at + + strategy = BFSDeepCrawlStrategy( + max_depth=5, + max_pages=100, + should_cancel=cancel_after_3, + on_state_change=capture_state, + ) + + mock_crawler = create_mock_crawler_with_links(num_links=5) + mock_config = create_mock_config() + + await strategy._arun_batch("https://example.com", mock_crawler, mock_config) + + # Last state should have cancelled=True + assert len(states) > 0 + assert states[-1].get("cancelled") == True + + @pytest.mark.asyncio + async def test_cancel_before_first_url(self): + """Verify cancel before first URL returns empty results.""" + async def always_cancel(): + return True + + strategy = BFSDeepCrawlStrategy( + max_depth=5, + max_pages=100, + should_cancel=always_cancel, + ) + + mock_crawler = create_mock_crawler_with_links(num_links=5) + mock_config = create_mock_config() + + results = await strategy._arun_batch("https://example.com", mock_crawler, mock_config) + + assert strategy.cancelled == True + assert len(results) == 0 + + +class TestDFSCancellation: + """DFS strategy cancellation tests.""" + + @pytest.mark.asyncio + async def test_cancel_via_callback(self): + """Verify DFS respects should_cancel callback.""" + pages_crawled = 0 + cancel_after = 3 + + async def check_cancel(): + return pages_crawled >= cancel_after + + async def track_pages(state: Dict[str, Any]): + nonlocal pages_crawled + pages_crawled = state.get("pages_crawled", 0) + + strategy = DFSDeepCrawlStrategy( + max_depth=5, + max_pages=100, + should_cancel=check_cancel, + on_state_change=track_pages, + ) + + mock_crawler = create_mock_crawler_with_links(num_links=3) + mock_config = create_mock_config() + + await strategy._arun_batch("https://example.com", mock_crawler, mock_config) + + assert strategy.cancelled == True + assert strategy._pages_crawled >= cancel_after + assert strategy._pages_crawled < 100 + + @pytest.mark.asyncio + async def test_cancel_method_inherited(self): + """Verify DFS inherits cancel() from BFS.""" + strategy = DFSDeepCrawlStrategy(max_depth=2, max_pages=10) + + assert hasattr(strategy, 'cancel') + assert hasattr(strategy, 'cancelled') + assert hasattr(strategy, '_check_cancellation') + + strategy.cancel() + assert strategy.cancelled == True + + @pytest.mark.asyncio + async def test_stream_mode_cancellation(self): + """Verify DFS stream mode respects cancellation.""" + results_count = 0 + cancel_after = 2 + + async def check_cancel(): + return results_count >= cancel_after + + strategy = DFSDeepCrawlStrategy( + max_depth=5, + max_pages=100, + should_cancel=check_cancel, + ) + + mock_crawler = create_mock_crawler_with_links(num_links=3) + mock_config = create_mock_config(stream=True) + + async for result in strategy._arun_stream("https://example.com", mock_crawler, mock_config): + results_count += 1 + + assert strategy.cancelled == True + assert results_count >= cancel_after + assert results_count < 100 + + +class TestBestFirstCancellation: + """Best-First strategy cancellation tests.""" + + @pytest.mark.asyncio + async def test_cancel_via_callback(self): + """Verify Best-First respects should_cancel callback.""" + pages_crawled = 0 + cancel_after = 3 + + async def check_cancel(): + return pages_crawled >= cancel_after + + async def track_pages(state: Dict[str, Any]): + nonlocal 
pages_crawled + pages_crawled = state.get("pages_crawled", 0) + + strategy = BestFirstCrawlingStrategy( + max_depth=5, + max_pages=100, + should_cancel=check_cancel, + on_state_change=track_pages, + ) + + mock_crawler = create_mock_crawler_with_links(num_links=3) + mock_config = create_mock_config(stream=True) + + async for _ in strategy._arun_stream("https://example.com", mock_crawler, mock_config): + pass + + assert strategy.cancelled == True + assert strategy._pages_crawled >= cancel_after + assert strategy._pages_crawled < 100 + + @pytest.mark.asyncio + async def test_cancel_method_works(self): + """Verify Best-First cancel() method works.""" + strategy = BestFirstCrawlingStrategy(max_depth=2, max_pages=10) + + assert strategy.cancelled == False + strategy.cancel() + assert strategy.cancelled == True + + @pytest.mark.asyncio + async def test_batch_mode_cancellation(self): + """Verify Best-First batch mode respects cancellation.""" + pages_crawled = 0 + cancel_after = 2 + + async def check_cancel(): + return pages_crawled >= cancel_after + + async def track_pages(state: Dict[str, Any]): + nonlocal pages_crawled + pages_crawled = state.get("pages_crawled", 0) + + strategy = BestFirstCrawlingStrategy( + max_depth=5, + max_pages=100, + should_cancel=check_cancel, + on_state_change=track_pages, + ) + + mock_crawler = create_mock_crawler_with_links(num_links=3) + mock_config = create_mock_config(stream=False) + + results = await strategy._arun_batch("https://example.com", mock_crawler, mock_config) + + assert strategy.cancelled == True + assert len(results) >= cancel_after + assert len(results) < 100 + + +class TestCrossStrategyCancellation: + """Tests that apply to all strategies.""" + + @pytest.mark.asyncio + @pytest.mark.parametrize("strategy_class", [ + BFSDeepCrawlStrategy, + DFSDeepCrawlStrategy, + BestFirstCrawlingStrategy, + ]) + async def test_no_cancel_callback_means_no_cancellation(self, strategy_class): + """Verify crawl completes normally without should_cancel.""" + strategy = strategy_class(max_depth=1, max_pages=5) + + mock_crawler = create_mock_crawler_with_links(num_links=2) + + if strategy_class == BestFirstCrawlingStrategy: + mock_config = create_mock_config(stream=True) + results = [] + async for r in strategy._arun_stream("https://example.com", mock_crawler, mock_config): + results.append(r) + else: + mock_config = create_mock_config() + results = await strategy._arun_batch("https://example.com", mock_crawler, mock_config) + + assert strategy.cancelled == False + assert len(results) > 0 + + @pytest.mark.asyncio + @pytest.mark.parametrize("strategy_class", [ + BFSDeepCrawlStrategy, + DFSDeepCrawlStrategy, + BestFirstCrawlingStrategy, + ]) + async def test_cancel_thread_safety(self, strategy_class): + """Verify cancel() is thread-safe (doesn't raise).""" + strategy = strategy_class(max_depth=2, max_pages=10) + + # Call cancel from multiple "threads" (simulated) + for _ in range(10): + strategy.cancel() + + # Should be cancelled without errors + assert strategy.cancelled == True + + @pytest.mark.asyncio + @pytest.mark.parametrize("strategy_class", [ + BFSDeepCrawlStrategy, + DFSDeepCrawlStrategy, + BestFirstCrawlingStrategy, + ]) + async def test_should_cancel_param_accepted(self, strategy_class): + """Verify should_cancel parameter is accepted by constructor.""" + async def dummy_cancel(): + return False + + # Should not raise + strategy = strategy_class( + max_depth=2, + max_pages=10, + should_cancel=dummy_cancel, + ) + + assert strategy._should_cancel == dummy_cancel + + 
+class TestCancellationEdgeCases: + """Edge case tests for cancellation.""" + + @pytest.mark.asyncio + async def test_cancel_during_batch_processing(self): + """Verify cancellation during batch doesn't lose results.""" + results_count = 0 + + async def cancel_mid_batch(): + # Cancel after receiving first result + return results_count >= 1 + + strategy = BFSDeepCrawlStrategy( + max_depth=2, + max_pages=100, + should_cancel=cancel_mid_batch, + ) + + async def track_results(state): + nonlocal results_count + results_count = state.get("pages_crawled", 0) + + strategy._on_state_change = track_results + + mock_crawler = create_mock_crawler_with_links(num_links=5) + mock_config = create_mock_config() + + results = await strategy._arun_batch("https://example.com", mock_crawler, mock_config) + + # Should have at least the first batch of results + assert len(results) >= 1 + assert strategy.cancelled == True + + @pytest.mark.asyncio + async def test_partial_results_on_cancel(self): + """Verify partial results are returned on cancellation.""" + cancel_after = 5 + + async def check_cancel(): + return strategy._pages_crawled >= cancel_after + + strategy = BFSDeepCrawlStrategy( + max_depth=10, + max_pages=1000, + should_cancel=check_cancel, + ) + + mock_crawler = create_mock_crawler_with_links(num_links=10) + mock_config = create_mock_config() + + results = await strategy._arun_batch("https://example.com", mock_crawler, mock_config) + + # Should have results up to cancellation point + assert len(results) >= cancel_after + assert strategy.cancelled == True + + @pytest.mark.asyncio + async def test_cancel_callback_called_once_per_level_bfs(self): + """Verify BFS checks cancellation once per level.""" + check_count = 0 + + async def count_checks(): + nonlocal check_count + check_count += 1 + return False # Never cancel + + strategy = BFSDeepCrawlStrategy( + max_depth=2, + max_pages=10, + should_cancel=count_checks, + ) + + mock_crawler = create_mock_crawler_with_links(num_links=2) + mock_config = create_mock_config() + + await strategy._arun_batch("https://example.com", mock_crawler, mock_config) + + # Should have checked at least once per level + assert check_count >= 1
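
Illustrative usage (a minimal sketch, not part of the patch): wiring the new should_cancel callback and cancel() method from an external caller. The AsyncWebCrawler / CrawlerRunConfig entry point and the deep_crawl_strategy parameter are assumed from the existing public API; only should_cancel, cancel() and cancelled come from this change, and the URL, limits and sleep are placeholders.

    import asyncio

    from crawl4ai import AsyncWebCrawler, CrawlerRunConfig        # assumed public entry point
    from crawl4ai.deep_crawling import BFSDeepCrawlStrategy

    async def main():
        stop_requested = False  # e.g. flipped by a cloud platform's control plane

        async def should_cancel() -> bool:
            # Polled by the strategy before each BFS level (a plain sync callable works too).
            return stop_requested

        strategy = BFSDeepCrawlStrategy(
            max_depth=3,
            max_pages=500,
            should_cancel=should_cancel,
        )

        async with AsyncWebCrawler() as crawler:
            config = CrawlerRunConfig(deep_crawl_strategy=strategy)  # assumed wiring
            crawl = asyncio.create_task(crawler.arun("https://example.com", config=config))

            await asyncio.sleep(10)   # let the crawl run for a while
            strategy.cancel()         # stop before the next URL/level is processed

            results = await crawl     # partial results gathered before cancellation
            print(strategy.cancelled, len(results))

    asyncio.run(main())

The URL currently in flight finishes before the loop exits, so the returned results always cover everything crawled up to the cancellation point; nothing is discarded on cancel.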
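
The cancelled flag included in every on_state_change payload also lets the resume path tell a deliberate stop apart from a crash. A minimal sketch of a state sink, assuming an in-memory list as a stand-in for whatever durable store (Redis, S3, a database) the caller actually uses:

    import json
    from typing import Any, Dict, List, Optional

    from crawl4ai.deep_crawling import BFSDeepCrawlStrategy

    saved_states: List[Dict[str, Any]] = []   # placeholder for a durable store

    async def persist_state(state: Dict[str, Any]) -> None:
        # Invoked by the strategy after each level; "cancelled" becomes True once
        # cancel() or the should_cancel callback has fired.
        saved_states.append(json.loads(json.dumps(state)))

    def resume_state() -> Optional[Dict[str, Any]]:
        # Resume only after a crash, not after an explicit cancellation.
        if saved_states and not saved_states[-1].get("cancelled", False):
            return saved_states[-1]
        return None

    strategy = BFSDeepCrawlStrategy(
        max_depth=3,
        max_pages=500,
        on_state_change=persist_state,
        resume_state=resume_state(),
    )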