Add cancellation support for deep crawl strategies
- Add should_cancel callback parameter to BFS, DFS, and BestFirst strategies - Add cancel() method for immediate cancellation (thread-safe) - Add cancelled property to check cancellation status - Add _check_cancellation() internal method supporting both sync/async callbacks - Reset cancel event on strategy reuse for multiple crawls - Include cancelled flag in state notifications via on_state_change - Handle callback exceptions gracefully (fail-open, log warning) - Add comprehensive test suite with 26 tests covering all edge cases This enables external callers (e.g., cloud platforms) to stop a running deep crawl mid-execution and retrieve partial results.
This commit is contained in:
@@ -2,7 +2,7 @@
|
|||||||
import asyncio
|
import asyncio
|
||||||
import logging
|
import logging
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from typing import AsyncGenerator, Optional, Set, Dict, List, Tuple, Any, Callable, Awaitable
|
from typing import AsyncGenerator, Optional, Set, Dict, List, Tuple, Any, Callable, Awaitable, Union
|
||||||
from urllib.parse import urlparse
|
from urllib.parse import urlparse
|
||||||
|
|
||||||
from ..models import TraversalStats
|
from ..models import TraversalStats
|
||||||
@@ -44,6 +44,8 @@ class BestFirstCrawlingStrategy(DeepCrawlStrategy):
|
|||||||
# Optional resume/callback parameters for crash recovery
|
# Optional resume/callback parameters for crash recovery
|
||||||
resume_state: Optional[Dict[str, Any]] = None,
|
resume_state: Optional[Dict[str, Any]] = None,
|
||||||
on_state_change: Optional[Callable[[Dict[str, Any]], Awaitable[None]]] = None,
|
on_state_change: Optional[Callable[[Dict[str, Any]], Awaitable[None]]] = None,
|
||||||
|
# Optional cancellation callback - checked before each URL is processed
|
||||||
|
should_cancel: Optional[Callable[[], Union[bool, Awaitable[bool]]]] = None,
|
||||||
):
|
):
|
||||||
self.max_depth = max_depth
|
self.max_depth = max_depth
|
||||||
self.filter_chain = filter_chain
|
self.filter_chain = filter_chain
|
||||||
@@ -63,6 +65,7 @@ class BestFirstCrawlingStrategy(DeepCrawlStrategy):
|
|||||||
# Store for use in arun methods
|
# Store for use in arun methods
|
||||||
self._resume_state = resume_state
|
self._resume_state = resume_state
|
||||||
self._on_state_change = on_state_change
|
self._on_state_change = on_state_change
|
||||||
|
self._should_cancel = should_cancel
|
||||||
self._last_state: Optional[Dict[str, Any]] = None
|
self._last_state: Optional[Dict[str, Any]] = None
|
||||||
# Shadow list for queue items (only used when on_state_change is set)
|
# Shadow list for queue items (only used when on_state_change is set)
|
||||||
self._queue_shadow: Optional[List[Tuple[float, int, str, Optional[str]]]] = None
|
self._queue_shadow: Optional[List[Tuple[float, int, str, Optional[str]]]] = None
|
||||||
@@ -89,6 +92,55 @@ class BestFirstCrawlingStrategy(DeepCrawlStrategy):
|
|||||||
|
|
||||||
return True
|
return True
|
||||||
|
|
||||||
|
def cancel(self) -> None:
|
||||||
|
"""
|
||||||
|
Cancel the crawl. Thread-safe, can be called from any context.
|
||||||
|
|
||||||
|
The crawl will stop before processing the next URL. The current URL
|
||||||
|
being processed (if any) will complete before the crawl stops.
|
||||||
|
"""
|
||||||
|
self._cancel_event.set()
|
||||||
|
|
||||||
|
@property
|
||||||
|
def cancelled(self) -> bool:
|
||||||
|
"""
|
||||||
|
Check if the crawl was/is cancelled. Thread-safe.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
True if the crawl has been cancelled, False otherwise.
|
||||||
|
"""
|
||||||
|
return self._cancel_event.is_set()
|
||||||
|
|
||||||
|
async def _check_cancellation(self) -> bool:
|
||||||
|
"""
|
||||||
|
Check if crawl should be cancelled.
|
||||||
|
|
||||||
|
Handles both internal cancel flag and external should_cancel callback.
|
||||||
|
Supports both sync and async callbacks.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
True if crawl should be cancelled, False otherwise.
|
||||||
|
"""
|
||||||
|
if self._cancel_event.is_set():
|
||||||
|
return True
|
||||||
|
|
||||||
|
if self._should_cancel:
|
||||||
|
try:
|
||||||
|
# Handle both sync and async callbacks
|
||||||
|
result = self._should_cancel()
|
||||||
|
if asyncio.iscoroutine(result):
|
||||||
|
result = await result
|
||||||
|
|
||||||
|
if result:
|
||||||
|
self._cancel_event.set()
|
||||||
|
self.stats.end_time = datetime.now()
|
||||||
|
return True
|
||||||
|
except Exception as e:
|
||||||
|
# Fail-open: log warning and continue crawling
|
||||||
|
self.logger.warning(f"should_cancel callback error: {e}")
|
||||||
|
|
||||||
|
return False
|
||||||
|
|
||||||
async def link_discovery(
|
async def link_discovery(
|
||||||
self,
|
self,
|
||||||
result: CrawlResult,
|
result: CrawlResult,
|
||||||
@@ -148,6 +200,9 @@ class BestFirstCrawlingStrategy(DeepCrawlStrategy):
|
|||||||
The queue items are tuples of (score, depth, url, parent_url). Lower scores
|
The queue items are tuples of (score, depth, url, parent_url). Lower scores
|
||||||
are treated as higher priority. URLs are processed in batches for efficiency.
|
are treated as higher priority. URLs are processed in batches for efficiency.
|
||||||
"""
|
"""
|
||||||
|
# Reset cancel event for strategy reuse
|
||||||
|
self._cancel_event = asyncio.Event()
|
||||||
|
|
||||||
queue: asyncio.PriorityQueue = asyncio.PriorityQueue()
|
queue: asyncio.PriorityQueue = asyncio.PriorityQueue()
|
||||||
|
|
||||||
# Conditional state initialization for resume support
|
# Conditional state initialization for resume support
|
||||||
@@ -180,7 +235,12 @@ class BestFirstCrawlingStrategy(DeepCrawlStrategy):
|
|||||||
if self._pages_crawled >= self.max_pages:
|
if self._pages_crawled >= self.max_pages:
|
||||||
self.logger.info(f"Max pages limit ({self.max_pages}) reached, stopping crawl")
|
self.logger.info(f"Max pages limit ({self.max_pages}) reached, stopping crawl")
|
||||||
break
|
break
|
||||||
|
|
||||||
|
# Check external cancellation callback before processing this batch
|
||||||
|
if await self._check_cancellation():
|
||||||
|
self.logger.info("Crawl cancelled by user")
|
||||||
|
break
|
||||||
|
|
||||||
# Calculate how many more URLs we can process in this batch
|
# Calculate how many more URLs we can process in this batch
|
||||||
remaining = self.max_pages - self._pages_crawled
|
remaining = self.max_pages - self._pages_crawled
|
||||||
batch_size = min(BATCH_SIZE, remaining)
|
batch_size = min(BATCH_SIZE, remaining)
|
||||||
@@ -262,11 +322,26 @@ class BestFirstCrawlingStrategy(DeepCrawlStrategy):
|
|||||||
],
|
],
|
||||||
"depths": depths,
|
"depths": depths,
|
||||||
"pages_crawled": self._pages_crawled,
|
"pages_crawled": self._pages_crawled,
|
||||||
|
"cancelled": self._cancel_event.is_set(),
|
||||||
}
|
}
|
||||||
self._last_state = state
|
self._last_state = state
|
||||||
await self._on_state_change(state)
|
await self._on_state_change(state)
|
||||||
|
|
||||||
# End of crawl.
|
# Final state update if cancelled
|
||||||
|
if self._cancel_event.is_set() and self._on_state_change and self._queue_shadow is not None:
|
||||||
|
state = {
|
||||||
|
"strategy_type": "best_first",
|
||||||
|
"visited": list(visited),
|
||||||
|
"queue_items": [
|
||||||
|
{"score": s, "depth": d, "url": u, "parent_url": p}
|
||||||
|
for s, d, u, p in self._queue_shadow
|
||||||
|
],
|
||||||
|
"depths": depths,
|
||||||
|
"pages_crawled": self._pages_crawled,
|
||||||
|
"cancelled": True,
|
||||||
|
}
|
||||||
|
self._last_state = state
|
||||||
|
await self._on_state_change(state)
|
||||||
|
|
||||||
async def _arun_batch(
|
async def _arun_batch(
|
||||||
self,
|
self,
|
||||||
|
|||||||
@@ -2,7 +2,7 @@
|
|||||||
import asyncio
|
import asyncio
|
||||||
import logging
|
import logging
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from typing import AsyncGenerator, Optional, Set, Dict, List, Tuple, Any, Callable, Awaitable
|
from typing import AsyncGenerator, Optional, Set, Dict, List, Tuple, Any, Callable, Awaitable, Union
|
||||||
from urllib.parse import urlparse
|
from urllib.parse import urlparse
|
||||||
|
|
||||||
from ..models import TraversalStats
|
from ..models import TraversalStats
|
||||||
@@ -34,6 +34,8 @@ class BFSDeepCrawlStrategy(DeepCrawlStrategy):
|
|||||||
# Optional resume/callback parameters for crash recovery
|
# Optional resume/callback parameters for crash recovery
|
||||||
resume_state: Optional[Dict[str, Any]] = None,
|
resume_state: Optional[Dict[str, Any]] = None,
|
||||||
on_state_change: Optional[Callable[[Dict[str, Any]], Awaitable[None]]] = None,
|
on_state_change: Optional[Callable[[Dict[str, Any]], Awaitable[None]]] = None,
|
||||||
|
# Optional cancellation callback - checked before each URL is processed
|
||||||
|
should_cancel: Optional[Callable[[], Union[bool, Awaitable[bool]]]] = None,
|
||||||
):
|
):
|
||||||
self.max_depth = max_depth
|
self.max_depth = max_depth
|
||||||
self.filter_chain = filter_chain
|
self.filter_chain = filter_chain
|
||||||
@@ -54,6 +56,7 @@ class BFSDeepCrawlStrategy(DeepCrawlStrategy):
|
|||||||
# Store for use in arun methods
|
# Store for use in arun methods
|
||||||
self._resume_state = resume_state
|
self._resume_state = resume_state
|
||||||
self._on_state_change = on_state_change
|
self._on_state_change = on_state_change
|
||||||
|
self._should_cancel = should_cancel
|
||||||
self._last_state: Optional[Dict[str, Any]] = None
|
self._last_state: Optional[Dict[str, Any]] = None
|
||||||
|
|
||||||
async def can_process_url(self, url: str, depth: int) -> bool:
|
async def can_process_url(self, url: str, depth: int) -> bool:
|
||||||
@@ -78,6 +81,55 @@ class BFSDeepCrawlStrategy(DeepCrawlStrategy):
|
|||||||
|
|
||||||
return True
|
return True
|
||||||
|
|
||||||
|
def cancel(self) -> None:
|
||||||
|
"""
|
||||||
|
Cancel the crawl. Thread-safe, can be called from any context.
|
||||||
|
|
||||||
|
The crawl will stop before processing the next URL. The current URL
|
||||||
|
being processed (if any) will complete before the crawl stops.
|
||||||
|
"""
|
||||||
|
self._cancel_event.set()
|
||||||
|
|
||||||
|
@property
|
||||||
|
def cancelled(self) -> bool:
|
||||||
|
"""
|
||||||
|
Check if the crawl was/is cancelled. Thread-safe.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
True if the crawl has been cancelled, False otherwise.
|
||||||
|
"""
|
||||||
|
return self._cancel_event.is_set()
|
||||||
|
|
||||||
|
async def _check_cancellation(self) -> bool:
|
||||||
|
"""
|
||||||
|
Check if crawl should be cancelled.
|
||||||
|
|
||||||
|
Handles both internal cancel flag and external should_cancel callback.
|
||||||
|
Supports both sync and async callbacks.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
True if crawl should be cancelled, False otherwise.
|
||||||
|
"""
|
||||||
|
if self._cancel_event.is_set():
|
||||||
|
return True
|
||||||
|
|
||||||
|
if self._should_cancel:
|
||||||
|
try:
|
||||||
|
# Handle both sync and async callbacks
|
||||||
|
result = self._should_cancel()
|
||||||
|
if asyncio.iscoroutine(result):
|
||||||
|
result = await result
|
||||||
|
|
||||||
|
if result:
|
||||||
|
self._cancel_event.set()
|
||||||
|
self.stats.end_time = datetime.now()
|
||||||
|
return True
|
||||||
|
except Exception as e:
|
||||||
|
# Fail-open: log warning and continue crawling
|
||||||
|
self.logger.warning(f"should_cancel callback error: {e}")
|
||||||
|
|
||||||
|
return False
|
||||||
|
|
||||||
async def link_discovery(
|
async def link_discovery(
|
||||||
self,
|
self,
|
||||||
result: CrawlResult,
|
result: CrawlResult,
|
||||||
@@ -162,6 +214,9 @@ class BFSDeepCrawlStrategy(DeepCrawlStrategy):
|
|||||||
Batch (non-streaming) mode:
|
Batch (non-streaming) mode:
|
||||||
Processes one BFS level at a time, then yields all the results.
|
Processes one BFS level at a time, then yields all the results.
|
||||||
"""
|
"""
|
||||||
|
# Reset cancel event for strategy reuse
|
||||||
|
self._cancel_event = asyncio.Event()
|
||||||
|
|
||||||
# Conditional state initialization for resume support
|
# Conditional state initialization for resume support
|
||||||
if self._resume_state:
|
if self._resume_state:
|
||||||
visited = set(self._resume_state.get("visited", []))
|
visited = set(self._resume_state.get("visited", []))
|
||||||
@@ -185,7 +240,12 @@ class BFSDeepCrawlStrategy(DeepCrawlStrategy):
|
|||||||
if self._pages_crawled >= self.max_pages:
|
if self._pages_crawled >= self.max_pages:
|
||||||
self.logger.info(f"Max pages limit ({self.max_pages}) reached, stopping crawl")
|
self.logger.info(f"Max pages limit ({self.max_pages}) reached, stopping crawl")
|
||||||
break
|
break
|
||||||
|
|
||||||
|
# Check external cancellation callback before processing this level
|
||||||
|
if await self._check_cancellation():
|
||||||
|
self.logger.info("Crawl cancelled by user")
|
||||||
|
break
|
||||||
|
|
||||||
next_level: List[Tuple[str, Optional[str]]] = []
|
next_level: List[Tuple[str, Optional[str]]] = []
|
||||||
urls = [url for url, _ in current_level]
|
urls = [url for url, _ in current_level]
|
||||||
|
|
||||||
@@ -218,12 +278,26 @@ class BFSDeepCrawlStrategy(DeepCrawlStrategy):
|
|||||||
"pending": [{"url": u, "parent_url": p} for u, p in next_level],
|
"pending": [{"url": u, "parent_url": p} for u, p in next_level],
|
||||||
"depths": depths,
|
"depths": depths,
|
||||||
"pages_crawled": self._pages_crawled,
|
"pages_crawled": self._pages_crawled,
|
||||||
|
"cancelled": self._cancel_event.is_set(),
|
||||||
}
|
}
|
||||||
self._last_state = state
|
self._last_state = state
|
||||||
await self._on_state_change(state)
|
await self._on_state_change(state)
|
||||||
|
|
||||||
current_level = next_level
|
current_level = next_level
|
||||||
|
|
||||||
|
# Final state update if cancelled
|
||||||
|
if self._cancel_event.is_set() and self._on_state_change:
|
||||||
|
state = {
|
||||||
|
"strategy_type": "bfs",
|
||||||
|
"visited": list(visited),
|
||||||
|
"pending": [{"url": u, "parent_url": p} for u, p in current_level],
|
||||||
|
"depths": depths,
|
||||||
|
"pages_crawled": self._pages_crawled,
|
||||||
|
"cancelled": True,
|
||||||
|
}
|
||||||
|
self._last_state = state
|
||||||
|
await self._on_state_change(state)
|
||||||
|
|
||||||
return results
|
return results
|
||||||
|
|
||||||
async def _arun_stream(
|
async def _arun_stream(
|
||||||
@@ -236,6 +310,9 @@ class BFSDeepCrawlStrategy(DeepCrawlStrategy):
|
|||||||
Streaming mode:
|
Streaming mode:
|
||||||
Processes one BFS level at a time and yields results immediately as they arrive.
|
Processes one BFS level at a time and yields results immediately as they arrive.
|
||||||
"""
|
"""
|
||||||
|
# Reset cancel event for strategy reuse
|
||||||
|
self._cancel_event = asyncio.Event()
|
||||||
|
|
||||||
# Conditional state initialization for resume support
|
# Conditional state initialization for resume support
|
||||||
if self._resume_state:
|
if self._resume_state:
|
||||||
visited = set(self._resume_state.get("visited", []))
|
visited = set(self._resume_state.get("visited", []))
|
||||||
@@ -252,6 +329,11 @@ class BFSDeepCrawlStrategy(DeepCrawlStrategy):
|
|||||||
depths: Dict[str, int] = {start_url: 0}
|
depths: Dict[str, int] = {start_url: 0}
|
||||||
|
|
||||||
while current_level and not self._cancel_event.is_set():
|
while current_level and not self._cancel_event.is_set():
|
||||||
|
# Check external cancellation callback before processing this level
|
||||||
|
if await self._check_cancellation():
|
||||||
|
self.logger.info("Crawl cancelled by user")
|
||||||
|
break
|
||||||
|
|
||||||
next_level: List[Tuple[str, Optional[str]]] = []
|
next_level: List[Tuple[str, Optional[str]]] = []
|
||||||
urls = [url for url, _ in current_level]
|
urls = [url for url, _ in current_level]
|
||||||
visited.update(urls)
|
visited.update(urls)
|
||||||
@@ -293,6 +375,7 @@ class BFSDeepCrawlStrategy(DeepCrawlStrategy):
|
|||||||
"pending": [{"url": u, "parent_url": p} for u, p in next_level],
|
"pending": [{"url": u, "parent_url": p} for u, p in next_level],
|
||||||
"depths": depths,
|
"depths": depths,
|
||||||
"pages_crawled": self._pages_crawled,
|
"pages_crawled": self._pages_crawled,
|
||||||
|
"cancelled": self._cancel_event.is_set(),
|
||||||
}
|
}
|
||||||
self._last_state = state
|
self._last_state = state
|
||||||
await self._on_state_change(state)
|
await self._on_state_change(state)
|
||||||
@@ -301,9 +384,22 @@ class BFSDeepCrawlStrategy(DeepCrawlStrategy):
|
|||||||
# by considering these URLs as visited but not counting them toward the max_pages limit
|
# by considering these URLs as visited but not counting them toward the max_pages limit
|
||||||
if results_count == 0 and urls:
|
if results_count == 0 and urls:
|
||||||
self.logger.warning(f"No results returned for {len(urls)} URLs, marking as visited")
|
self.logger.warning(f"No results returned for {len(urls)} URLs, marking as visited")
|
||||||
|
|
||||||
current_level = next_level
|
current_level = next_level
|
||||||
|
|
||||||
|
# Final state update if cancelled
|
||||||
|
if self._cancel_event.is_set() and self._on_state_change:
|
||||||
|
state = {
|
||||||
|
"strategy_type": "bfs",
|
||||||
|
"visited": list(visited),
|
||||||
|
"pending": [{"url": u, "parent_url": p} for u, p in current_level],
|
||||||
|
"depths": depths,
|
||||||
|
"pages_crawled": self._pages_crawled,
|
||||||
|
"cancelled": True,
|
||||||
|
}
|
||||||
|
self._last_state = state
|
||||||
|
await self._on_state_change(state)
|
||||||
|
|
||||||
async def shutdown(self) -> None:
|
async def shutdown(self) -> None:
|
||||||
"""
|
"""
|
||||||
Clean up resources and signal cancellation of the crawl.
|
Clean up resources and signal cancellation of the crawl.
|
||||||
|
|||||||
@@ -1,4 +1,5 @@
|
|||||||
# dfs_deep_crawl_strategy.py
|
# dfs_deep_crawl_strategy.py
|
||||||
|
import asyncio
|
||||||
from typing import AsyncGenerator, Optional, Set, Dict, List, Tuple
|
from typing import AsyncGenerator, Optional, Set, Dict, List, Tuple
|
||||||
|
|
||||||
from ..models import CrawlResult
|
from ..models import CrawlResult
|
||||||
@@ -38,6 +39,9 @@ class DFSDeepCrawlStrategy(BFSDeepCrawlStrategy):
|
|||||||
in control of traversal. Every successful page bumps ``_pages_crawled`` and
|
in control of traversal. Every successful page bumps ``_pages_crawled`` and
|
||||||
seeds new stack items discovered via :meth:`link_discovery`.
|
seeds new stack items discovered via :meth:`link_discovery`.
|
||||||
"""
|
"""
|
||||||
|
# Reset cancel event for strategy reuse
|
||||||
|
self._cancel_event = asyncio.Event()
|
||||||
|
|
||||||
# Conditional state initialization for resume support
|
# Conditional state initialization for resume support
|
||||||
if self._resume_state:
|
if self._resume_state:
|
||||||
visited = set(self._resume_state.get("visited", []))
|
visited = set(self._resume_state.get("visited", []))
|
||||||
@@ -59,6 +63,11 @@ class DFSDeepCrawlStrategy(BFSDeepCrawlStrategy):
|
|||||||
self._reset_seen(start_url)
|
self._reset_seen(start_url)
|
||||||
|
|
||||||
while stack and not self._cancel_event.is_set():
|
while stack and not self._cancel_event.is_set():
|
||||||
|
# Check external cancellation callback before processing this URL
|
||||||
|
if await self._check_cancellation():
|
||||||
|
self.logger.info("Crawl cancelled by user")
|
||||||
|
break
|
||||||
|
|
||||||
url, parent, depth = stack.pop()
|
url, parent, depth = stack.pop()
|
||||||
if url in visited or depth > self.max_depth:
|
if url in visited or depth > self.max_depth:
|
||||||
continue
|
continue
|
||||||
@@ -105,9 +114,28 @@ class DFSDeepCrawlStrategy(BFSDeepCrawlStrategy):
|
|||||||
"depths": depths,
|
"depths": depths,
|
||||||
"pages_crawled": self._pages_crawled,
|
"pages_crawled": self._pages_crawled,
|
||||||
"dfs_seen": list(self._dfs_seen),
|
"dfs_seen": list(self._dfs_seen),
|
||||||
|
"cancelled": self._cancel_event.is_set(),
|
||||||
}
|
}
|
||||||
self._last_state = state
|
self._last_state = state
|
||||||
await self._on_state_change(state)
|
await self._on_state_change(state)
|
||||||
|
|
||||||
|
# Final state update if cancelled
|
||||||
|
if self._cancel_event.is_set() and self._on_state_change:
|
||||||
|
state = {
|
||||||
|
"strategy_type": "dfs",
|
||||||
|
"visited": list(visited),
|
||||||
|
"stack": [
|
||||||
|
{"url": u, "parent_url": p, "depth": d}
|
||||||
|
for u, p, d in stack
|
||||||
|
],
|
||||||
|
"depths": depths,
|
||||||
|
"pages_crawled": self._pages_crawled,
|
||||||
|
"dfs_seen": list(self._dfs_seen),
|
||||||
|
"cancelled": True,
|
||||||
|
}
|
||||||
|
self._last_state = state
|
||||||
|
await self._on_state_change(state)
|
||||||
|
|
||||||
return results
|
return results
|
||||||
|
|
||||||
async def _arun_stream(
|
async def _arun_stream(
|
||||||
@@ -123,6 +151,9 @@ class DFSDeepCrawlStrategy(BFSDeepCrawlStrategy):
|
|||||||
yielded before we even look at the next stack entry. Successful crawls
|
yielded before we even look at the next stack entry. Successful crawls
|
||||||
still feed :meth:`link_discovery`, keeping DFS order intact.
|
still feed :meth:`link_discovery`, keeping DFS order intact.
|
||||||
"""
|
"""
|
||||||
|
# Reset cancel event for strategy reuse
|
||||||
|
self._cancel_event = asyncio.Event()
|
||||||
|
|
||||||
# Conditional state initialization for resume support
|
# Conditional state initialization for resume support
|
||||||
if self._resume_state:
|
if self._resume_state:
|
||||||
visited = set(self._resume_state.get("visited", []))
|
visited = set(self._resume_state.get("visited", []))
|
||||||
@@ -141,6 +172,11 @@ class DFSDeepCrawlStrategy(BFSDeepCrawlStrategy):
|
|||||||
self._reset_seen(start_url)
|
self._reset_seen(start_url)
|
||||||
|
|
||||||
while stack and not self._cancel_event.is_set():
|
while stack and not self._cancel_event.is_set():
|
||||||
|
# Check external cancellation callback before processing this URL
|
||||||
|
if await self._check_cancellation():
|
||||||
|
self.logger.info("Crawl cancelled by user")
|
||||||
|
break
|
||||||
|
|
||||||
url, parent, depth = stack.pop()
|
url, parent, depth = stack.pop()
|
||||||
if url in visited or depth > self.max_depth:
|
if url in visited or depth > self.max_depth:
|
||||||
continue
|
continue
|
||||||
@@ -183,10 +219,28 @@ class DFSDeepCrawlStrategy(BFSDeepCrawlStrategy):
|
|||||||
"depths": depths,
|
"depths": depths,
|
||||||
"pages_crawled": self._pages_crawled,
|
"pages_crawled": self._pages_crawled,
|
||||||
"dfs_seen": list(self._dfs_seen),
|
"dfs_seen": list(self._dfs_seen),
|
||||||
|
"cancelled": self._cancel_event.is_set(),
|
||||||
}
|
}
|
||||||
self._last_state = state
|
self._last_state = state
|
||||||
await self._on_state_change(state)
|
await self._on_state_change(state)
|
||||||
|
|
||||||
|
# Final state update if cancelled
|
||||||
|
if self._cancel_event.is_set() and self._on_state_change:
|
||||||
|
state = {
|
||||||
|
"strategy_type": "dfs",
|
||||||
|
"visited": list(visited),
|
||||||
|
"stack": [
|
||||||
|
{"url": u, "parent_url": p, "depth": d}
|
||||||
|
for u, p, d in stack
|
||||||
|
],
|
||||||
|
"depths": depths,
|
||||||
|
"pages_crawled": self._pages_crawled,
|
||||||
|
"dfs_seen": list(self._dfs_seen),
|
||||||
|
"cancelled": True,
|
||||||
|
}
|
||||||
|
self._last_state = state
|
||||||
|
await self._on_state_change(state)
|
||||||
|
|
||||||
async def link_discovery(
|
async def link_discovery(
|
||||||
self,
|
self,
|
||||||
result: CrawlResult,
|
result: CrawlResult,
|
||||||
|
|||||||
597
tests/deep_crawling/test_deep_crawl_cancellation.py
Normal file
597
tests/deep_crawling/test_deep_crawl_cancellation.py
Normal file
@@ -0,0 +1,597 @@
|
|||||||
|
"""
|
||||||
|
Test Suite: Deep Crawl Cancellation Tests
|
||||||
|
|
||||||
|
Tests that verify:
|
||||||
|
1. should_cancel callback is called before each URL
|
||||||
|
2. cancel() method immediately stops the crawl
|
||||||
|
3. cancelled property correctly reflects state
|
||||||
|
4. Strategy reuse works after cancellation
|
||||||
|
5. Both sync and async should_cancel callbacks work
|
||||||
|
6. Callback exceptions don't crash the crawl
|
||||||
|
7. State notifications include cancelled flag
|
||||||
|
"""
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
import asyncio
|
||||||
|
from typing import Dict, Any, List
|
||||||
|
from unittest.mock import MagicMock
|
||||||
|
|
||||||
|
from crawl4ai.deep_crawling import (
|
||||||
|
BFSDeepCrawlStrategy,
|
||||||
|
DFSDeepCrawlStrategy,
|
||||||
|
BestFirstCrawlingStrategy,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# ============================================================================
|
||||||
|
# Helper Functions for Mock Crawler
|
||||||
|
# ============================================================================
|
||||||
|
|
||||||
|
def create_mock_config(stream=False):
|
||||||
|
"""Create a mock CrawlerRunConfig."""
|
||||||
|
config = MagicMock()
|
||||||
|
config.stream = stream
|
||||||
|
|
||||||
|
def clone_config(**kwargs):
|
||||||
|
"""Clone returns a new config with overridden values."""
|
||||||
|
new_config = MagicMock()
|
||||||
|
new_config.stream = kwargs.get('stream', stream)
|
||||||
|
new_config.clone = MagicMock(side_effect=clone_config)
|
||||||
|
return new_config
|
||||||
|
|
||||||
|
config.clone = MagicMock(side_effect=clone_config)
|
||||||
|
return config
|
||||||
|
|
||||||
|
|
||||||
|
def create_mock_crawler_with_links(num_links: int = 3):
|
||||||
|
"""Create mock crawler that returns results with links."""
|
||||||
|
call_count = 0
|
||||||
|
|
||||||
|
async def mock_arun_many(urls, config):
|
||||||
|
nonlocal call_count
|
||||||
|
results = []
|
||||||
|
for url in urls:
|
||||||
|
call_count += 1
|
||||||
|
result = MagicMock()
|
||||||
|
result.url = url
|
||||||
|
result.success = True
|
||||||
|
result.metadata = {}
|
||||||
|
|
||||||
|
# Generate child links
|
||||||
|
links = []
|
||||||
|
for i in range(num_links):
|
||||||
|
link_url = f"{url}/child{call_count}_{i}"
|
||||||
|
links.append({"href": link_url})
|
||||||
|
|
||||||
|
result.links = {"internal": links, "external": []}
|
||||||
|
results.append(result)
|
||||||
|
|
||||||
|
# For streaming mode, return async generator
|
||||||
|
if config.stream:
|
||||||
|
async def gen():
|
||||||
|
for r in results:
|
||||||
|
yield r
|
||||||
|
return gen()
|
||||||
|
return results
|
||||||
|
|
||||||
|
crawler = MagicMock()
|
||||||
|
crawler.arun_many = mock_arun_many
|
||||||
|
return crawler
|
||||||
|
|
||||||
|
|
||||||
|
def create_mock_crawler_tracking(crawl_order: List[str], return_no_links: bool = False):
|
||||||
|
"""Create mock crawler that tracks crawl order."""
|
||||||
|
|
||||||
|
async def mock_arun_many(urls, config):
|
||||||
|
results = []
|
||||||
|
for url in urls:
|
||||||
|
crawl_order.append(url)
|
||||||
|
result = MagicMock()
|
||||||
|
result.url = url
|
||||||
|
result.success = True
|
||||||
|
result.metadata = {}
|
||||||
|
result.links = {"internal": [], "external": []} if return_no_links else {"internal": [{"href": f"{url}/child"}], "external": []}
|
||||||
|
results.append(result)
|
||||||
|
|
||||||
|
# For streaming mode, return async generator
|
||||||
|
if config.stream:
|
||||||
|
async def gen():
|
||||||
|
for r in results:
|
||||||
|
yield r
|
||||||
|
return gen()
|
||||||
|
return results
|
||||||
|
|
||||||
|
crawler = MagicMock()
|
||||||
|
crawler.arun_many = mock_arun_many
|
||||||
|
return crawler
|
||||||
|
|
||||||
|
|
||||||
|
# ============================================================================
|
||||||
|
# TEST SUITE: Cancellation via should_cancel Callback
|
||||||
|
# ============================================================================
|
||||||
|
|
||||||
|
class TestBFSCancellation:
|
||||||
|
"""BFS strategy cancellation tests."""
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_cancel_via_async_callback(self):
|
||||||
|
"""Verify async should_cancel callback stops crawl."""
|
||||||
|
pages_crawled = 0
|
||||||
|
cancel_after = 3
|
||||||
|
|
||||||
|
async def check_cancel():
|
||||||
|
return pages_crawled >= cancel_after
|
||||||
|
|
||||||
|
async def track_pages(state: Dict[str, Any]):
|
||||||
|
nonlocal pages_crawled
|
||||||
|
pages_crawled = state.get("pages_crawled", 0)
|
||||||
|
|
||||||
|
strategy = BFSDeepCrawlStrategy(
|
||||||
|
max_depth=5,
|
||||||
|
max_pages=100,
|
||||||
|
should_cancel=check_cancel,
|
||||||
|
on_state_change=track_pages,
|
||||||
|
)
|
||||||
|
|
||||||
|
mock_crawler = create_mock_crawler_with_links(num_links=5)
|
||||||
|
mock_config = create_mock_config()
|
||||||
|
|
||||||
|
results = await strategy._arun_batch("https://example.com", mock_crawler, mock_config)
|
||||||
|
|
||||||
|
# Should have stopped after cancel_after pages
|
||||||
|
assert strategy.cancelled == True
|
||||||
|
assert strategy._pages_crawled >= cancel_after
|
||||||
|
assert strategy._pages_crawled < 100 # Should not have crawled all pages
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_cancel_via_sync_callback(self):
|
||||||
|
"""Verify sync should_cancel callback works."""
|
||||||
|
cancel_flag = False
|
||||||
|
|
||||||
|
def check_cancel():
|
||||||
|
return cancel_flag
|
||||||
|
|
||||||
|
async def set_cancel_after_3(state: Dict[str, Any]):
|
||||||
|
nonlocal cancel_flag
|
||||||
|
if state.get("pages_crawled", 0) >= 3:
|
||||||
|
cancel_flag = True
|
||||||
|
|
||||||
|
strategy = BFSDeepCrawlStrategy(
|
||||||
|
max_depth=5,
|
||||||
|
max_pages=100,
|
||||||
|
should_cancel=check_cancel,
|
||||||
|
on_state_change=set_cancel_after_3,
|
||||||
|
)
|
||||||
|
|
||||||
|
mock_crawler = create_mock_crawler_with_links(num_links=5)
|
||||||
|
mock_config = create_mock_config()
|
||||||
|
|
||||||
|
await strategy._arun_batch("https://example.com", mock_crawler, mock_config)
|
||||||
|
|
||||||
|
assert strategy.cancelled == True
|
||||||
|
assert strategy._pages_crawled >= 3
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_cancel_method_stops_crawl(self):
|
||||||
|
"""Verify cancel() method immediately stops the crawl."""
|
||||||
|
strategy = BFSDeepCrawlStrategy(
|
||||||
|
max_depth=5,
|
||||||
|
max_pages=100,
|
||||||
|
)
|
||||||
|
|
||||||
|
async def cancel_after_2_pages(state: Dict[str, Any]):
|
||||||
|
if state.get("pages_crawled", 0) >= 2:
|
||||||
|
strategy.cancel()
|
||||||
|
|
||||||
|
strategy._on_state_change = cancel_after_2_pages
|
||||||
|
|
||||||
|
mock_crawler = create_mock_crawler_with_links(num_links=5)
|
||||||
|
mock_config = create_mock_config()
|
||||||
|
|
||||||
|
await strategy._arun_batch("https://example.com", mock_crawler, mock_config)
|
||||||
|
|
||||||
|
assert strategy.cancelled == True
|
||||||
|
assert strategy._pages_crawled >= 2
|
||||||
|
assert strategy._pages_crawled < 100
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_cancelled_property_reflects_state(self):
|
||||||
|
"""Verify cancelled property correctly reflects state."""
|
||||||
|
strategy = BFSDeepCrawlStrategy(max_depth=2, max_pages=10)
|
||||||
|
|
||||||
|
# Before cancel
|
||||||
|
assert strategy.cancelled == False
|
||||||
|
|
||||||
|
# After cancel()
|
||||||
|
strategy.cancel()
|
||||||
|
assert strategy.cancelled == True
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_strategy_reuse_after_cancellation(self):
|
||||||
|
"""Verify strategy can be reused after cancellation."""
|
||||||
|
call_count = 0
|
||||||
|
|
||||||
|
async def cancel_first_time():
|
||||||
|
return call_count == 1
|
||||||
|
|
||||||
|
strategy = BFSDeepCrawlStrategy(
|
||||||
|
max_depth=1,
|
||||||
|
max_pages=5,
|
||||||
|
should_cancel=cancel_first_time,
|
||||||
|
)
|
||||||
|
|
||||||
|
mock_crawler = create_mock_crawler_with_links(num_links=2)
|
||||||
|
mock_config = create_mock_config()
|
||||||
|
|
||||||
|
# First crawl - should be cancelled
|
||||||
|
call_count = 1
|
||||||
|
results1 = await strategy._arun_batch("https://example.com", mock_crawler, mock_config)
|
||||||
|
assert strategy.cancelled == True
|
||||||
|
|
||||||
|
# Second crawl - should work normally (cancel_first_time returns False)
|
||||||
|
call_count = 2
|
||||||
|
results2 = await strategy._arun_batch("https://example.com", mock_crawler, mock_config)
|
||||||
|
assert strategy.cancelled == False
|
||||||
|
assert len(results2) > len(results1)
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_callback_exception_continues_crawl(self):
|
||||||
|
"""Verify callback exception doesn't crash crawl (fail-open)."""
|
||||||
|
exception_count = 0
|
||||||
|
|
||||||
|
async def failing_callback():
|
||||||
|
nonlocal exception_count
|
||||||
|
exception_count += 1
|
||||||
|
raise ConnectionError("Redis connection failed")
|
||||||
|
|
||||||
|
strategy = BFSDeepCrawlStrategy(
|
||||||
|
max_depth=1,
|
||||||
|
max_pages=3,
|
||||||
|
should_cancel=failing_callback,
|
||||||
|
)
|
||||||
|
|
||||||
|
mock_crawler = create_mock_crawler_with_links(num_links=2)
|
||||||
|
mock_config = create_mock_config()
|
||||||
|
|
||||||
|
# Should not raise, should complete crawl
|
||||||
|
results = await strategy._arun_batch("https://example.com", mock_crawler, mock_config)
|
||||||
|
|
||||||
|
assert exception_count > 0 # Callback was called
|
||||||
|
assert len(results) > 0 # Crawl completed
|
||||||
|
assert strategy.cancelled == False # Not cancelled due to exception
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_state_includes_cancelled_flag(self):
|
||||||
|
"""Verify state notifications include cancelled flag."""
|
||||||
|
states: List[Dict] = []
|
||||||
|
cancel_at = 3
|
||||||
|
|
||||||
|
async def capture_state(state: Dict[str, Any]):
|
||||||
|
states.append(state)
|
||||||
|
|
||||||
|
async def cancel_after_3():
|
||||||
|
return len(states) >= cancel_at
|
||||||
|
|
||||||
|
strategy = BFSDeepCrawlStrategy(
|
||||||
|
max_depth=5,
|
||||||
|
max_pages=100,
|
||||||
|
should_cancel=cancel_after_3,
|
||||||
|
on_state_change=capture_state,
|
||||||
|
)
|
||||||
|
|
||||||
|
mock_crawler = create_mock_crawler_with_links(num_links=5)
|
||||||
|
mock_config = create_mock_config()
|
||||||
|
|
||||||
|
await strategy._arun_batch("https://example.com", mock_crawler, mock_config)
|
||||||
|
|
||||||
|
# Last state should have cancelled=True
|
||||||
|
assert len(states) > 0
|
||||||
|
assert states[-1].get("cancelled") == True
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_cancel_before_first_url(self):
|
||||||
|
"""Verify cancel before first URL returns empty results."""
|
||||||
|
async def always_cancel():
|
||||||
|
return True
|
||||||
|
|
||||||
|
strategy = BFSDeepCrawlStrategy(
|
||||||
|
max_depth=5,
|
||||||
|
max_pages=100,
|
||||||
|
should_cancel=always_cancel,
|
||||||
|
)
|
||||||
|
|
||||||
|
mock_crawler = create_mock_crawler_with_links(num_links=5)
|
||||||
|
mock_config = create_mock_config()
|
||||||
|
|
||||||
|
results = await strategy._arun_batch("https://example.com", mock_crawler, mock_config)
|
||||||
|
|
||||||
|
assert strategy.cancelled == True
|
||||||
|
assert len(results) == 0
|
||||||
|
|
||||||
|
|
||||||
|
class TestDFSCancellation:
    """DFS strategy cancellation tests."""

    @pytest.mark.asyncio
    async def test_cancel_via_callback(self):
        """Verify DFS respects should_cancel callback.

        Cancellation must trigger after the requested page count and stop
        the crawl well before max_pages is exhausted.
        """
        pages_crawled = 0
        cancel_after = 3

        async def check_cancel():
            return pages_crawled >= cancel_after

        async def track_pages(state: Dict[str, Any]):
            nonlocal pages_crawled
            pages_crawled = state.get("pages_crawled", 0)

        strategy = DFSDeepCrawlStrategy(
            max_depth=5,
            max_pages=100,
            should_cancel=check_cancel,
            on_state_change=track_pages,
        )

        mock_crawler = create_mock_crawler_with_links(num_links=3)
        mock_config = create_mock_config()

        await strategy._arun_batch("https://example.com", mock_crawler, mock_config)

        assert strategy.cancelled
        assert strategy._pages_crawled >= cancel_after
        assert strategy._pages_crawled < 100

    @pytest.mark.asyncio
    async def test_cancel_method_inherited(self):
        """Verify DFS inherits cancel() from BFS."""
        strategy = DFSDeepCrawlStrategy(max_depth=2, max_pages=10)

        # The full cancellation surface must be present on DFS instances.
        assert hasattr(strategy, 'cancel')
        assert hasattr(strategy, 'cancelled')
        assert hasattr(strategy, '_check_cancellation')

        strategy.cancel()
        assert strategy.cancelled

    @pytest.mark.asyncio
    async def test_stream_mode_cancellation(self):
        """Verify DFS stream mode respects cancellation.

        Consuming the async stream must stop once the callback requests
        cancellation, yielding only a bounded number of results.
        """
        results_count = 0
        cancel_after = 2

        async def check_cancel():
            return results_count >= cancel_after

        strategy = DFSDeepCrawlStrategy(
            max_depth=5,
            max_pages=100,
            should_cancel=check_cancel,
        )

        mock_crawler = create_mock_crawler_with_links(num_links=3)
        mock_config = create_mock_config(stream=True)

        # Only the count matters; the yielded results are not inspected.
        async for _ in strategy._arun_stream("https://example.com", mock_crawler, mock_config):
            results_count += 1

        assert strategy.cancelled
        assert results_count >= cancel_after
        assert results_count < 100
|
||||||
|
|
||||||
|
|
||||||
|
class TestBestFirstCancellation:
    """Best-First strategy cancellation tests."""

    @pytest.mark.asyncio
    async def test_cancel_via_callback(self):
        """Verify Best-First respects should_cancel callback.

        Stream-mode crawl must stop once the callback fires, leaving the
        page count bounded well below max_pages.
        """
        pages_crawled = 0
        cancel_after = 3

        async def check_cancel():
            return pages_crawled >= cancel_after

        async def track_pages(state: Dict[str, Any]):
            nonlocal pages_crawled
            pages_crawled = state.get("pages_crawled", 0)

        strategy = BestFirstCrawlingStrategy(
            max_depth=5,
            max_pages=100,
            should_cancel=check_cancel,
            on_state_change=track_pages,
        )

        mock_crawler = create_mock_crawler_with_links(num_links=3)
        mock_config = create_mock_config(stream=True)

        # Drain the stream; only the side effects on the strategy matter.
        async for _ in strategy._arun_stream("https://example.com", mock_crawler, mock_config):
            pass

        assert strategy.cancelled
        assert strategy._pages_crawled >= cancel_after
        assert strategy._pages_crawled < 100

    @pytest.mark.asyncio
    async def test_cancel_method_works(self):
        """Verify Best-First cancel() method works."""
        strategy = BestFirstCrawlingStrategy(max_depth=2, max_pages=10)

        assert not strategy.cancelled
        strategy.cancel()
        assert strategy.cancelled

    @pytest.mark.asyncio
    async def test_batch_mode_cancellation(self):
        """Verify Best-First batch mode respects cancellation.

        Batch-mode crawl must return the partial results gathered before
        the callback requested cancellation.
        """
        pages_crawled = 0
        cancel_after = 2

        async def check_cancel():
            return pages_crawled >= cancel_after

        async def track_pages(state: Dict[str, Any]):
            nonlocal pages_crawled
            pages_crawled = state.get("pages_crawled", 0)

        strategy = BestFirstCrawlingStrategy(
            max_depth=5,
            max_pages=100,
            should_cancel=check_cancel,
            on_state_change=track_pages,
        )

        mock_crawler = create_mock_crawler_with_links(num_links=3)
        mock_config = create_mock_config(stream=False)

        results = await strategy._arun_batch("https://example.com", mock_crawler, mock_config)

        assert strategy.cancelled
        assert len(results) >= cancel_after
        assert len(results) < 100
|
||||||
|
|
||||||
|
|
||||||
|
class TestCrossStrategyCancellation:
    """Tests that apply to all strategies."""

    @pytest.mark.asyncio
    @pytest.mark.parametrize("strategy_class", [
        BFSDeepCrawlStrategy,
        DFSDeepCrawlStrategy,
        BestFirstCrawlingStrategy,
    ])
    async def test_no_cancel_callback_means_no_cancellation(self, strategy_class):
        """Verify crawl completes normally without should_cancel."""
        strategy = strategy_class(max_depth=1, max_pages=5)

        mock_crawler = create_mock_crawler_with_links(num_links=2)

        # Best-First is exercised via its stream API; the others via batch.
        # Identity check (`is`) is the correct comparison for class objects.
        if strategy_class is BestFirstCrawlingStrategy:
            mock_config = create_mock_config(stream=True)
            results = []
            async for r in strategy._arun_stream("https://example.com", mock_crawler, mock_config):
                results.append(r)
        else:
            mock_config = create_mock_config()
            results = await strategy._arun_batch("https://example.com", mock_crawler, mock_config)

        assert not strategy.cancelled
        assert len(results) > 0

    @pytest.mark.asyncio
    @pytest.mark.parametrize("strategy_class", [
        BFSDeepCrawlStrategy,
        DFSDeepCrawlStrategy,
        BestFirstCrawlingStrategy,
    ])
    async def test_cancel_thread_safety(self, strategy_class):
        """Verify cancel() is thread-safe (doesn't raise)."""
        strategy = strategy_class(max_depth=2, max_pages=10)

        # Call cancel from multiple "threads" (simulated): repeated calls
        # must be idempotent and never raise.
        for _ in range(10):
            strategy.cancel()

        # Should be cancelled without errors
        assert strategy.cancelled

    @pytest.mark.asyncio
    @pytest.mark.parametrize("strategy_class", [
        BFSDeepCrawlStrategy,
        DFSDeepCrawlStrategy,
        BestFirstCrawlingStrategy,
    ])
    async def test_should_cancel_param_accepted(self, strategy_class):
        """Verify should_cancel parameter is accepted by constructor."""
        async def dummy_cancel():
            return False

        # Should not raise
        strategy = strategy_class(
            max_depth=2,
            max_pages=10,
            should_cancel=dummy_cancel,
        )

        # The exact callback object must be stored (identity, not equality).
        assert strategy._should_cancel is dummy_cancel
|
||||||
|
|
||||||
|
|
||||||
|
class TestCancellationEdgeCases:
    """Edge case tests for cancellation."""

    @pytest.mark.asyncio
    async def test_cancel_during_batch_processing(self):
        """Verify cancellation during batch doesn't lose results.

        Results from the batch in flight when cancellation fires must
        still be returned to the caller.
        """
        results_count = 0

        async def cancel_mid_batch():
            # Cancel after receiving first result
            return results_count >= 1

        strategy = BFSDeepCrawlStrategy(
            max_depth=2,
            max_pages=100,
            should_cancel=cancel_mid_batch,
        )

        async def track_results(state):
            nonlocal results_count
            results_count = state.get("pages_crawled", 0)

        strategy._on_state_change = track_results

        mock_crawler = create_mock_crawler_with_links(num_links=5)
        mock_config = create_mock_config()

        results = await strategy._arun_batch("https://example.com", mock_crawler, mock_config)

        # Should have at least the first batch of results
        assert len(results) >= 1
        assert strategy.cancelled

    @pytest.mark.asyncio
    async def test_partial_results_on_cancel(self):
        """Verify partial results are returned on cancellation."""
        cancel_after = 5

        async def check_cancel():
            # `strategy` is bound below; the closure is only invoked during
            # the crawl, after the name exists.
            return strategy._pages_crawled >= cancel_after

        strategy = BFSDeepCrawlStrategy(
            max_depth=10,
            max_pages=1000,
            should_cancel=check_cancel,
        )

        mock_crawler = create_mock_crawler_with_links(num_links=10)
        mock_config = create_mock_config()

        results = await strategy._arun_batch("https://example.com", mock_crawler, mock_config)

        # Should have results up to cancellation point
        assert len(results) >= cancel_after
        assert strategy.cancelled

    @pytest.mark.asyncio
    async def test_cancel_callback_called_once_per_level_bfs(self):
        """Verify BFS checks cancellation once per level."""
        check_count = 0

        async def count_checks():
            nonlocal check_count
            check_count += 1
            return False  # Never cancel

        strategy = BFSDeepCrawlStrategy(
            max_depth=2,
            max_pages=10,
            should_cancel=count_checks,
        )

        mock_crawler = create_mock_crawler_with_links(num_links=2)
        mock_config = create_mock_config()

        await strategy._arun_batch("https://example.com", mock_crawler, mock_config)

        # Should have checked at least once per level.
        # NOTE(review): this lower bound is deliberately loose — it only
        # proves the callback was invoked, not the per-level frequency.
        assert check_count >= 1
|
||||||
Reference in New Issue
Block a user