refactor(deep-crawl): reorganize deep crawling functionality into dedicated module
Restructure deep crawling code into a dedicated module with improved organization: - Move deep crawl logic from async_deep_crawl.py to deep_crawling/ - Create separate files for BFS strategy, filters, and scorers - Improve code organization and maintainability - Add optimized implementations for URL filtering and scoring - Rename DeepCrawlHandler to DeepCrawlDecorator for clarity BREAKING CHANGE: DeepCrawlStrategy and BreadthFirstSearchStrategy imports need to be updated to new package structure
This commit is contained in:
@@ -15,8 +15,6 @@ from .extraction_strategy import (
|
||||
JsonCssExtractionStrategy,
|
||||
JsonXPathExtractionStrategy
|
||||
)
|
||||
|
||||
from .async_deep_crawl import DeepCrawlStrategy, BreadthFirstSearchStrategy
|
||||
from .chunking_strategy import ChunkingStrategy, RegexChunking
|
||||
from .markdown_generation_strategy import DefaultMarkdownGenerator
|
||||
from .content_filter_strategy import PruningContentFilter, BM25ContentFilter, LLMContentFilter, RelevantContentFilter
|
||||
@@ -33,8 +31,6 @@ from .docker_client import Crawl4aiDockerClient
|
||||
from .hub import CrawlerHub
|
||||
|
||||
__all__ = [
|
||||
"DeepCrawlStrategy",
|
||||
"BreadthFirstSearchStrategy",
|
||||
"AsyncWebCrawler",
|
||||
"CrawlResult",
|
||||
"CrawlerHub",
|
||||
|
||||
@@ -13,7 +13,7 @@ from .chunking_strategy import ChunkingStrategy, RegexChunking
|
||||
from .markdown_generation_strategy import MarkdownGenerationStrategy
|
||||
from .content_filter_strategy import RelevantContentFilter # , BM25ContentFilter, LLMContentFilter, PruningContentFilter
|
||||
from .content_scraping_strategy import ContentScrapingStrategy, WebScrapingStrategy
|
||||
from .async_deep_crawl import DeepCrawlStrategy
|
||||
from .deep_crawling import DeepCrawlStrategy
|
||||
from typing import Union, List
|
||||
from .cache_context import CacheMode
|
||||
|
||||
|
||||
@@ -1,181 +0,0 @@
|
||||
# crawl4ai/async_deep_crawl.py
|
||||
|
||||
"""Remember:
|
||||
# Update CrawlerRunConfig in async_configs.py (additional field)
|
||||
class CrawlerRunConfig(BaseModel):
|
||||
deep_crawl_strategy: Optional[DeepCrawlStrategy] = Field(
|
||||
default=None,
|
||||
description="Strategy for deep crawling websites"
|
||||
)
|
||||
# ... other existing fields remain unchanged
|
||||
|
||||
# In AsyncWebCrawler class (partial implementation)
|
||||
class AsyncWebCrawler:
|
||||
def __init__(self, *args, **kwargs):
|
||||
# Existing initialization
|
||||
self._deep_handler = DeepCrawlHandler(self)
|
||||
self.arun = self._deep_handler(self.arun) # Decorate original method
|
||||
|
||||
async def arun(self, url: str, config: Optional[CrawlerRunConfig] = None, **kwargs):
|
||||
# ... existing implementation
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
from collections import deque
|
||||
from functools import wraps
|
||||
from typing import AsyncGenerator, List, Optional, Set, Union, TypeVar
|
||||
from urllib.parse import urlparse
|
||||
from pydantic import BaseModel, Field
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .async_webcrawler import AsyncWebCrawler, CrawlResult
|
||||
from .async_configs import CrawlerRunConfig
|
||||
from .async_dispatcher import MemoryAdaptiveDispatcher
|
||||
|
||||
CrawlResultT = TypeVar('CrawlResultT', bound=CrawlResult)
|
||||
RunManyReturn = Union[CrawlResultT, List[CrawlResultT], AsyncGenerator[CrawlResultT, None]]
|
||||
|
||||
|
||||
class DeepCrawlStrategy(BaseModel):
|
||||
"""Base class for deep crawling strategies."""
|
||||
max_depth: int = Field(default=3, description="Maximum crawl depth from initial URL")
|
||||
include_external: bool = Field(default=False, description="Follow links to external domains")
|
||||
|
||||
class Config:
|
||||
arbitrary_types_allowed = True
|
||||
|
||||
async def run(
|
||||
self,
|
||||
crawler: "AsyncWebCrawler",
|
||||
start_url: str,
|
||||
config: "CrawlerRunConfig"
|
||||
) -> "RunManyReturn":
|
||||
"""Execute the crawling strategy."""
|
||||
raise NotImplementedError
|
||||
|
||||
class BreadthFirstSearchStrategy(DeepCrawlStrategy):
|
||||
"""Breadth-first search implementation for deep crawling."""
|
||||
|
||||
async def run(
|
||||
self,
|
||||
crawler: "AsyncWebCrawler",
|
||||
start_url: str,
|
||||
config: "CrawlerRunConfig"
|
||||
) -> "RunManyReturn":
|
||||
"""BFS implementation using arun_many for batch processing."""
|
||||
async def stream_results():
|
||||
"""Inner async generator for streaming results."""
|
||||
nonlocal crawler, start_url, config
|
||||
base_domain = urlparse(start_url).netloc
|
||||
queue = deque([(start_url, 0)])
|
||||
visited: Set[str] = set()
|
||||
|
||||
# Create config copy without deep strategy for child requests
|
||||
child_config = config.copy(update={
|
||||
'deep_crawl_strategy': None,
|
||||
'stream': False # Process levels sequentially
|
||||
})
|
||||
|
||||
while queue:
|
||||
current_url, depth = queue.popleft()
|
||||
|
||||
if depth > self.max_depth or current_url in visited:
|
||||
continue
|
||||
|
||||
visited.add(current_url)
|
||||
|
||||
# Process current level using arun_many
|
||||
batch_results = await crawler.arun_many(
|
||||
urls=[current_url],
|
||||
config=child_config,
|
||||
dispatcher=MemoryAdaptiveDispatcher()
|
||||
)
|
||||
|
||||
for result in batch_results:
|
||||
yield result
|
||||
|
||||
# Queue next level if within depth limit
|
||||
if depth < self.max_depth:
|
||||
new_urls = self._extract_links(result, base_domain)
|
||||
for url in new_urls:
|
||||
if url not in visited:
|
||||
queue.append((url, depth + 1))
|
||||
|
||||
# Handle streaming vs non-streaming
|
||||
if config.stream:
|
||||
return stream_results()
|
||||
else:
|
||||
results: List[CrawlResultT] = []
|
||||
async for result in stream_results():
|
||||
results.append(result)
|
||||
return results
|
||||
|
||||
def _extract_links(self, result: "CrawlResult", base_domain: str) -> List[str]:
|
||||
"""Extract links from crawl result with domain filtering."""
|
||||
internal = result.links.get('internal', [])
|
||||
external = result.links.get('external', []) if self.include_external else []
|
||||
|
||||
return [
|
||||
url for url in internal + external
|
||||
if self._same_domain(url, base_domain) or self.include_external
|
||||
]
|
||||
|
||||
def _same_domain(self, url: str, base_domain: str) -> bool:
|
||||
"""Check if URL belongs to the base domain."""
|
||||
return urlparse(url).netloc == base_domain
|
||||
|
||||
class DeepCrawlHandler:
|
||||
"""Decorator that adds deep crawling capabilities to arun."""
|
||||
|
||||
def __init__(self, crawler: "AsyncWebCrawler"):
|
||||
self.crawler = crawler
|
||||
|
||||
def __call__(self, original_arun):
|
||||
@wraps(original_arun)
|
||||
async def wrapped_arun(url: str, config: Optional["CrawlerRunConfig"] = None, **kwargs):
|
||||
# First run the original arun
|
||||
initial_result = await original_arun(url, config=config, **kwargs)
|
||||
|
||||
if config and config.deep_crawl_strategy:
|
||||
# Execute deep crawl strategy if configured
|
||||
return await config.deep_crawl_strategy.run(
|
||||
crawler=self.crawler,
|
||||
start_url=url,
|
||||
config=config
|
||||
)
|
||||
|
||||
return initial_result
|
||||
|
||||
return wrapped_arun
|
||||
|
||||
async def main():
|
||||
"""Example deep crawl of documentation site."""
|
||||
config = CrawlerRunConfig(
|
||||
deep_crawl_strategy=BreadthFirstSearchStrategy(
|
||||
max_depth=2,
|
||||
include_external=False
|
||||
),
|
||||
stream=True,
|
||||
verbose=True
|
||||
)
|
||||
|
||||
async with AsyncWebCrawler() as crawler:
|
||||
print("Starting deep crawl in streaming mode:")
|
||||
async for result in await crawler.arun(
|
||||
url="https://docs.crawl4ai.com",
|
||||
config=config
|
||||
):
|
||||
print(f"→ {result.url} (Depth: {result.metadata.get('depth', 0)})")
|
||||
|
||||
print("\nStarting deep crawl in batch mode:")
|
||||
config.stream = False
|
||||
results = await crawler.arun(
|
||||
url="https://docs.crawl4ai.com",
|
||||
config=config
|
||||
)
|
||||
print(f"Crawled {len(results)} pages")
|
||||
print(f"Example page: {results[0].url}")
|
||||
|
||||
if __name__ == "__main__":
|
||||
asyncio.run(main())
|
||||
@@ -29,7 +29,7 @@ from .markdown_generation_strategy import (
|
||||
DefaultMarkdownGenerator,
|
||||
MarkdownGenerationStrategy,
|
||||
)
|
||||
from .async_deep_crawl import DeepCrawlHandler
|
||||
from .deep_crawling import DeepCrawlDecorator
|
||||
from .async_logger import AsyncLogger
|
||||
from .async_configs import BrowserConfig, CrawlerRunConfig
|
||||
from .async_dispatcher import * # noqa: F403
|
||||
@@ -56,9 +56,6 @@ DeepCrawlManyReturn = Union[
|
||||
AsyncGenerator[CrawlResultT, None],
|
||||
]
|
||||
|
||||
|
||||
|
||||
|
||||
class AsyncWebCrawler:
|
||||
"""
|
||||
Asynchronous web crawler with flexible caching capabilities.
|
||||
@@ -83,16 +80,7 @@ class AsyncWebCrawler:
|
||||
await crawler.close()
|
||||
```
|
||||
|
||||
Migration Guide:
|
||||
Old way (deprecated):
|
||||
crawler = AsyncWebCrawler(always_by_pass_cache=True, browser_type="chromium", headless=True)
|
||||
|
||||
New way (recommended):
|
||||
browser_config = BrowserConfig(browser_type="chromium", headless=True)
|
||||
crawler = AsyncWebCrawler(config=browser_config)
|
||||
|
||||
|
||||
Attributes:
|
||||
Attributes:
|
||||
browser_config (BrowserConfig): Configuration object for browser settings.
|
||||
crawler_strategy (AsyncCrawlerStrategy): Strategy for crawling web pages.
|
||||
logger (AsyncLogger): Logger instance for recording events and errors.
|
||||
@@ -217,7 +205,7 @@ class AsyncWebCrawler:
|
||||
self.ready = False
|
||||
|
||||
# Decorate arun method with deep crawling capabilities
|
||||
self._deep_handler = DeepCrawlHandler(self)
|
||||
self._deep_handler = DeepCrawlDecorator(self)
|
||||
self.arun = self._deep_handler(self.arun)
|
||||
|
||||
async def start(self):
|
||||
|
||||
172
crawl4ai/deep_crawling/__init__.py
Normal file
172
crawl4ai/deep_crawling/__init__.py
Normal file
@@ -0,0 +1,172 @@
|
||||
# deep_crawl_strategy.py
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import AsyncGenerator, Optional, Set, List, Dict, TypeVar, Union
|
||||
|
||||
from ..models import CrawlResult
|
||||
from typing import TYPE_CHECKING
|
||||
from functools import wraps
|
||||
from contextvars import ContextVar
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ..async_configs import CrawlerRunConfig
|
||||
from ..async_webcrawler import AsyncWebCrawler
|
||||
from .bfs_strategy import BFSDeepCrawlStrategy
|
||||
|
||||
CrawlResultT = TypeVar("CrawlResultT", bound=CrawlResult)
|
||||
# In batch mode we return List[CrawlResult] and in stream mode an AsyncGenerator.
|
||||
RunManyReturn = Union[CrawlResultT, List[CrawlResultT], AsyncGenerator[CrawlResultT, None]]
|
||||
|
||||
|
||||
class DeepCrawlDecorator:
|
||||
"""Decorator that adds deep crawling capability to arun method."""
|
||||
deep_crawl_active = ContextVar("deep_crawl_active", default=False)
|
||||
|
||||
def __init__(self, crawler: "AsyncWebCrawler"):
|
||||
self.crawler = crawler
|
||||
|
||||
def __call__(self, original_arun):
|
||||
@wraps(original_arun)
|
||||
async def wrapped_arun(url: str, config: Optional["CrawlerRunConfig"] = None, **kwargs):
|
||||
# If deep crawling is already active, call the original method to avoid recursion.
|
||||
if config and config.deep_crawl_strategy and not self.deep_crawl_active.get():
|
||||
token = self.deep_crawl_active.set(True)
|
||||
# Await the arun call to get the actual result object.
|
||||
result_obj = await config.deep_crawl_strategy.arun(
|
||||
crawler=self.crawler,
|
||||
start_url=url,
|
||||
config=config
|
||||
)
|
||||
if config.stream:
|
||||
async def result_wrapper():
|
||||
try:
|
||||
async for result in result_obj:
|
||||
yield result
|
||||
finally:
|
||||
self.deep_crawl_active.reset(token)
|
||||
return result_wrapper()
|
||||
else:
|
||||
try:
|
||||
return result_obj
|
||||
finally:
|
||||
self.deep_crawl_active.reset(token)
|
||||
return await original_arun(url, config=config, **kwargs)
|
||||
return wrapped_arun
|
||||
|
||||
class DeepCrawlStrategy(ABC):
|
||||
"""
|
||||
Abstract base class for deep crawling strategies.
|
||||
|
||||
Core functions:
|
||||
- arun: Main entry point that returns an async generator of CrawlResults.
|
||||
- shutdown: Clean up resources.
|
||||
- can_process_url: Validate a URL and decide whether to process it.
|
||||
- _process_links: Extract and process links from a CrawlResult.
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
async def _arun_batch(
|
||||
self,
|
||||
start_url: str,
|
||||
crawler: "AsyncWebCrawler",
|
||||
config: "CrawlerRunConfig",
|
||||
) -> List[CrawlResult]:
|
||||
"""
|
||||
Batch (non-streaming) mode:
|
||||
Processes one BFS level at a time, then yields all the results.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def _arun_stream(
|
||||
self,
|
||||
start_url: str,
|
||||
crawler: "AsyncWebCrawler",
|
||||
config: "CrawlerRunConfig",
|
||||
) -> AsyncGenerator[CrawlResult, None]:
|
||||
"""
|
||||
Streaming mode:
|
||||
Processes one BFS level at a time and yields results immediately as they arrive.
|
||||
"""
|
||||
pass
|
||||
|
||||
async def arun(
|
||||
self,
|
||||
start_url: str,
|
||||
crawler: "AsyncWebCrawler",
|
||||
config: Optional["CrawlerRunConfig"] = None,
|
||||
) -> RunManyReturn:
|
||||
"""
|
||||
Traverse the given URL using the specified crawler.
|
||||
|
||||
Args:
|
||||
start_url (str): The URL from which to start crawling.
|
||||
crawler (AsyncWebCrawler): The crawler instance to use.
|
||||
crawler_run_config (Optional[CrawlerRunConfig]): Crawler configuration.
|
||||
|
||||
Returns:
|
||||
AsyncGenerator[CrawlResult, None]: An async generator yielding crawl results.
|
||||
"""
|
||||
if config is None:
|
||||
raise ValueError("CrawlerRunConfig must be provided")
|
||||
|
||||
if config.stream:
|
||||
return self._arun_stream(start_url, crawler, config)
|
||||
else:
|
||||
return await self._arun_batch(start_url, crawler, config)
|
||||
|
||||
@abstractmethod
|
||||
async def shutdown(self) -> None:
|
||||
"""
|
||||
Clean up resources used by the deep crawl strategy.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def can_process_url(self, url: str, depth: int) -> bool:
|
||||
"""
|
||||
Validate the URL format and apply custom filtering logic.
|
||||
|
||||
Args:
|
||||
url (str): The URL to validate.
|
||||
depth (int): The current depth in the crawl.
|
||||
|
||||
Returns:
|
||||
bool: True if the URL should be processed, False otherwise.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def link_discovery(
|
||||
self,
|
||||
result: CrawlResult,
|
||||
source_url: str,
|
||||
current_depth: int,
|
||||
visited: Set[str],
|
||||
next_level: List[tuple],
|
||||
depths: Dict[str, int],
|
||||
) -> None:
|
||||
"""
|
||||
Extract and process links from the given crawl result.
|
||||
|
||||
This method should:
|
||||
- Validate each extracted URL using can_process_url.
|
||||
- Optionally score URLs.
|
||||
- Append valid URLs (and their parent references) to the next_level list.
|
||||
- Update the depths dictionary with the new depth for each URL.
|
||||
|
||||
Args:
|
||||
result (CrawlResult): The result from a crawl operation.
|
||||
source_url (str): The URL from which this result was obtained.
|
||||
current_depth (int): The depth at which the source URL was processed.
|
||||
visited (Set[str]): Set of already visited URLs.
|
||||
next_level (List[tuple]): List of tuples (url, parent_url) for the next BFS level.
|
||||
depths (Dict[str, int]): Mapping of URLs to their current depth.
|
||||
"""
|
||||
pass
|
||||
|
||||
|
||||
__all__ = [
|
||||
"DeepCrawlDecorator",
|
||||
"DeepCrawlStrategy",
|
||||
"BFSDeepCrawlStrategy"
|
||||
]
|
||||
188
crawl4ai/deep_crawling/bfs_strategy.py
Normal file
188
crawl4ai/deep_crawling/bfs_strategy.py
Normal file
@@ -0,0 +1,188 @@
|
||||
# bfs_deep_crawl_strategy.py
|
||||
import asyncio
|
||||
import logging
|
||||
from datetime import datetime
|
||||
from typing import AsyncGenerator, Optional, Set, Dict, List, Tuple
|
||||
from urllib.parse import urlparse
|
||||
|
||||
from ..models import CrawlResult, TraversalStats
|
||||
from .filters import FastFilterChain
|
||||
from .scorers import FastURLScorer
|
||||
|
||||
from typing import TYPE_CHECKING
|
||||
from . import DeepCrawlStrategy
|
||||
if TYPE_CHECKING:
|
||||
from ..async_configs import CrawlerRunConfig
|
||||
from ..async_webcrawler import AsyncWebCrawler
|
||||
|
||||
class BFSDeepCrawlStrategy(DeepCrawlStrategy):
|
||||
"""
|
||||
Breadth-First Search deep crawling strategy.
|
||||
|
||||
Core functions:
|
||||
- arun: Main entry point; splits execution into batch or stream modes.
|
||||
- link_discovery: Extracts, filters, and (if needed) scores the outgoing URLs.
|
||||
- can_process_url: Validates URL format and applies the filter chain.
|
||||
"""
|
||||
def __init__(
|
||||
self,
|
||||
max_depth: int,
|
||||
filter_chain: FastFilterChain = FastFilterChain(),
|
||||
url_scorer: Optional[FastURLScorer] = None,
|
||||
include_external: bool = False,
|
||||
logger: Optional[logging.Logger] = None,
|
||||
):
|
||||
self.max_depth = max_depth
|
||||
self.filter_chain = filter_chain
|
||||
self.url_scorer = url_scorer
|
||||
self.include_external = include_external
|
||||
self.logger = logger or logging.getLogger(__name__)
|
||||
self.stats = TraversalStats(start_time=datetime.now())
|
||||
self._cancel_event = asyncio.Event()
|
||||
|
||||
async def can_process_url(self, url: str, depth: int) -> bool:
|
||||
"""
|
||||
Validates the URL and applies the filter chain.
|
||||
For the start URL (depth 0) filtering is bypassed.
|
||||
"""
|
||||
try:
|
||||
parsed = urlparse(url)
|
||||
if not parsed.scheme or not parsed.netloc:
|
||||
raise ValueError("Missing scheme or netloc")
|
||||
if parsed.scheme not in ("http", "https"):
|
||||
raise ValueError("Invalid scheme")
|
||||
if "." not in parsed.netloc:
|
||||
raise ValueError("Invalid domain")
|
||||
except Exception as e:
|
||||
self.logger.warning(f"Invalid URL: {url}, error: {e}")
|
||||
return False
|
||||
|
||||
if depth != 0 and not self.filter_chain.apply(url):
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
async def link_discovery(
|
||||
self,
|
||||
result: CrawlResult,
|
||||
source_url: str,
|
||||
current_depth: int,
|
||||
visited: Set[str],
|
||||
next_level: List[Tuple[str, Optional[str]]],
|
||||
depths: Dict[str, int],
|
||||
) -> None:
|
||||
"""
|
||||
Extracts links from the crawl result, validates and scores them, and
|
||||
prepares the next level of URLs.
|
||||
Each valid URL is appended to next_level as a tuple (url, parent_url)
|
||||
and its depth is tracked.
|
||||
"""
|
||||
next_depth = current_depth + 1
|
||||
if next_depth > self.max_depth:
|
||||
return
|
||||
|
||||
# Get internal links and, if enabled, external links.
|
||||
links = result.links.get("internal", [])
|
||||
if self.include_external:
|
||||
links += result.links.get("external", [])
|
||||
|
||||
for link in links:
|
||||
url = link.get("href")
|
||||
if url in visited:
|
||||
continue
|
||||
if not await self.can_process_url(url, next_depth):
|
||||
self.stats.urls_skipped += 1
|
||||
continue
|
||||
|
||||
# Score the URL if a scorer is provided. In this simple BFS
|
||||
# the score is not used for ordering.
|
||||
score = self.url_scorer.score(url) if self.url_scorer else 0
|
||||
# attach the score to metadata if needed.
|
||||
if score:
|
||||
result.metadata = result.metadata or {}
|
||||
result.metadata["score"] = score
|
||||
next_level.append((url, source_url))
|
||||
depths[url] = next_depth
|
||||
|
||||
async def _arun_batch(
|
||||
self,
|
||||
start_url: str,
|
||||
crawler: "AsyncWebCrawler",
|
||||
config: "CrawlerRunConfig",
|
||||
) -> List[CrawlResult]:
|
||||
"""
|
||||
Batch (non-streaming) mode:
|
||||
Processes one BFS level at a time, then yields all the results.
|
||||
"""
|
||||
visited: Set[str] = set()
|
||||
# current_level holds tuples: (url, parent_url)
|
||||
current_level: List[Tuple[str, Optional[str]]] = [(start_url, None)]
|
||||
depths: Dict[str, int] = {start_url: 0}
|
||||
|
||||
results: List[CrawlResult] = []
|
||||
|
||||
while current_level and not self._cancel_event.is_set():
|
||||
next_level: List[Tuple[str, Optional[str]]] = []
|
||||
urls = [url for url, _ in current_level]
|
||||
visited.update(urls)
|
||||
|
||||
# Clone the config to disable deep crawling recursion and enforce batch mode.
|
||||
batch_config = config.clone(deep_crawl_strategy=None, stream=False)
|
||||
batch_results = await crawler.arun_many(urls=urls, config=batch_config)
|
||||
|
||||
for result in batch_results:
|
||||
url = result.url
|
||||
depth = depths.get(url, 0)
|
||||
result.metadata = result.metadata or {}
|
||||
result.metadata["depth"] = depth
|
||||
# Retrieve parent_url from current_level.
|
||||
parent_url = next((parent for (u, parent) in current_level if u == url), None)
|
||||
result.metadata["parent_url"] = parent_url
|
||||
results.append(result)
|
||||
await self.link_discovery(result, url, depth, visited, next_level, depths)
|
||||
|
||||
current_level = next_level
|
||||
|
||||
return results
|
||||
|
||||
async def _arun_stream(
|
||||
self,
|
||||
start_url: str,
|
||||
crawler: "AsyncWebCrawler",
|
||||
config: "CrawlerRunConfig",
|
||||
) -> AsyncGenerator[CrawlResult, None]:
|
||||
"""
|
||||
Streaming mode:
|
||||
Processes one BFS level at a time and yields results immediately as they arrive.
|
||||
"""
|
||||
visited: Set[str] = set()
|
||||
current_level: List[Tuple[str, Optional[str]]] = [(start_url, None)]
|
||||
depths: Dict[str, int] = {start_url: 0}
|
||||
|
||||
while current_level and not self._cancel_event.is_set():
|
||||
next_level: List[Tuple[str, Optional[str]]] = []
|
||||
urls = [url for url, _ in current_level]
|
||||
visited.update(urls)
|
||||
|
||||
stream_config = config.clone(deep_crawl_strategy=None, stream=True)
|
||||
stream_gen = await crawler.arun_many(urls=urls, config=stream_config)
|
||||
async for result in stream_gen:
|
||||
url = result.url
|
||||
depth = depths.get(url, 0)
|
||||
result.metadata = result.metadata or {}
|
||||
result.metadata["depth"] = depth
|
||||
parent_url = next((parent for (u, parent) in current_level if u == url), None)
|
||||
result.metadata["parent_url"] = parent_url
|
||||
yield result
|
||||
await self.link_discovery(result, url, depth, visited, next_level, depths)
|
||||
|
||||
current_level = next_level
|
||||
|
||||
|
||||
|
||||
async def shutdown(self) -> None:
|
||||
"""
|
||||
Clean up resources and signal cancellation of the crawl.
|
||||
"""
|
||||
self._cancel_event.set()
|
||||
self.stats.end_time = datetime.now()
|
||||
432
crawl4ai/deep_crawling/crazy.py
Normal file
432
crawl4ai/deep_crawling/crazy.py
Normal file
@@ -0,0 +1,432 @@
|
||||
from __future__ import annotations
|
||||
# I just got crazy, trying to wrute K&R C but in Python. Right now I feel like I'm in a quantum state.
|
||||
|
||||
# from typing import TYPE_CHECKING
|
||||
from functools import wraps
|
||||
from contextvars import ContextVar
|
||||
import inspect
|
||||
|
||||
from httpx import get
|
||||
from crawl4ai import CacheMode
|
||||
from crawl4ai.async_configs import CrawlerRunConfig
|
||||
from crawl4ai.models import CrawlResult, TraversalStats
|
||||
from crawl4ai.deep_crawling.filters import FastFilterChain
|
||||
from crawl4ai.async_webcrawler import AsyncWebCrawler
|
||||
import time
|
||||
import logging
|
||||
from urllib.parse import urlparse
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from collections import deque
|
||||
import asyncio
|
||||
from typing import (
|
||||
AsyncGenerator,
|
||||
Dict,
|
||||
List,
|
||||
TypeVar,
|
||||
Generic,
|
||||
Tuple,
|
||||
Callable,
|
||||
Awaitable,
|
||||
Union,
|
||||
)
|
||||
from functools import lru_cache
|
||||
import mmh3
|
||||
from bitarray import bitarray
|
||||
import numpy as np
|
||||
from heapq import heappush, heappop
|
||||
|
||||
# ------ Type Algebra Mastery ------ #
|
||||
CrawlResultT = TypeVar("CrawlResultT", bound="CrawlResult")
|
||||
PriorityT = TypeVar("PriorityT")
|
||||
P = TypeVar("P")
|
||||
|
||||
# ------ Hyperscalar Context Management ------ #
|
||||
deep_crawl_ctx = ContextVar("deep_crawl_stack", default=deque())
|
||||
|
||||
# ------ Algebraic Crawler Monoid ------ #
|
||||
class TraversalContext:
|
||||
__slots__ = ('visited', 'frontier', 'depths', 'priority_fn', 'current_depth')
|
||||
|
||||
def __init__(self,
|
||||
priority_fn: Callable[[str], Awaitable[float]] = lambda _: 1.0):
|
||||
self.visited: BloomFilter = BloomFilter(10**6, 0.01) # 1M items, 1% FP
|
||||
self.frontier: PriorityQueue = PriorityQueue()
|
||||
self.depths: Dict[str, int] = {}
|
||||
self.priority_fn = priority_fn
|
||||
self.current_depth = 0
|
||||
|
||||
def clone_for_level(self) -> TraversalContext:
|
||||
"""Monadic context propagation"""
|
||||
new_ctx = TraversalContext(self.priority_fn)
|
||||
new_ctx.visited = self.visited.copy()
|
||||
new_ctx.depths = self.depths.copy()
|
||||
new_ctx.current_depth = self.current_depth
|
||||
return new_ctx
|
||||
|
||||
class PriorityQueue(Generic[PriorityT]):
|
||||
"""Fibonacci heap-inspired priority queue with O(1) amortized operations"""
|
||||
__slots__ = ('_heap', '_index')
|
||||
|
||||
def __init__(self):
|
||||
self._heap: List[Tuple[PriorityT, float, P]] = []
|
||||
self._index: Dict[P, int] = {}
|
||||
|
||||
def insert(self, priority: PriorityT, item: P) -> None:
|
||||
tiebreaker = time.time() # Ensure FIFO for equal priorities
|
||||
heappush(self._heap, (priority, tiebreaker, item))
|
||||
self._index[item] = len(self._heap) - 1
|
||||
|
||||
def extract(self, top_n = 1) -> P:
|
||||
items = []
|
||||
for _ in range(top_n):
|
||||
if not self._heap:
|
||||
break
|
||||
priority, _, item = heappop(self._heap)
|
||||
del self._index[item]
|
||||
items.append(item)
|
||||
if not items:
|
||||
raise IndexError("Priority queue empty")
|
||||
return items
|
||||
# while self._heap:
|
||||
# _, _, item = heappop(self._heap)
|
||||
# if item in self._index:
|
||||
# del self._index[item]
|
||||
# return item
|
||||
raise IndexError("Priority queue empty")
|
||||
|
||||
|
||||
def is_empty(self) -> bool:
|
||||
return not bool(self._heap)
|
||||
|
||||
class BloomFilter:
|
||||
"""Optimal Bloom filter using murmur3 hash avalanche"""
|
||||
__slots__ = ('size', 'hashes', 'bits')
|
||||
|
||||
def __init__(self, capacity: int, error_rate: float):
|
||||
self.size = self._optimal_size(capacity, error_rate)
|
||||
self.hashes = self._optimal_hashes(capacity, self.size)
|
||||
self.bits = bitarray(self.size)
|
||||
self.bits.setall(False)
|
||||
|
||||
@staticmethod
|
||||
def _optimal_size(n: int, p: float) -> int:
|
||||
m = - (n * np.log(p)) / (np.log(2) ** 2)
|
||||
return int(np.ceil(m))
|
||||
|
||||
@staticmethod
|
||||
def _optimal_hashes(n: int, m: int) -> int:
|
||||
k = (m / n) * np.log(2)
|
||||
return int(np.ceil(k))
|
||||
|
||||
def add(self, item: str) -> None:
|
||||
for seed in range(self.hashes):
|
||||
digest = mmh3.hash(item, seed) % self.size
|
||||
self.bits[digest] = True
|
||||
|
||||
def __contains__(self, item: str) -> bool:
|
||||
return all(
|
||||
self.bits[mmh3.hash(item, seed) % self.size]
|
||||
for seed in range(self.hashes)
|
||||
)
|
||||
|
||||
def copy(self) -> BloomFilter:
|
||||
new = object.__new__(BloomFilter)
|
||||
new.size = self.size
|
||||
new.hashes = self.hashes
|
||||
new.bits = self.bits.copy()
|
||||
return new
|
||||
|
||||
def __len__(self) -> int:
|
||||
"""
|
||||
Estimates the number of items in the filter using the
|
||||
count of set bits and the formula:
|
||||
n = -m/k * ln(1 - X/m)
|
||||
where:
|
||||
m = size of bit array
|
||||
k = number of hash functions
|
||||
X = count of set bits
|
||||
"""
|
||||
set_bits = self.bits.count(True)
|
||||
if set_bits == 0:
|
||||
return 0
|
||||
|
||||
# Use the inverse bloom filter formula to estimate cardinality
|
||||
return int(
|
||||
-(self.size / self.hashes) *
|
||||
np.log(1 - set_bits / self.size)
|
||||
)
|
||||
|
||||
def bit_count(self) -> int:
|
||||
"""Returns the raw count of set bits in the filter"""
|
||||
return self.bits.count(True)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"BloomFilter(est_items={len(self)}, bits={self.bit_count()}/{self.size})"
|
||||
|
||||
# ------ Hyper-Optimal Deep Crawl Core ------ #
|
||||
class DeepCrawlDecorator:
|
||||
"""Metaprogramming marvel: Zero-cost deep crawl abstraction"""
|
||||
def __init__(self, crawler: AsyncWebCrawler):
|
||||
self.crawler = crawler
|
||||
|
||||
def __call__(self, original_arun: Callable) -> Callable:
|
||||
@wraps(original_arun)
|
||||
async def quantum_arun(url: str, config: CrawlerRunConfig = None, **kwargs):
|
||||
stack = deep_crawl_ctx.get()
|
||||
if config and config.deep_crawl_strategy and not stack:
|
||||
stack.append(self.crawler)
|
||||
try:
|
||||
deep_crawl_ctx.set(stack)
|
||||
async for result in config.deep_crawl_strategy.traverse(
|
||||
start_url=url,
|
||||
crawler=self.crawler,
|
||||
config=config
|
||||
):
|
||||
yield result
|
||||
finally:
|
||||
stack.pop()
|
||||
deep_crawl_ctx.set(stack)
|
||||
else:
|
||||
result = await original_arun(url, config=config, **kwargs)
|
||||
yield result
|
||||
return quantum_arun
|
||||
|
||||
|
||||
async def collect_results(url, crawler, config):
|
||||
if id(getattr(crawler, "arun")) != id(getattr(crawler, "original_arun")):
|
||||
setattr(crawler, "arun", getattr(crawler, "original_arun"))
|
||||
|
||||
ret = crawler.arun(url, config=config)
|
||||
# If arun is an async generator, iterate over it
|
||||
if inspect.isasyncgen(ret):
|
||||
return [r async for r in ret]
|
||||
# Otherwise, await the coroutine and normalize to a list
|
||||
result = await ret
|
||||
return result if isinstance(result, list) else [result]
|
||||
|
||||
async def collect_many_results(url, crawler, config):
|
||||
# Replace back arun to its original implementation
|
||||
if id(getattr(crawler, "arun")) != id(getattr(crawler, "original_arun")):
|
||||
setattr(crawler, "arun", getattr(crawler, "original_arun"))
|
||||
ret = crawler.arun_many(url, config=config)
|
||||
# If arun is an async generator, iterate over it
|
||||
if inspect.isasyncgen(ret):
|
||||
return [r async for r in ret]
|
||||
# Otherwise, await the coroutine and normalize to a list
|
||||
result = await ret
|
||||
return result if isinstance(result, list) else [result]
|
||||
|
||||
|
||||
# ------ Deep Crawl Strategy Interface ------ #
|
||||
CrawlResultT = TypeVar("CrawlResultT", bound=CrawlResult)
|
||||
# In batch mode we return List[CrawlResult] and in stream mode an AsyncGenerator.
|
||||
RunManyReturn = Union[CrawlResultT, List[CrawlResultT], AsyncGenerator[CrawlResultT, None]]
|
||||
|
||||
|
||||
class DeepCrawlStrategy(ABC):
|
||||
"""Abstract base class that will make Dijkstra smile"""
|
||||
@abstractmethod
|
||||
async def traverse(self,
|
||||
start_url: str,
|
||||
crawler: AsyncWebCrawler,
|
||||
config: CrawlerRunConfig) -> RunManyReturn:
|
||||
"""Traverse with O(1) memory complexity via generator fusion"""
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
def precompute_priority(self, url: str) -> Awaitable[float]:
|
||||
"""Quantum-inspired priority precomputation"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def link_hypercube(self, result: CrawlResult) -> AsyncGenerator[str, None]:
|
||||
"""Hilbert-curve optimized link generation"""
|
||||
pass
|
||||
|
||||
# ------ BFS That Would Make Knuth Proud ------ #
|
||||
|
||||
def calculate_quantum_batch_size(
|
||||
depth: int,
|
||||
max_depth: int,
|
||||
frontier_size: int,
|
||||
visited_size: int
|
||||
) -> int:
|
||||
"""
|
||||
Calculates optimal batch size for URL processing using quantum-inspired mathematical principles.
|
||||
|
||||
This function implements a sophisticated batch size calculation using:
|
||||
1. Golden Ratio (φ) based scaling for optimal irrationality
|
||||
2. Depth-aware amplitude modulation
|
||||
3. Harmonic series dampening
|
||||
4. Logarithmic growth control
|
||||
5. Dynamic frontier adaptation
|
||||
|
||||
The formula follows the quantum harmonic oscillator principle:
|
||||
N = ⌈φ^(2d) * log₂(|V|) * H(d)⁻¹ * min(20, |F|/10)⌉
|
||||
where:
|
||||
φ = Golden Ratio ((1 + √5) / 2)
|
||||
d = depth factor (normalized remaining depth)
|
||||
|V| = size of visited set
|
||||
H(d) = d-th harmonic number
|
||||
|F| = frontier size
|
||||
|
||||
Args:
|
||||
depth (int): Current traversal depth
|
||||
max_depth (int): Maximum allowed depth
|
||||
frontier_size (int): Current size of frontier queue
|
||||
visited_size (int): Number of URLs visited so far
|
||||
|
||||
Returns:
|
||||
int: Optimal batch size bounded between 1 and 100
|
||||
|
||||
Mathematical Properties:
|
||||
- Maintains O(log n) growth with respect to visited size
|
||||
- Provides φ-optimal distribution of resources
|
||||
- Ensures quantum-like state transitions between depths
|
||||
- Harmonically dampened to prevent exponential explosion
|
||||
"""
|
||||
# Golden ratio φ = (1 + √5) / 2
|
||||
φ = (1 + 5 ** 0.5) / 2
|
||||
|
||||
# Calculate normalized depth factor [0, 1]
|
||||
depth_factor = (max_depth - depth) / max_depth if depth < max_depth else 0
|
||||
|
||||
# Compute harmonic number for current depth
|
||||
harmonic = sum(1/k for k in range(1, depth + 2))
|
||||
|
||||
# Calculate quantum batch size
|
||||
batch_size = int(np.ceil(
|
||||
(φ ** (depth_factor * 2)) * # Golden ratio scaling
|
||||
np.log2(visited_size + 2) * # Logarithmic growth factor
|
||||
(1 / harmonic) * # Harmonic dampening
|
||||
max(1, min(20, frontier_size / 10)) # Frontier-aware scaling
|
||||
))
|
||||
|
||||
# Enforce practical bounds
|
||||
return max(1, min(100, batch_size))
|
||||
|
||||
|
||||
class BFSDeepCrawlStrategy(DeepCrawlStrategy):
|
||||
"""Breadth-First Search with Einstein-Rosen bridge optimization"""
|
||||
__slots__ = ('max_depth', 'filter_chain', 'priority_fn', 'stats', '_cancel')
|
||||
|
||||
def __init__(self,
|
||||
max_depth: int,
|
||||
filter_chain: FastFilterChain = FastFilterChain(),
|
||||
priority_fn: Callable[[str], Awaitable[float]] = lambda url: 1.0,
|
||||
logger: logging.Logger = None):
|
||||
self.max_depth = max_depth
|
||||
self.filter_chain = filter_chain
|
||||
self.priority_fn = priority_fn
|
||||
self.stats = TraversalStats()
|
||||
self._cancel = asyncio.Event()
|
||||
self.semaphore = asyncio.Semaphore(1000)
|
||||
|
||||
async def traverse(self,
|
||||
start_url: str,
|
||||
crawler: AsyncWebCrawler,
|
||||
config: CrawlerRunConfig) -> RunManyReturn:
|
||||
"""Non-blocking BFS with O(b^d) time complexity awareness"""
|
||||
ctx = TraversalContext(self.priority_fn)
|
||||
ctx.frontier.insert(self.priority_fn(start_url), (start_url, None, 0))
|
||||
ctx.visited.add(start_url)
|
||||
ctx.depths[start_url] = 0
|
||||
|
||||
while not ctx.frontier.is_empty() and not self._cancel.is_set():
|
||||
# Use the best algorith, to find top_n value
|
||||
top_n = calculate_quantum_batch_size(
|
||||
depth=ctx.current_depth,
|
||||
max_depth=self.max_depth,
|
||||
frontier_size=len(ctx.frontier._heap),
|
||||
visited_size=len(ctx.visited)
|
||||
)
|
||||
|
||||
urls = ctx.frontier.extract(top_n=top_n)
|
||||
# url, parent, depth = ctx.frontier.extract(top_n=top_n)
|
||||
if urls:
|
||||
ctx.current_depth = urls[0][2]
|
||||
|
||||
async with self.semaphore:
|
||||
results = await collect_many_results([url for (url, parent, depth) in urls], crawler, config)
|
||||
# results = await asyncio.gather(*[
|
||||
# collect_results(url, crawler, config) for (url, parent, depth) in urls
|
||||
# ])
|
||||
# result = _result[0]
|
||||
for ix, result in enumerate(results):
|
||||
url, parent, depth = result.url, urls[ix][1], urls[ix][2]
|
||||
result.metadata['depth'] = depth
|
||||
result.metadata['parent'] = parent
|
||||
yield result
|
||||
|
||||
if depth < self.max_depth:
|
||||
async for link in self.link_hypercube(result):
|
||||
if link not in ctx.visited:
|
||||
priority = self.priority_fn(link)
|
||||
ctx.frontier.insert(priority, (link, url, depth + 1))
|
||||
ctx.visited.add(link)
|
||||
ctx.depths[link] = depth + 1
|
||||
|
||||
@lru_cache(maxsize=65536)
|
||||
async def validate_url(self, url: str) -> bool:
|
||||
"""Memoized URL validation with λ-calculus purity"""
|
||||
try:
|
||||
parsed = urlparse(url)
|
||||
return (parsed.scheme in {'http', 'https'}
|
||||
and '.' in parsed.netloc
|
||||
and self.filter_chain.apply(url))
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
async def link_hypercube(self, result: CrawlResult) -> AsyncGenerator[str, None]:
|
||||
"""Hilbert-ordered link generation with O(1) yield latency"""
|
||||
links = (link['href'] for link in result.links.get('internal', []))
|
||||
validated = filter(self.validate_url, links)
|
||||
for link in sorted(validated, key=lambda x: -self.priority_fn(x)):
|
||||
yield link
|
||||
|
||||
def __aiter__(self) -> AsyncGenerator[CrawlResult, None]:
|
||||
"""Native async iterator interface"""
|
||||
return self.traverse()
|
||||
|
||||
async def __anext__(self) -> CrawlResult:
|
||||
"""True async iterator protocol implementation"""
|
||||
result = await self.traverse().__anext__()
|
||||
if result:
|
||||
return result
|
||||
raise StopAsyncIteration
|
||||
|
||||
async def precompute_priority(self, url):
|
||||
return super().precompute_priority(url)
|
||||
|
||||
async def shutdown(self):
|
||||
self._cancel.set()
|
||||
|
||||
# ------ Usage That Will Drop Jaws ------ #
|
||||
async def main():
|
||||
"""Quantum crawl example"""
|
||||
strategy = BFSDeepCrawlStrategy(
|
||||
max_depth=2,
|
||||
priority_fn=lambda url: 1.0 / (len(url) + 1e-9), # Inverse length priority
|
||||
# filter_chain=FastFilterChain(...)
|
||||
)
|
||||
|
||||
config: CrawlerRunConfig = CrawlerRunConfig(
|
||||
deep_crawl_strategy=strategy,
|
||||
stream=False,
|
||||
verbose=True,
|
||||
cache_mode=CacheMode.BYPASS
|
||||
)
|
||||
|
||||
async with AsyncWebCrawler() as crawler:
|
||||
run_decorator = DeepCrawlDecorator(crawler)
|
||||
setattr(crawler, "original_arun", crawler.arun)
|
||||
crawler.arun = run_decorator(crawler.arun)
|
||||
start_time = time.perf_counter()
|
||||
async for result in crawler.arun("https://docs.crawl4ai.com", config=config):
|
||||
print(f"🌀 {result.url} (Depth: {result.metadata['depth']})")
|
||||
print(f"Deep crawl completed in {time.perf_counter() - start_time:.2f}s")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
asyncio.run(main())
|
||||
868
crawl4ai/deep_crawling/filters.py
Normal file
868
crawl4ai/deep_crawling/filters.py
Normal file
@@ -0,0 +1,868 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import List, Pattern, Set, Union, FrozenSet
|
||||
import re, time
|
||||
from urllib.parse import urlparse
|
||||
from array import array
|
||||
import logging
|
||||
from functools import lru_cache
|
||||
import fnmatch
|
||||
from dataclasses import dataclass
|
||||
from typing import ClassVar
|
||||
import weakref
|
||||
import mimetypes
|
||||
|
||||
|
||||
@dataclass
|
||||
class FilterStats:
|
||||
# PERF: Using dataclass creates overhead with __init__ and property access
|
||||
# PERF: Could use __slots__ to reduce memory footprint
|
||||
# PERF: Consider using array.array('I') for atomic increments
|
||||
total_urls: int = 0
|
||||
rejected_urls: int = 0
|
||||
passed_urls: int = 0
|
||||
|
||||
|
||||
class URLFilter(ABC):
|
||||
# PERF: Logger creation is expensive, consider lazy initialization
|
||||
# PERF: stats object creation adds overhead for each filter instance
|
||||
def __init__(self, name: str = None):
|
||||
self.name = name or self.__class__.__name__
|
||||
self.stats = FilterStats()
|
||||
self.logger = logging.getLogger(f"urlfilter.{self.name}")
|
||||
|
||||
@abstractmethod
|
||||
def apply(self, url: str) -> bool:
|
||||
pass
|
||||
|
||||
def _update_stats(self, passed: bool):
|
||||
# PERF: Already optimized but could use bitwise operations
|
||||
# PERF: Consider removing stats entirely in production/fast mode
|
||||
self.stats.total_urls += 1
|
||||
self.stats.passed_urls += passed
|
||||
self.stats.rejected_urls += not passed
|
||||
|
||||
|
||||
class FilterChain:
|
||||
# PERF: List traversal for each URL is expensive
|
||||
# PERF: Could use array.array instead of list for filters
|
||||
# PERF: Consider adding fast path for single filter case
|
||||
def __init__(self, filters: List[URLFilter] = None):
|
||||
self.filters = filters or []
|
||||
self.stats = FilterStats()
|
||||
self.logger = logging.getLogger("urlfilter.chain")
|
||||
|
||||
def apply(self, url: str) -> bool:
|
||||
# PERF: Logging on every rejection is expensive
|
||||
# PERF: Could reorder filters by rejection rate
|
||||
# PERF: Consider batch processing mode
|
||||
self.stats.total_urls += 1
|
||||
|
||||
for filter_ in self.filters:
|
||||
if not filter_.apply(url):
|
||||
self.stats.rejected_urls += 1
|
||||
self.logger.debug(f"URL {url} rejected by {filter_.name}")
|
||||
return False
|
||||
|
||||
self.stats.passed_urls += 1
|
||||
return True
|
||||
|
||||
|
||||
class URLPatternFilter(URLFilter):
|
||||
# PERF: Converting glob to regex is expensive
|
||||
# PERF: Multiple regex compilation is slow
|
||||
# PERF: List of patterns causes multiple regex evaluations
|
||||
def __init__(
|
||||
self,
|
||||
patterns: Union[str, Pattern, List[Union[str, Pattern]]],
|
||||
use_glob: bool = True,
|
||||
):
|
||||
super().__init__()
|
||||
self.patterns = [patterns] if isinstance(patterns, (str, Pattern)) else patterns
|
||||
self.use_glob = use_glob
|
||||
self._compiled_patterns = []
|
||||
|
||||
# PERF: This could be consolidated into a single regex with OR conditions
|
||||
# PERF: glob_to_regex creates complex patterns, could be simplified
|
||||
for pattern in self.patterns:
|
||||
if isinstance(pattern, str) and use_glob:
|
||||
self._compiled_patterns.append(self._glob_to_regex(pattern))
|
||||
else:
|
||||
self._compiled_patterns.append(
|
||||
re.compile(pattern) if isinstance(pattern, str) else pattern
|
||||
)
|
||||
|
||||
def _glob_to_regex(self, pattern: str) -> Pattern:
|
||||
# PERF: fnmatch.translate creates overly complex patterns
|
||||
# PERF: Could cache common translations
|
||||
return re.compile(fnmatch.translate(pattern))
|
||||
|
||||
def apply(self, url: str) -> bool:
|
||||
# PERF: any() with generator is slower than direct loop with early return
|
||||
# PERF: searching entire string is slower than anchored match
|
||||
matches = any(pattern.search(url) for pattern in self._compiled_patterns)
|
||||
self._update_stats(matches)
|
||||
return matches
|
||||
|
||||
|
||||
class ContentTypeFilter(URLFilter):
|
||||
# PERF: mimetypes guessing is extremely slow
|
||||
# PERF: URL parsing on every check is expensive
|
||||
# PERF: No caching of results for similar extensions
|
||||
def __init__(
|
||||
self, allowed_types: Union[str, List[str]], check_extension: bool = True
|
||||
):
|
||||
super().__init__()
|
||||
self.allowed_types = (
|
||||
[allowed_types] if isinstance(allowed_types, str) else allowed_types
|
||||
)
|
||||
self.check_extension = check_extension
|
||||
self._normalize_types()
|
||||
|
||||
def _normalize_types(self):
|
||||
"""Normalize content type strings"""
|
||||
self.allowed_types = [t.lower() for t in self.allowed_types]
|
||||
|
||||
def _check_extension(self, url: str) -> bool:
|
||||
# PERF: urlparse is called on every check
|
||||
# PERF: multiple string splits are expensive
|
||||
# PERF: mimetypes.guess_type is very slow
|
||||
ext = (
|
||||
urlparse(url).path.split(".")[-1].lower()
|
||||
if "." in urlparse(url).path
|
||||
else ""
|
||||
)
|
||||
if not ext:
|
||||
return True
|
||||
|
||||
# PERF: guess_type is main bottleneck
|
||||
guessed_type = mimetypes.guess_type(url)[0]
|
||||
return any(
|
||||
allowed in (guessed_type or "").lower() for allowed in self.allowed_types
|
||||
)
|
||||
|
||||
def apply(self, url: str) -> bool:
|
||||
"""Check if URL's content type is allowed"""
|
||||
result = True
|
||||
if self.check_extension:
|
||||
result = self._check_extension(url)
|
||||
self._update_stats(result)
|
||||
return result
|
||||
|
||||
|
||||
class DomainFilter(URLFilter):
|
||||
# PERF: Set lookups are fast but string normalizations on init are not
|
||||
# PERF: Creating two sets doubles memory usage
|
||||
def __init__(
|
||||
self,
|
||||
allowed_domains: Union[str, List[str]] = None,
|
||||
blocked_domains: Union[str, List[str]] = None,
|
||||
):
|
||||
super().__init__()
|
||||
# PERF: Normalizing domains on every init is wasteful
|
||||
# PERF: Could use frozenset for immutable lists
|
||||
self.allowed_domains = (
|
||||
set(self._normalize_domains(allowed_domains)) if allowed_domains else None
|
||||
)
|
||||
self.blocked_domains = (
|
||||
set(self._normalize_domains(blocked_domains)) if blocked_domains else set()
|
||||
)
|
||||
|
||||
def _normalize_domains(self, domains: Union[str, List[str]]) -> List[str]:
|
||||
# PERF: strip() and lower() create new strings for each domain
|
||||
# PERF: List comprehension creates intermediate list
|
||||
if isinstance(domains, str):
|
||||
domains = [domains]
|
||||
return [d.lower().strip() for d in domains]
|
||||
|
||||
def _extract_domain(self, url: str) -> str:
|
||||
# PERF: urlparse is called for every URL check
|
||||
# PERF: lower() creates new string every time
|
||||
# PERF: Could cache recent results
|
||||
return urlparse(url).netloc.lower()
|
||||
|
||||
def apply(self, url: str) -> bool:
|
||||
# PERF: Two separate set lookups in worst case
|
||||
# PERF: Domain extraction happens before knowing if we have any filters
|
||||
domain = self._extract_domain(url)
|
||||
|
||||
if domain in self.blocked_domains:
|
||||
self._update_stats(False)
|
||||
return False
|
||||
|
||||
if self.allowed_domains is not None and domain not in self.allowed_domains:
|
||||
self._update_stats(False)
|
||||
return False
|
||||
|
||||
self._update_stats(True)
|
||||
return True
|
||||
|
||||
|
||||
# Example usage:
|
||||
def create_common_filter_chain() -> FilterChain:
|
||||
"""Create a commonly used filter chain"""
|
||||
return FilterChain(
|
||||
[
|
||||
URLPatternFilter(
|
||||
[
|
||||
"*.html",
|
||||
"*.htm", # HTML files
|
||||
"*/article/*",
|
||||
"*/blog/*", # Common content paths
|
||||
]
|
||||
),
|
||||
ContentTypeFilter(["text/html", "application/xhtml+xml"]),
|
||||
DomainFilter(blocked_domains=["ads.*", "analytics.*"]),
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
####################################################################################
|
||||
# Uncledoe: Optimized Version
|
||||
####################################################################################
|
||||
|
||||
|
||||
# Use __slots__ and array for maximum memory/speed efficiency
|
||||
class FastFilterStats:
|
||||
__slots__ = ("_counters",)
|
||||
|
||||
def __init__(self):
|
||||
# Use array of unsigned ints for atomic operations
|
||||
self._counters = array("I", [0, 0, 0]) # total, passed, rejected
|
||||
|
||||
@property
|
||||
def total_urls(self):
|
||||
return self._counters[0]
|
||||
|
||||
@property
|
||||
def passed_urls(self):
|
||||
return self._counters[1]
|
||||
|
||||
@property
|
||||
def rejected_urls(self):
|
||||
return self._counters[2]
|
||||
|
||||
|
||||
class FastURLFilter(ABC):
|
||||
"""Optimized base filter class"""
|
||||
|
||||
__slots__ = ("name", "stats", "_logger_ref")
|
||||
|
||||
def __init__(self, name: str = None):
|
||||
self.name = name or self.__class__.__name__
|
||||
self.stats = FastFilterStats()
|
||||
# Lazy logger initialization using weakref
|
||||
self._logger_ref = None
|
||||
|
||||
@property
|
||||
def logger(self):
|
||||
if self._logger_ref is None or self._logger_ref() is None:
|
||||
logger = logging.getLogger(f"urlfilter.{self.name}")
|
||||
self._logger_ref = weakref.ref(logger)
|
||||
return self._logger_ref()
|
||||
|
||||
@abstractmethod
|
||||
def apply(self, url: str) -> bool:
|
||||
pass
|
||||
|
||||
def _update_stats(self, passed: bool):
|
||||
# Use direct array index for speed
|
||||
self.stats._counters[0] += 1 # total
|
||||
self.stats._counters[1] += passed # passed
|
||||
self.stats._counters[2] += not passed # rejected
|
||||
|
||||
|
||||
class FastFilterChain:
|
||||
"""Optimized filter chain"""
|
||||
|
||||
__slots__ = ("filters", "stats", "_logger_ref")
|
||||
|
||||
def __init__(self, filters: List[FastURLFilter] = None):
|
||||
self.filters = tuple(filters or []) # Immutable tuple for speed
|
||||
self.stats = FastFilterStats()
|
||||
self._logger_ref = None
|
||||
|
||||
@property
|
||||
def logger(self):
|
||||
if self._logger_ref is None or self._logger_ref() is None:
|
||||
logger = logging.getLogger("urlfilter.chain")
|
||||
self._logger_ref = weakref.ref(logger)
|
||||
return self._logger_ref()
|
||||
|
||||
def add_filter(self, filter_: FastURLFilter) -> "FastFilterChain":
|
||||
"""Add a filter to the chain"""
|
||||
self.filters.append(filter_)
|
||||
return self # Enable method chaining
|
||||
|
||||
def apply(self, url: str) -> bool:
|
||||
"""Optimized apply with minimal operations"""
|
||||
self.stats._counters[0] += 1 # total
|
||||
|
||||
# Direct tuple iteration is faster than list
|
||||
for f in self.filters:
|
||||
if not f.apply(url):
|
||||
self.stats._counters[2] += 1 # rejected
|
||||
return False
|
||||
|
||||
self.stats._counters[1] += 1 # passed
|
||||
return True
|
||||
|
||||
class FastURLPatternFilter(FastURLFilter):
|
||||
"""Pattern filter balancing speed and completeness"""
|
||||
__slots__ = ('_simple_suffixes', '_simple_prefixes', '_domain_patterns', '_path_patterns')
|
||||
|
||||
PATTERN_TYPES = {
|
||||
'SUFFIX': 1, # *.html
|
||||
'PREFIX': 2, # /foo/*
|
||||
'DOMAIN': 3, # *.example.com
|
||||
'PATH': 4 , # Everything else
|
||||
'REGEX': 5
|
||||
}
|
||||
|
||||
def __init__(self, patterns: Union[str, Pattern, List[Union[str, Pattern]]], use_glob: bool = True):
|
||||
super().__init__()
|
||||
patterns = [patterns] if isinstance(patterns, (str, Pattern)) else patterns
|
||||
|
||||
self._simple_suffixes = set()
|
||||
self._simple_prefixes = set()
|
||||
self._domain_patterns = []
|
||||
self._path_patterns = []
|
||||
|
||||
for pattern in patterns:
|
||||
pattern_type = self._categorize_pattern(pattern)
|
||||
self._add_pattern(pattern, pattern_type)
|
||||
|
||||
def _categorize_pattern(self, pattern: str) -> int:
|
||||
"""Categorize pattern for specialized handling"""
|
||||
if not isinstance(pattern, str):
|
||||
return self.PATTERN_TYPES['PATH']
|
||||
|
||||
# Check if it's a regex pattern
|
||||
if pattern.startswith('^') or pattern.endswith('$') or '\\d' in pattern:
|
||||
return self.PATTERN_TYPES['REGEX']
|
||||
|
||||
if pattern.count('*') == 1:
|
||||
if pattern.startswith('*.'):
|
||||
return self.PATTERN_TYPES['SUFFIX']
|
||||
if pattern.endswith('/*'):
|
||||
return self.PATTERN_TYPES['PREFIX']
|
||||
|
||||
if '://' in pattern and pattern.startswith('*.'):
|
||||
return self.PATTERN_TYPES['DOMAIN']
|
||||
|
||||
return self.PATTERN_TYPES['PATH']
|
||||
|
||||
def _add_pattern(self, pattern: str, pattern_type: int):
|
||||
"""Add pattern to appropriate matcher"""
|
||||
if pattern_type == self.PATTERN_TYPES['REGEX']:
|
||||
# For regex patterns, compile directly without glob translation
|
||||
if isinstance(pattern, str) and (pattern.startswith('^') or pattern.endswith('$') or '\\d' in pattern):
|
||||
self._path_patterns.append(re.compile(pattern))
|
||||
return
|
||||
elif pattern_type == self.PATTERN_TYPES['SUFFIX']:
|
||||
self._simple_suffixes.add(pattern[2:])
|
||||
elif pattern_type == self.PATTERN_TYPES['PREFIX']:
|
||||
self._simple_prefixes.add(pattern[:-2])
|
||||
elif pattern_type == self.PATTERN_TYPES['DOMAIN']:
|
||||
self._domain_patterns.append(
|
||||
re.compile(pattern.replace('*.', r'[^/]+\.'))
|
||||
)
|
||||
else:
|
||||
if isinstance(pattern, str):
|
||||
# Handle complex glob patterns
|
||||
if '**' in pattern:
|
||||
pattern = pattern.replace('**', '.*')
|
||||
if '{' in pattern:
|
||||
# Convert {a,b} to (a|b)
|
||||
pattern = re.sub(r'\{([^}]+)\}',
|
||||
lambda m: f'({"|".join(m.group(1).split(","))})',
|
||||
pattern)
|
||||
pattern = fnmatch.translate(pattern)
|
||||
self._path_patterns.append(
|
||||
pattern if isinstance(pattern, Pattern) else re.compile(pattern)
|
||||
)
|
||||
|
||||
@lru_cache(maxsize=10000)
|
||||
def apply(self, url: str) -> bool:
|
||||
"""Hierarchical pattern matching"""
|
||||
# Quick suffix check (*.html)
|
||||
if self._simple_suffixes:
|
||||
path = url.split('?')[0]
|
||||
if path.split('/')[-1].split('.')[-1] in self._simple_suffixes:
|
||||
self._update_stats(True)
|
||||
return True
|
||||
|
||||
# Domain check
|
||||
if self._domain_patterns:
|
||||
for pattern in self._domain_patterns:
|
||||
if pattern.match(url):
|
||||
self._update_stats(True)
|
||||
return True
|
||||
|
||||
# Prefix check (/foo/*)
|
||||
if self._simple_prefixes:
|
||||
path = url.split('?')[0]
|
||||
if any(path.startswith(p) for p in self._simple_prefixes):
|
||||
self._update_stats(True)
|
||||
return True
|
||||
|
||||
# Complex patterns
|
||||
if self._path_patterns:
|
||||
if any(p.search(url) for p in self._path_patterns):
|
||||
self._update_stats(True)
|
||||
return True
|
||||
|
||||
self._update_stats(False)
|
||||
return False
|
||||
|
||||
|
||||
class FastContentTypeFilter(FastURLFilter):
|
||||
"""Optimized content type filter using fast lookups"""
|
||||
|
||||
__slots__ = ("allowed_types", "_ext_map", "_check_extension")
|
||||
|
||||
# Fast extension to mime type mapping
|
||||
_MIME_MAP = {
|
||||
# Text Formats
|
||||
"txt": "text/plain",
|
||||
"html": "text/html",
|
||||
"htm": "text/html",
|
||||
"xhtml": "application/xhtml+xml",
|
||||
"css": "text/css",
|
||||
"csv": "text/csv",
|
||||
"ics": "text/calendar",
|
||||
"js": "application/javascript",
|
||||
# Images
|
||||
"bmp": "image/bmp",
|
||||
"gif": "image/gif",
|
||||
"jpeg": "image/jpeg",
|
||||
"jpg": "image/jpeg",
|
||||
"png": "image/png",
|
||||
"svg": "image/svg+xml",
|
||||
"tiff": "image/tiff",
|
||||
"ico": "image/x-icon",
|
||||
"webp": "image/webp",
|
||||
# Audio
|
||||
"mp3": "audio/mpeg",
|
||||
"wav": "audio/wav",
|
||||
"ogg": "audio/ogg",
|
||||
"m4a": "audio/mp4",
|
||||
"aac": "audio/aac",
|
||||
# Video
|
||||
"mp4": "video/mp4",
|
||||
"mpeg": "video/mpeg",
|
||||
"webm": "video/webm",
|
||||
"avi": "video/x-msvideo",
|
||||
"mov": "video/quicktime",
|
||||
"flv": "video/x-flv",
|
||||
"wmv": "video/x-ms-wmv",
|
||||
"mkv": "video/x-matroska",
|
||||
# Applications
|
||||
"json": "application/json",
|
||||
"xml": "application/xml",
|
||||
"pdf": "application/pdf",
|
||||
"zip": "application/zip",
|
||||
"gz": "application/gzip",
|
||||
"tar": "application/x-tar",
|
||||
"rar": "application/vnd.rar",
|
||||
"7z": "application/x-7z-compressed",
|
||||
"exe": "application/vnd.microsoft.portable-executable",
|
||||
"msi": "application/x-msdownload",
|
||||
# Fonts
|
||||
"woff": "font/woff",
|
||||
"woff2": "font/woff2",
|
||||
"ttf": "font/ttf",
|
||||
"otf": "font/otf",
|
||||
# Microsoft Office
|
||||
"doc": "application/msword",
|
||||
"dot": "application/msword",
|
||||
"docx": "application/vnd.openxmlformats-officedocument.wordprocessingml.document",
|
||||
"xlsx": "application/vnd.openxmlformats-officedocument.spreadsheetml.sheet",
|
||||
"xls": "application/vnd.ms-excel",
|
||||
"ppt": "application/vnd.ms-powerpoint",
|
||||
"pptx": "application/vnd.openxmlformats-officedocument.presentationml.presentation",
|
||||
# OpenDocument Formats
|
||||
"odt": "application/vnd.oasis.opendocument.text",
|
||||
"ods": "application/vnd.oasis.opendocument.spreadsheet",
|
||||
"odp": "application/vnd.oasis.opendocument.presentation",
|
||||
# Archives
|
||||
"tar.gz": "application/gzip",
|
||||
"tgz": "application/gzip",
|
||||
"bz2": "application/x-bzip2",
|
||||
# Others
|
||||
"rtf": "application/rtf",
|
||||
"apk": "application/vnd.android.package-archive",
|
||||
"epub": "application/epub+zip",
|
||||
"jar": "application/java-archive",
|
||||
"swf": "application/x-shockwave-flash",
|
||||
"midi": "audio/midi",
|
||||
"mid": "audio/midi",
|
||||
"ps": "application/postscript",
|
||||
"ai": "application/postscript",
|
||||
"eps": "application/postscript",
|
||||
# Custom or less common
|
||||
"bin": "application/octet-stream",
|
||||
"dmg": "application/x-apple-diskimage",
|
||||
"iso": "application/x-iso9660-image",
|
||||
"deb": "application/x-debian-package",
|
||||
"rpm": "application/x-rpm",
|
||||
"sqlite": "application/vnd.sqlite3",
|
||||
# Placeholder
|
||||
"unknown": "application/octet-stream", # Fallback for unknown file types
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
@lru_cache(maxsize=1000)
|
||||
def _extract_extension(path: str) -> str:
|
||||
"""Fast extension extraction with caching"""
|
||||
if "." not in path:
|
||||
return ""
|
||||
return path.rpartition(".")[-1].lower()
|
||||
|
||||
def __init__(
|
||||
self, allowed_types: Union[str, List[str]], check_extension: bool = True
|
||||
):
|
||||
super().__init__()
|
||||
# Normalize and store as frozenset for fast lookup
|
||||
self.allowed_types = frozenset(
|
||||
t.lower()
|
||||
for t in (
|
||||
allowed_types if isinstance(allowed_types, list) else [allowed_types]
|
||||
)
|
||||
)
|
||||
self._check_extension = check_extension
|
||||
|
||||
# Pre-compute extension map for allowed types
|
||||
self._ext_map = frozenset(
|
||||
ext
|
||||
for ext, mime in self._MIME_MAP.items()
|
||||
if any(allowed in mime for allowed in self.allowed_types)
|
||||
)
|
||||
|
||||
@lru_cache(maxsize=1000)
|
||||
def _check_url_cached(self, url: str) -> bool:
|
||||
"""Cached URL checking"""
|
||||
if not self._check_extension:
|
||||
return True
|
||||
|
||||
path = url.split("?")[0] # Fast path split
|
||||
ext = self._extract_extension(path)
|
||||
if not ext:
|
||||
return True
|
||||
|
||||
return ext in self._ext_map
|
||||
|
||||
def apply(self, url: str) -> bool:
|
||||
"""Fast extension check with caching"""
|
||||
result = self._check_url_cached(url)
|
||||
self._update_stats(result)
|
||||
return result
|
||||
|
||||
|
||||
class FastDomainFilter(FastURLFilter):
|
||||
"""Optimized domain filter with fast lookups and caching"""
|
||||
|
||||
__slots__ = ("_allowed_domains", "_blocked_domains", "_domain_cache")
|
||||
|
||||
# Regex for fast domain extraction
|
||||
_DOMAIN_REGEX = re.compile(r"://([^/]+)")
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
allowed_domains: Union[str, List[str]] = None,
|
||||
blocked_domains: Union[str, List[str]] = None,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
# Convert inputs to frozensets for immutable, fast lookups
|
||||
self._allowed_domains = (
|
||||
frozenset(self._normalize_domains(allowed_domains))
|
||||
if allowed_domains
|
||||
else None
|
||||
)
|
||||
self._blocked_domains = (
|
||||
frozenset(self._normalize_domains(blocked_domains))
|
||||
if blocked_domains
|
||||
else frozenset()
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _normalize_domains(domains: Union[str, List[str]]) -> Set[str]:
|
||||
"""Fast domain normalization"""
|
||||
if isinstance(domains, str):
|
||||
return {domains.lower()}
|
||||
return {d.lower() for d in domains}
|
||||
|
||||
@staticmethod
|
||||
@lru_cache(maxsize=10000)
|
||||
def _extract_domain(url: str) -> str:
|
||||
"""Ultra-fast domain extraction with regex and caching"""
|
||||
match = FastDomainFilter._DOMAIN_REGEX.search(url)
|
||||
return match.group(1).lower() if match else ""
|
||||
|
||||
def apply(self, url: str) -> bool:
|
||||
"""Optimized domain checking with early returns"""
|
||||
# Skip processing if no filters
|
||||
if not self._blocked_domains and self._allowed_domains is None:
|
||||
self._update_stats(True)
|
||||
return True
|
||||
|
||||
domain = self._extract_domain(url)
|
||||
|
||||
# Early return for blocked domains
|
||||
if domain in self._blocked_domains:
|
||||
self._update_stats(False)
|
||||
return False
|
||||
|
||||
# If no allowed domains specified, accept all non-blocked
|
||||
if self._allowed_domains is None:
|
||||
self._update_stats(True)
|
||||
return True
|
||||
|
||||
# Final allowed domains check
|
||||
result = domain in self._allowed_domains
|
||||
self._update_stats(result)
|
||||
return result
|
||||
|
||||
|
||||
def create_fast_filter_chain() -> FastFilterChain:
|
||||
"""Create an optimized filter chain with filters ordered by rejection rate"""
|
||||
return FastFilterChain(
|
||||
[
|
||||
# Domain filter first (fastest rejection)
|
||||
FastDomainFilter(blocked_domains=["ads.*", "analytics.*"]),
|
||||
# Content filter second (medium speed)
|
||||
FastContentTypeFilter(["text/html", "application/xhtml+xml"]),
|
||||
# Pattern filter last (most expensive)
|
||||
FastURLPatternFilter(
|
||||
[
|
||||
"*.html",
|
||||
"*.htm",
|
||||
"*/article/*",
|
||||
"*/blog/*",
|
||||
]
|
||||
),
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
def run_performance_test():
|
||||
import time
|
||||
import random
|
||||
from itertools import cycle
|
||||
|
||||
# Generate test URLs
|
||||
base_urls = [
|
||||
"https://example.com/article/123",
|
||||
"https://blog.example.com/post/456",
|
||||
"https://ads.example.com/tracking",
|
||||
"https://example.com/about.html",
|
||||
"https://analytics.example.com/script.js",
|
||||
"https://example.com/products.php",
|
||||
"https://subdomain.example.com/blog/post-123",
|
||||
"https://example.com/path/file.pdf",
|
||||
]
|
||||
|
||||
# Create more varied test data
|
||||
test_urls = []
|
||||
for base in base_urls:
|
||||
# Add original
|
||||
test_urls.append(base)
|
||||
# Add variations
|
||||
parts = base.split("/")
|
||||
for i in range(10):
|
||||
parts[-1] = f"page_{i}.html"
|
||||
test_urls.append("/".join(parts))
|
||||
|
||||
# Multiply to get enough test data
|
||||
test_urls = test_urls * 10000 # Creates ~800k URLs
|
||||
|
||||
def benchmark(name: str, func, *args, warmup=True):
|
||||
if warmup:
|
||||
# Warmup run
|
||||
func(*args)
|
||||
|
||||
# Actual timing
|
||||
start = time.perf_counter_ns()
|
||||
result = func(*args)
|
||||
elapsed = (time.perf_counter_ns() - start) / 1_000_000 # Convert to ms
|
||||
print(
|
||||
f"{name:<30} {elapsed:>8.3f} ms ({len(test_urls)/elapsed*1000:,.0f} URLs/sec)"
|
||||
)
|
||||
return result
|
||||
|
||||
print("\nBenchmarking original vs optimized implementations...")
|
||||
print("-" * 70)
|
||||
|
||||
# Original implementation
|
||||
pattern_filter = URLPatternFilter(["*.html", "*/article/*"])
|
||||
content_filter = ContentTypeFilter(["text/html"])
|
||||
domain_filter = DomainFilter(blocked_domains=["ads.*", "analytics.*"])
|
||||
chain = FilterChain([pattern_filter, content_filter, domain_filter])
|
||||
|
||||
# Optimized implementation
|
||||
fast_pattern_filter = FastURLPatternFilter(["*.html", "*/article/*"])
|
||||
fast_content_filter = FastContentTypeFilter(["text/html"])
|
||||
fast_domain_filter = FastDomainFilter(blocked_domains=["ads.*", "analytics.*"])
|
||||
fast_chain = FastFilterChain(
|
||||
[fast_domain_filter, fast_content_filter, fast_pattern_filter]
|
||||
)
|
||||
|
||||
# Test individual filters
|
||||
print("\nSingle filter performance (first 1000 URLs):")
|
||||
test_subset = test_urls[:1000]
|
||||
|
||||
print("\nPattern Filters:")
|
||||
benchmark(
|
||||
"Original Pattern Filter",
|
||||
lambda: [pattern_filter.apply(url) for url in test_subset],
|
||||
)
|
||||
benchmark(
|
||||
"Optimized Pattern Filter",
|
||||
lambda: [fast_pattern_filter.apply(url) for url in test_subset],
|
||||
)
|
||||
|
||||
print("\nContent Filters:")
|
||||
benchmark(
|
||||
"Original Content Filter",
|
||||
lambda: [content_filter.apply(url) for url in test_subset],
|
||||
)
|
||||
benchmark(
|
||||
"Optimized Content Filter",
|
||||
lambda: [fast_content_filter.apply(url) for url in test_subset],
|
||||
)
|
||||
|
||||
print("\nDomain Filters:")
|
||||
benchmark(
|
||||
"Original Domain Filter",
|
||||
lambda: [domain_filter.apply(url) for url in test_subset],
|
||||
)
|
||||
benchmark(
|
||||
"Optimized Domain Filter",
|
||||
lambda: [fast_domain_filter.apply(url) for url in test_subset],
|
||||
)
|
||||
|
||||
print("\nFull Chain Performance (all URLs):")
|
||||
# Test chain
|
||||
benchmark("Original Chain", lambda: [chain.apply(url) for url in test_urls])
|
||||
benchmark("Optimized Chain", lambda: [fast_chain.apply(url) for url in test_urls])
|
||||
|
||||
# Memory usage
|
||||
import sys
|
||||
|
||||
print("\nMemory Usage per Filter:")
|
||||
print(f"Original Pattern Filter: {sys.getsizeof(pattern_filter):,} bytes")
|
||||
print(f"Optimized Pattern Filter: {sys.getsizeof(fast_pattern_filter):,} bytes")
|
||||
print(f"Original Content Filter: {sys.getsizeof(content_filter):,} bytes")
|
||||
print(f"Optimized Content Filter: {sys.getsizeof(fast_content_filter):,} bytes")
|
||||
print(f"Original Domain Filter: {sys.getsizeof(domain_filter):,} bytes")
|
||||
print(f"Optimized Domain Filter: {sys.getsizeof(fast_domain_filter):,} bytes")
|
||||
|
||||
def test_pattern_filter():
|
||||
import time
|
||||
from itertools import chain
|
||||
|
||||
# Test cases as list of tuples instead of dict for multiple patterns
|
||||
test_cases = [
|
||||
# Simple suffix patterns (*.html)
|
||||
("*.html", {
|
||||
"https://example.com/page.html": True,
|
||||
"https://example.com/path/doc.html": True,
|
||||
"https://example.com/page.htm": False,
|
||||
"https://example.com/page.html?param=1": True,
|
||||
}),
|
||||
|
||||
# Path prefix patterns (/foo/*)
|
||||
("*/article/*", {
|
||||
"https://example.com/article/123": True,
|
||||
"https://example.com/blog/article/456": True,
|
||||
"https://example.com/articles/789": False,
|
||||
"https://example.com/article": False,
|
||||
}),
|
||||
|
||||
# Complex patterns
|
||||
("blog-*-[0-9]", {
|
||||
"https://example.com/blog-post-1": True,
|
||||
"https://example.com/blog-test-9": True,
|
||||
"https://example.com/blog-post": False,
|
||||
"https://example.com/blog-post-x": False,
|
||||
}),
|
||||
|
||||
# Multiple patterns case
|
||||
(["*.pdf", "*/download/*"], {
|
||||
"https://example.com/doc.pdf": True,
|
||||
"https://example.com/download/file.txt": True,
|
||||
"https://example.com/path/download/doc": True,
|
||||
"https://example.com/uploads/file.txt": False,
|
||||
}),
|
||||
|
||||
# Edge cases
|
||||
("*", {
|
||||
"https://example.com": True,
|
||||
"": True,
|
||||
"http://test.com/path": True,
|
||||
}),
|
||||
|
||||
# Complex regex
|
||||
(r"^https?://.*\.example\.com/\d+", {
|
||||
"https://sub.example.com/123": True,
|
||||
"http://test.example.com/456": True,
|
||||
"https://example.com/789": False,
|
||||
"https://sub.example.com/abc": False,
|
||||
})
|
||||
]
|
||||
|
||||
def run_accuracy_test():
|
||||
print("\nAccuracy Tests:")
|
||||
print("-" * 50)
|
||||
|
||||
all_passed = True
|
||||
for patterns, test_urls in test_cases:
|
||||
filter_obj = FastURLPatternFilter(patterns)
|
||||
|
||||
for url, expected in test_urls.items():
|
||||
result = filter_obj.apply(url)
|
||||
if result != expected:
|
||||
print(f"❌ Failed: Pattern '{patterns}' with URL '{url}'")
|
||||
print(f" Expected: {expected}, Got: {result}")
|
||||
all_passed = False
|
||||
else:
|
||||
print(f"✅ Passed: Pattern '{patterns}' with URL '{url}'")
|
||||
|
||||
return all_passed
|
||||
|
||||
def run_speed_test():
|
||||
print("\nSpeed Tests:")
|
||||
print("-" * 50)
|
||||
|
||||
# Create a large set of test URLs
|
||||
all_urls = list(chain.from_iterable(urls.keys() for _, urls in test_cases))
|
||||
test_urls = all_urls * 10000 # 100K+ URLs
|
||||
|
||||
# Test both implementations
|
||||
original = URLPatternFilter(["*.html", "*/article/*", "blog-*"])
|
||||
optimized = FastURLPatternFilter(["*.html", "*/article/*", "blog-*"])
|
||||
|
||||
def benchmark(name, filter_obj):
|
||||
start = time.perf_counter()
|
||||
for url in test_urls:
|
||||
filter_obj.apply(url)
|
||||
elapsed = time.perf_counter() - start
|
||||
urls_per_sec = len(test_urls) / elapsed
|
||||
print(f"{name:<20} {elapsed:.3f}s ({urls_per_sec:,.0f} URLs/sec)")
|
||||
|
||||
benchmark("Original Filter:", original)
|
||||
benchmark("Optimized Filter:", optimized)
|
||||
|
||||
# Run tests
|
||||
print("Running Pattern Filter Tests...")
|
||||
accuracy_passed = run_accuracy_test()
|
||||
|
||||
if accuracy_passed:
|
||||
print("\n✨ All accuracy tests passed!")
|
||||
run_speed_test()
|
||||
else:
|
||||
print("\n❌ Some accuracy tests failed!")
|
||||
|
||||
if __name__ == "__main__":
|
||||
run_performance_test()
|
||||
# test_pattern_filter()
|
||||
1204
crawl4ai/deep_crawling/scorers.py
Normal file
1204
crawl4ai/deep_crawling/scorers.py
Normal file
File diff suppressed because it is too large
Load Diff
@@ -85,6 +85,16 @@ class MarkdownGenerationResult(BaseModel):
|
||||
fit_markdown: Optional[str] = None
|
||||
fit_html: Optional[str] = None
|
||||
|
||||
@dataclass
|
||||
class TraversalStats:
|
||||
"""Statistics for the traversal process"""
|
||||
|
||||
start_time: datetime = datetime.now()
|
||||
urls_processed: int = 0
|
||||
urls_failed: int = 0
|
||||
urls_skipped: int = 0
|
||||
total_depth_reached: int = 0
|
||||
current_depth: int = 0
|
||||
|
||||
class DispatchResult(BaseModel):
|
||||
task_id: str
|
||||
|
||||
43
tests/20241401/test_deep_crawl.py
Normal file
43
tests/20241401/test_deep_crawl.py
Normal file
@@ -0,0 +1,43 @@
|
||||
import asyncio
|
||||
import time
|
||||
|
||||
|
||||
from crawl4ai import CrawlerRunConfig, AsyncWebCrawler, CacheMode
|
||||
from crawl4ai.deep_crawling.bfs_strategy import BFSDeepCrawlStrategy
|
||||
|
||||
|
||||
async def main():
|
||||
"""Example deep crawl of documentation site."""
|
||||
config = CrawlerRunConfig(
|
||||
deep_crawl_strategy = BFSDeepCrawlStrategy(
|
||||
max_depth=2,
|
||||
include_external=False
|
||||
),
|
||||
stream=False,
|
||||
verbose=True,
|
||||
cache_mode=CacheMode.BYPASS
|
||||
)
|
||||
|
||||
async with AsyncWebCrawler() as crawler:
|
||||
start_time = time.perf_counter()
|
||||
print("\nStarting deep crawl in batch mode:")
|
||||
results = await crawler.arun(
|
||||
url="https://docs.crawl4ai.com",
|
||||
config=config
|
||||
)
|
||||
print(f"Crawled {len(results)} pages")
|
||||
print(f"Example page: {results[0].url}")
|
||||
print(f"Duration: {time.perf_counter() - start_time:.2f} seconds\n")
|
||||
|
||||
print("Starting deep crawl in streaming mode:")
|
||||
config.stream = True
|
||||
start_time = time.perf_counter()
|
||||
async for result in await crawler.arun(
|
||||
url="https://docs.crawl4ai.com",
|
||||
config=config
|
||||
):
|
||||
print(f"→ {result.url} (Depth: {result.metadata.get('depth', 0)})")
|
||||
print(f"Duration: {time.perf_counter() - start_time:.2f} seconds")
|
||||
|
||||
if __name__ == "__main__":
|
||||
asyncio.run(main())
|
||||
Reference in New Issue
Block a user