diff --git a/tests/pipeline/test_batch_crawl.py b/tests/pipeline/test_batch_crawl.py new file mode 100644 index 00000000..84d7a63f --- /dev/null +++ b/tests/pipeline/test_batch_crawl.py @@ -0,0 +1,405 @@ +# test_batch_crawl.py + +import asyncio +import unittest +from unittest.mock import Mock, patch, AsyncMock +import pytest + +from crawl4ai.pipeline import Pipeline, create_pipeline, batch_crawl +from crawl4ai.async_configs import BrowserConfig, CrawlerRunConfig +from crawl4ai.async_logger import AsyncLogger +from crawl4ai.models import CrawlResult, CrawlResultContainer +from crawl4ai.browser_hub_manager import BrowserHubManager + + +# Utility function for tests +async def create_mock_result(url, success=True, status_code=200, html=""): + """Create a mock crawl result for testing""" + result = CrawlResult( + url=url, + html=html, + success=success, + status_code=status_code, + error_message="" if success else f"Error crawling {url}" + ) + return CrawlResultContainer(result) + + +class TestBatchCrawl(unittest.IsolatedAsyncioTestCase): + """Test cases for the batch_crawl function""" + + async def asyncSetUp(self): + """Set up test environment""" + self.logger = AsyncLogger(verbose=False) + self.browser_config = BrowserConfig(headless=True) + self.crawler_config = CrawlerRunConfig() + + # URLs for testing + self.test_urls = [ + "https://example.com/1", + "https://example.com/2", + "https://example.com/3", + "https://example.com/4", + "https://example.com/5" + ] + + # Mock pipeline to avoid actual crawling + self.mock_pipeline = AsyncMock() + self.mock_pipeline.crawl = AsyncMock() + + # Set up pipeline to return success for most URLs, but failure for one + async def mock_crawl(url, config=None): + if url == "https://example.com/3": + return await create_mock_result(url, success=False, status_code=404) + return await create_mock_result(url, success=True) + + self.mock_pipeline.crawl.side_effect = mock_crawl + + # Patch the create_pipeline function + self.create_pipeline_patch = patch( + 'crawl4ai.pipeline.create_pipeline', + return_value=self.mock_pipeline + ) + self.mock_create_pipeline = self.create_pipeline_patch.start() + + async def asyncTearDown(self): + """Clean up after tests""" + self.create_pipeline_patch.stop() + await BrowserHubManager.shutdown_all() + + # === Basic Functionality Tests === + + async def test_simple_batch_with_single_config(self): + """Test basic batch crawling with one configuration for all URLs""" + # Call the batch_crawl function with a list of URLs and single config + results = await batch_crawl( + urls=self.test_urls, + browser_config=self.browser_config, + crawler_config=self.crawler_config + ) + + # Verify we got results for all URLs + self.assertEqual(len(results), len(self.test_urls)) + + # Check that pipeline.crawl was called for each URL + self.assertEqual(self.mock_pipeline.crawl.call_count, len(self.test_urls)) + + # Check success/failure as expected + success_count = sum(1 for r in results if r.success) + self.assertEqual(success_count, len(self.test_urls) - 1) # All except URL 3 + + # Verify URLs in results match input URLs + result_urls = sorted([r.url for r in results]) + self.assertEqual(result_urls, sorted(self.test_urls)) + + async def test_batch_with_crawl_specs(self): + """Test batch crawling with different configurations per URL""" + # Create different configs for each URL + crawl_specs = [ + {"url": url, "crawler_config": CrawlerRunConfig(screenshot=i % 2 == 0)} + for i, url in enumerate(self.test_urls) + ] + + # Call batch_crawl with crawl specs + results = await batch_crawl( + crawl_specs=crawl_specs, + browser_config=self.browser_config + ) + + # Verify results + self.assertEqual(len(results), len(crawl_specs)) + + # Verify each URL was crawled with its specific config + for i, spec in enumerate(crawl_specs): + call_args = self.mock_pipeline.crawl.call_args_list[i] + self.assertEqual(call_args[1]['url'], spec['url']) + self.assertEqual( + call_args[1]['config'].screenshot, + spec['crawler_config'].screenshot + ) + + # === Advanced Configuration Tests === + + async def test_with_multiple_browser_configs(self): + """Test using different browser configurations for different URLs""" + # Create different browser configs + browser_config1 = BrowserConfig(headless=True, browser_type="chromium") + browser_config2 = BrowserConfig(headless=True, browser_type="firefox") + + # Create crawl specs with different browser configs + crawl_specs = [ + { + "url": self.test_urls[0], + "browser_config": browser_config1, + "crawler_config": self.crawler_config + }, + { + "url": self.test_urls[1], + "browser_config": browser_config2, + "crawler_config": self.crawler_config + } + ] + + # Call batch_crawl with mixed browser configs + results = await batch_crawl(crawl_specs=crawl_specs) + + # Verify results + self.assertEqual(len(results), len(crawl_specs)) + + # Verify create_pipeline was called with different browser configs + self.assertEqual(self.mock_create_pipeline.call_count, 2) + + # Check call arguments for create_pipeline + call_args_list = self.mock_create_pipeline.call_args_list + self.assertEqual(call_args_list[0][1]['browser_config'], browser_config1) + self.assertEqual(call_args_list[1][1]['browser_config'], browser_config2) + + async def test_with_existing_browser_hub(self): + """Test using a pre-initialized browser hub""" + # Create a mock browser hub + mock_hub = AsyncMock() + + # Call batch_crawl with browser hub + results = await batch_crawl( + urls=self.test_urls, + browser_hub=mock_hub, + crawler_config=self.crawler_config + ) + + # Verify create_pipeline was called with the browser hub + self.mock_create_pipeline.assert_called_with( + browser_hub=mock_hub, + logger=self.logger + ) + + # Verify results + self.assertEqual(len(results), len(self.test_urls)) + + # === Error Handling and Retry Tests === + + async def test_retry_on_failure(self): + """Test retrying failed URLs up to max_tries""" + # Modify mock to fail first 2 times for URL 3, then succeed + attempt_counts = {url: 0 for url in self.test_urls} + + async def mock_crawl_with_retries(url, config=None): + attempt_counts[url] += 1 + if url == "https://example.com/3" and attempt_counts[url] <= 2: + return await create_mock_result(url, success=False, status_code=500) + return await create_mock_result(url, success=True) + + self.mock_pipeline.crawl.side_effect = mock_crawl_with_retries + + # Call batch_crawl with retry configuration + results = await batch_crawl( + urls=self.test_urls, + browser_config=self.browser_config, + crawler_config=self.crawler_config, + max_tries=3 + ) + + # Verify all URLs succeeded after retries + self.assertTrue(all(r.success for r in results)) + + # Check retry count for URL 3 + self.assertEqual(attempt_counts["https://example.com/3"], 3) + + # Check other URLs were only tried once + for url in self.test_urls: + if url != "https://example.com/3": + self.assertEqual(attempt_counts[url], 1) + + async def test_give_up_after_max_tries(self): + """Test that crawling gives up after max_tries""" + # Modify mock to always fail for URL 3 + async def mock_crawl_always_fail(url, config=None): + if url == "https://example.com/3": + return await create_mock_result(url, success=False, status_code=500) + return await create_mock_result(url, success=True) + + self.mock_pipeline.crawl.side_effect = mock_crawl_always_fail + + # Call batch_crawl with retry configuration + results = await batch_crawl( + urls=self.test_urls, + browser_config=self.browser_config, + crawler_config=self.crawler_config, + max_tries=3 + ) + + # Find result for URL 3 + url3_result = next(r for r in results if r.url == "https://example.com/3") + + # Verify URL 3 still failed after max retries + self.assertFalse(url3_result.success) + + # Verify retry metadata is present (assuming we add this to the result) + self.assertEqual(url3_result.attempt_count, 3) + self.assertTrue(hasattr(url3_result, 'retry_error_messages')) + + async def test_exception_during_crawl(self): + """Test handling exceptions during crawling""" + # Modify mock to raise exception for URL 4 + async def mock_crawl_with_exception(url, config=None): + if url == "https://example.com/4": + raise RuntimeError("Simulated crawler exception") + return await create_mock_result(url, success=True) + + self.mock_pipeline.crawl.side_effect = mock_crawl_with_exception + + # Call batch_crawl + results = await batch_crawl( + urls=self.test_urls, + browser_config=self.browser_config, + crawler_config=self.crawler_config + ) + + # Verify we still get results for all URLs + self.assertEqual(len(results), len(self.test_urls)) + + # Find result for URL 4 + url4_result = next(r for r in results if r.url == "https://example.com/4") + + # Verify URL 4 is marked as failed + self.assertFalse(url4_result.success) + + # Verify exception info is captured + self.assertIn("Simulated crawler exception", url4_result.error_message) + + # === Performance and Control Tests === + + async def test_concurrency_limit(self): + """Test limiting concurrent crawls""" + # Create a slow mock crawl function to test concurrency + crawl_started = {url: asyncio.Event() for url in self.test_urls} + crawl_proceed = {url: asyncio.Event() for url in self.test_urls} + + async def slow_mock_crawl(url, config=None): + crawl_started[url].set() + await crawl_proceed[url].wait() + return await create_mock_result(url) + + self.mock_pipeline.crawl.side_effect = slow_mock_crawl + + # Start batch_crawl with concurrency limit of 2 + task = asyncio.create_task( + batch_crawl( + urls=self.test_urls, + browser_config=self.browser_config, + crawler_config=self.crawler_config, + concurrency=2 + ) + ) + + # Wait for first 2 crawls to start + await asyncio.wait( + [crawl_started[self.test_urls[0]].wait(), + crawl_started[self.test_urls[1]].wait()], + timeout=1 + ) + + # Verify only 2 crawls started + started_count = sum(1 for url in self.test_urls if crawl_started[url].is_set()) + self.assertEqual(started_count, 2) + + # Allow first crawl to complete + crawl_proceed[self.test_urls[0]].set() + + # Wait for next crawl to start + await asyncio.wait([crawl_started[self.test_urls[2]].wait()], timeout=1) + + # Now 3 total should have started (2 running, 1 completed) + started_count = sum(1 for url in self.test_urls if crawl_started[url].is_set()) + self.assertEqual(started_count, 3) + + # Allow all remaining crawls to complete + for url in self.test_urls: + crawl_proceed[url].set() + + # Wait for batch_crawl to complete + results = await task + + # Verify all URLs were crawled + self.assertEqual(len(results), len(self.test_urls)) + + async def test_cancel_batch_crawl(self): + """Test cancelling a batch crawl operation""" + # Create a crawl function that won't complete unless signaled + crawl_started = {url: asyncio.Event() for url in self.test_urls} + + async def endless_mock_crawl(url, config=None): + crawl_started[url].set() + # This will wait forever unless cancelled + await asyncio.Future() + + self.mock_pipeline.crawl.side_effect = endless_mock_crawl + + # Start batch_crawl + task = asyncio.create_task( + batch_crawl( + urls=self.test_urls, + browser_config=self.browser_config, + crawler_config=self.crawler_config + ) + ) + + # Wait for at least one crawl to start + await asyncio.wait( + [crawl_started[self.test_urls[0]].wait()], + timeout=1 + ) + + # Cancel the task + task.cancel() + + # Verify task was cancelled + with self.assertRaises(asyncio.CancelledError): + await task + + # === Edge Cases Tests === + + async def test_empty_url_list(self): + """Test behavior with empty URL list""" + results = await batch_crawl( + urls=[], + browser_config=self.browser_config, + crawler_config=self.crawler_config + ) + + # Should return empty list + self.assertEqual(results, []) + + # Verify crawl wasn't called + self.mock_pipeline.crawl.assert_not_called() + + async def test_mix_of_valid_and_invalid_urls(self): + """Test with a mix of valid and invalid URLs""" + # Include some invalid URLs + mixed_urls = [ + "https://example.com/valid", + "invalid-url", + "http:/missing-slash", + "https://example.com/valid2" + ] + + # Call batch_crawl + results = await batch_crawl( + urls=mixed_urls, + browser_config=self.browser_config, + crawler_config=self.crawler_config + ) + + # Should have results for all URLs + self.assertEqual(len(results), len(mixed_urls)) + + # Check invalid URLs were marked as failed + for result in results: + if result.url in ["invalid-url", "http:/missing-slash"]: + self.assertFalse(result.success) + self.assertIn("Invalid URL", result.error_message) + else: + self.assertTrue(result.success) + + +if __name__ == "__main__": + unittest.main() \ No newline at end of file