Files
crawl4ai/tests/async/test_async_executor.py

219 lines
8.0 KiB
Python

import os, sys
import unittest
import asynctest
import asyncio
import time
from typing import Dict, Any, List
from unittest.mock import AsyncMock, MagicMock, patch
# Add the parent directory to the Python path
parent_dir = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
sys.path.append(parent_dir)
# Assuming all classes and imports are already available from the code above
from crawl4ai.async_webcrawler import AsyncWebCrawler
from crawl4ai.config import MAX_METRICS_HISTORY
from crawl4ai.async_executor import (
SpeedOptimizedExecutor,
ResourceOptimizedExecutor,
AsyncWebCrawler,
ExecutionMode,
SystemMetrics,
CallbackType
)
class TestAsyncExecutor(asynctest.TestCase):
async def setUp(self):
# Set up a mock crawler
self.mock_crawler = AsyncMock(spec=AsyncWebCrawler)
self.mock_crawler.arun = AsyncMock(side_effect=self.mock_crawl)
# Sample URLs
self.urls = [
"https://www.example.com",
"https://www.python.org",
"https://www.asyncio.org",
"https://www.nonexistenturl.xyz", # This will simulate a failure
]
# Set up callbacks
self.callbacks = {
CallbackType.PRE_EXECUTION: AsyncMock(),
CallbackType.POST_EXECUTION: AsyncMock(),
CallbackType.ON_ERROR: AsyncMock(),
CallbackType.ON_RETRY: AsyncMock(),
CallbackType.ON_BATCH_START: AsyncMock(),
CallbackType.ON_BATCH_END: AsyncMock(),
CallbackType.ON_COMPLETE: AsyncMock(),
}
async def mock_crawl(self, url: str, **kwargs):
if "nonexistenturl" in url:
raise Exception("Failed to fetch URL")
return f"Mock content for {url}"
async def test_speed_executor_basic(self):
"""Test basic functionality of SpeedOptimizedExecutor."""
executor = SpeedOptimizedExecutor(
crawler=self.mock_crawler,
callbacks=self.callbacks,
max_retries=1,
)
results = await executor.execute(self.urls)
# Assertions
self.assertEqual(len(results), len(self.urls))
self.mock_crawler.arun.assert_awaited()
self.callbacks[CallbackType.PRE_EXECUTION].assert_awaited()
self.callbacks[CallbackType.POST_EXECUTION].assert_awaited()
self.callbacks[CallbackType.ON_ERROR].assert_awaited()
async def test_resource_executor_basic(self):
"""Test basic functionality of ResourceOptimizedExecutor."""
executor = ResourceOptimizedExecutor(
crawler=self.mock_crawler,
callbacks=self.callbacks,
max_concurrent_tasks=2,
max_retries=1,
)
results = await executor.execute(self.urls)
# Assertions
self.assertEqual(len(results), len(self.urls))
self.mock_crawler.arun.assert_awaited()
self.callbacks[CallbackType.PRE_EXECUTION].assert_awaited()
self.callbacks[CallbackType.POST_EXECUTION].assert_awaited()
self.callbacks[CallbackType.ON_ERROR].assert_awaited()
async def test_pause_and_resume(self):
"""Test the pause and resume functionality."""
executor = SpeedOptimizedExecutor(
crawler=self.mock_crawler,
callbacks=self.callbacks,
max_retries=1,
)
execution_task = asyncio.create_task(executor.execute(self.urls))
await asyncio.sleep(0.1)
await executor.control.pause()
self.assertTrue(await executor.control.is_paused())
# Ensure that execution is paused
await asyncio.sleep(0.5)
await executor.control.resume()
self.assertFalse(await executor.control.is_paused())
results = await execution_task
# Assertions
self.assertEqual(len(results), len(self.urls))
async def test_cancellation(self):
"""Test the cancellation functionality."""
executor = SpeedOptimizedExecutor(
crawler=self.mock_crawler,
callbacks=self.callbacks,
max_retries=1,
)
execution_task = asyncio.create_task(executor.execute(self.urls))
await asyncio.sleep(0.1)
await executor.control.cancel()
self.assertTrue(await executor.control.is_cancelled())
with self.assertRaises(asyncio.CancelledError):
await execution_task
async def test_max_retries(self):
"""Test that the executor respects the max_retries setting."""
executor = SpeedOptimizedExecutor(
crawler=self.mock_crawler,
callbacks=self.callbacks,
max_retries=2,
)
results = await executor.execute(self.urls)
# The failing URL should have been retried
self.assertEqual(self.mock_crawler.arun.call_count, len(self.urls) + 2)
self.assertEqual(executor.metrics.total_retries, 2)
async def test_callbacks_invoked(self):
"""Test that all callbacks are invoked appropriately."""
executor = SpeedOptimizedExecutor(
crawler=self.mock_crawler,
callbacks=self.callbacks,
max_retries=1,
)
await executor.execute(self.urls)
# Check that callbacks were called the correct number of times
self.assertEqual(
self.callbacks[CallbackType.PRE_EXECUTION].call_count,
len(self.urls) * (1 + executor.metrics.total_retries),
)
self.assertEqual(
self.callbacks[CallbackType.POST_EXECUTION].call_count,
executor.metrics.completed_urls,
)
self.assertEqual(
self.callbacks[CallbackType.ON_ERROR].call_count,
executor.metrics.failed_urls * (1 + executor.metrics.total_retries),
)
self.callbacks[CallbackType.ON_COMPLETE].assert_awaited_once()
async def test_resource_limits(self):
"""Test that the ResourceOptimizedExecutor respects resource limits."""
with patch('psutil.cpu_percent', return_value=95), \
patch('psutil.virtual_memory', return_value=MagicMock(percent=85, available=1000)):
executor = ResourceOptimizedExecutor(
crawler=self.mock_crawler,
callbacks=self.callbacks,
max_concurrent_tasks=2,
max_retries=1,
)
results = await executor.execute(self.urls)
# Assertions
self.assertEqual(len(results), len(self.urls))
# Since resources are over threshold, batch size should be minimized
batch_sizes = [executor.resource_monitor.get_optimal_batch_size(len(self.urls))]
self.assertTrue(all(size == 1 for size in batch_sizes))
async def test_system_metrics_limit(self):
"""Test that the system_metrics list does not grow indefinitely."""
executor = SpeedOptimizedExecutor(
crawler=self.mock_crawler,
callbacks=self.callbacks,
max_retries=1,
)
# Simulate many batches to exceed MAX_METRICS_HISTORY
original_max_history = MAX_METRICS_HISTORY
try:
# Temporarily reduce MAX_METRICS_HISTORY for the test
globals()['MAX_METRICS_HISTORY'] = 5
# Mock capture_system_metrics to increase system_metrics length
with patch.object(executor.metrics, 'capture_system_metrics') as mock_capture:
def side_effect():
executor.metrics.system_metrics.append(SystemMetrics(0, 0, 0, time.time()))
if len(executor.metrics.system_metrics) > MAX_METRICS_HISTORY:
executor.metrics.system_metrics.pop(0)
mock_capture.side_effect = side_effect
await executor.execute(self.urls * 3) # Multiply URLs to create more batches
# Assertions
self.assertLessEqual(len(executor.metrics.system_metrics), MAX_METRICS_HISTORY)
finally:
# Restore original MAX_METRICS_HISTORY
globals()['MAX_METRICS_HISTORY'] = original_max_history
if __name__ == "__main__":
unittest.main()