diff --git a/crawl4ai/async_executor.py b/crawl4ai/async_executor.py
new file mode 100644
index 00000000..4f391b7b
--- /dev/null
+++ b/crawl4ai/async_executor.py
@@ -0,0 +1,741 @@
+from __future__ import annotations
+import asyncio
+import inspect
+import logging
+import time
+from abc import ABC, abstractmethod
+from dataclasses import dataclass, field
+from enum import Enum
+from pathlib import Path
+from typing import Any, Awaitable, Callable, Dict, List, Optional, Set, Union
+
+import psutil
+
+# Imports from the crawler package
+from .chunking_strategy import ChunkingStrategy, RegexChunking
+from .extraction_strategy import ExtractionStrategy
+from .models import CrawlResult
+from .config import MIN_WORD_THRESHOLD, MAX_METRICS_HISTORY
+from .async_webcrawler import AsyncWebCrawler
+
+# Configure logging
+logging.basicConfig(
+    level=logging.INFO,
+    format='%(asctime)s [%(levelname)s] %(name)s - %(message)s',
+    datefmt='%Y-%m-%d %H:%M:%S'
+)
+
+
+# Enums and Constants
+class ExecutionMode(Enum):
+    """Execution mode for the crawler executor."""
+    SPEED = "speed"
+    RESOURCE = "resource"
+
+
+class TaskState(Enum):
+    """Possible states for a crawling task."""
+    PENDING = "pending"
+    SCHEDULED = "scheduled"
+    RUNNING = "running"
+    COMPLETED = "completed"
+    FAILED = "failed"
+    RETRYING = "retrying"
+
+
+class CallbackType(Enum):
+    """Types of callbacks supported by the executor."""
+    PRE_EXECUTION = "pre_execution"    # Before processing a URL
+    POST_EXECUTION = "post_execution"  # After successful processing
+    ON_ERROR = "on_error"              # When an error occurs
+    ON_RETRY = "on_retry"              # Before retrying a failed URL
+    ON_BATCH_START = "on_batch_start"  # Before starting a batch
+    ON_BATCH_END = "on_batch_end"      # After completing a batch
+    ON_COMPLETE = "on_complete"        # After all URLs are processed
+
+
+@dataclass
+class SystemMetrics:
+    """System resource metrics."""
+    cpu_percent: float
+    memory_percent: float
+    available_memory: int
+    timestamp: float
+
+    @classmethod
+    def capture(cls) -> 'SystemMetrics':
+        """Capture current system metrics."""
+        return cls(
+            cpu_percent=psutil.cpu_percent(),
+            memory_percent=psutil.virtual_memory().percent,
+            available_memory=psutil.virtual_memory().available,
+            timestamp=time.time()
+        )
+
+
+@dataclass
+class TaskMetadata:
+    """Metadata for a crawling task."""
+    url: str
+    state: TaskState
+    attempts: int = 0
+    last_attempt: Optional[float] = None
+    error: Optional[str] = None
+    result: Optional[Any] = None
+
+
+@dataclass
+class ExecutorMetrics:
+    """Performance and resource metrics for the executor."""
+    # Performance metrics
+    total_urls: int = 0
+    completed_urls: int = 0
+    failed_urls: int = 0
+    start_time: Optional[float] = None
+    total_retries: int = 0
+    response_times: List[float] = field(default_factory=list)
+
+    # Resource metrics
+    system_metrics: List[SystemMetrics] = field(default_factory=list)
+    active_connections: int = 0
+
+    def capture_system_metrics(self):
+        """Capture system metrics and enforce the history size limit."""
+        metrics = SystemMetrics.capture()
+        self.system_metrics.append(metrics)
+        if len(self.system_metrics) > MAX_METRICS_HISTORY:
+            self.system_metrics.pop(0)  # Remove the oldest metric
+
+    @property
+    def urls_per_second(self) -> float:
+        """Calculate URLs processed per second."""
+        if not self.start_time or not self.completed_urls:
+            return 0.0
+        duration = time.time() - self.start_time
+        return self.completed_urls / duration if duration > 0 else 0.0
+
+    @property
+    def success_rate(self) -> float:
+        """Calculate success rate as a percentage."""
+        if not self.total_urls:
+            return 0.0
+        return (self.completed_urls / self.total_urls) * 100
+
+    @property
+    def retry_rate(self) -> float:
+        """Calculate retry rate as a percentage."""
+        if not self.total_urls:
+            return 0.0
+        return (self.total_retries / self.total_urls) * 100
+
+    @property
+    def average_response_time(self) -> float:
+        """Calculate average response time in seconds."""
+        if not self.response_times:
+            return 0.0
+        return sum(self.response_times) / len(self.response_times)
+
+    def to_dict(self) -> Dict[str, Any]:
+        """Convert metrics to dictionary format."""
+        return {
+            "performance": {
+                "urls_per_second": self.urls_per_second,
+                "success_rate": self.success_rate,
+                "retry_rate": self.retry_rate,
+                "average_response_time": self.average_response_time,
+                "total_urls": self.total_urls,
+                "completed_urls": self.completed_urls,
+                "failed_urls": self.failed_urls
+            },
+            "resources": {
+                "cpu_utilization": self.system_metrics[-1].cpu_percent if self.system_metrics else 0,
+                "memory_usage": self.system_metrics[-1].memory_percent if self.system_metrics else 0,
+                "active_connections": self.active_connections
+            }
+        }
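+
+# Illustrative sketch (not part of the module's runtime path): the metrics
+# object is a plain dataclass, so it can be exercised on its own. The values
+# below are made up for illustration.
+#
+#     m = ExecutorMetrics(total_urls=10, completed_urls=8, failed_urls=2)
+#     m.start_time = time.time() - 4.0
+#     m.response_times.extend([0.5, 0.7])
+#     print(m.success_rate)                            # 80.0
+#     print(m.average_response_time)                   # 0.6
+#     print(m.to_dict()["performance"]["total_urls"])  # 10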
+
+
+class ResourceMonitor:
+    """Monitors and manages system resources."""
+
+    def __init__(self, mode: ExecutionMode):
+        self.mode = mode
+        self.metrics_history: List[SystemMetrics] = []
+        self._setup_thresholds()
+
+    def _setup_thresholds(self):
+        """Set up resource thresholds based on execution mode."""
+        if self.mode == ExecutionMode.SPEED:
+            self.memory_threshold = 80  # 80% memory usage limit
+            self.cpu_threshold = 90     # 90% CPU usage limit
+        else:
+            self.memory_threshold = 40  # 40% memory usage limit
+            self.cpu_threshold = 30     # 30% CPU usage limit
+
+    async def check_resources(self) -> bool:
+        """Check if system resources are within acceptable limits."""
+        metrics = SystemMetrics.capture()
+        self.metrics_history.append(metrics)
+
+        # Keep only the last hour of metrics
+        cutoff_time = time.time() - 3600
+        self.metrics_history = [m for m in self.metrics_history if m.timestamp > cutoff_time]
+
+        return (metrics.cpu_percent < self.cpu_threshold and
+                metrics.memory_percent < self.memory_threshold)
+
+    def get_optimal_batch_size(self, total_urls: int) -> int:
+        """Derive a batch size from the current CPU/memory headroom."""
+        metrics = SystemMetrics.capture()
+        if self.mode == ExecutionMode.SPEED:
+            base_size = min(1000, total_urls)
+
+            # Scale down based on how close usage is to the thresholds
+            cpu_factor = max(0.0, (self.cpu_threshold - metrics.cpu_percent) / self.cpu_threshold)
+            mem_factor = max(0.0, (self.memory_threshold - metrics.memory_percent) / self.memory_threshold)
+
+            min_factor = min(cpu_factor, mem_factor)
+            adjusted_size = max(1, int(base_size * min_factor))
+            return min(total_urls, adjusted_size)
+        else:
+            # For resource optimization, use a conservative batch size based on resource usage
+            cpu_factor = max(0.1, (self.cpu_threshold - metrics.cpu_percent) / self.cpu_threshold)
+            mem_factor = max(0.1, (self.memory_threshold - metrics.memory_percent) / self.memory_threshold)
+
+            min_factor = min(cpu_factor, mem_factor)
+            adjusted_size = max(1, int(50 * min_factor))
+            return min(total_urls, adjusted_size)
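+
+# Worked example (numbers assumed for illustration): in RESOURCE mode the
+# thresholds are cpu=30 and memory=40. With cpu_percent=15 and
+# memory_percent=20, cpu_factor = (30 - 15) / 30 = 0.5 and
+# mem_factor = (40 - 20) / 40 = 0.5, so the batch size is
+# max(1, int(50 * 0.5)) = 25, clamped to the number of remaining URLs.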
+
+
+class ExecutorControl:
+    """Control interface for the executor."""
+
+    def __init__(self):
+        self._paused = False
+        self._cancelled = False
+        self._pause_event = asyncio.Event()
+        self._pause_event.set()  # Not paused initially
+        self._lock = asyncio.Lock()  # Lock to protect shared state
+        self.logger = logging.getLogger(f"{__name__}.{self.__class__.__name__}")
+
+    async def pause(self):
+        """Pause the execution."""
+        async with self._lock:
+            self._paused = True
+            self._pause_event.clear()
+
+    async def resume(self):
+        """Resume the execution."""
+        async with self._lock:
+            self._paused = False
+            self._pause_event.set()
+
+    async def cancel(self):
+        """Cancel all pending operations."""
+        async with self._lock:
+            self._cancelled = True
+            self._pause_event.set()  # Release any paused operations
+
+    async def is_paused(self) -> bool:
+        """Check if execution is paused."""
+        async with self._lock:
+            return self._paused
+
+    async def is_cancelled(self) -> bool:
+        """Check if execution is cancelled."""
+        async with self._lock:
+            return self._cancelled
+
+    async def wait_if_paused(self, timeout: Optional[float] = None):
+        """Wait if execution is paused, with an optional timeout.
+
+        If the timeout expires, the pause is force-released so execution
+        proceeds rather than hanging indefinitely.
+        """
+        try:
+            await asyncio.wait_for(self._pause_event.wait(), timeout=timeout)
+        except asyncio.TimeoutError:
+            # Timeout occurred: reset the paused state and proceed
+            async with self._lock:
+                self._paused = False
+                self._pause_event.set()
+            self.logger.warning(
+                f"ExecutorControl: wait_if_paused() timed out after {timeout} seconds. "
+                "Proceeding with execution."
+            )
+
+    async def reset(self):
+        """Reset control state."""
+        async with self._lock:
+            self._paused = False
+            self._cancelled = False
+            self._pause_event.set()
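+
+# Minimal usage sketch for the control interface (assumes an event loop is
+# already running; intended to be read, not executed here):
+#
+#     control = ExecutorControl()
+#     await control.pause()
+#     assert await control.is_paused()
+#     await control.resume()
+#     await control.wait_if_paused(timeout=5)  # returns immediately when not paused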
+
+
+class ExecutorStrategy(ABC):
+    """Abstract base class for executor strategies.
+
+    Callbacks:
+    - PRE_EXECUTION: Callable[[str, Dict[str, Any]], None]
+    - POST_EXECUTION: Callable[[str, Any, Dict[str, Any]], None]
+    - ON_ERROR: Callable[[str, Exception, Dict[str, Any]], None]
+    - ON_RETRY: Callable[[str, int, Dict[str, Any]], None]
+    - ON_BATCH_START: Callable[[List[str], Dict[str, Any]], None]
+    - ON_BATCH_END: Callable[[List[str], Dict[str, Any]], None]
+    - ON_COMPLETE: Callable[[Dict[str, Any], Dict[str, Any]], None]
+    """
+
+    def __init__(
+        self,
+        crawler: AsyncWebCrawler,
+        mode: ExecutionMode,
+        callbacks: Optional[Dict[CallbackType, Callable[..., Union[Awaitable[None], None]]]] = None,
+        persistence_path: Optional[Path] = None,
+        **crawl_config_kwargs
+    ):
+        self.crawler = crawler
+        self.mode = mode
+        self.callbacks = callbacks or {}
+        self.resource_monitor = ResourceMonitor(mode)
+        self.tasks: Dict[str, TaskMetadata] = {}
+        self.active_tasks: Set[str] = set()
+        self.logger = logging.getLogger(f"{__name__}.{self.__class__.__name__}")
+        self.metrics = ExecutorMetrics()
+        self.control = ExecutorControl()
+        self.persistence_path = persistence_path  # Persistence backend is not wired up yet
+        self.crawl_config_kwargs = crawl_config_kwargs  # Stored parameters for arun
+
+    async def get_status(self) -> Dict[str, Any]:
+        """Get current executor status and metrics."""
+        return {
+            "status": {
+                "paused": await self.control.is_paused(),
+                "cancelled": await self.control.is_cancelled(),
+                "active_tasks": len(self.active_tasks)
+            },
+            "metrics": self.metrics.to_dict()
+        }
+
+    async def clear_state(self):
+        """Reset executor state."""
+        self.tasks.clear()
+        self.active_tasks.clear()
+        self.metrics = ExecutorMetrics()
+        await self.control.reset()
+        # Persistence is not implemented yet; clear it only if a backend is attached
+        persistence = getattr(self, "persistence", None)
+        if persistence is not None:
+            await persistence.clear()
+
+    async def _execute_callback(
+        self,
+        callback_type: CallbackType,
+        *args,
+        **kwargs
+    ):
+        """Execute the callback registered for this event type, if any."""
+        if callback := self.callbacks.get(callback_type):
+            try:
+                if inspect.iscoroutinefunction(callback):
+                    await callback(*args, **kwargs)
+                else:
+                    callback(*args, **kwargs)
+            except Exception as e:
+                self.logger.error(
+                    f"Executor {self.__class__.__name__}: Callback {callback_type.value} failed: {e}",
+                    exc_info=True
+                )
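+
+    # Callback sketch (illustrative): both sync and async callables are
+    # accepted, since _execute_callback inspects the function before calling.
+    #
+    #     async def on_error(url, error, context):
+    #         print(f"{url} failed: {error}")
+    #
+    #     executor = SpeedOptimizedExecutor(
+    #         crawler=AsyncWebCrawler(),
+    #         callbacks={CallbackType.ON_ERROR: on_error},
+    #     )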
+
+    async def _process_url(self, url: str) -> CrawlResult:
+        """Process a single URL using the crawler, retrying with backoff."""
+        max_retries = self.crawl_config_kwargs.get('max_retries', 3)
+        backoff_factor = getattr(self, 'backoff_factor', self.crawl_config_kwargs.get('backoff_factor', 1))
+        attempts = 0
+
+        while attempts <= max_retries:
+            # Invoke PRE_EXECUTION callback
+            await self._execute_callback(CallbackType.PRE_EXECUTION, url, self.metrics.to_dict())
+
+            # Wait if execution is paused
+            await self.control.wait_if_paused(timeout=300)
+
+            # Check if cancelled
+            if await self.control.is_cancelled():
+                raise asyncio.CancelledError("Execution was cancelled")
+
+            start_time = time.time()
+            self.metrics.active_connections += 1
+
+            try:
+                result = await self.crawler.arun(url, **self.crawl_config_kwargs)
+                self.metrics.completed_urls += 1
+                self.metrics.response_times.append(time.time() - start_time)
+                # Invoke POST_EXECUTION callback
+                await self._execute_callback(CallbackType.POST_EXECUTION, url, result, self.metrics.to_dict())
+                return result
+
+            except Exception as e:
+                attempts += 1
+                self.logger.error(
+                    f"Executor {self.__class__.__name__}: Error processing URL {url}: {e}",
+                    exc_info=True
+                )
+                # Invoke ON_ERROR callback
+                await self._execute_callback(CallbackType.ON_ERROR, url, e, self.metrics.to_dict())
+
+                if attempts <= max_retries:
+                    # Count the retry and invoke the ON_RETRY callback
+                    self.metrics.total_retries += 1
+                    await self._execute_callback(CallbackType.ON_RETRY, url, attempts, self.metrics.to_dict())
+                    # Wait before retrying
+                    await asyncio.sleep(backoff_factor * attempts)
+                else:
+                    # Retries exhausted: count the URL as failed exactly once
+                    self.metrics.failed_urls += 1
+                    raise
+
+            finally:
+                self.metrics.active_connections -= 1
+                # INFO: System metrics could be captured here after each URL,
+                # but that causes a performance hit, so it is done per batch instead.
+
+    async def execute(self, urls: List[str]) -> Dict[str, Any]:
+        """Execute crawling tasks over the given URLs, batch by batch."""
+        # Initialize metrics
+        self.metrics.total_urls = len(urls)
+        self.metrics.start_time = time.time()
+
+        # Create context with metrics (used for callbacks)
+        context = {
+            "mode": self.mode,
+            "start_time": self.metrics.start_time,
+            "total_urls": self.metrics.total_urls
+        }
+
+        # Invoke ON_BATCH_START callback
+        await self._execute_callback(CallbackType.ON_BATCH_START, urls, context)
+
+        results = {}
+        batch_errors = []
+
+        # Use the crawler within an async context manager
+        async with self.crawler:
+            # Check for cancellation before starting
+            if await self.control.is_cancelled():
+                raise asyncio.CancelledError("Execution was cancelled")
+
+            # Wait if paused
+            await self.control.wait_if_paused(timeout=300)
+
+            # Prepare the list of batches
+            batches = []
+            total_urls_remaining = len(urls)
+            index = 0
+
+            while index < len(urls):
+                batch_size = self.resource_monitor.get_optimal_batch_size(total_urls_remaining)
+                batch_urls = urls[index:index + batch_size]
+                batches.append(batch_urls)
+                index += batch_size
+                total_urls_remaining -= batch_size
+
+            # Process each batch
+            for batch_urls in batches:
+                # Check for cancellation
+                if await self.control.is_cancelled():
+                    raise asyncio.CancelledError("Execution was cancelled")
+
+                # Wait if paused
+                await self.control.wait_if_paused(timeout=300)
+
+                try:
+                    # Process the batch
+                    batch_results = await self.process_batch(batch_urls)
+                    results.update(batch_results)
+                    # Capture system metrics after each batch (bounded history)
+                    self.metrics.capture_system_metrics()
+                    # Invoke ON_BATCH_END callback
+                    await self._execute_callback(CallbackType.ON_BATCH_END, batch_urls, context)
+                except Exception as e:
+                    # Handle batch-level exceptions and continue with the next batch
+                    self.logger.error(f"Executor {self.__class__.__name__}: Error processing batch: {e}", exc_info=True)
+                    await self._execute_callback(CallbackType.ON_ERROR, "batch", e, context)
+                    batch_errors.append((batch_urls, e))
+                    continue
+
+            # Surface cancellation even if it arrived during the final batch
+            if await self.control.is_cancelled():
+                raise asyncio.CancelledError("Execution was cancelled")
+
+        # Execution complete
+        await self._execute_callback(CallbackType.ON_COMPLETE, results, context)
+
+        # Log final metrics and batch errors, if any
+        final_status = await self.get_status()
+        self.logger.info(f"Executor {self.__class__.__name__}: Execution completed. Metrics: {final_status}")
+
+        if batch_errors:
+            self.logger.warning(
+                f"Executor {self.__class__.__name__}: Execution completed with errors in {len(batch_errors)} batches."
+            )
+
+        return results
+
+    @abstractmethod
+    async def process_batch(self, batch_urls: List[str]) -> Dict[str, Any]:
+        """Process a batch of URLs."""
+        ...
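+
+# End-to-end sketch (illustrative; the URL is a placeholder): subclasses only
+# supply process_batch, while execute() handles batching, control flow,
+# callbacks, and metrics.
+#
+#     async def demo():
+#         executor = SpeedOptimizedExecutor(crawler=AsyncWebCrawler(), max_retries=2)
+#         results = await executor.execute(["https://example.com"])
+#         print((await executor.get_status())["metrics"]["performance"])
+#
+#     asyncio.run(demo())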
+
+
+class SpeedOptimizedExecutor(ExecutorStrategy):
+    """Executor optimized for speed."""
+
+    def __init__(
+        self,
+        crawler: AsyncWebCrawler,
+        callbacks: Optional[Dict[CallbackType, Callable]] = None,
+        persistence_path: Optional[Path] = None,
+        word_count_threshold=MIN_WORD_THRESHOLD,
+        extraction_strategy: ExtractionStrategy = None,
+        chunking_strategy: ChunkingStrategy = None,
+        bypass_cache: bool = False,
+        css_selector: str = None,
+        screenshot: bool = False,
+        user_agent: str = None,
+        verbose=True,
+        connection_pool_size: int = 1000,
+        dns_cache_size: int = 10000,
+        backoff_factor: int = 1,
+        **kwargs
+    ):
+        if chunking_strategy is None:
+            chunking_strategy = RegexChunking()
+
+        super().__init__(
+            crawler=crawler,
+            mode=ExecutionMode.SPEED,
+            callbacks=callbacks,
+            persistence_path=persistence_path,
+            word_count_threshold=word_count_threshold,
+            extraction_strategy=extraction_strategy,
+            chunking_strategy=chunking_strategy,
+            bypass_cache=bypass_cache,
+            css_selector=css_selector,
+            screenshot=screenshot,
+            user_agent=user_agent,
+            verbose=verbose,
+            **kwargs
+        )
+
+        self.connection_pool_size = connection_pool_size
+        self.dns_cache_size = dns_cache_size
+        self.backoff_factor = backoff_factor
+
+        self.logger.info(
+            f"Executor {self.__class__.__name__}: Initialized with:"
+            f" connection_pool_size={self.connection_pool_size},"
+            f" dns_cache_size={self.dns_cache_size}"
+        )
+
+    async def process_batch(self, batch_urls: List[str]) -> Dict[str, Any]:
+        """Process a batch of URLs concurrently."""
+        batch_tasks = [self._process_url(url) for url in batch_urls]
+
+        # Execute the whole batch concurrently, collecting exceptions as results
+        batch_results_list = await asyncio.gather(*batch_tasks, return_exceptions=True)
+
+        batch_results = {}
+        for url, result in zip(batch_urls, batch_results_list):
+            if isinstance(result, Exception):
+                batch_results[url] = {"success": False, "error": str(result)}
+            else:
+                batch_results[url] = {"success": True, "result": result}
+
+        return batch_results
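+
+# Result shape sketch (values illustrative): each URL maps to a small dict,
+# so callers can branch on "success" without catching exceptions themselves.
+#
+#     {
+#         "https://example.com": {"success": True, "result": "<CrawlResult>"},
+#         "https://bad.example": {"success": False, "error": "Failed to fetch URL"},
+#     }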
+
+
+class ResourceOptimizedExecutor(ExecutorStrategy):
+    """Executor optimized for resource usage."""
+
+    def __init__(
+        self,
+        crawler: AsyncWebCrawler,
+        callbacks: Optional[Dict[CallbackType, Callable]] = None,
+        persistence_path: Optional[Path] = None,
+        word_count_threshold=MIN_WORD_THRESHOLD,
+        extraction_strategy: ExtractionStrategy = None,
+        chunking_strategy: ChunkingStrategy = None,
+        bypass_cache: bool = False,
+        css_selector: str = None,
+        screenshot: bool = False,
+        user_agent: str = None,
+        verbose=True,
+        connection_pool_size: int = 50,
+        dns_cache_size: int = 1000,
+        backoff_factor: int = 5,
+        max_concurrent_tasks: int = 5,
+        **kwargs
+    ):
+        if chunking_strategy is None:
+            chunking_strategy = RegexChunking()
+
+        super().__init__(
+            crawler=crawler,
+            mode=ExecutionMode.RESOURCE,
+            callbacks=callbacks,
+            persistence_path=persistence_path,
+            word_count_threshold=word_count_threshold,
+            extraction_strategy=extraction_strategy,
+            chunking_strategy=chunking_strategy,
+            bypass_cache=bypass_cache,
+            css_selector=css_selector,
+            screenshot=screenshot,
+            user_agent=user_agent,
+            verbose=verbose,
+            **kwargs
+        )
+
+        self.connection_pool_size = connection_pool_size
+        self.dns_cache_size = dns_cache_size
+        self.backoff_factor = backoff_factor
+        self.max_concurrent_tasks = max_concurrent_tasks
+
+        self.logger.info(
+            f"Executor {self.__class__.__name__}: Initialized with:"
+            f" connection_pool_size={self.connection_pool_size},"
+            f" dns_cache_size={self.dns_cache_size},"
+            f" max_concurrent_tasks={self.max_concurrent_tasks}"
+        )
+
+    async def process_batch(self, batch_urls: List[str]) -> Dict[str, Any]:
+        """Process a batch of URLs with resource optimization."""
+        batch_results = {}
+        semaphore = asyncio.Semaphore(self.max_concurrent_tasks)
+
+        # Wait until resources are available before processing the batch
+        while not await self.resource_monitor.check_resources():
+            self.logger.warning(f"Executor {self.__class__.__name__}: Resource limits reached, waiting...")
+            await asyncio.sleep(self.backoff_factor)
+            # Check for cancellation while waiting
+            if await self.control.is_cancelled():
+                raise asyncio.CancelledError("Execution was cancelled")
+
+        async def process_url_with_semaphore(url):
+            async with semaphore:
+                # Check for cancellation
+                if await self.control.is_cancelled():
+                    raise asyncio.CancelledError("Execution was cancelled")
+                # Wait if paused
+                await self.control.wait_if_paused(timeout=300)
+
+                try:
+                    result = await self._process_url(url)
+                    batch_results[url] = {"success": True, "result": result}
+                except Exception as e:
+                    batch_results[url] = {"success": False, "error": str(e)}
+                finally:
+                    # Controlled delay between URLs for resource management
+                    await asyncio.sleep(0.1)
+
+        tasks = [process_url_with_semaphore(url) for url in batch_urls]
+        await asyncio.gather(*tasks)
+
+        return batch_results
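+
+# Throttling sketch (standalone, illustrative): the semaphore pattern above
+# is the core of RESOURCE mode, capping how much work is in flight at once.
+#
+#     sem = asyncio.Semaphore(5)
+#
+#     async def limited(coro):
+#         async with sem:
+#             return await coro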
+
+
+async def main():
+    # Sample callback functions
+    async def pre_execution_callback(url: str, context: Dict[str, Any]):
+        print(f"Pre-execution callback: About to process URL {url}")
+
+    async def post_execution_callback(url: str, result: Any, context: Dict[str, Any]):
+        print(f"Post-execution callback: Successfully processed URL {url}")
+
+    async def on_error_callback(url: str, error: Exception, context: Dict[str, Any]):
+        print(f"Error callback: Error processing URL {url}: {error}")
+
+    async def on_retry_callback(url: str, attempt: int, context: Dict[str, Any]):
+        print(f"Retry callback: Retrying URL {url}, attempt {attempt}")
+
+    async def on_batch_start_callback(urls: List[str], context: Dict[str, Any]):
+        print(f"Batch start callback: Starting batch with {len(urls)} URLs")
+
+    async def on_batch_end_callback(urls: List[str], context: Dict[str, Any]):
+        print(f"Batch end callback: Completed batch with {len(urls)} URLs")
+
+    async def on_complete_callback(results: Dict[str, Any], context: Dict[str, Any]):
+        print(f"Complete callback: Execution completed with {len(results)} results")
+
+    # Sample URLs to crawl
+    urls = [
+        "https://www.example.com",
+        "https://www.python.org",
+        "https://www.asyncio.org",
+        # Add more URLs as needed
+    ]
+
+    # Instantiate the crawler
+    crawler = AsyncWebCrawler()
+
+    # Set up callbacks
+    callbacks = {
+        CallbackType.PRE_EXECUTION: pre_execution_callback,
+        CallbackType.POST_EXECUTION: post_execution_callback,
+        CallbackType.ON_ERROR: on_error_callback,
+        CallbackType.ON_RETRY: on_retry_callback,
+        CallbackType.ON_BATCH_START: on_batch_start_callback,
+        CallbackType.ON_BATCH_END: on_batch_end_callback,
+        CallbackType.ON_COMPLETE: on_complete_callback,
+    }
+
+    # Instantiate the executors
+    speed_executor = SpeedOptimizedExecutor(
+        crawler=crawler,
+        callbacks=callbacks,
+        max_retries=2,  # Example additional config
+    )
+
+    resource_executor = ResourceOptimizedExecutor(
+        crawler=crawler,
+        callbacks=callbacks,
+        max_concurrent_tasks=3,  # Limit concurrency
+        max_retries=2,  # Example additional config
+    )
+
+    # Choose which executor to use
+    executor = speed_executor  # Or resource_executor
+
+    # Start the execution in a background task
+    execution_task = asyncio.create_task(executor.execute(urls))
+
+    # Simulate control operations
+    await asyncio.sleep(2)  # Let it run for a bit
+    print("Pausing execution...")
+    await executor.control.pause()
+    await asyncio.sleep(2)  # Wait while paused
+    print("Resuming execution...")
+    await executor.control.resume()
+
+    # Wait for execution to complete
+    results = await execution_task
+
+    # Print the results
+    print("Execution results:")
+    for url, result in results.items():
+        print(f"{url}: {result}")
+
+    # Get and print final metrics
+    final_status = await executor.get_status()
+    print("Final executor status and metrics:")
+    print(final_status)
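+
+# The demo above pauses and resumes a live run from a second coroutine; the
+# same control handle could equally be driven from a signal handler or an
+# admin endpoint (hypothetical integrations, not included here).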
+
+# Run the main function
+if __name__ == "__main__":
+    asyncio.run(main())
\ No newline at end of file
diff --git a/tests/async/test_async_executor.py b/tests/async/test_async_executor.py
new file mode 100644
index 00000000..8cdaeb10
--- /dev/null
+++ b/tests/async/test_async_executor.py
@@ -0,0 +1,219 @@
+import os
+import sys
+import asyncio
+import unittest
+from unittest.mock import AsyncMock, MagicMock, patch
+
+# Add the parent directory to the Python path
+parent_dir = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
+sys.path.append(parent_dir)
+
+from crawl4ai.async_webcrawler import AsyncWebCrawler
+from crawl4ai.async_executor import (
+    SpeedOptimizedExecutor,
+    ResourceOptimizedExecutor,
+    CallbackType,
+)
+
+
+class TestAsyncExecutor(unittest.IsolatedAsyncioTestCase):
+    async def asyncSetUp(self):
+        # Set up a mock crawler
+        self.mock_crawler = AsyncMock(spec=AsyncWebCrawler)
+        self.mock_crawler.arun = AsyncMock(side_effect=self.mock_crawl)
+
+        # Sample URLs
+        self.urls = [
+            "https://www.example.com",
+            "https://www.python.org",
+            "https://www.asyncio.org",
+            "https://www.nonexistenturl.xyz",  # This will simulate a failure
+        ]
+
+        # Set up callbacks
+        self.callbacks = {
+            CallbackType.PRE_EXECUTION: AsyncMock(),
+            CallbackType.POST_EXECUTION: AsyncMock(),
+            CallbackType.ON_ERROR: AsyncMock(),
+            CallbackType.ON_RETRY: AsyncMock(),
+            CallbackType.ON_BATCH_START: AsyncMock(),
+            CallbackType.ON_BATCH_END: AsyncMock(),
+            CallbackType.ON_COMPLETE: AsyncMock(),
+        }
+
+    async def mock_crawl(self, url: str, **kwargs):
+        # Simulate a little network latency so the pause/cancel tests have
+        # work in flight to act on
+        await asyncio.sleep(0.2)
+        if "nonexistenturl" in url:
+            raise Exception("Failed to fetch URL")
+        return f"Mock content for {url}"
+
+    async def test_speed_executor_basic(self):
+        """Test basic functionality of SpeedOptimizedExecutor."""
+        executor = SpeedOptimizedExecutor(
+            crawler=self.mock_crawler,
+            callbacks=self.callbacks,
+            max_retries=1,
+        )
+
+        results = await executor.execute(self.urls)
+
+        # Assertions
+        self.assertEqual(len(results), len(self.urls))
+        self.mock_crawler.arun.assert_awaited()
+        self.callbacks[CallbackType.PRE_EXECUTION].assert_awaited()
+        self.callbacks[CallbackType.POST_EXECUTION].assert_awaited()
+        self.callbacks[CallbackType.ON_ERROR].assert_awaited()
+
+    async def test_resource_executor_basic(self):
+        """Test basic functionality of ResourceOptimizedExecutor."""
+        executor = ResourceOptimizedExecutor(
+            crawler=self.mock_crawler,
+            callbacks=self.callbacks,
+            max_concurrent_tasks=2,
+            max_retries=1,
+        )
+
+        results = await executor.execute(self.urls)
+
+        # Assertions
+        self.assertEqual(len(results), len(self.urls))
+        self.mock_crawler.arun.assert_awaited()
+        self.callbacks[CallbackType.PRE_EXECUTION].assert_awaited()
+        self.callbacks[CallbackType.POST_EXECUTION].assert_awaited()
+        self.callbacks[CallbackType.ON_ERROR].assert_awaited()
+
+    async def test_pause_and_resume(self):
+        """Test the pause and resume functionality."""
+        executor = SpeedOptimizedExecutor(
+            crawler=self.mock_crawler,
+            callbacks=self.callbacks,
+            max_retries=1,
+        )
+
+        execution_task = asyncio.create_task(executor.execute(self.urls))
+        await asyncio.sleep(0.1)
+        await executor.control.pause()
+        self.assertTrue(await executor.control.is_paused())
+
+        # Ensure that execution is paused
+        await asyncio.sleep(0.5)
+        await executor.control.resume()
+        self.assertFalse(await executor.control.is_paused())
+
+        results = await execution_task
+
+        # Assertions
+        self.assertEqual(len(results), len(self.urls))
+
+    async def test_cancellation(self):
+        """Test the cancellation functionality."""
+        executor = SpeedOptimizedExecutor(
+            crawler=self.mock_crawler,
+            callbacks=self.callbacks,
+            max_retries=1,
+        )
+
+        execution_task = asyncio.create_task(executor.execute(self.urls))
+        await asyncio.sleep(0.1)
+        await executor.control.cancel()
+        self.assertTrue(await executor.control.is_cancelled())
+
+        with self.assertRaises(asyncio.CancelledError):
+            await execution_task
+
+    async def test_max_retries(self):
+        """Test that the executor respects the max_retries setting."""
+        executor = SpeedOptimizedExecutor(
+            crawler=self.mock_crawler,
+            callbacks=self.callbacks,
+            max_retries=2,
+        )
+
+        results = await executor.execute(self.urls)
+
+        # The failing URL should have been retried twice
+        self.assertEqual(self.mock_crawler.arun.call_count, len(self.urls) + 2)
+        self.assertEqual(executor.metrics.total_retries, 2)
+
+    async def test_callbacks_invoked(self):
+        """Test that all callbacks are invoked appropriately."""
+        executor = SpeedOptimizedExecutor(
+            crawler=self.mock_crawler,
+            callbacks=self.callbacks,
+            max_retries=1,
+        )
+
+        await executor.execute(self.urls)
+
+        # PRE_EXECUTION fires once per attempt: one per URL plus one per retry
+        self.assertEqual(
+            self.callbacks[CallbackType.PRE_EXECUTION].call_count,
+            len(self.urls) + executor.metrics.total_retries,
+        )
+        # POST_EXECUTION fires once per successful URL
+        self.assertEqual(
+            self.callbacks[CallbackType.POST_EXECUTION].call_count,
+            executor.metrics.completed_urls,
+        )
+        # ON_ERROR fires on every failed attempt: one per retry plus the final failure
+        self.assertEqual(
+            self.callbacks[CallbackType.ON_ERROR].call_count,
+            executor.metrics.total_retries + executor.metrics.failed_urls,
+        )
+        self.callbacks[CallbackType.ON_COMPLETE].assert_awaited_once()
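+
+    # Accounting sketch for the assertions above (illustrative numbers): with
+    # 4 URLs, one failing, and max_retries=1, the failing URL is attempted
+    # twice, so PRE_EXECUTION fires 5 times (4 + 1 retry), POST_EXECUTION
+    # 3 times, and ON_ERROR twice (1 retry + 1 final failure).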
+
+    async def test_resource_limits(self):
+        """Test that batch sizing responds to constrained resources."""
+        with patch('psutil.cpu_percent', return_value=95), \
+             patch('psutil.virtual_memory', return_value=MagicMock(percent=85, available=1000)):
+            speed_executor = SpeedOptimizedExecutor(
+                crawler=self.mock_crawler,
+                callbacks=self.callbacks,
+                max_retries=1,
+            )
+            resource_executor = ResourceOptimizedExecutor(
+                crawler=self.mock_crawler,
+                callbacks=self.callbacks,
+                max_concurrent_tasks=2,
+                max_retries=1,
+            )
+
+            # In SPEED mode the scaling factor bottoms out at 0.0 once usage
+            # exceeds the thresholds, so the batch size collapses to 1
+            self.assertEqual(
+                speed_executor.resource_monitor.get_optimal_batch_size(len(self.urls)), 1
+            )
+            # In RESOURCE mode the factor is floored at 0.1, giving a small
+            # but non-zero batch size
+            batch_size = resource_executor.resource_monitor.get_optimal_batch_size(len(self.urls))
+            self.assertLessEqual(batch_size, len(self.urls))
+            self.assertGreaterEqual(batch_size, 1)
+
+    async def test_system_metrics_limit(self):
+        """Test that the system_metrics list does not grow indefinitely."""
+        executor = SpeedOptimizedExecutor(
+            crawler=self.mock_crawler,
+            callbacks=self.callbacks,
+            max_retries=1,
+        )
+
+        # Shrink the history limit where the executor module reads it, then
+        # capture far more samples than the limit allows
+        with patch('crawl4ai.async_executor.MAX_METRICS_HISTORY', 5):
+            for _ in range(20):
+                executor.metrics.capture_system_metrics()
+
+            self.assertLessEqual(len(executor.metrics.system_metrics), 5)
+
+
+if __name__ == "__main__":
+    unittest.main()
\ No newline at end of file