Refactor: Moved deep_crawl_strategy inside crawler run config

Aravind Karnam
2025-01-30 16:18:15 +05:30
parent 858c18df39
commit ca3f0126d3
9 changed files with 79 additions and 57 deletions

View File

@@ -10,6 +10,7 @@ from .config import (
 from .user_agent_generator import UserAgentGenerator, UAGen, ValidUAGenerator, OnlineUAGenerator
 from .extraction_strategy import ExtractionStrategy
 from .chunking_strategy import ChunkingStrategy, RegexChunking
+from .deep_crawl import DeepCrawlStrategy
 from .markdown_generation_strategy import MarkdownGenerationStrategy
 from .content_filter_strategy import RelevantContentFilter, BM25ContentFilter, LLMContentFilter, PruningContentFilter
 from .content_scraping_strategy import ContentScrapingStrategy, WebScrapingStrategy
@@ -395,6 +396,7 @@ class CrawlerRunConfig:
         word_count_threshold: int = MIN_WORD_THRESHOLD,
         extraction_strategy: ExtractionStrategy = None,
         chunking_strategy: ChunkingStrategy = RegexChunking(),
+        deep_crawl_strategy: DeepCrawlStrategy = None,
         markdown_generator: MarkdownGenerationStrategy = None,
         content_filter : RelevantContentFilter = None,
         only_text: bool = False,
@@ -468,6 +470,7 @@ class CrawlerRunConfig:
         self.word_count_threshold = word_count_threshold
         self.extraction_strategy = extraction_strategy
         self.chunking_strategy = chunking_strategy
+        self.deep_crawl_strategy = deep_crawl_strategy
         self.markdown_generator = markdown_generator
         self.content_filter = content_filter
         self.only_text = only_text
@@ -555,6 +558,14 @@ class CrawlerRunConfig:
             raise ValueError(
                 "extraction_strategy must be an instance of ExtractionStrategy"
             )
+        if self.deep_crawl_strategy is not None and not isinstance(
+            self.deep_crawl_strategy, DeepCrawlStrategy
+        ):
+            raise ValueError(
+                "deep_crawl_strategy must be an instance of DeepCrawlStrategy"
+            )
         if self.chunking_strategy is not None and not isinstance(
             self.chunking_strategy, ChunkingStrategy
         ):
@@ -573,6 +584,7 @@ class CrawlerRunConfig:
             word_count_threshold=kwargs.get("word_count_threshold", 200),
             extraction_strategy=kwargs.get("extraction_strategy"),
             chunking_strategy=kwargs.get("chunking_strategy", RegexChunking()),
+            deep_crawl_strategy=kwargs.get("deep_crawl_strategy"),
             markdown_generator=kwargs.get("markdown_generator"),
             content_filter=kwargs.get("content_filter"),
             only_text=kwargs.get("only_text", False),
@@ -656,6 +668,7 @@ class CrawlerRunConfig:
"word_count_threshold": self.word_count_threshold, "word_count_threshold": self.word_count_threshold,
"extraction_strategy": self.extraction_strategy, "extraction_strategy": self.extraction_strategy,
"chunking_strategy": self.chunking_strategy, "chunking_strategy": self.chunking_strategy,
"deep_crawl_strategy": self.deep_crawl_strategy,
"markdown_generator": self.markdown_generator, "markdown_generator": self.markdown_generator,
"content_filter": self.content_filter, "content_filter": self.content_filter,
"only_text": self.only_text, "only_text": self.only_text,

View File

@@ -38,7 +38,7 @@ from .async_logger import AsyncLogger
 from .async_configs import BrowserConfig, CrawlerRunConfig
 from .async_dispatcher import *  # noqa: F403
 from .async_dispatcher import BaseDispatcher, MemoryAdaptiveDispatcher, RateLimiter
-from .traversal import TraversalStrategy
+from .deep_crawl import DeepCrawlStrategy
 from .config import MIN_WORD_THRESHOLD
 from .utils import (
@@ -53,11 +53,17 @@ from .utils import (
 from typing import Union, AsyncGenerator, List, TypeVar
 from collections.abc import AsyncGenerator
-CrawlResultT = TypeVar("CrawlResultT", bound=CrawlResult)
-RunManyReturn = Union[List[CrawlResultT], AsyncGenerator[CrawlResultT, None]]
 from .__version__ import __version__ as crawl4ai_version
+CrawlResultT = TypeVar("CrawlResultT", bound=CrawlResult)
+RunManyReturn = Union[List[CrawlResultT], AsyncGenerator[CrawlResultT, None]]
+DeepCrawlSingleReturn = Union[List[CrawlResultT], AsyncGenerator[CrawlResultT, None]]
+DeepCrawlManyReturn = Union[
+    List[List[CrawlResultT]],
+    AsyncGenerator[CrawlResultT, None],
+]
 class AsyncWebCrawler:
     """
@@ -289,7 +295,7 @@ class AsyncWebCrawler:
         user_agent: str = None,
         verbose=True,
         **kwargs,
-    ) -> CrawlResult:
+    ) -> Union[CrawlResult, DeepCrawlSingleReturn]:
         """
         Runs the crawler for a single source: URL (web, local file, or raw HTML).
@@ -391,6 +397,23 @@ class AsyncWebCrawler:
             extracted_content = None
             start_time = time.perf_counter()
+            if crawler_config.deep_crawl_strategy:
+                if crawler_config.stream:
+                    return crawler_config.deep_crawl_strategy.arun(
+                        start_url=url,
+                        crawler=self,
+                        crawler_run_config=crawler_config,
+                    )
+                else:
+                    results = []
+                    async for result in crawler_config.deep_crawl_strategy.arun(
+                        start_url=url,
+                        crawler=self,
+                        crawler_run_config=crawler_config,
+                    ):
+                        results.append(result)
+                    return results
             # Try to get cached result if appropriate
             if cache_context.should_read():
                 cached_result = await async_db_manager.aget_cached_url(url)
@@ -743,7 +766,7 @@ class AsyncWebCrawler:
         user_agent: str = None,
         verbose=True,
         **kwargs,
-    ) -> RunManyReturn:
+    ) -> Union[RunManyReturn, DeepCrawlManyReturn]:
         """
         Runs the crawler for multiple URLs concurrently using a configurable dispatcher strategy.
@@ -830,7 +853,7 @@ class AsyncWebCrawler:
     async def adeep_crawl(
         self,
         url: str,
-        strategy: TraversalStrategy,
+        strategy: DeepCrawlStrategy,
         crawler_run_config: Optional[CrawlerRunConfig] = None,
         stream: Optional[bool] = False,
     ) -> Union[AsyncGenerator[CrawlResult,None],List[CrawlResult]]:
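
The delegation block added to arun() means a deep crawl is now triggered purely by the config: with stream=True the strategy's async generator is returned as-is, with stream=False it is drained into a list. A rough sketch of the two call patterns; the strategy construction repeats the same assumption as the sketch above.

import asyncio
from crawl4ai.async_configs import CrawlerRunConfig
from crawl4ai.async_webcrawler import AsyncWebCrawler
from crawl4ai.deep_crawl import BFSDeepCrawlStrategy, FilterChain

async def main():
    strategy = BFSDeepCrawlStrategy(max_depth=1, filter_chain=FilterChain([]), url_scorer=None)
    async with AsyncWebCrawler() as crawler:
        # Batch mode: arun() collects the strategy's results into a list.
        results = await crawler.arun(
            "https://crawl4ai.com/mkdocs",
            config=CrawlerRunConfig(deep_crawl_strategy=strategy, stream=False),
        )
        print(f"crawled {len(results)} pages")

        # Streaming mode: arun() hands back the async generator itself.
        async for result in await crawler.arun(
            "https://crawl4ai.com/mkdocs",
            config=CrawlerRunConfig(deep_crawl_strategy=strategy, stream=True),
        ):
            print(result.url)

asyncio.run(main())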

View File

@@ -1,4 +1,4 @@
-from .bfs_traversal_strategy import BFSTraversalStrategy
+from .bfs_deep_crawl_strategy import BFSDeepCrawlStrategy
 from .filters import (
     URLFilter,
     FilterChain,
@@ -12,10 +12,10 @@ from .scorers import (
     FreshnessScorer,
     CompositeScorer,
 )
-from .traversal_strategy import TraversalStrategy
+from .deep_crawl_strategty import DeepCrawlStrategy
 __all__ = [
-    "BFSTraversalStrategy",
+    "BFSDeepCrawlStrategy",
     "FilterChain",
     "URLFilter",
     "URLPatternFilter",
@@ -25,5 +25,5 @@ __all__ = [
"PathDepthScorer", "PathDepthScorer",
"FreshnessScorer", "FreshnessScorer",
"CompositeScorer", "CompositeScorer",
"TraversalStrategy", "DeepCrawlStrategy",
] ]

View File

@@ -3,15 +3,14 @@ from datetime import datetime
 import asyncio
 import logging
 from urllib.parse import urlparse
-from ..async_configs import CrawlerRunConfig
 from ..models import CrawlResult, TraversalStats
 from .filters import FilterChain
 from .scorers import URLScorer
-from .traversal_strategy import TraversalStrategy
+from .deep_crawl_strategty import DeepCrawlStrategy
 from ..config import DEEP_CRAWL_BATCH_SIZE
-class BFSTraversalStrategy(TraversalStrategy):
+class BFSDeepCrawlStrategy(DeepCrawlStrategy):
     """Best-First Search traversal strategy with filtering and scoring."""
     def __init__(
@@ -98,11 +97,11 @@ class BFSTraversalStrategy(TraversalStrategy):
                 self.stats.total_depth_reached, next_depth
             )
-    async def deep_crawl(
+    async def arun(
         self,
         start_url: str,
         crawler: "AsyncWebCrawler",
-        crawler_run_config: Optional[CrawlerRunConfig] = None,
+        crawler_run_config: Optional["CrawlerRunConfig"] = None,
     ) -> AsyncGenerator[CrawlResult, None]:
         """Implement BFS traversal strategy"""
@@ -136,7 +135,9 @@ class BFSTraversalStrategy(TraversalStrategy):
""" """
# Collect batch of URLs into active_crawls to process # Collect batch of URLs into active_crawls to process
async with active_crawls_lock: async with active_crawls_lock:
while len(active_crawls) < DEEP_CRAWL_BATCH_SIZE and not queue.empty(): while (
len(active_crawls) < DEEP_CRAWL_BATCH_SIZE and not queue.empty()
):
score, depth, url, parent_url = await queue.get() score, depth, url, parent_url = await queue.get()
active_crawls[url] = { active_crawls[url] = {
"depth": depth, "depth": depth,
@@ -151,14 +152,14 @@ class BFSTraversalStrategy(TraversalStrategy):
                     continue
             # Process batch
             try:
-                stream_config = (
-                    crawler_run_config.clone(stream=True)
-                    if crawler_run_config
-                    else CrawlerRunConfig(stream=True)
-                )
+                # This is very important to ensure recursively you don't deep_crawl down the children.
+                if crawler_run_config:
+                    crawler_run_config = crawler_run_config.clone(
+                        deep_crawl_strategy=None, stream=True
+                    )
                 async for result in await crawler.arun_many(
                     urls=list(active_crawls.keys()),
-                    config=stream_config,
+                    config=crawler_run_config
                 ):
                     async with active_crawls_lock:
                         crawl_info = active_crawls.pop(result.url, None)
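
The clone() call above is the recursion guard: child URLs go through crawler.arun_many() with deep_crawl_strategy cleared, so they are fetched as ordinary pages rather than each starting a new deep crawl. Condensed, the pattern the strategy relies on looks roughly like this (a sketch, not the full method):

# Inside the strategy, before dispatching a batch of child URLs:
child_config = crawler_run_config.clone(
    deep_crawl_strategy=None,  # children must not re-enter the deep-crawl branch
    stream=True,               # let arun_many yield results as they complete
)
async for result in await crawler.arun_many(
    urls=list(active_crawls.keys()),
    config=child_config,
):
    ...  # score the page, record its depth, enqueue newly discovered links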

View File

@@ -1,17 +1,16 @@
 from abc import ABC, abstractmethod
-from typing import AsyncGenerator
-from ..async_configs import CrawlerRunConfig
+from typing import AsyncGenerator, Optional
 from ..models import CrawlResult
-class TraversalStrategy(ABC):
+class DeepCrawlStrategy(ABC):
     @abstractmethod
-    async def deep_crawl(
+    async def arun(
         self,
         url: str,
         crawler: "AsyncWebCrawler",
-        crawler_run_config: CrawlerRunConfig = None,
+        crawler_run_config: Optional["CrawlerRunConfig"] = None,
     ) -> AsyncGenerator[CrawlResult, None]:
         """Traverse the given URL using the specified crawler.

View File

@@ -140,7 +140,6 @@ class CrawlResult(BaseModel):
     class Config:
         arbitrary_types_allowed = True
 class AsyncCrawlResponse(BaseModel):
     html: str
     response_headers: Dict[str, str]

View File

@@ -1,18 +1,25 @@
 # basic_scraper_example.py
 from crawl4ai.async_configs import CrawlerRunConfig, BrowserConfig
 from crawl4ai.content_scraping_strategy import LXMLWebScrapingStrategy
-from crawl4ai.traversal import (
-    BFSTraversalStrategy,
+from crawl4ai.deep_crawl import (
+    BFSDeepCrawlStrategy,
     FilterChain,
     URLPatternFilter,
     ContentTypeFilter,
+    DomainFilter,
+    KeywordRelevanceScorer,
+    PathDepthScorer,
+    FreshnessScorer,
+    CompositeScorer,
 )
 from crawl4ai.async_webcrawler import AsyncWebCrawler
 import re
 import time
+import logging
 browser_config = BrowserConfig(headless=True, viewport_width=800, viewport_height=600)
 async def basic_scraper_example():
     """
     Basic example: Scrape a blog site for articles
@@ -31,7 +38,7 @@ async def basic_scraper_example():
     )
     # Initialize the strategy with basic configuration
-    bfs_strategy = BFSTraversalStrategy(
+    bfs_strategy = BFSDeepCrawlStrategy(
         max_depth=2,  # Only go 2 levels deep
         filter_chain=filter_chain,
         url_scorer=None,  # Use default scoring
@@ -44,8 +51,8 @@ async def basic_scraper_example():
     ) as crawler:
         # Start scraping
         try:
-            results = await crawler.adeep_crawl(
-                "https://crawl4ai.com/mkdocs", strategy=bfs_strategy
+            results = await crawler.arun(
+                "https://crawl4ai.com/mkdocs", CrawlerRunConfig(deep_crawl_strategy=bfs_strategy)
             )
             # Process results
             print(f"Crawled {len(results)} pages:")
@@ -55,23 +62,6 @@ async def basic_scraper_example():
         except Exception as e:
             print(f"Error during scraping: {e}")
-# advanced_scraper_example.py
-import logging
-from crawl4ai.traversal import (
-    BFSTraversalStrategy,
-    FilterChain,
-    URLPatternFilter,
-    ContentTypeFilter,
-    DomainFilter,
-    KeywordRelevanceScorer,
-    PathDepthScorer,
-    FreshnessScorer,
-    CompositeScorer,
-)
 async def advanced_scraper_example():
     """
     Advanced example: Intelligent news site scraping
@@ -121,7 +111,7 @@ async def advanced_scraper_example():
     )
     # Initialize strategy with advanced configuration
-    bfs_strategy = BFSTraversalStrategy(
+    bfs_strategy = BFSDeepCrawlStrategy(
         max_depth=2, filter_chain=filter_chain, url_scorer=scorer
     )
@@ -136,13 +126,10 @@ async def advanced_scraper_example():
         try:
             # Use streaming mode
             results = []
-            result_generator = await crawler.adeep_crawl(
+            result_generator = await crawler.arun(
                 "https://techcrunch.com",
-                strategy=bfs_strategy,
-                crawler_run_config=CrawlerRunConfig(
-                    scraping_strategy=LXMLWebScrapingStrategy()
-                ),
-                stream=True,
+                config=CrawlerRunConfig(deep_crawl_strategy=bfs_strategy,
+                    stream=True)
             )
             async for result in result_generator:
                 stats["processed"] += 1