diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml new file mode 100644 index 00000000..92f3c1a7 --- /dev/null +++ b/.pre-commit-config.yaml @@ -0,0 +1,8 @@ +# .pre-commit-config.yaml +repos: +- repo: https://github.com/astral-sh/ruff-pre-commit + rev: v0.1.11 + hooks: + - id: ruff + args: [--fix] + - id: ruff-format \ No newline at end of file diff --git a/crawl4ai/__init__.py b/crawl4ai/__init__.py index 86c2cb9e..78ccdb02 100644 --- a/crawl4ai/__init__.py +++ b/crawl4ai/__init__.py @@ -2,14 +2,28 @@ from .async_webcrawler import AsyncWebCrawler, CacheMode from .async_configs import BrowserConfig, CrawlerRunConfig -from .content_scraping_strategy import ContentScrapingStrategy, WebScrapingStrategy, LXMLWebScrapingStrategy -from .extraction_strategy import ExtractionStrategy, LLMExtractionStrategy, CosineStrategy, JsonCssExtractionStrategy +from .content_scraping_strategy import ( + ContentScrapingStrategy, + WebScrapingStrategy, + LXMLWebScrapingStrategy, +) +from .extraction_strategy import ( + ExtractionStrategy, + LLMExtractionStrategy, + CosineStrategy, + JsonCssExtractionStrategy, +) from .chunking_strategy import ChunkingStrategy, RegexChunking from .markdown_generation_strategy import DefaultMarkdownGenerator from .content_filter_strategy import PruningContentFilter, BM25ContentFilter from .models import CrawlResult, MarkdownGenerationResult -from .async_dispatcher import MemoryAdaptiveDispatcher, SemaphoreDispatcher, RateLimiter, CrawlerMonitor, DisplayMode -from .__version__ import __version__ +from .async_dispatcher import ( + MemoryAdaptiveDispatcher, + SemaphoreDispatcher, + RateLimiter, + CrawlerMonitor, + DisplayMode, +) __all__ = [ "AsyncWebCrawler", @@ -18,40 +32,45 @@ __all__ = [ "ContentScrapingStrategy", "WebScrapingStrategy", "LXMLWebScrapingStrategy", - 'BrowserConfig', - 'CrawlerRunConfig', - 'ExtractionStrategy', - 'LLMExtractionStrategy', - 'CosineStrategy', - 'JsonCssExtractionStrategy', - 'ChunkingStrategy', - 'RegexChunking', - 'DefaultMarkdownGenerator', - 'PruningContentFilter', - 'BM25ContentFilter', - 'MemoryAdaptiveDispatcher', - 'SemaphoreDispatcher', - 'RateLimiter', - 'CrawlerMonitor', - 'DisplayMode', - 'MarkdownGenerationResult', + "BrowserConfig", + "CrawlerRunConfig", + "ExtractionStrategy", + "LLMExtractionStrategy", + "CosineStrategy", + "JsonCssExtractionStrategy", + "ChunkingStrategy", + "RegexChunking", + "DefaultMarkdownGenerator", + "PruningContentFilter", + "BM25ContentFilter", + "MemoryAdaptiveDispatcher", + "SemaphoreDispatcher", + "RateLimiter", + "CrawlerMonitor", + "DisplayMode", + "MarkdownGenerationResult", ] + def is_sync_version_installed(): try: import selenium + return True except ImportError: return False + if is_sync_version_installed(): try: from .web_crawler import WebCrawler + __all__.append("WebCrawler") except ImportError: - import warnings - print("Warning: Failed to import WebCrawler even though selenium is installed. This might be due to other missing dependencies.") + print( + "Warning: Failed to import WebCrawler even though selenium is installed. This might be due to other missing dependencies." + ) else: WebCrawler = None # import warnings - # print("Warning: Synchronous WebCrawler is not available. Install crawl4ai[sync] for synchronous support. However, please note that the synchronous version will be deprecated soon.") \ No newline at end of file + # print("Warning: Synchronous WebCrawler is not available. Install crawl4ai[sync] for synchronous support. 
However, please note that the synchronous version will be deprecated soon.") diff --git a/crawl4ai/async_configs.py b/crawl4ai/async_configs.py index 28f90bb3..c6f25994 100644 --- a/crawl4ai/async_configs.py +++ b/crawl4ai/async_configs.py @@ -5,7 +5,6 @@ from .config import ( PAGE_TIMEOUT, IMAGE_SCORE_THRESHOLD, SOCIAL_MEDIA_DOMAINS, - ) from .user_agent_generator import UserAgentGenerator from .extraction_strategy import ExtractionStrategy @@ -14,6 +13,7 @@ from .markdown_generation_strategy import MarkdownGenerationStrategy from .content_scraping_strategy import ContentScrapingStrategy, WebScrapingStrategy from typing import Union, List + class BrowserConfig: """ Configuration class for setting up a browser instance and its context in AsyncPlaywrightCrawlerStrategy. @@ -84,7 +84,7 @@ class BrowserConfig: proxy: str = None, proxy_config: dict = None, viewport_width: int = 1080, - viewport_height: int = 600, + viewport_height: int = 600, accept_downloads: bool = False, downloads_path: str = None, storage_state=None, @@ -103,7 +103,7 @@ class BrowserConfig: text_mode: bool = False, light_mode: bool = False, extra_args: list = None, - debugging_port : int = 9222, + debugging_port: int = 9222, ): self.browser_type = browser_type self.headless = headless @@ -142,7 +142,7 @@ class BrowserConfig: self.user_agent = user_agenr_generator.generate() else: pass - + self.browser_hint = user_agenr_generator.generate_client_hints(self.user_agent) self.headers.setdefault("sec-ch-ua", self.browser_hint) @@ -313,7 +313,7 @@ class CrawlerRunConfig: Default: True. log_console (bool): If True, log console messages from the page. Default: False. - + # Optional Parameters url: str = None # This is not a compulsory parameter """ @@ -335,10 +335,8 @@ class CrawlerRunConfig: prettiify: bool = False, parser_type: str = "lxml", scraping_strategy: ContentScrapingStrategy = None, - # SSL Parameters fetch_ssl_certificate: bool = False, - # Caching Parameters cache_mode=None, session_id: str = None, @@ -346,7 +344,6 @@ class CrawlerRunConfig: disable_cache: bool = False, no_cache_read: bool = False, no_cache_write: bool = False, - # Page Navigation and Timing Parameters wait_until: str = "domcontentloaded", page_timeout: int = PAGE_TIMEOUT, @@ -356,7 +353,6 @@ class CrawlerRunConfig: mean_delay: float = 0.1, max_range: float = 0.3, semaphore_count: int = 5, - # Page Interaction Parameters js_code: Union[str, List[str]] = None, js_only: bool = False, @@ -369,7 +365,6 @@ class CrawlerRunConfig: override_navigator: bool = False, magic: bool = False, adjust_viewport_to_content: bool = False, - # Media Handling Parameters screenshot: bool = False, screenshot_wait_for: float = None, @@ -378,21 +373,18 @@ class CrawlerRunConfig: image_description_min_word_threshold: int = IMAGE_DESCRIPTION_MIN_WORD_THRESHOLD, image_score_threshold: int = IMAGE_SCORE_THRESHOLD, exclude_external_images: bool = False, - # Link and Domain Handling Parameters exclude_social_media_domains: list = None, exclude_external_links: bool = False, exclude_social_media_links: bool = False, exclude_domains: list = None, - # Debugging and Logging Parameters verbose: bool = True, log_console: bool = False, - url: str = None, ): self.url = url - + # Content Processing Parameters self.word_count_threshold = word_count_threshold self.extraction_strategy = extraction_strategy @@ -453,7 +445,9 @@ class CrawlerRunConfig: self.exclude_external_images = exclude_external_images # Link and Domain Handling Parameters - self.exclude_social_media_domains = 
exclude_social_media_domains or SOCIAL_MEDIA_DOMAINS + self.exclude_social_media_domains = ( + exclude_social_media_domains or SOCIAL_MEDIA_DOMAINS + ) self.exclude_external_links = exclude_external_links self.exclude_social_media_links = exclude_social_media_links self.exclude_domains = exclude_domains or [] @@ -466,11 +460,15 @@ class CrawlerRunConfig: if self.extraction_strategy is not None and not isinstance( self.extraction_strategy, ExtractionStrategy ): - raise ValueError("extraction_strategy must be an instance of ExtractionStrategy") + raise ValueError( + "extraction_strategy must be an instance of ExtractionStrategy" + ) if self.chunking_strategy is not None and not isinstance( self.chunking_strategy, ChunkingStrategy ): - raise ValueError("chunking_strategy must be an instance of ChunkingStrategy") + raise ValueError( + "chunking_strategy must be an instance of ChunkingStrategy" + ) # Set default chunking strategy if None if self.chunking_strategy is None: @@ -494,10 +492,8 @@ class CrawlerRunConfig: prettiify=kwargs.get("prettiify", False), parser_type=kwargs.get("parser_type", "lxml"), scraping_strategy=kwargs.get("scraping_strategy"), - # SSL Parameters fetch_ssl_certificate=kwargs.get("fetch_ssl_certificate", False), - # Caching Parameters cache_mode=kwargs.get("cache_mode"), session_id=kwargs.get("session_id"), @@ -505,7 +501,6 @@ class CrawlerRunConfig: disable_cache=kwargs.get("disable_cache", False), no_cache_read=kwargs.get("no_cache_read", False), no_cache_write=kwargs.get("no_cache_write", False), - # Page Navigation and Timing Parameters wait_until=kwargs.get("wait_until", "domcontentloaded"), page_timeout=kwargs.get("page_timeout", 60000), @@ -515,7 +510,6 @@ class CrawlerRunConfig: mean_delay=kwargs.get("mean_delay", 0.1), max_range=kwargs.get("max_range", 0.3), semaphore_count=kwargs.get("semaphore_count", 5), - # Page Interaction Parameters js_code=kwargs.get("js_code"), js_only=kwargs.get("js_only", False), @@ -528,29 +522,34 @@ class CrawlerRunConfig: override_navigator=kwargs.get("override_navigator", False), magic=kwargs.get("magic", False), adjust_viewport_to_content=kwargs.get("adjust_viewport_to_content", False), - # Media Handling Parameters screenshot=kwargs.get("screenshot", False), screenshot_wait_for=kwargs.get("screenshot_wait_for"), - screenshot_height_threshold=kwargs.get("screenshot_height_threshold", SCREENSHOT_HEIGHT_TRESHOLD), + screenshot_height_threshold=kwargs.get( + "screenshot_height_threshold", SCREENSHOT_HEIGHT_TRESHOLD + ), pdf=kwargs.get("pdf", False), - image_description_min_word_threshold=kwargs.get("image_description_min_word_threshold", IMAGE_DESCRIPTION_MIN_WORD_THRESHOLD), - image_score_threshold=kwargs.get("image_score_threshold", IMAGE_SCORE_THRESHOLD), + image_description_min_word_threshold=kwargs.get( + "image_description_min_word_threshold", + IMAGE_DESCRIPTION_MIN_WORD_THRESHOLD, + ), + image_score_threshold=kwargs.get( + "image_score_threshold", IMAGE_SCORE_THRESHOLD + ), exclude_external_images=kwargs.get("exclude_external_images", False), - # Link and Domain Handling Parameters - exclude_social_media_domains=kwargs.get("exclude_social_media_domains", SOCIAL_MEDIA_DOMAINS), + exclude_social_media_domains=kwargs.get( + "exclude_social_media_domains", SOCIAL_MEDIA_DOMAINS + ), exclude_external_links=kwargs.get("exclude_external_links", False), exclude_social_media_links=kwargs.get("exclude_social_media_links", False), exclude_domains=kwargs.get("exclude_domains", []), - # Debugging and Logging Parameters 
verbose=kwargs.get("verbose", True), log_console=kwargs.get("log_console", False), - url=kwargs.get("url"), ) - + # Create a funciton returns dict of the object def to_dict(self): return { diff --git a/crawl4ai/async_crawler_strategy.py b/crawl4ai/async_crawler_strategy.py index ed7407f8..0edefa73 100644 --- a/crawl4ai/async_crawler_strategy.py +++ b/crawl4ai/async_crawler_strategy.py @@ -2,27 +2,25 @@ import asyncio import base64 import time from abc import ABC, abstractmethod -from typing import Callable, Dict, Any, List, Optional, Awaitable, Union -import os, sys, shutil -import tempfile, subprocess -from playwright.async_api import async_playwright, Page, Browser, Error, BrowserContext +from typing import Callable, Dict, Any, List, Optional, Union +import os +import sys +import shutil +import tempfile +import subprocess +from playwright.async_api import Page, Error, BrowserContext from playwright.async_api import TimeoutError as PlaywrightTimeoutError from io import BytesIO from PIL import Image, ImageDraw, ImageFont -from pathlib import Path -from playwright.async_api import ProxySettings -from pydantic import BaseModel import hashlib -import json import uuid from .js_snippet import load_js_script from .models import AsyncCrawlResponse -from .utils import get_error_context from .user_agent_generator import UserAgentGenerator from .config import SCREENSHOT_HEIGHT_TRESHOLD, DOWNLOAD_PAGE_TIMEOUT from .async_configs import BrowserConfig, CrawlerRunConfig from .async_logger import AsyncLogger -from playwright_stealth import StealthConfig, stealth_async +from playwright_stealth import StealthConfig from .ssl_certificate import SSLCertificate stealth_config = StealthConfig( @@ -66,7 +64,7 @@ BROWSER_DISABLE_OPTIONS = [ class ManagedBrowser: """ Manages the browser process and context. This class allows to connect to the browser using CDP protocol. - + Attributes: browser_type (str): The type of browser to launch. Supported values: "chromium", "firefox", "webkit". Default: "chromium". @@ -75,16 +73,16 @@ class ManagedBrowser: headless (bool): Whether to run the browser in headless mode (no visible GUI). Default: True. browser_process (subprocess.Popen): The process object for the browser. - temp_dir (str): Temporary directory for user data if not provided. + temp_dir (str): Temporary directory for user data if not provided. debugging_port (int): Port for debugging the browser. host (str): Host for debugging the browser. - + Methods: start(): Starts the browser process and returns the CDP endpoint URL. _get_browser_path(): Returns the browser executable path based on OS and browser type. _get_browser_args(): Returns browser-specific command line arguments. _get_user_data_dir(): Returns the user data directory path. - _cleanup(): Terminates the browser process and removes the temporary directory. + _cleanup(): Terminates the browser process and removes the temporary directory. """ browser_type: str @@ -94,6 +92,7 @@ class ManagedBrowser: temp_dir: str debugging_port: int host: str + def __init__( self, browser_type: str = "chromium", @@ -105,7 +104,7 @@ class ManagedBrowser: ): """ Initialize the ManagedBrowser instance. - + Args: browser_type (str): The type of browser to launch. Supported values: "chromium", "firefox", "webkit". Default: "chromium". @@ -116,7 +115,7 @@ class ManagedBrowser: logger (logging.Logger): Logger instance for logging messages. Default: None. host (str): Host for debugging the browser. Default: "localhost". debugging_port (int): Port for debugging the browser. 
Default: 9222. - """ + """ self.browser_type = browser_type self.user_data_dir = user_data_dir self.headless = headless @@ -139,7 +138,7 @@ class ManagedBrowser: self.user_data_dir = self.temp_dir # Get browser path and args based on OS and browser type - browser_path = self._get_browser_path() + # browser_path = self._get_browser_path() args = self._get_browser_args() # Start browser process @@ -158,13 +157,13 @@ class ManagedBrowser: async def _monitor_browser_process(self): """ Monitor the browser process for unexpected termination. - + How it works: 1. Read stdout and stderr from the browser process. 2. If the process has terminated, log the error message and terminate the browser. 3. If the shutting_down flag is set, log the normal termination message. 4. If any other error occurs, log the error message. - + Note: This method should be called in a separate task to avoid blocking the main event loop. """ if self.browser_process: @@ -289,17 +288,18 @@ class ManagedBrowser: class BrowserManager: """ Manages the browser instance and context. - - Attributes: + + Attributes: config (BrowserConfig): Configuration object containing all browser settings logger: Logger instance for recording events and errors browser (Browser): The browser instance - default_context (BrowserContext): The default browser context + default_context (BrowserContext): The default browser context managed_browser (ManagedBrowser): The managed browser instance playwright (Playwright): The Playwright instance sessions (dict): Dictionary to store session information session_ttl (int): Session timeout in seconds """ + def __init__(self, browser_config: BrowserConfig, logger=None): """ Initialize the BrowserManager with a browser configuration. @@ -334,13 +334,13 @@ class BrowserManager: async def start(self): """ Start the browser instance and set up the default context. - + How it works: 1. Check if Playwright is already initialized. 2. If not, initialize Playwright. 3. If managed browser is used, start it and connect to the CDP endpoint. 4. If managed browser is not used, launch the browser and set up the default context. - + Note: This method should be called in a separate task to avoid blocking the main event loop. """ if self.playwright is None: @@ -453,7 +453,12 @@ class BrowserManager: return browser_args - async def setup_context(self, context: BrowserContext, crawlerRunConfig: CrawlerRunConfig = None, is_default=False): + async def setup_context( + self, + context: BrowserContext, + crawlerRunConfig: CrawlerRunConfig = None, + is_default=False, + ): """ Set up a browser context with the configured options. @@ -474,11 +479,11 @@ class BrowserManager: 14. Set default timeouts for navigation and download if enabled. 15. Set user agent if provided. 16. Set browser hints if provided. 
- + Args: context (BrowserContext): The browser context to set up crawlerRunConfig (CrawlerRunConfig): Configuration object containing all browser settings - is_default (bool): Flag indicating if this is the default context + is_default (bool): Flag indicating if this is the default context Returns: None """ @@ -496,9 +501,9 @@ class BrowserManager: context.set_default_navigation_timeout(DOWNLOAD_PAGE_TIMEOUT) if self.config.downloads_path: context._impl_obj._options["accept_downloads"] = True - context._impl_obj._options["downloads_path"] = ( - self.config.downloads_path - ) + context._impl_obj._options[ + "downloads_path" + ] = self.config.downloads_path # Handle user agent and browser hints if self.config.user_agent: @@ -511,7 +516,15 @@ class BrowserManager: # Add default cookie await context.add_cookies( - [{"name": "cookiesEnabled", "value": "true", "url": crawlerRunConfig.url if crawlerRunConfig else "https://crawl4ai.com/"}] + [ + { + "name": "cookiesEnabled", + "value": "true", + "url": crawlerRunConfig.url + if crawlerRunConfig + else "https://crawl4ai.com/", + } + ] ) # Handle navigator overrides @@ -527,7 +540,7 @@ class BrowserManager: """ Creates and returns a new browser context with configured settings. Applies text-only mode settings if text_mode is enabled in config. - + Returns: Context: Browser context object with the specified configurations """ @@ -538,25 +551,62 @@ class BrowserManager: "height": self.config.viewport_height, } proxy_settings = {"server": self.config.proxy} if self.config.proxy else None - + blocked_extensions = [ # Images - 'jpg', 'jpeg', 'png', 'gif', 'webp', 'svg', 'ico', 'bmp', 'tiff', 'psd', + "jpg", + "jpeg", + "png", + "gif", + "webp", + "svg", + "ico", + "bmp", + "tiff", + "psd", # Fonts - 'woff', 'woff2', 'ttf', 'otf', 'eot', + "woff", + "woff2", + "ttf", + "otf", + "eot", # Styles # 'css', 'less', 'scss', 'sass', # Media - 'mp4', 'webm', 'ogg', 'avi', 'mov', 'wmv', 'flv', 'm4v', - 'mp3', 'wav', 'aac', 'm4a', 'opus', 'flac', + "mp4", + "webm", + "ogg", + "avi", + "mov", + "wmv", + "flv", + "m4v", + "mp3", + "wav", + "aac", + "m4a", + "opus", + "flac", # Documents - 'pdf', 'doc', 'docx', 'xls', 'xlsx', 'ppt', 'pptx', + "pdf", + "doc", + "docx", + "xls", + "xlsx", + "ppt", + "pptx", # Archives - 'zip', 'rar', '7z', 'tar', 'gz', + "zip", + "rar", + "7z", + "tar", + "gz", # Scripts and data - 'xml', 'swf', 'wasm' + "xml", + "swf", + "wasm", ] - + # Common context settings context_settings = { "user_agent": user_agent, @@ -568,7 +618,7 @@ class BrowserManager: "device_scale_factor": 1.0, "java_script_enabled": self.config.java_script_enabled, } - + if self.config.text_mode: text_mode_settings = { "has_touch": False, @@ -576,22 +626,22 @@ class BrowserManager: } # Update context settings with text mode settings context_settings.update(text_mode_settings) - + # Create and return the context with all settings context = await self.browser.new_context(**context_settings) - + # Apply text mode settings if enabled if self.config.text_mode: # Create and apply route patterns for each extension for ext in blocked_extensions: await context.route(f"**/*.{ext}", lambda route: route.abort()) return context - + # async def get_page(self, session_id: Optional[str], user_agent: str): async def get_page(self, crawlerRunConfig: CrawlerRunConfig): """ Get a page for the given session ID, creating a new one if needed. 
- + Args: crawlerRunConfig (CrawlerRunConfig): Configuration object containing all browser settings @@ -621,8 +671,8 @@ class BrowserManager: async def kill_session(self, session_id: str): """ - Kill a browser session and clean up resources. - + Kill a browser session and clean up resources. + Args: session_id (str): The session ID to kill. """ @@ -672,20 +722,20 @@ class AsyncCrawlerStrategy(ABC): Abstract base class for crawler strategies. Subclasses must implement the crawl method. """ + @abstractmethod async def crawl(self, url: str, **kwargs) -> AsyncCrawlResponse: pass # 4 + 3 - class AsyncPlaywrightCrawlerStrategy(AsyncCrawlerStrategy): """ Crawler strategy using Playwright. - + Attributes: browser_config (BrowserConfig): Configuration object containing browser settings. logger (AsyncLogger): Logger instance for recording events and errors. - _downloaded_files (List[str]): List of downloaded file paths. + _downloaded_files (List[str]): List of downloaded file paths. hooks (Dict[str, Callable]): Dictionary of hooks for custom behavior. browser_manager (BrowserManager): Manager for browser creation and management. @@ -704,8 +754,9 @@ class AsyncPlaywrightCrawlerStrategy(AsyncCrawlerStrategy): Kill a browser session and clean up resources. crawl(self, url, **kwargs): Run the crawler for a single URL. - + """ + def __init__( self, browser_config: BrowserConfig = None, logger: AsyncLogger = None, **kwargs ): @@ -769,10 +820,10 @@ class AsyncPlaywrightCrawlerStrategy(AsyncCrawlerStrategy): async def kill_session(self, session_id: str): """ Kill a browser session and clean up resources. - + Args: session_id (str): The ID of the session to kill. - + Returns: None """ @@ -787,20 +838,20 @@ class AsyncPlaywrightCrawlerStrategy(AsyncCrawlerStrategy): """ Set a hook function for a specific hook type. Following are list of hook types: - on_browser_created: Called when a new browser instance is created. - - on_page_context_created: Called when a new page context is created. - - on_user_agent_updated: Called when the user agent is updated. - - on_execution_started: Called when the execution starts. - - before_goto: Called before a goto operation. - - after_goto: Called after a goto operation. - - before_return_html: Called before returning HTML content. - - before_retrieve_html: Called before retrieving HTML content. - + - on_page_context_created: Called when a new page context is created. + - on_user_agent_updated: Called when the user agent is updated. + - on_execution_started: Called when the execution starts. + - before_goto: Called before a goto operation. + - after_goto: Called after a goto operation. + - before_return_html: Called before returning HTML content. + - before_retrieve_html: Called before retrieving HTML content. + All hooks except on_browser_created accepts a context and a page as arguments and **kwargs. However, on_browser_created accepts a browser and a context as arguments and **kwargs. - + Args: hook_type (str): The type of the hook. hook (Callable): The hook function to set. - + Returns: None """ @@ -812,12 +863,12 @@ class AsyncPlaywrightCrawlerStrategy(AsyncCrawlerStrategy): async def execute_hook(self, hook_type: str, *args, **kwargs): """ Execute a hook function for a specific hook type. - + Args: hook_type (str): The type of the hook. *args: Variable length positional arguments. **kwargs: Keyword arguments. - + Returns: The return value of the hook function, if any. 
""" @@ -832,42 +883,42 @@ class AsyncPlaywrightCrawlerStrategy(AsyncCrawlerStrategy): def update_user_agent(self, user_agent: str): """ Update the user agent for the browser. - + Args: user_agent (str): The new user agent string. - + Returns: None """ self.user_agent = user_agent def set_custom_headers(self, headers: Dict[str, str]): - """ - Set custom headers for the browser. - + """ + Set custom headers for the browser. + Args: headers (Dict[str, str]): A dictionary of headers to set. - + Returns: None """ self.headers = headers async def smart_wait(self, page: Page, wait_for: str, timeout: float = 30000): - """ + """ Wait for a condition in a smart way. This functions works as below: - + 1. If wait_for starts with 'js:', it assumes it's a JavaScript function and waits for it to return true. 2. If wait_for starts with 'css:', it assumes it's a CSS selector and waits for it to be present. 3. Otherwise, it tries to evaluate wait_for as a JavaScript function and waits for it to return true. 4. If it's not a JavaScript function, it assumes it's a CSS selector and waits for it to be present. - - This is a more advanced version of the wait_for parameter in CrawlerStrategy.crawl(). + + This is a more advanced version of the wait_for parameter in CrawlerStrategy.crawl(). Args: page: Playwright page object wait_for (str): The condition to wait for. Can be a CSS selector, a JavaScript function, or explicitly prefixed with 'js:' or 'css:'. timeout (float): Maximum time to wait in milliseconds - + Returns: None """ @@ -917,18 +968,20 @@ class AsyncPlaywrightCrawlerStrategy(AsyncCrawlerStrategy): "or explicitly prefixed with 'js:' or 'css:'." ) - async def csp_compliant_wait( self, page: Page, user_wait_function: str, timeout: float = 30000 ): + async def csp_compliant_wait( + self, page: Page, user_wait_function: str, timeout: float = 30000 + ): """ Wait for a condition in a CSP-compliant way. - + Args: page: Playwright page object user_wait_function: JavaScript function as string that returns boolean timeout: Maximum time to wait in milliseconds - + Returns: bool: True if condition was met, False if timed out - + Raises: RuntimeError: If there's an error evaluating the condition """ @@ -964,10 +1017,10 @@ class AsyncPlaywrightCrawlerStrategy(AsyncCrawlerStrategy): async def process_iframes(self, page): """ Process iframes on a page. This function will extract the content of each iframe and replace it with a div containing the extracted content. - + Args: page: Playwright page object - + Returns: Playwright page object """ @@ -1029,10 +1082,10 @@ class AsyncPlaywrightCrawlerStrategy(AsyncCrawlerStrategy): """ Creates a new browser session and returns its ID. A browse session is a unique openned page can be reused for multiple crawls. This function is asynchronous and returns a string representing the session ID. - + Args: **kwargs: Optional keyword arguments to configure the session. - + Returns: str: The session ID. """ @@ -1045,7 +1098,9 @@ class AsyncPlaywrightCrawlerStrategy(AsyncCrawlerStrategy): page, context = await self.browser_manager.get_page(session_id, user_agent) return session_id - async def crawl( self, url: str, config: CrawlerRunConfig, **kwargs ) -> AsyncCrawlResponse: + async def crawl( + self, url: str, config: CrawlerRunConfig, **kwargs + ) -> AsyncCrawlResponse: """ Crawls a given URL or processes raw HTML/local file content based on the URL prefix. 
@@ -1104,7 +1159,9 @@ class AsyncPlaywrightCrawlerStrategy(AsyncCrawlerStrategy): "URL must start with 'http://', 'https://', 'file://', or 'raw:'" ) - async def _crawl_web( self, url: str, config: CrawlerRunConfig ) -> AsyncCrawlResponse: + async def _crawl_web( + self, url: str, config: CrawlerRunConfig + ) -> AsyncCrawlResponse: """ Internal method to crawl web URLs with the specified configuration. @@ -1188,11 +1245,13 @@ class AsyncPlaywrightCrawlerStrategy(AsyncCrawlerStrategy): try: # Generate a unique nonce for this request nonce = hashlib.sha256(os.urandom(32)).hexdigest() - + # Add CSP headers to the request - await page.set_extra_http_headers({ - 'Content-Security-Policy': f"default-src 'self'; script-src 'self' 'nonce-{nonce}' 'strict-dynamic'" - }) + await page.set_extra_http_headers( + { + "Content-Security-Policy": f"default-src 'self'; script-src 'self' 'nonce-{nonce}' 'strict-dynamic'" + } + ) response = await page.goto( url, wait_until=config.wait_until, timeout=config.page_timeout @@ -1200,7 +1259,9 @@ class AsyncPlaywrightCrawlerStrategy(AsyncCrawlerStrategy): except Error as e: raise RuntimeError(f"Failed on navigating ACS-GOTO:\n{str(e)}") - await self.execute_hook("after_goto", page, context=context, url=url, response=response) + await self.execute_hook( + "after_goto", page, context=context, url=url, response=response + ) if response is None: status_code = 200 @@ -1216,7 +1277,7 @@ class AsyncPlaywrightCrawlerStrategy(AsyncCrawlerStrategy): # Wait for body element and visibility try: await page.wait_for_selector("body", state="attached", timeout=30000) - + # Use the new check_visibility function with csp_compliant_wait is_visible = await self.csp_compliant_wait( page, @@ -1229,16 +1290,16 @@ class AsyncPlaywrightCrawlerStrategy(AsyncCrawlerStrategy): style.opacity !== '0'; return isVisible; }""", - timeout=30000 + timeout=30000, ) - + if not is_visible and not config.ignore_body_visibility: visibility_info = await self.check_visibility(page) raise Error(f"Body element is hidden: {visibility_info}") - except Error as e: + except Error: visibility_info = await self.check_visibility(page) - + if self.config.verbose: self.logger.debug( message="Body visibility info: {info}", @@ -1247,19 +1308,18 @@ class AsyncPlaywrightCrawlerStrategy(AsyncCrawlerStrategy): ) if not config.ignore_body_visibility: - raise Error(f"Body element is hidden: {visibility_info}") - - + raise Error(f"Body element is hidden: {visibility_info}") + # try: # await page.wait_for_selector("body", state="attached", timeout=30000) - + # await page.wait_for_function( # """ # () => { # const body = document.body; # const style = window.getComputedStyle(body); - # return style.display !== 'none' && - # style.visibility !== 'hidden' && + # return style.display !== 'none' && + # style.visibility !== 'hidden' && # style.opacity !== '0'; # } # """, @@ -1298,14 +1358,14 @@ class AsyncPlaywrightCrawlerStrategy(AsyncCrawlerStrategy): ): await page.wait_for_load_state("domcontentloaded") await asyncio.sleep(0.1) - + # Check for image loading with improved error handling images_loaded = await self.csp_compliant_wait( page, "() => Array.from(document.getElementsByTagName('img')).every(img => img.complete)", - timeout=1000 + timeout=1000, ) - + if not images_loaded and self.logger: self.logger.warning( message="Some images failed to load within timeout", @@ -1316,8 +1376,8 @@ class AsyncPlaywrightCrawlerStrategy(AsyncCrawlerStrategy): if not self.browser_config.text_mode and config.adjust_viewport_to_content: try: 
dimensions = await self.get_page_dimensions(page) - page_height = dimensions['height'] - page_width = dimensions['width'] + page_height = dimensions["height"] + page_width = dimensions["width"] # page_width = await page.evaluate( # "document.documentElement.scrollWidth" # ) @@ -1361,16 +1421,18 @@ class AsyncPlaywrightCrawlerStrategy(AsyncCrawlerStrategy): # elif isinstance(config.js_code, list): # for js in config.js_code: # await page.evaluate(js) - + if config.js_code: # execution_result = await self.execute_user_script(page, config.js_code) - execution_result = await self.robust_execute_user_script(page, config.js_code) + execution_result = await self.robust_execute_user_script( + page, config.js_code + ) if not execution_result["success"]: self.logger.warning( message="User script execution had issues: {error}", tag="JS_EXEC", - params={"error": execution_result.get("error")} - ) + params={"error": execution_result.get("error")}, + ) await self.execute_hook("on_execution_started", page, context=context) @@ -1385,7 +1447,7 @@ class AsyncPlaywrightCrawlerStrategy(AsyncCrawlerStrategy): # Todo: Decide how to handle this if not config.wait_for and config.css_selector and False: config.wait_for = f"css:{config.css_selector}" - + if config.wait_for: try: await self.smart_wait( @@ -1425,7 +1487,9 @@ class AsyncPlaywrightCrawlerStrategy(AsyncCrawlerStrategy): # Get final HTML content html = await page.content() - await self.execute_hook("before_return_html", page = page, html = html, context=context) + await self.execute_hook( + "before_return_html", page=page, html=html, context=context + ) # Handle PDF and screenshot generation start_export_time = time.perf_counter() @@ -1475,7 +1539,7 @@ class AsyncPlaywrightCrawlerStrategy(AsyncCrawlerStrategy): except Exception as e: raise e - + finally: # If no session_id is given we should close the page if not config.session_id: @@ -1483,20 +1547,20 @@ class AsyncPlaywrightCrawlerStrategy(AsyncCrawlerStrategy): async def _handle_full_page_scan(self, page: Page, scroll_delay: float = 0.1): """ - Helper method to handle full page scanning. - + Helper method to handle full page scanning. + How it works: 1. Get the viewport height. 2. Scroll to the bottom of the page. 3. Get the total height of the page. 4. Scroll back to the top of the page. - 5. Scroll to the bottom of the page again. + 5. Scroll to the bottom of the page again. 6. Continue scrolling until the bottom of the page is reached. 
- + Args: page (Page): The Playwright page object scroll_delay (float): The delay between page scrolls - + """ try: viewport_height = page.viewport_size.get( @@ -1511,8 +1575,8 @@ class AsyncPlaywrightCrawlerStrategy(AsyncCrawlerStrategy): # total_height = await page.evaluate("document.documentElement.scrollHeight") dimensions = await self.get_page_dimensions(page) - total_height = dimensions['height'] - + total_height = dimensions["height"] + while current_position < total_height: current_position = min(current_position + viewport_height, total_height) await self.safe_scroll(page, 0, current_position, delay=scroll_delay) @@ -1521,8 +1585,8 @@ class AsyncPlaywrightCrawlerStrategy(AsyncCrawlerStrategy): # new_height = await page.evaluate("document.documentElement.scrollHeight") dimensions = await self.get_page_dimensions(page) - new_height = dimensions['height'] - + new_height = dimensions["height"] + if new_height > total_height: total_height = new_height @@ -1542,7 +1606,7 @@ class AsyncPlaywrightCrawlerStrategy(AsyncCrawlerStrategy): async def _handle_download(self, download): """ Handle file downloads. - + How it works: 1. Get the suggested filename. 2. Get the download path. @@ -1550,10 +1614,10 @@ class AsyncPlaywrightCrawlerStrategy(AsyncCrawlerStrategy): 4. Start the download. 5. Save the downloaded file. 6. Log the completion. - + Args: download (Download): The Playwright download object - + Returns: None """ @@ -1598,7 +1662,8 @@ class AsyncPlaywrightCrawlerStrategy(AsyncCrawlerStrategy): remove_overlays_js = load_js_script("remove_overlay_elements") try: - await page.evaluate(f""" + await page.evaluate( + f""" (() => {{ try {{ {remove_overlays_js} @@ -1611,7 +1676,8 @@ class AsyncPlaywrightCrawlerStrategy(AsyncCrawlerStrategy): }}; }} }})() - """) + """ + ) await page.wait_for_timeout(500) # Wait for any animations to complete except Exception as e: self.logger.warning( @@ -1623,10 +1689,10 @@ class AsyncPlaywrightCrawlerStrategy(AsyncCrawlerStrategy): async def export_pdf(self, page: Page) -> bytes: """ Exports the current page as a PDF. - + Args: page (Page): The Playwright page object - + Returns: bytes: The PDF data """ @@ -1636,16 +1702,16 @@ class AsyncPlaywrightCrawlerStrategy(AsyncCrawlerStrategy): async def take_screenshot(self, page, **kwargs) -> str: """ Take a screenshot of the current page. - + Args: page (Page): The Playwright page object kwargs: Additional keyword arguments - + Returns: str: The base64-encoded screenshot data """ need_scroll = await self.page_need_scroll(page) - + if not need_scroll: # Page is short enough, just take a screenshot return await self.take_screenshot_naive(page) @@ -1656,13 +1722,13 @@ class AsyncPlaywrightCrawlerStrategy(AsyncCrawlerStrategy): async def take_screenshot_from_pdf(self, pdf_data: bytes) -> str: """ - Convert the first page of the PDF to a screenshot. - + Convert the first page of the PDF to a screenshot. + Requires pdf2image and poppler. - + Args: pdf_data (bytes): The PDF data - + Returns: str: The base64-encoded screenshot data """ @@ -1694,21 +1760,21 @@ class AsyncPlaywrightCrawlerStrategy(AsyncCrawlerStrategy): """ Attempt to set a large viewport and take a full-page screenshot. If still too large, segment the page as before. - + Requires pdf2image and poppler. 
- + Args: page (Page): The Playwright page object kwargs: Additional keyword arguments - + Returns: str: The base64-encoded screenshot data """ try: # Get page height dimensions = await self.get_page_dimensions(page) - page_width = dimensions['width'] - page_height = dimensions['height'] + page_width = dimensions["width"] + page_height = dimensions["height"] # page_height = await page.evaluate("document.documentElement.scrollHeight") # page_width = await page.evaluate("document.documentElement.scrollWidth") @@ -1805,10 +1871,10 @@ class AsyncPlaywrightCrawlerStrategy(AsyncCrawlerStrategy): """ Exports the current storage state (cookies, localStorage, sessionStorage) to a JSON file at the specified path. - + Args: path (str): The path to save the storage state JSON file - + Returns: dict: The exported storage state """ @@ -1826,33 +1892,35 @@ class AsyncPlaywrightCrawlerStrategy(AsyncCrawlerStrategy): tag="WARNING", ) - async def robust_execute_user_script(self, page: Page, js_code: Union[str, List[str]]) -> Dict[str, Any]: + async def robust_execute_user_script( + self, page: Page, js_code: Union[str, List[str]] + ) -> Dict[str, Any]: """ Executes user-provided JavaScript code with proper error handling and context, supporting both synchronous and async user code, plus navigations. - + How it works: 1. Wait for load state 'domcontentloaded' 2. If js_code is a string, execute it directly 3. If js_code is a list, execute each element in sequence - 4. Wait for load state 'networkidle' - 5. Return results - - Args: + 4. Wait for load state 'networkidle' + 5. Return results + + Args: page (Page): The Playwright page instance js_code (Union[str, List[str]]): The JavaScript code to execute - + Returns: Dict[str, Any]: The results of the execution """ try: - await page.wait_for_load_state('domcontentloaded') - + await page.wait_for_load_state("domcontentloaded") + if isinstance(js_code, str): scripts = [js_code] else: scripts = js_code - + results = [] for script in scripts: try: @@ -1861,7 +1929,8 @@ class AsyncPlaywrightCrawlerStrategy(AsyncCrawlerStrategy): # then wait for the new page to load before continuing result = None try: - result = await page.evaluate(f""" + result = await page.evaluate( + f""" (async () => {{ try {{ {script} @@ -1870,53 +1939,62 @@ class AsyncPlaywrightCrawlerStrategy(AsyncCrawlerStrategy): return {{ success: false, error: err.toString(), stack: err.stack }}; }} }})(); - """) + """ + ) except Error as e: # If it's due to navigation destroying the context, handle gracefully if "Execution context was destroyed" in str(e): - self.logger.info("Navigation triggered by script, waiting for load state", tag="JS_EXEC") + self.logger.info( + "Navigation triggered by script, waiting for load state", + tag="JS_EXEC", + ) try: - await page.wait_for_load_state('load', timeout=30000) + await page.wait_for_load_state("load", timeout=30000) except Error as nav_err: self.logger.warning( message="Navigation wait failed: {error}", tag="JS_EXEC", - params={"error": str(nav_err)} + params={"error": str(nav_err)}, ) try: - await page.wait_for_load_state('networkidle', timeout=30000) + await page.wait_for_load_state( + "networkidle", timeout=30000 + ) except Error as nav_err: self.logger.warning( message="Network idle wait failed: {error}", tag="JS_EXEC", - params={"error": str(nav_err)} + params={"error": str(nav_err)}, ) # Return partial success, or adapt as you see fit result = { "success": True, - "info": "Navigation triggered, ignoring context destroyed error" + "info": "Navigation 
triggered, ignoring context destroyed error", } else: # It's some other error, log and continue self.logger.error( message="Playwright execution error: {error}", tag="JS_EXEC", - params={"error": str(e)} + params={"error": str(e)}, ) result = {"success": False, "error": str(e)} - + # If we made it this far with no repeated error, do post-load waits t1 = time.time() try: - await page.wait_for_load_state('domcontentloaded', timeout=5000) - print("DOM content loaded after script execution in", time.time() - t1) + await page.wait_for_load_state("domcontentloaded", timeout=5000) + print( + "DOM content loaded after script execution in", + time.time() - t1, + ) except Error as e: self.logger.warning( message="DOM content load timeout: {error}", tag="JS_EXEC", - params={"error": str(e)} + params={"error": str(e)}, ) - + # t1 = time.time() # try: # await page.wait_for_load_state('networkidle', timeout=5000) @@ -1935,46 +2013,49 @@ class AsyncPlaywrightCrawlerStrategy(AsyncCrawlerStrategy): self.logger.error( message="Script chunk failed: {error}", tag="JS_EXEC", - params={"error": str(e)} + params={"error": str(e)}, ) results.append({"success": False, "error": str(e)}) return {"success": True, "results": results} - + except Exception as e: self.logger.error( message="Script execution failed: {error}", tag="JS_EXEC", - params={"error": str(e)} + params={"error": str(e)}, ) return {"success": False, "error": str(e)} - async def execute_user_script(self, page: Page, js_code: Union[str, List[str]]) -> Dict[str, Any]: + async def execute_user_script( + self, page: Page, js_code: Union[str, List[str]] + ) -> Dict[str, Any]: """ Executes user-provided JavaScript code with proper error handling and context. - + Args: page: Playwright page object js_code: Single JavaScript string or list of JavaScript code strings - + Returns: Dict containing execution status and results/errors """ try: # Ensure the page is ready for script execution - await page.wait_for_load_state('domcontentloaded') - + await page.wait_for_load_state("domcontentloaded") + # Handle single script or multiple scripts if isinstance(js_code, str): scripts = [js_code] else: scripts = js_code - + results = [] for script in scripts: try: # Execute the script and wait for network idle - result = await page.evaluate(f""" + result = await page.evaluate( + f""" (() => {{ return new Promise((resolve) => {{ try {{ @@ -2007,57 +2088,61 @@ class AsyncPlaywrightCrawlerStrategy(AsyncCrawlerStrategy): }} }}); }})() - """) - + """ + ) + # Wait for network idle after script execution t1 = time.time() - await page.wait_for_load_state('domcontentloaded', timeout=5000) - print("DOM content loaded after script execution in", time.time() - t1) + await page.wait_for_load_state("domcontentloaded", timeout=5000) + print( + "DOM content loaded after script execution in", time.time() - t1 + ) t1 = time.time() - await page.wait_for_load_state('networkidle', timeout=5000) + await page.wait_for_load_state("networkidle", timeout=5000) print("Network idle after script execution in", time.time() - t1) - + results.append(result if result else {"success": True}) - + except Error as e: # Handle Playwright-specific errors self.logger.error( message="Playwright execution error: {error}", tag="JS_EXEC", - params={"error": str(e)} + params={"error": str(e)}, ) results.append({"success": False, "error": str(e)}) - + return {"success": True, "results": results} - + except Exception as e: self.logger.error( message="Script execution failed: {error}", tag="JS_EXEC", - 
params={"error": str(e)} + params={"error": str(e)}, ) return {"success": False, "error": str(e)} - + except Exception as e: self.logger.error( message="Script execution failed: {error}", tag="JS_EXEC", - params={"error": str(e)} + params={"error": str(e)}, ) return {"success": False, "error": str(e)} async def check_visibility(self, page): """ Checks if an element is visible on the page. - + Args: page: Playwright page object - + Returns: Boolean indicating visibility """ - return await page.evaluate(""" + return await page.evaluate( + """ () => { const element = document.body; if (!element) return false; @@ -2067,31 +2152,32 @@ class AsyncPlaywrightCrawlerStrategy(AsyncCrawlerStrategy): style.opacity !== '0'; return isVisible; } - """) - + """ + ) + async def safe_scroll(self, page: Page, x: int, y: int, delay: float = 0.1): """ Safely scroll the page with rendering time. - + Args: page: Playwright page object x: Horizontal scroll position y: Vertical scroll position """ result = await self.csp_scroll_to(page, x, y) - if result['success']: + if result["success"]: await page.wait_for_timeout(delay * 1000) return result - + async def csp_scroll_to(self, page: Page, x: int, y: int) -> Dict[str, Any]: """ Performs a CSP-compliant scroll operation and returns the result status. - + Args: page: Playwright page object x: Horizontal scroll position y: Vertical scroll position - + Returns: Dict containing scroll status and position information """ @@ -2125,67 +2211,68 @@ class AsyncPlaywrightCrawlerStrategy(AsyncCrawlerStrategy): }} }}""" ) - - if not result['success']: + + if not result["success"]: self.logger.warning( message="Scroll operation failed: {error}", tag="SCROLL", - params={"error": result.get('error')} + params={"error": result.get("error")}, ) - + return result - + except Exception as e: self.logger.error( message="Failed to execute scroll: {error}", tag="SCROLL", - params={"error": str(e)} + params={"error": str(e)}, ) - return { - "success": False, - "error": str(e) - } - + return {"success": False, "error": str(e)} + async def get_page_dimensions(self, page: Page): """ Get the dimensions of the page. - + Args: page: Playwright page object - + Returns: Dict containing width and height of the page """ - return await page.evaluate(""" + return await page.evaluate( + """ () => { const {scrollWidth, scrollHeight} = document.documentElement; return {width: scrollWidth, height: scrollHeight}; } - """) - + """ + ) + async def page_need_scroll(self, page: Page) -> bool: """ Determine whether the page need to scroll - + Args: page: Playwright page object - + Returns: bool: True if page needs scrolling """ try: - need_scroll = await page.evaluate(""" + need_scroll = await page.evaluate( + """ () => { const scrollHeight = document.documentElement.scrollHeight; const viewportHeight = window.innerHeight; return scrollHeight > viewportHeight; } - """) + """ + ) return need_scroll except Exception as e: self.logger.warning( message="Failed to check scroll need: {error}. 
Defaulting to True for safety.", tag="SCROLL", - params={"error": str(e)} + params={"error": str(e)}, ) - return True # Default to scrolling if check fails \ No newline at end of file + return True # Default to scrolling if check fails diff --git a/crawl4ai/async_database.py b/crawl4ai/async_database.py index aed9c76b..669ddec2 100644 --- a/crawl4ai/async_database.py +++ b/crawl4ai/async_database.py @@ -1,27 +1,29 @@ -import os, sys +import os from pathlib import Path import aiosqlite import asyncio -from typing import Optional, Tuple, Dict +from typing import Optional, Dict from contextlib import asynccontextmanager import logging import json # Added for serialization/deserialization from .utils import ensure_content_dirs, generate_content_hash from .models import CrawlResult, MarkdownGenerationResult -import xxhash import aiofiles -from .config import NEED_MIGRATION from .version_manager import VersionManager from .async_logger import AsyncLogger from .utils import get_error_context, create_box_message + # Set up logging logging.basicConfig(level=logging.INFO) logger = logging.getLogger(__name__) -base_directory = DB_PATH = os.path.join(os.getenv("CRAWL4_AI_BASE_DIRECTORY", Path.home()), ".crawl4ai") +base_directory = DB_PATH = os.path.join( + os.getenv("CRAWL4_AI_BASE_DIRECTORY", Path.home()), ".crawl4ai" +) os.makedirs(DB_PATH, exist_ok=True) DB_PATH = os.path.join(base_directory, "crawl4ai.db") + class AsyncDatabaseManager: def __init__(self, pool_size: int = 10, max_retries: int = 3): self.db_path = DB_PATH @@ -32,28 +34,27 @@ class AsyncDatabaseManager: self.pool_lock = asyncio.Lock() self.init_lock = asyncio.Lock() self.connection_semaphore = asyncio.Semaphore(pool_size) - self._initialized = False + self._initialized = False self.version_manager = VersionManager() self.logger = AsyncLogger( log_file=os.path.join(base_directory, ".crawl4ai", "crawler_db.log"), verbose=False, - tag_width=10 + tag_width=10, ) - - + async def initialize(self): """Initialize the database and connection pool""" try: self.logger.info("Initializing database", tag="INIT") # Ensure the database file exists os.makedirs(os.path.dirname(self.db_path), exist_ok=True) - + # Check if version update is needed needs_update = self.version_manager.needs_update() - + # Always ensure base table exists await self.ainit_db() - + # Verify the table exists async with aiosqlite.connect(self.db_path, timeout=30.0) as db: async with db.execute( @@ -62,33 +63,37 @@ class AsyncDatabaseManager: result = await cursor.fetchone() if not result: raise Exception("crawled_data table was not created") - + # If version changed or fresh install, run updates if needs_update: self.logger.info("New version detected, running updates", tag="INIT") await self.update_db_schema() - from .migrations import run_migration # Import here to avoid circular imports + from .migrations import ( + run_migration, + ) # Import here to avoid circular imports + await run_migration() self.version_manager.update_version() # Update stored version after successful migration - self.logger.success("Version update completed successfully", tag="COMPLETE") + self.logger.success( + "Version update completed successfully", tag="COMPLETE" + ) else: - self.logger.success("Database initialization completed successfully", tag="COMPLETE") + self.logger.success( + "Database initialization completed successfully", tag="COMPLETE" + ) - except Exception as e: self.logger.error( message="Database initialization error: {error}", tag="ERROR", - params={"error": str(e)} + 
params={"error": str(e)}, ) self.logger.info( - message="Database will be initialized on first use", - tag="INIT" + message="Database will be initialized on first use", tag="INIT" ) - + raise - async def cleanup(self): """Cleanup connections when shutting down""" async with self.pool_lock: @@ -107,6 +112,7 @@ class AsyncDatabaseManager: self._initialized = True except Exception as e: import sys + error_context = get_error_context(sys.exc_info()) self.logger.error( message="Database initialization failed:\n{error}\n\nContext:\n{context}\n\nTraceback:\n{traceback}", @@ -115,41 +121,52 @@ class AsyncDatabaseManager: params={ "error": str(e), "context": error_context["code_context"], - "traceback": error_context["full_traceback"] - } + "traceback": error_context["full_traceback"], + }, ) raise await self.connection_semaphore.acquire() task_id = id(asyncio.current_task()) - + try: async with self.pool_lock: if task_id not in self.connection_pool: try: - conn = await aiosqlite.connect( - self.db_path, - timeout=30.0 - ) - await conn.execute('PRAGMA journal_mode = WAL') - await conn.execute('PRAGMA busy_timeout = 5000') - + conn = await aiosqlite.connect(self.db_path, timeout=30.0) + await conn.execute("PRAGMA journal_mode = WAL") + await conn.execute("PRAGMA busy_timeout = 5000") + # Verify database structure - async with conn.execute("PRAGMA table_info(crawled_data)") as cursor: + async with conn.execute( + "PRAGMA table_info(crawled_data)" + ) as cursor: columns = await cursor.fetchall() column_names = [col[1] for col in columns] expected_columns = { - 'url', 'html', 'cleaned_html', 'markdown', 'extracted_content', - 'success', 'media', 'links', 'metadata', 'screenshot', - 'response_headers', 'downloaded_files' + "url", + "html", + "cleaned_html", + "markdown", + "extracted_content", + "success", + "media", + "links", + "metadata", + "screenshot", + "response_headers", + "downloaded_files", } missing_columns = expected_columns - set(column_names) if missing_columns: - raise ValueError(f"Database missing columns: {missing_columns}") - + raise ValueError( + f"Database missing columns: {missing_columns}" + ) + self.connection_pool[task_id] = conn except Exception as e: import sys + error_context = get_error_context(sys.exc_info()) error_message = ( f"Unexpected error in db get_connection at line {error_context['line_no']} " @@ -158,7 +175,7 @@ class AsyncDatabaseManager: f"Code context:\n{error_context['code_context']}" ) self.logger.error( - message=create_box_message(error_message, type= "error"), + message=create_box_message(error_message, type="error"), ) raise @@ -167,6 +184,7 @@ class AsyncDatabaseManager: except Exception as e: import sys + error_context = get_error_context(sys.exc_info()) error_message = ( f"Unexpected error in db get_connection at line {error_context['line_no']} " @@ -175,7 +193,7 @@ class AsyncDatabaseManager: f"Code context:\n{error_context['code_context']}" ) self.logger.error( - message=create_box_message(error_message, type= "error"), + message=create_box_message(error_message, type="error"), ) raise finally: @@ -185,7 +203,6 @@ class AsyncDatabaseManager: del self.connection_pool[task_id] self.connection_semaphore.release() - async def execute_with_retry(self, operation, *args): """Execute database operations with retry logic""" for attempt in range(self.max_retries): @@ -200,18 +217,16 @@ class AsyncDatabaseManager: message="Operation failed after {retries} attempts: {error}", tag="ERROR", force_verbose=True, - params={ - "retries": self.max_retries, - "error": 
str(e) - } - ) + params={"retries": self.max_retries, "error": str(e)}, + ) raise await asyncio.sleep(1 * (attempt + 1)) # Exponential backoff async def ainit_db(self): """Initialize database schema""" async with aiosqlite.connect(self.db_path, timeout=30.0) as db: - await db.execute(''' + await db.execute( + """ CREATE TABLE IF NOT EXISTS crawled_data ( url TEXT PRIMARY KEY, html TEXT, @@ -226,21 +241,27 @@ class AsyncDatabaseManager: response_headers TEXT DEFAULT "{}", downloaded_files TEXT DEFAULT "{}" -- New column added ) - ''') + """ + ) await db.commit() - - async def update_db_schema(self): """Update database schema if needed""" async with aiosqlite.connect(self.db_path, timeout=30.0) as db: cursor = await db.execute("PRAGMA table_info(crawled_data)") columns = await cursor.fetchall() column_names = [column[1] for column in columns] - + # List of new columns to add - new_columns = ['media', 'links', 'metadata', 'screenshot', 'response_headers', 'downloaded_files'] - + new_columns = [ + "media", + "links", + "metadata", + "screenshot", + "response_headers", + "downloaded_files", + ] + for column in new_columns: if column not in column_names: await self.aalter_db_add_column(column, db) @@ -248,75 +269,91 @@ class AsyncDatabaseManager: async def aalter_db_add_column(self, new_column: str, db): """Add new column to the database""" - if new_column == 'response_headers': - await db.execute(f'ALTER TABLE crawled_data ADD COLUMN {new_column} TEXT DEFAULT "{{}}"') + if new_column == "response_headers": + await db.execute( + f'ALTER TABLE crawled_data ADD COLUMN {new_column} TEXT DEFAULT "{{}}"' + ) else: - await db.execute(f'ALTER TABLE crawled_data ADD COLUMN {new_column} TEXT DEFAULT ""') + await db.execute( + f'ALTER TABLE crawled_data ADD COLUMN {new_column} TEXT DEFAULT ""' + ) self.logger.info( message="Added column '{column}' to the database", tag="INIT", - params={"column": new_column} - ) - + params={"column": new_column}, + ) async def aget_cached_url(self, url: str) -> Optional[CrawlResult]: """Retrieve cached URL data as CrawlResult""" + async def _get(db): async with db.execute( - 'SELECT * FROM crawled_data WHERE url = ?', (url,) + "SELECT * FROM crawled_data WHERE url = ?", (url,) ) as cursor: row = await cursor.fetchone() if not row: return None - + # Get column names columns = [description[0] for description in cursor.description] # Create dict from row data row_dict = dict(zip(columns, row)) - + # Load content from files using stored hashes content_fields = { - 'html': row_dict['html'], - 'cleaned_html': row_dict['cleaned_html'], - 'markdown': row_dict['markdown'], - 'extracted_content': row_dict['extracted_content'], - 'screenshot': row_dict['screenshot'], - 'screenshots': row_dict['screenshot'], + "html": row_dict["html"], + "cleaned_html": row_dict["cleaned_html"], + "markdown": row_dict["markdown"], + "extracted_content": row_dict["extracted_content"], + "screenshot": row_dict["screenshot"], + "screenshots": row_dict["screenshot"], } - + for field, hash_value in content_fields.items(): if hash_value: content = await self._load_content( - hash_value, - field.split('_')[0] # Get content type from field name + hash_value, + field.split("_")[0], # Get content type from field name ) row_dict[field] = content or "" else: row_dict[field] = "" # Parse JSON fields - json_fields = ['media', 'links', 'metadata', 'response_headers', 'markdown'] + json_fields = [ + "media", + "links", + "metadata", + "response_headers", + "markdown", + ] for field in json_fields: try: - 
row_dict[field] = json.loads(row_dict[field]) if row_dict[field] else {} + row_dict[field] = ( + json.loads(row_dict[field]) if row_dict[field] else {} + ) except json.JSONDecodeError: row_dict[field] = {} - if isinstance(row_dict['markdown'], Dict): - row_dict['markdown_v2'] = row_dict['markdown'] - if row_dict['markdown'].get('raw_markdown'): - row_dict['markdown'] = row_dict['markdown']['raw_markdown'] - + if isinstance(row_dict["markdown"], Dict): + row_dict["markdown_v2"] = row_dict["markdown"] + if row_dict["markdown"].get("raw_markdown"): + row_dict["markdown"] = row_dict["markdown"]["raw_markdown"] + # Parse downloaded_files try: - row_dict['downloaded_files'] = json.loads(row_dict['downloaded_files']) if row_dict['downloaded_files'] else [] + row_dict["downloaded_files"] = ( + json.loads(row_dict["downloaded_files"]) + if row_dict["downloaded_files"] + else [] + ) except json.JSONDecodeError: - row_dict['downloaded_files'] = [] + row_dict["downloaded_files"] = [] # Remove any fields not in CrawlResult model valid_fields = CrawlResult.__annotations__.keys() filtered_dict = {k: v for k, v in row_dict.items() if k in valid_fields} - + return CrawlResult(**filtered_dict) try: @@ -326,7 +363,7 @@ class AsyncDatabaseManager: message="Error retrieving cached URL: {error}", tag="ERROR", force_verbose=True, - params={"error": str(e)} + params={"error": str(e)}, ) return None @@ -334,37 +371,52 @@ class AsyncDatabaseManager: """Cache CrawlResult data""" # Store content files and get hashes content_map = { - 'html': (result.html, 'html'), - 'cleaned_html': (result.cleaned_html or "", 'cleaned'), - 'markdown': None, - 'extracted_content': (result.extracted_content or "", 'extracted'), - 'screenshot': (result.screenshot or "", 'screenshots') + "html": (result.html, "html"), + "cleaned_html": (result.cleaned_html or "", "cleaned"), + "markdown": None, + "extracted_content": (result.extracted_content or "", "extracted"), + "screenshot": (result.screenshot or "", "screenshots"), } try: if isinstance(result.markdown, MarkdownGenerationResult): - content_map['markdown'] = (result.markdown.model_dump_json(), 'markdown') - elif hasattr(result, 'markdown_v2'): - content_map['markdown'] = (result.markdown_v2.model_dump_json(), 'markdown') + content_map["markdown"] = ( + result.markdown.model_dump_json(), + "markdown", + ) + elif hasattr(result, "markdown_v2"): + content_map["markdown"] = ( + result.markdown_v2.model_dump_json(), + "markdown", + ) elif isinstance(result.markdown, str): markdown_result = MarkdownGenerationResult(raw_markdown=result.markdown) - content_map['markdown'] = (markdown_result.model_dump_json(), 'markdown') + content_map["markdown"] = ( + markdown_result.model_dump_json(), + "markdown", + ) else: - content_map['markdown'] = (MarkdownGenerationResult().model_dump_json(), 'markdown') + content_map["markdown"] = ( + MarkdownGenerationResult().model_dump_json(), + "markdown", + ) except Exception as e: self.logger.warning( - message=f"Error processing markdown content: {str(e)}", - tag="WARNING" + message=f"Error processing markdown content: {str(e)}", tag="WARNING" ) # Fallback to empty markdown result - content_map['markdown'] = (MarkdownGenerationResult().model_dump_json(), 'markdown') - + content_map["markdown"] = ( + MarkdownGenerationResult().model_dump_json(), + "markdown", + ) + content_hashes = {} for field, (content, content_type) in content_map.items(): content_hashes[field] = await self._store_content(content, content_type) async def _cache(db): - await db.execute(''' + 
await db.execute( + """ INSERT INTO crawled_data ( url, html, cleaned_html, markdown, extracted_content, success, media, links, metadata, @@ -383,20 +435,22 @@ class AsyncDatabaseManager: screenshot = excluded.screenshot, response_headers = excluded.response_headers, downloaded_files = excluded.downloaded_files - ''', ( - result.url, - content_hashes['html'], - content_hashes['cleaned_html'], - content_hashes['markdown'], - content_hashes['extracted_content'], - result.success, - json.dumps(result.media), - json.dumps(result.links), - json.dumps(result.metadata or {}), - content_hashes['screenshot'], - json.dumps(result.response_headers or {}), - json.dumps(result.downloaded_files or []) - )) + """, + ( + result.url, + content_hashes["html"], + content_hashes["cleaned_html"], + content_hashes["markdown"], + content_hashes["extracted_content"], + result.success, + json.dumps(result.media), + json.dumps(result.links), + json.dumps(result.metadata or {}), + content_hashes["screenshot"], + json.dumps(result.response_headers or {}), + json.dumps(result.downloaded_files or []), + ), + ) try: await self.execute_with_retry(_cache) @@ -405,14 +459,14 @@ class AsyncDatabaseManager: message="Error caching URL: {error}", tag="ERROR", force_verbose=True, - params={"error": str(e)} + params={"error": str(e)}, ) - async def aget_total_count(self) -> int: """Get total number of cached URLs""" + async def _count(db): - async with db.execute('SELECT COUNT(*) FROM crawled_data') as cursor: + async with db.execute("SELECT COUNT(*) FROM crawled_data") as cursor: result = await cursor.fetchone() return result[0] if result else 0 @@ -423,14 +477,15 @@ class AsyncDatabaseManager: message="Error getting total count: {error}", tag="ERROR", force_verbose=True, - params={"error": str(e)} + params={"error": str(e)}, ) return 0 async def aclear_db(self): """Clear all data from the database""" + async def _clear(db): - await db.execute('DELETE FROM crawled_data') + await db.execute("DELETE FROM crawled_data") try: await self.execute_with_retry(_clear) @@ -439,13 +494,14 @@ class AsyncDatabaseManager: message="Error clearing database: {error}", tag="ERROR", force_verbose=True, - params={"error": str(e)} + params={"error": str(e)}, ) async def aflush_db(self): """Drop the entire table""" + async def _flush(db): - await db.execute('DROP TABLE IF EXISTS crawled_data') + await db.execute("DROP TABLE IF EXISTS crawled_data") try: await self.execute_with_retry(_flush) @@ -454,42 +510,44 @@ class AsyncDatabaseManager: message="Error flushing database: {error}", tag="ERROR", force_verbose=True, - params={"error": str(e)} + params={"error": str(e)}, ) - - + async def _store_content(self, content: str, content_type: str) -> str: """Store content in filesystem and return hash""" if not content: return "" - + content_hash = generate_content_hash(content) file_path = os.path.join(self.content_paths[content_type], content_hash) - + # Only write if file doesn't exist if not os.path.exists(file_path): - async with aiofiles.open(file_path, 'w', encoding='utf-8') as f: + async with aiofiles.open(file_path, "w", encoding="utf-8") as f: await f.write(content) - + return content_hash - async def _load_content(self, content_hash: str, content_type: str) -> Optional[str]: + async def _load_content( + self, content_hash: str, content_type: str + ) -> Optional[str]: """Load content from filesystem by hash""" if not content_hash: return None - + file_path = os.path.join(self.content_paths[content_type], content_hash) try: - async with 
aiofiles.open(file_path, 'r', encoding='utf-8') as f: + async with aiofiles.open(file_path, "r", encoding="utf-8") as f: return await f.read() except: self.logger.error( message="Failed to load content: {file_path}", tag="ERROR", force_verbose=True, - params={"file_path": file_path} + params={"file_path": file_path}, ) return None + # Create a singleton instance async_db_manager = AsyncDatabaseManager() diff --git a/crawl4ai/async_dispatcher.py b/crawl4ai/async_dispatcher.py index 8f5fbe81..b796a92b 100644 --- a/crawl4ai/async_dispatcher.py +++ b/crawl4ai/async_dispatcher.py @@ -1,14 +1,19 @@ -from typing import Dict, Optional, List -from .async_configs import * -from .models import * +from typing import Dict, Optional, List, Tuple +from .async_configs import CrawlerRunConfig +from .models import ( + CrawlResult, + CrawlerTaskResult, + CrawlStatus, + DisplayMode, + CrawlStats, + DomainState, +) from rich.live import Live from rich.table import Table from rich.console import Console -from rich.style import Style from rich import box from datetime import datetime, timedelta -from dataclasses import dataclass import time import psutil @@ -26,63 +31,66 @@ class RateLimiter: base_delay: Tuple[float, float] = (1.0, 3.0), max_delay: float = 60.0, max_retries: int = 3, - rate_limit_codes: List[int] = None + rate_limit_codes: List[int] = None, ): self.base_delay = base_delay self.max_delay = max_delay self.max_retries = max_retries self.rate_limit_codes = rate_limit_codes or [429, 503] self.domains: Dict[str, DomainState] = {} - + def get_domain(self, url: str) -> str: return urlparse(url).netloc - + async def wait_if_needed(self, url: str) -> None: domain = self.get_domain(url) state = self.domains.get(domain) - + if not state: self.domains[domain] = DomainState() state = self.domains[domain] - + now = time.time() if state.last_request_time: wait_time = max(0, state.current_delay - (now - state.last_request_time)) if wait_time > 0: await asyncio.sleep(wait_time) - + # Random delay within base range if no current delay if state.current_delay == 0: state.current_delay = random.uniform(*self.base_delay) - + state.last_request_time = time.time() - + def update_delay(self, url: str, status_code: int) -> bool: domain = self.get_domain(url) state = self.domains[domain] - + if status_code in self.rate_limit_codes: state.fail_count += 1 if state.fail_count > self.max_retries: return False - + # Exponential backoff with random jitter state.current_delay = min( - state.current_delay * 2 * random.uniform(0.75, 1.25), - self.max_delay + state.current_delay * 2 * random.uniform(0.75, 1.25), self.max_delay ) else: # Gradually reduce delay on success state.current_delay = max( - random.uniform(*self.base_delay), - state.current_delay * 0.75 + random.uniform(*self.base_delay), state.current_delay * 0.75 ) state.fail_count = 0 - + return True + class CrawlerMonitor: - def __init__(self, max_visible_rows: int = 15, display_mode: DisplayMode = DisplayMode.DETAILED): + def __init__( + self, + max_visible_rows: int = 15, + display_mode: DisplayMode = DisplayMode.DETAILED, + ): self.console = Console() self.max_visible_rows = max_visible_rows self.display_mode = display_mode @@ -90,23 +98,25 @@ class CrawlerMonitor: self.process = psutil.Process() self.start_time = datetime.now() self.live = Live(self._create_table(), refresh_per_second=2) - + def start(self): self.live.start() - + def stop(self): self.live.stop() - + def add_task(self, task_id: str, url: str): - self.stats[task_id] = CrawlStats(task_id=task_id, 
url=url, status=CrawlStatus.QUEUED) + self.stats[task_id] = CrawlStats( + task_id=task_id, url=url, status=CrawlStatus.QUEUED + ) self.live.update(self._create_table()) - + def update_task(self, task_id: str, **kwargs): if task_id in self.stats: for key, value in kwargs.items(): setattr(self.stats[task_id], key, value) self.live.update(self._create_table()) - + def _create_aggregated_table(self) -> Table: """Creates a compact table showing only aggregated statistics""" table = Table( @@ -114,78 +124,78 @@ class CrawlerMonitor: title="Crawler Status Overview", title_style="bold magenta", header_style="bold blue", - show_lines=True + show_lines=True, ) - + # Calculate statistics total_tasks = len(self.stats) - queued = sum(1 for stat in self.stats.values() if stat.status == CrawlStatus.QUEUED) - in_progress = sum(1 for stat in self.stats.values() if stat.status == CrawlStatus.IN_PROGRESS) - completed = sum(1 for stat in self.stats.values() if stat.status == CrawlStatus.COMPLETED) - failed = sum(1 for stat in self.stats.values() if stat.status == CrawlStatus.FAILED) - + queued = sum( + 1 for stat in self.stats.values() if stat.status == CrawlStatus.QUEUED + ) + in_progress = sum( + 1 for stat in self.stats.values() if stat.status == CrawlStatus.IN_PROGRESS + ) + completed = sum( + 1 for stat in self.stats.values() if stat.status == CrawlStatus.COMPLETED + ) + failed = sum( + 1 for stat in self.stats.values() if stat.status == CrawlStatus.FAILED + ) + # Memory statistics current_memory = self.process.memory_info().rss / (1024 * 1024) total_task_memory = sum(stat.memory_usage for stat in self.stats.values()) - peak_memory = max((stat.peak_memory for stat in self.stats.values()), default=0.0) - + peak_memory = max( + (stat.peak_memory for stat in self.stats.values()), default=0.0 + ) + # Duration duration = datetime.now() - self.start_time - + # Create status row table.add_column("Status", style="bold cyan") table.add_column("Count", justify="right") table.add_column("Percentage", justify="right") - - table.add_row( - "Total Tasks", - str(total_tasks), - "100%" - ) + + table.add_row("Total Tasks", str(total_tasks), "100%") table.add_row( "[yellow]In Queue[/yellow]", str(queued), - f"{(queued/total_tasks*100):.1f}%" if total_tasks > 0 else "0%" + f"{(queued/total_tasks*100):.1f}%" if total_tasks > 0 else "0%", ) table.add_row( "[blue]In Progress[/blue]", str(in_progress), - f"{(in_progress/total_tasks*100):.1f}%" if total_tasks > 0 else "0%" + f"{(in_progress/total_tasks*100):.1f}%" if total_tasks > 0 else "0%", ) table.add_row( "[green]Completed[/green]", str(completed), - f"{(completed/total_tasks*100):.1f}%" if total_tasks > 0 else "0%" + f"{(completed/total_tasks*100):.1f}%" if total_tasks > 0 else "0%", ) table.add_row( "[red]Failed[/red]", str(failed), - f"{(failed/total_tasks*100):.1f}%" if total_tasks > 0 else "0%" + f"{(failed/total_tasks*100):.1f}%" if total_tasks > 0 else "0%", ) - + # Add memory information table.add_section() table.add_row( - "[magenta]Current Memory[/magenta]", - f"{current_memory:.1f} MB", - "" + "[magenta]Current Memory[/magenta]", f"{current_memory:.1f} MB", "" ) table.add_row( - "[magenta]Total Task Memory[/magenta]", - f"{total_task_memory:.1f} MB", - "" + "[magenta]Total Task Memory[/magenta]", f"{total_task_memory:.1f} MB", "" ) table.add_row( - "[magenta]Peak Task Memory[/magenta]", - f"{peak_memory:.1f} MB", - "" + "[magenta]Peak Task Memory[/magenta]", f"{peak_memory:.1f} MB", "" ) table.add_row( "[yellow]Runtime[/yellow]", 
str(timedelta(seconds=int(duration.total_seconds()))), - "" + "", ) - + return table def _create_detailed_table(self) -> Table: @@ -193,9 +203,9 @@ class CrawlerMonitor: box=box.ROUNDED, title="Crawler Performance Monitor", title_style="bold magenta", - header_style="bold blue" + header_style="bold blue", ) - + # Add columns table.add_column("Task ID", style="cyan", no_wrap=True) table.add_column("URL", style="cyan", no_wrap=True) @@ -204,47 +214,54 @@ class CrawlerMonitor: table.add_column("Peak (MB)", justify="right") table.add_column("Duration", justify="right") table.add_column("Info", style="italic") - + # Add summary row total_memory = sum(stat.memory_usage for stat in self.stats.values()) - active_count = sum(1 for stat in self.stats.values() - if stat.status == CrawlStatus.IN_PROGRESS) - completed_count = sum(1 for stat in self.stats.values() - if stat.status == CrawlStatus.COMPLETED) - failed_count = sum(1 for stat in self.stats.values() - if stat.status == CrawlStatus.FAILED) - + active_count = sum( + 1 for stat in self.stats.values() if stat.status == CrawlStatus.IN_PROGRESS + ) + completed_count = sum( + 1 for stat in self.stats.values() if stat.status == CrawlStatus.COMPLETED + ) + failed_count = sum( + 1 for stat in self.stats.values() if stat.status == CrawlStatus.FAILED + ) + table.add_row( "[bold yellow]SUMMARY", f"Total: {len(self.stats)}", f"Active: {active_count}", f"{total_memory:.1f}", f"{self.process.memory_info().rss / (1024 * 1024):.1f}", - str(timedelta(seconds=int((datetime.now() - self.start_time).total_seconds()))), + str( + timedelta( + seconds=int((datetime.now() - self.start_time).total_seconds()) + ) + ), f"✓{completed_count} ✗{failed_count}", - style="bold" + style="bold", ) - + table.add_section() - + # Add rows for each task visible_stats = sorted( self.stats.values(), key=lambda x: ( x.status != CrawlStatus.IN_PROGRESS, x.status != CrawlStatus.QUEUED, - x.end_time or datetime.max - ) - )[:self.max_visible_rows] - + x.end_time or datetime.max, + ), + )[: self.max_visible_rows] + for stat in visible_stats: status_style = { CrawlStatus.QUEUED: "white", CrawlStatus.IN_PROGRESS: "yellow", CrawlStatus.COMPLETED: "green", - CrawlStatus.FAILED: "red" + CrawlStatus.FAILED: "red", }[stat.status] - + table.add_row( stat.task_id[:8], # Show first 8 chars of task ID stat.url[:40] + "..." 
if len(stat.url) > 40 else stat.url, @@ -252,9 +269,9 @@ class CrawlerMonitor: f"{stat.memory_usage:.1f}", f"{stat.peak_memory:.1f}", stat.duration, - stat.error_message[:40] if stat.error_message else "" + stat.error_message[:40] if stat.error_message else "", ) - + return table def _create_table(self) -> Table: @@ -268,7 +285,7 @@ class BaseDispatcher(ABC): def __init__( self, rate_limiter: Optional[RateLimiter] = None, - monitor: Optional[CrawlerMonitor] = None + monitor: Optional[CrawlerMonitor] = None, ): self.crawler = None self._domain_last_hit: Dict[str, float] = {} @@ -278,24 +295,25 @@ class BaseDispatcher(ABC): @abstractmethod async def crawl_url( - self, - url: str, - config: CrawlerRunConfig, + self, + url: str, + config: CrawlerRunConfig, task_id: str, - monitor: Optional[CrawlerMonitor] = None + monitor: Optional[CrawlerMonitor] = None, ) -> CrawlerTaskResult: pass @abstractmethod async def run_urls( - self, - urls: List[str], - crawler: "AsyncWebCrawler", + self, + urls: List[str], + crawler: "AsyncWebCrawler", # noqa: F821 config: CrawlerRunConfig, - monitor: Optional[CrawlerMonitor] = None + monitor: Optional[CrawlerMonitor] = None, ) -> List[CrawlerTaskResult]: pass + class MemoryAdaptiveDispatcher(BaseDispatcher): def __init__( self, @@ -304,39 +322,41 @@ class MemoryAdaptiveDispatcher(BaseDispatcher): max_session_permit: int = 20, memory_wait_timeout: float = 300.0, # 5 minutes default timeout rate_limiter: Optional[RateLimiter] = None, - monitor: Optional[CrawlerMonitor] = None + monitor: Optional[CrawlerMonitor] = None, ): super().__init__(rate_limiter, monitor) self.memory_threshold_percent = memory_threshold_percent self.check_interval = check_interval self.max_session_permit = max_session_permit self.memory_wait_timeout = memory_wait_timeout - + async def crawl_url( - self, - url: str, - config: CrawlerRunConfig, + self, + url: str, + config: CrawlerRunConfig, task_id: str, ) -> CrawlerTaskResult: start_time = datetime.now() error_message = "" memory_usage = peak_memory = 0.0 - + try: if self.monitor: - self.monitor.update_task(task_id, status=CrawlStatus.IN_PROGRESS, start_time=start_time) + self.monitor.update_task( + task_id, status=CrawlStatus.IN_PROGRESS, start_time=start_time + ) self.concurrent_sessions += 1 - + if self.rate_limiter: await self.rate_limiter.wait_if_needed(url) - + process = psutil.Process() start_memory = process.memory_info().rss / (1024 * 1024) result = await self.crawler.arun(url, config=config, session_id=task_id) end_memory = process.memory_info().rss / (1024 * 1024) - + memory_usage = peak_memory = end_memory - start_memory - + if self.rate_limiter and result.status_code: if not self.rate_limiter.update_delay(url, result.status_code): error_message = f"Rate limit retry count exceeded for domain {urlparse(url).netloc}" @@ -350,22 +370,24 @@ class MemoryAdaptiveDispatcher(BaseDispatcher): peak_memory=peak_memory, start_time=start_time, end_time=datetime.now(), - error_message=error_message + error_message=error_message, ) - + if not result.success: error_message = result.error_message if self.monitor: self.monitor.update_task(task_id, status=CrawlStatus.FAILED) elif self.monitor: self.monitor.update_task(task_id, status=CrawlStatus.COMPLETED) - + except Exception as e: error_message = str(e) if self.monitor: self.monitor.update_task(task_id, status=CrawlStatus.FAILED) - result = CrawlResult(url=url, html="", metadata={}, success=False, error_message=str(e)) - + result = CrawlResult( + url=url, html="", metadata={}, success=False, 
error_message=str(e) + ) + finally: end_time = datetime.now() if self.monitor: @@ -374,10 +396,10 @@ class MemoryAdaptiveDispatcher(BaseDispatcher): end_time=end_time, memory_usage=memory_usage, peak_memory=peak_memory, - error_message=error_message + error_message=error_message, ) self.concurrent_sessions -= 1 - + return CrawlerTaskResult( task_id=task_id, url=url, @@ -386,20 +408,20 @@ class MemoryAdaptiveDispatcher(BaseDispatcher): peak_memory=peak_memory, start_time=start_time, end_time=end_time, - error_message=error_message + error_message=error_message, ) async def run_urls( - self, - urls: List[str], - crawler: "AsyncWebCrawler", + self, + urls: List[str], + crawler: "AsyncWebCrawler", # noqa: F821 config: CrawlerRunConfig, ) -> List[CrawlerTaskResult]: self.crawler = crawler - + if self.monitor: self.monitor.start() - + try: pending_tasks = [] active_tasks = [] @@ -417,23 +439,24 @@ class MemoryAdaptiveDispatcher(BaseDispatcher): if psutil.virtual_memory().percent >= self.memory_threshold_percent: # Check if we've exceeded the timeout if time.time() - wait_start_time > self.memory_wait_timeout: - raise MemoryError(f"Memory usage above threshold ({self.memory_threshold_percent}%) for more than {self.memory_wait_timeout} seconds") + raise MemoryError( + f"Memory usage above threshold ({self.memory_threshold_percent}%) for more than {self.memory_wait_timeout} seconds" + ) await asyncio.sleep(self.check_interval) continue - + url, task_id = task_queue.pop(0) task = asyncio.create_task(self.crawl_url(url, config, task_id)) active_tasks.append(task) - + if not active_tasks: await asyncio.sleep(self.check_interval) continue - + done, pending = await asyncio.wait( - active_tasks, - return_when=asyncio.FIRST_COMPLETED + active_tasks, return_when=asyncio.FIRST_COMPLETED ) - + pending_tasks.extend(done) active_tasks = list(pending) @@ -442,24 +465,25 @@ class MemoryAdaptiveDispatcher(BaseDispatcher): if self.monitor: self.monitor.stop() + class SemaphoreDispatcher(BaseDispatcher): def __init__( self, semaphore_count: int = 5, max_session_permit: int = 20, rate_limiter: Optional[RateLimiter] = None, - monitor: Optional[CrawlerMonitor] = None + monitor: Optional[CrawlerMonitor] = None, ): super().__init__(rate_limiter, monitor) self.semaphore_count = semaphore_count - self.max_session_permit = max_session_permit - + self.max_session_permit = max_session_permit + async def crawl_url( - self, - url: str, - config: CrawlerRunConfig, + self, + url: str, + config: CrawlerRunConfig, task_id: str, - semaphore: asyncio.Semaphore = None + semaphore: asyncio.Semaphore = None, ) -> CrawlerTaskResult: start_time = datetime.now() error_message = "" @@ -467,7 +491,9 @@ class SemaphoreDispatcher(BaseDispatcher): try: if self.monitor: - self.monitor.update_task(task_id, status=CrawlStatus.IN_PROGRESS, start_time=start_time) + self.monitor.update_task( + task_id, status=CrawlStatus.IN_PROGRESS, start_time=start_time + ) if self.rate_limiter: await self.rate_limiter.wait_if_needed(url) @@ -477,7 +503,7 @@ class SemaphoreDispatcher(BaseDispatcher): start_memory = process.memory_info().rss / (1024 * 1024) result = await self.crawler.arun(url, config=config, session_id=task_id) end_memory = process.memory_info().rss / (1024 * 1024) - + memory_usage = peak_memory = end_memory - start_memory if self.rate_limiter and result.status_code: @@ -493,7 +519,7 @@ class SemaphoreDispatcher(BaseDispatcher): peak_memory=peak_memory, start_time=start_time, end_time=datetime.now(), - error_message=error_message + 
error_message=error_message, ) if not result.success: @@ -507,7 +533,9 @@ class SemaphoreDispatcher(BaseDispatcher): error_message = str(e) if self.monitor: self.monitor.update_task(task_id, status=CrawlStatus.FAILED) - result = CrawlResult(url=url, html="", metadata={}, success=False, error_message=str(e)) + result = CrawlResult( + url=url, html="", metadata={}, success=False, error_message=str(e) + ) finally: end_time = datetime.now() @@ -517,7 +545,7 @@ class SemaphoreDispatcher(BaseDispatcher): end_time=end_time, memory_usage=memory_usage, peak_memory=peak_memory, - error_message=error_message + error_message=error_message, ) return CrawlerTaskResult( @@ -528,13 +556,13 @@ class SemaphoreDispatcher(BaseDispatcher): peak_memory=peak_memory, start_time=start_time, end_time=end_time, - error_message=error_message + error_message=error_message, ) async def run_urls( - self, - crawler: "AsyncWebCrawler", - urls: List[str], + self, + crawler: "AsyncWebCrawler", # noqa: F821 + urls: List[str], config: CrawlerRunConfig, ) -> List[CrawlerTaskResult]: self.crawler = crawler @@ -557,4 +585,4 @@ class SemaphoreDispatcher(BaseDispatcher): return await asyncio.gather(*tasks, return_exceptions=True) finally: if self.monitor: - self.monitor.stop() \ No newline at end of file + self.monitor.stop() diff --git a/crawl4ai/async_logger.py b/crawl4ai/async_logger.py index 5d2d54b5..0e049289 100644 --- a/crawl4ai/async_logger.py +++ b/crawl4ai/async_logger.py @@ -1,10 +1,10 @@ from enum import Enum -from typing import Optional, Dict, Any, Union -from colorama import Fore, Back, Style, init -import time +from typing import Optional, Dict, Any +from colorama import Fore, Style, init import os from datetime import datetime + class LogLevel(Enum): DEBUG = 1 INFO = 2 @@ -12,23 +12,24 @@ class LogLevel(Enum): WARNING = 4 ERROR = 5 + class AsyncLogger: """ Asynchronous logger with support for colored console output and file logging. Supports templated messages with colored components. """ - + DEFAULT_ICONS = { - 'INIT': '→', - 'READY': '✓', - 'FETCH': '↓', - 'SCRAPE': '◆', - 'EXTRACT': '■', - 'COMPLETE': '●', - 'ERROR': '×', - 'DEBUG': '⋯', - 'INFO': 'ℹ', - 'WARNING': '⚠', + "INIT": "→", + "READY": "✓", + "FETCH": "↓", + "SCRAPE": "◆", + "EXTRACT": "■", + "COMPLETE": "●", + "ERROR": "×", + "DEBUG": "⋯", + "INFO": "ℹ", + "WARNING": "⚠", } DEFAULT_COLORS = { @@ -46,11 +47,11 @@ class AsyncLogger: tag_width: int = 10, icons: Optional[Dict[str, str]] = None, colors: Optional[Dict[LogLevel, str]] = None, - verbose: bool = True + verbose: bool = True, ): """ Initialize the logger. 
- + Args: log_file: Optional file path for logging log_level: Minimum log level to display @@ -66,7 +67,7 @@ class AsyncLogger: self.icons = icons or self.DEFAULT_ICONS self.colors = colors or self.DEFAULT_COLORS self.verbose = verbose - + # Create log file directory if needed if log_file: os.makedirs(os.path.dirname(os.path.abspath(log_file)), exist_ok=True) @@ -77,18 +78,20 @@ class AsyncLogger: def _get_icon(self, tag: str) -> str: """Get the icon for a tag, defaulting to info icon if not found.""" - return self.icons.get(tag, self.icons['INFO']) + return self.icons.get(tag, self.icons["INFO"]) def _write_to_file(self, message: str): """Write a message to the log file if configured.""" if self.log_file: - timestamp = datetime.now().strftime('%Y-%m-%d %H:%M:%S.%f')[:-3] - with open(self.log_file, 'a', encoding='utf-8') as f: + timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S.%f")[:-3] + with open(self.log_file, "a", encoding="utf-8") as f: # Strip ANSI color codes for file output - clean_message = message.replace(Fore.RESET, '').replace(Style.RESET_ALL, '') + clean_message = message.replace(Fore.RESET, "").replace( + Style.RESET_ALL, "" + ) for color in vars(Fore).values(): if isinstance(color, str): - clean_message = clean_message.replace(color, '') + clean_message = clean_message.replace(color, "") f.write(f"[{timestamp}] {clean_message}\n") def _log( @@ -99,11 +102,11 @@ class AsyncLogger: params: Optional[Dict[str, Any]] = None, colors: Optional[Dict[str, str]] = None, base_color: Optional[str] = None, - **kwargs + **kwargs, ): """ Core logging method that handles message formatting and output. - + Args: level: Log level for this message message: Message template string @@ -120,7 +123,7 @@ class AsyncLogger: try: # First format the message with raw parameters formatted_message = message.format(**params) - + # Then apply colors if specified if colors: for key, color in colors.items(): @@ -128,12 +131,13 @@ class AsyncLogger: if key in params: value_str = str(params[key]) formatted_message = formatted_message.replace( - value_str, - f"{color}{value_str}{Style.RESET_ALL}" + value_str, f"{color}{value_str}{Style.RESET_ALL}" ) - + except KeyError as e: - formatted_message = f"LOGGING ERROR: Missing parameter {e} in message template" + formatted_message = ( + f"LOGGING ERROR: Missing parameter {e} in message template" + ) level = LogLevel.ERROR else: formatted_message = message @@ -175,11 +179,11 @@ class AsyncLogger: success: bool, timing: float, tag: str = "FETCH", - url_length: int = 50 + url_length: int = 50, ): """ Convenience method for logging URL fetch status. - + Args: url: The URL being processed success: Whether the operation was successful @@ -195,24 +199,20 @@ class AsyncLogger: "url": url, "url_length": url_length, "status": success, - "timing": timing + "timing": timing, }, colors={ "status": Fore.GREEN if success else Fore.RED, - "timing": Fore.YELLOW - } + "timing": Fore.YELLOW, + }, ) def error_status( - self, - url: str, - error: str, - tag: str = "ERROR", - url_length: int = 50 + self, url: str, error: str, tag: str = "ERROR", url_length: int = 50 ): """ Convenience method for logging error status. - + Args: url: The URL being processed error: Error message @@ -223,9 +223,5 @@ class AsyncLogger: level=LogLevel.ERROR, message="{url:.{url_length}}... 
| Error: {error}", tag=tag, - params={ - "url": url, - "url_length": url_length, - "error": error - } - ) \ No newline at end of file + params={"url": url, "url_length": url_length, "error": error}, + ) diff --git a/crawl4ai/async_webcrawler.py b/crawl4ai/async_webcrawler.py index 9ef19966..a7596a55 100644 --- a/crawl4ai/async_webcrawler.py +++ b/crawl4ai/async_webcrawler.py @@ -1,49 +1,54 @@ -import os, sys +import os +import sys import time import warnings -from enum import Enum -from colorama import init, Fore, Back, Style +from colorama import Fore from pathlib import Path -from typing import Optional, List, Union +from typing import Optional, List import json import asyncio + # from contextlib import nullcontext, asynccontextmanager from contextlib import asynccontextmanager -from .models import CrawlResult, MarkdownGenerationResult, CrawlerTaskResult +from .models import CrawlResult, MarkdownGenerationResult, CrawlerTaskResult, DispatchResult, RateLimiter from .async_database import async_db_manager -from .chunking_strategy import * -from .content_filter_strategy import * -from .extraction_strategy import * -from .async_crawler_strategy import AsyncCrawlerStrategy, AsyncPlaywrightCrawlerStrategy, AsyncCrawlResponse +from .chunking_strategy import * # noqa: F403 +from .chunking_strategy import RegexChunking, ChunkingStrategy, IdentityChunking +from .content_filter_strategy import * # noqa: F403 +from .content_filter_strategy import RelevantContentFilter +from .extraction_strategy import * # noqa: F403 +from .extraction_strategy import NoExtractionStrategy, ExtractionStrategy +from .async_crawler_strategy import ( + AsyncCrawlerStrategy, + AsyncPlaywrightCrawlerStrategy, + AsyncCrawlResponse, +) from .cache_context import CacheMode, CacheContext, _legacy_to_cache_mode -from .markdown_generation_strategy import DefaultMarkdownGenerator, MarkdownGenerationStrategy -from .content_scraping_strategy import WebScrapingStrategy, LXMLWebScrapingStrategy +from .markdown_generation_strategy import ( + DefaultMarkdownGenerator, + MarkdownGenerationStrategy, +) from .async_logger import AsyncLogger from .async_configs import BrowserConfig, CrawlerRunConfig -from .async_dispatcher import * +from .async_dispatcher import * # noqa: F403 +from .async_dispatcher import BaseDispatcher, MemoryAdaptiveDispatcher -from .config import ( - MIN_WORD_THRESHOLD, - IMAGE_DESCRIPTION_MIN_WORD_THRESHOLD, - URL_LOG_SHORTEN_LENGTH -) +from .config import MIN_WORD_THRESHOLD from .utils import ( sanitize_input_encode, InvalidCSSSelectorError, - format_html, fast_format_html, - create_box_message + create_box_message, + get_error_context, ) -from urllib.parse import urlparse -import random from .__version__ import __version__ as crawl4ai_version class AsyncWebCrawler: """ Asynchronous web crawler with flexible caching capabilities. - + There are two ways to use the crawler: 1. 
Using context manager (recommended for simple cases): @@ -56,23 +61,23 @@ class AsyncWebCrawler: ```python crawler = AsyncWebCrawler() await crawler.start() - + # Use the crawler multiple times result1 = await crawler.arun(url="https://example.com") result2 = await crawler.arun(url="https://another.com") - + await crawler.close() ``` - + Migration Guide: Old way (deprecated): crawler = AsyncWebCrawler(always_by_pass_cache=True, browser_type="chromium", headless=True) - + New way (recommended): browser_config = BrowserConfig(browser_type="chromium", headless=True) crawler = AsyncWebCrawler(config=browser_config) - - + + Attributes: browser_config (BrowserConfig): Configuration object for browser settings. crawler_strategy (AsyncCrawlerStrategy): Strategy for crawling web pages. @@ -81,7 +86,7 @@ class AsyncWebCrawler: crawl4ai_folder (str): Directory for storing cache. base_directory (str): Base directory for storing cache. ready (bool): Whether the crawler is ready for use. - + Methods: start(): Start the crawler explicitly without using context manager. close(): Close the crawler explicitly without using context manager. @@ -89,21 +94,22 @@ class AsyncWebCrawler: awarmup(): Perform warmup sequence. arun_many(): Run the crawler for multiple sources. aprocess_html(): Process HTML content. - + Typical Usage: async with AsyncWebCrawler() as crawler: result = await crawler.arun(url="https://example.com") print(result.markdown) - + Using configuration: browser_config = BrowserConfig(browser_type="chromium", headless=True) async with AsyncWebCrawler(config=browser_config) as crawler: crawler_config = CrawlerRunConfig( - cache_mode=CacheMode.BYPASS + cache_mode=CacheMode.BYPASS ) result = await crawler.arun(url="https://example.com", config=crawler_config) print(result.markdown) """ + _domain_last_hit = {} def __init__( @@ -127,43 +133,48 @@ class AsyncWebCrawler: base_directory: Base directory for storing cache thread_safe: Whether to use thread-safe operations **kwargs: Additional arguments for backwards compatibility - """ + """ # Handle browser configuration browser_config = config if browser_config is not None: - if any(k in kwargs for k in ["browser_type", "headless", "viewport_width", "viewport_height"]): + if any( + k in kwargs + for k in [ + "browser_type", + "headless", + "viewport_width", + "viewport_height", + ] + ): self.logger.warning( message="Both browser_config and legacy browser parameters provided. 
browser_config will take precedence.", - tag="WARNING" + tag="WARNING", ) else: # Create browser config from kwargs for backwards compatibility browser_config = BrowserConfig.from_kwargs(kwargs) self.browser_config = browser_config - + # Initialize logger first since other components may need it self.logger = AsyncLogger( log_file=os.path.join(base_directory, ".crawl4ai", "crawler.log"), - verbose=self.browser_config.verbose, - tag_width=10 + verbose=self.browser_config.verbose, + tag_width=10, ) - # Initialize crawler strategy - params = { - k:v for k, v in kwargs.items() if k in ['browser_congig', 'logger'] - } + params = {k: v for k, v in kwargs.items() if k in ["browser_congig", "logger"]} self.crawler_strategy = crawler_strategy or AsyncPlaywrightCrawlerStrategy( browser_config=browser_config, logger=self.logger, - **params # Pass remaining kwargs for backwards compatibility + **params, # Pass remaining kwargs for backwards compatibility ) - + # If craweler strategy doesnt have logger, use crawler logger if not self.crawler_strategy.logger: self.crawler_strategy.logger = self.logger - + # Handle deprecated cache parameter if always_by_pass_cache is not None: if kwargs.get("warning", True): @@ -172,7 +183,7 @@ class AsyncWebCrawler: "Use 'always_bypass_cache' instead. " "Pass warning=False to suppress this warning.", DeprecationWarning, - stacklevel=2 + stacklevel=2, ) self.always_bypass_cache = always_by_pass_cache else: @@ -180,24 +191,24 @@ class AsyncWebCrawler: # Thread safety setup self._lock = asyncio.Lock() if thread_safe else None - + # Initialize directories self.crawl4ai_folder = os.path.join(base_directory, ".crawl4ai") os.makedirs(self.crawl4ai_folder, exist_ok=True) os.makedirs(f"{self.crawl4ai_folder}/cache", exist_ok=True) - + self.ready = False async def start(self): """ Start the crawler explicitly without using context manager. This is equivalent to using 'async with' but gives more control over the lifecycle. - + This method will: 1. Initialize the browser and context 2. Perform warmup sequence 3. Return the crawler instance for method chaining - + Returns: AsyncWebCrawler: The initialized crawler instance """ @@ -209,7 +220,7 @@ class AsyncWebCrawler: """ Close the crawler explicitly without using context manager. This should be called when you're done with the crawler if you used start(). - + This method will: 1. Clean up browser resources 2. Close any open pages and contexts @@ -221,11 +232,11 @@ class AsyncWebCrawler: async def __aexit__(self, exc_type, exc_val, exc_tb): await self.close() - + async def awarmup(self): """ Initialize the crawler with warm-up sequence. - + This method: 1. Logs initialization info 2. 
Sets up browser configuration @@ -240,548 +251,553 @@ class AsyncWebCrawler: yield async def arun( - self, - url: str, - config: Optional[CrawlerRunConfig] = None, - # Legacy parameters maintained for backwards compatibility - word_count_threshold=MIN_WORD_THRESHOLD, - extraction_strategy: ExtractionStrategy = None, - chunking_strategy: ChunkingStrategy = RegexChunking(), - content_filter: RelevantContentFilter = None, - cache_mode: Optional[CacheMode] = None, - # Deprecated cache parameters - bypass_cache: bool = False, - disable_cache: bool = False, - no_cache_read: bool = False, - no_cache_write: bool = False, - # Other legacy parameters - css_selector: str = None, - screenshot: bool = False, - pdf: bool = False, - user_agent: str = None, - verbose=True, - **kwargs, - ) -> CrawlResult: - """ - Runs the crawler for a single source: URL (web, local file, or raw HTML). + self, + url: str, + config: Optional[CrawlerRunConfig] = None, + # Legacy parameters maintained for backwards compatibility + word_count_threshold=MIN_WORD_THRESHOLD, + extraction_strategy: ExtractionStrategy = None, + chunking_strategy: ChunkingStrategy = RegexChunking(), + content_filter: RelevantContentFilter = None, + cache_mode: Optional[CacheMode] = None, + # Deprecated cache parameters + bypass_cache: bool = False, + disable_cache: bool = False, + no_cache_read: bool = False, + no_cache_write: bool = False, + # Other legacy parameters + css_selector: str = None, + screenshot: bool = False, + pdf: bool = False, + user_agent: str = None, + verbose=True, + **kwargs, + ) -> CrawlResult: + """ + Runs the crawler for a single source: URL (web, local file, or raw HTML). - Migration Guide: - Old way (deprecated): - result = await crawler.arun( - url="https://example.com", - word_count_threshold=200, - screenshot=True, - ... - ) + Migration Guide: + Old way (deprecated): + result = await crawler.arun( + url="https://example.com", + word_count_threshold=200, + screenshot=True, + ... + ) - New way (recommended): - config = CrawlerRunConfig( - word_count_threshold=200, - screenshot=True, - ... - ) - result = await crawler.arun(url="https://example.com", crawler_config=config) + New way (recommended): + config = CrawlerRunConfig( + word_count_threshold=200, + screenshot=True, + ... + ) + result = await crawler.arun(url="https://example.com", crawler_config=config) - Args: - url: The URL to crawl (http://, https://, file://, or raw:) - crawler_config: Configuration object controlling crawl behavior - [other parameters maintained for backwards compatibility] + Args: + url: The URL to crawl (http://, https://, file://, or raw:) + crawler_config: Configuration object controlling crawl behavior + [other parameters maintained for backwards compatibility] - Returns: - CrawlResult: The result of crawling and processing - """ - crawler_config = config - if not isinstance(url, str) or not url: - raise ValueError("Invalid URL, make sure the URL is a non-empty string") - - async with self._lock or self.nullcontext(): - try: - # Handle configuration - if crawler_config is not None: - # if any(param is not None for param in [ - # word_count_threshold, extraction_strategy, chunking_strategy, - # content_filter, cache_mode, css_selector, screenshot, pdf - # ]): - # self.logger.warning( - # message="Both crawler_config and legacy parameters provided. 
crawler_config will take precedence.", - # tag="WARNING" - # ) - config = crawler_config - else: - # Merge all parameters into a single kwargs dict for config creation - config_kwargs = { - "word_count_threshold": word_count_threshold, - "extraction_strategy": extraction_strategy, - "chunking_strategy": chunking_strategy, - "content_filter": content_filter, - "cache_mode": cache_mode, - "bypass_cache": bypass_cache, - "disable_cache": disable_cache, - "no_cache_read": no_cache_read, - "no_cache_write": no_cache_write, - "css_selector": css_selector, - "screenshot": screenshot, - "pdf": pdf, - "verbose": verbose, - **kwargs - } - config = CrawlerRunConfig.from_kwargs(config_kwargs) + Returns: + CrawlResult: The result of crawling and processing + """ + crawler_config = config + if not isinstance(url, str) or not url: + raise ValueError("Invalid URL, make sure the URL is a non-empty string") - # Handle deprecated cache parameters - if any([bypass_cache, disable_cache, no_cache_read, no_cache_write]): - if kwargs.get("warning", True): - warnings.warn( - "Cache control boolean flags are deprecated and will be removed in version 0.5.0. " - "Use 'cache_mode' parameter instead.", - DeprecationWarning, - stacklevel=2 - ) - - # Convert legacy parameters if cache_mode not provided - if config.cache_mode is None: - config.cache_mode = _legacy_to_cache_mode( - disable_cache=disable_cache, - bypass_cache=bypass_cache, - no_cache_read=no_cache_read, - no_cache_write=no_cache_write - ) - - # Default to ENABLED if no cache mode specified + async with self._lock or self.nullcontext(): + try: + # Handle configuration + if crawler_config is not None: + # if any(param is not None for param in [ + # word_count_threshold, extraction_strategy, chunking_strategy, + # content_filter, cache_mode, css_selector, screenshot, pdf + # ]): + # self.logger.warning( + # message="Both crawler_config and legacy parameters provided. crawler_config will take precedence.", + # tag="WARNING" + # ) + config = crawler_config + else: + # Merge all parameters into a single kwargs dict for config creation + config_kwargs = { + "word_count_threshold": word_count_threshold, + "extraction_strategy": extraction_strategy, + "chunking_strategy": chunking_strategy, + "content_filter": content_filter, + "cache_mode": cache_mode, + "bypass_cache": bypass_cache, + "disable_cache": disable_cache, + "no_cache_read": no_cache_read, + "no_cache_write": no_cache_write, + "css_selector": css_selector, + "screenshot": screenshot, + "pdf": pdf, + "verbose": verbose, + **kwargs, + } + config = CrawlerRunConfig.from_kwargs(config_kwargs) + + # Handle deprecated cache parameters + if any([bypass_cache, disable_cache, no_cache_read, no_cache_write]): + if kwargs.get("warning", True): + warnings.warn( + "Cache control boolean flags are deprecated and will be removed in version 0.5.0. 
" + "Use 'cache_mode' parameter instead.", + DeprecationWarning, + stacklevel=2, + ) + + # Convert legacy parameters if cache_mode not provided if config.cache_mode is None: - config.cache_mode = CacheMode.ENABLED - - # Create cache context - cache_context = CacheContext(url, config.cache_mode, self.always_bypass_cache) - - # Initialize processing variables - async_response: AsyncCrawlResponse = None - cached_result: CrawlResult = None - screenshot_data = None - pdf_data = None - extracted_content = None - start_time = time.perf_counter() - - # Try to get cached result if appropriate - if cache_context.should_read(): - cached_result = await async_db_manager.aget_cached_url(url) - - if cached_result: - html = sanitize_input_encode(cached_result.html) - extracted_content = sanitize_input_encode(cached_result.extracted_content or "") - extracted_content = None if not extracted_content or extracted_content == "[]" else extracted_content - # If screenshot is requested but its not in cache, then set cache_result to None - screenshot_data = cached_result.screenshot - pdf_data = cached_result.pdf - if config.screenshot and not screenshot or config.pdf and not pdf: - cached_result = None - - self.logger.url_status( - url=cache_context.display_url, - success=bool(html), - timing=time.perf_counter() - start_time, - tag="FETCH" + config.cache_mode = _legacy_to_cache_mode( + disable_cache=disable_cache, + bypass_cache=bypass_cache, + no_cache_read=no_cache_read, + no_cache_write=no_cache_write, ) - # Fetch fresh content if needed - if not cached_result or not html: - t1 = time.perf_counter() - - if user_agent: - self.crawler_strategy.update_user_agent(user_agent) - - # Pass config to crawl method - async_response = await self.crawler_strategy.crawl( - url, - config=config # Pass the entire config object - ) - - html = sanitize_input_encode(async_response.html) - screenshot_data = async_response.screenshot - pdf_data = async_response.pdf_data - - t2 = time.perf_counter() - self.logger.url_status( - url=cache_context.display_url, - success=bool(html), - timing=t2 - t1, - tag="FETCH" - ) + # Default to ENABLED if no cache mode specified + if config.cache_mode is None: + config.cache_mode = CacheMode.ENABLED - # Process the HTML content - crawl_result = await self.aprocess_html( - url=url, - html=html, - extracted_content=extracted_content, - config=config, # Pass the config object instead of individual parameters - screenshot=screenshot_data, - pdf_data=pdf_data, - verbose=config.verbose, - is_raw_html = True if url.startswith("raw:") else False, - **kwargs - ) + # Create cache context + cache_context = CacheContext( + url, config.cache_mode, self.always_bypass_cache + ) - crawl_result.status_code = async_response.status_code - crawl_result.response_headers = async_response.response_headers - crawl_result.downloaded_files = async_response.downloaded_files - crawl_result.ssl_certificate = async_response.ssl_certificate # Add SSL certificate + # Initialize processing variables + async_response: AsyncCrawlResponse = None + cached_result: CrawlResult = None + screenshot_data = None + pdf_data = None + extracted_content = None + start_time = time.perf_counter() - # # Check and set values from async_response to crawl_result - # try: - # for key in vars(async_response): - # if hasattr(crawl_result, key): - # value = getattr(async_response, key, None) - # current_value = getattr(crawl_result, key, None) - # if value is not None and not current_value: - # try: - # setattr(crawl_result, key, value) - # except 
Exception as e: - # self.logger.warning( - # message=f"Failed to set attribute {key}: {str(e)}", - # tag="WARNING" - # ) - # except Exception as e: - # self.logger.warning( - # message=f"Error copying response attributes: {str(e)}", - # tag="WARNING" - # ) + # Try to get cached result if appropriate + if cache_context.should_read(): + cached_result = await async_db_manager.aget_cached_url(url) - crawl_result.success = bool(html) - crawl_result.session_id = getattr(config, 'session_id', None) - - self.logger.success( - message="{url:.50}... | Status: {status} | Total: {timing}", - tag="COMPLETE", - params={ - "url": cache_context.display_url, - "status": crawl_result.success, - "timing": f"{time.perf_counter() - start_time:.2f}s" - }, - colors={ - "status": Fore.GREEN if crawl_result.success else Fore.RED, - "timing": Fore.YELLOW - } - ) - - # Update cache if appropriate - if cache_context.should_write() and not bool(cached_result): - await async_db_manager.acache_url(crawl_result) - - return crawl_result - - else: - self.logger.success( - message="{url:.50}... | Status: {status} | Total: {timing}", - tag="COMPLETE", - params={ - "url": cache_context.display_url, - "status": True, - "timing": f"{time.perf_counter() - start_time:.2f}s" - }, - colors={ - "status": Fore.GREEN, - "timing": Fore.YELLOW - } - ) - - cached_result.success = bool(html) - cached_result.session_id = getattr(config, 'session_id', None) - return cached_result - - except Exception as e: - error_context = get_error_context(sys.exc_info()) - - error_message = ( - f"Unexpected error in _crawl_web at line {error_context['line_no']} " - f"in {error_context['function']} ({error_context['filename']}):\n" - f"Error: {str(e)}\n\n" - f"Code context:\n{error_context['code_context']}" + if cached_result: + html = sanitize_input_encode(cached_result.html) + extracted_content = sanitize_input_encode( + cached_result.extracted_content or "" ) - # if not hasattr(e, "msg"): - # e.msg = str(e) - - self.logger.error_status( + extracted_content = ( + None + if not extracted_content or extracted_content == "[]" + else extracted_content + ) + # If screenshot is requested but its not in cache, then set cache_result to None + screenshot_data = cached_result.screenshot + pdf_data = cached_result.pdf + if config.screenshot and not screenshot or config.pdf and not pdf: + cached_result = None + + self.logger.url_status( + url=cache_context.display_url, + success=bool(html), + timing=time.perf_counter() - start_time, + tag="FETCH", + ) + + # Fetch fresh content if needed + if not cached_result or not html: + t1 = time.perf_counter() + + if user_agent: + self.crawler_strategy.update_user_agent(user_agent) + + # Pass config to crawl method + async_response = await self.crawler_strategy.crawl( + url, + config=config, # Pass the entire config object + ) + + html = sanitize_input_encode(async_response.html) + screenshot_data = async_response.screenshot + pdf_data = async_response.pdf_data + + t2 = time.perf_counter() + self.logger.url_status( + url=cache_context.display_url, + success=bool(html), + timing=t2 - t1, + tag="FETCH", + ) + + # Process the HTML content + crawl_result = await self.aprocess_html( url=url, - error=create_box_message(error_message, type="error"), - tag="ERROR" + html=html, + extracted_content=extracted_content, + config=config, # Pass the config object instead of individual parameters + screenshot=screenshot_data, + pdf_data=pdf_data, + verbose=config.verbose, + is_raw_html=True if url.startswith("raw:") else False, + **kwargs, ) 
- - return CrawlResult( - url=url, - html="", - success=False, - error_message=error_message + + crawl_result.status_code = async_response.status_code + crawl_result.response_headers = async_response.response_headers + crawl_result.downloaded_files = async_response.downloaded_files + crawl_result.ssl_certificate = ( + async_response.ssl_certificate + ) # Add SSL certificate + + # # Check and set values from async_response to crawl_result + # try: + # for key in vars(async_response): + # if hasattr(crawl_result, key): + # value = getattr(async_response, key, None) + # current_value = getattr(crawl_result, key, None) + # if value is not None and not current_value: + # try: + # setattr(crawl_result, key, value) + # except Exception as e: + # self.logger.warning( + # message=f"Failed to set attribute {key}: {str(e)}", + # tag="WARNING" + # ) + # except Exception as e: + # self.logger.warning( + # message=f"Error copying response attributes: {str(e)}", + # tag="WARNING" + # ) + + crawl_result.success = bool(html) + crawl_result.session_id = getattr(config, "session_id", None) + + self.logger.success( + message="{url:.50}... | Status: {status} | Total: {timing}", + tag="COMPLETE", + params={ + "url": cache_context.display_url, + "status": crawl_result.success, + "timing": f"{time.perf_counter() - start_time:.2f}s", + }, + colors={ + "status": Fore.GREEN if crawl_result.success else Fore.RED, + "timing": Fore.YELLOW, + }, ) + # Update cache if appropriate + if cache_context.should_write() and not bool(cached_result): + await async_db_manager.acache_url(crawl_result) + + return crawl_result + + else: + self.logger.success( + message="{url:.50}... | Status: {status} | Total: {timing}", + tag="COMPLETE", + params={ + "url": cache_context.display_url, + "status": True, + "timing": f"{time.perf_counter() - start_time:.2f}s", + }, + colors={"status": Fore.GREEN, "timing": Fore.YELLOW}, + ) + + cached_result.success = bool(html) + cached_result.session_id = getattr(config, "session_id", None) + return cached_result + + except Exception as e: + error_context = get_error_context(sys.exc_info()) + + error_message = ( + f"Unexpected error in _crawl_web at line {error_context['line_no']} " + f"in {error_context['function']} ({error_context['filename']}):\n" + f"Error: {str(e)}\n\n" + f"Code context:\n{error_context['code_context']}" + ) + # if not hasattr(e, "msg"): + # e.msg = str(e) + + self.logger.error_status( + url=url, + error=create_box_message(error_message, type="error"), + tag="ERROR", + ) + + return CrawlResult( + url=url, html="", success=False, error_message=error_message + ) + async def aprocess_html( - self, - url: str, - html: str, - extracted_content: str, - config: CrawlerRunConfig, - screenshot: str, - pdf_data: str, - verbose: bool, - **kwargs, - ) -> CrawlResult: - """ - Process HTML content using the provided configuration. 
- - Args: - url: The URL being processed - html: Raw HTML content - extracted_content: Previously extracted content (if any) - config: Configuration object controlling processing behavior - screenshot: Screenshot data (if any) - pdf_data: PDF data (if any) - verbose: Whether to enable verbose logging - **kwargs: Additional parameters for backwards compatibility - - Returns: - CrawlResult: Processed result containing extracted and formatted content - """ - try: - _url = url if not kwargs.get("is_raw_html", False) else "Raw HTML" - t1 = time.perf_counter() + self, + url: str, + html: str, + extracted_content: str, + config: CrawlerRunConfig, + screenshot: str, + pdf_data: str, + verbose: bool, + **kwargs, + ) -> CrawlResult: + """ + Process HTML content using the provided configuration. - # Get scraping strategy and ensure it has a logger - scraping_strategy = config.scraping_strategy - if not scraping_strategy.logger: - scraping_strategy.logger = self.logger + Args: + url: The URL being processed + html: Raw HTML content + extracted_content: Previously extracted content (if any) + config: Configuration object controlling processing behavior + screenshot: Screenshot data (if any) + pdf_data: PDF data (if any) + verbose: Whether to enable verbose logging + **kwargs: Additional parameters for backwards compatibility - # Process HTML content - params = {k:v for k, v in config.to_dict().items() if k not in ["url"]} - # add keys from kwargs to params that doesn't exist in params - params.update({k:v for k, v in kwargs.items() if k not in params.keys()}) - - result = scraping_strategy.scrap( - url, - html, - **params + Returns: + CrawlResult: Processed result containing extracted and formatted content + """ + try: + _url = url if not kwargs.get("is_raw_html", False) else "Raw HTML" + t1 = time.perf_counter() + + # Get scraping strategy and ensure it has a logger + scraping_strategy = config.scraping_strategy + if not scraping_strategy.logger: + scraping_strategy.logger = self.logger + + # Process HTML content + params = {k: v for k, v in config.to_dict().items() if k not in ["url"]} + # add keys from kwargs to params that doesn't exist in params + params.update({k: v for k, v in kwargs.items() if k not in params.keys()}) + + result = scraping_strategy.scrap(url, html, **params) + + if result is None: + raise ValueError( + f"Process HTML, Failed to extract content from the website: {url}" ) - if result is None: - raise ValueError(f"Process HTML, Failed to extract content from the website: {url}") + except InvalidCSSSelectorError as e: + raise ValueError(str(e)) + except Exception as e: + raise ValueError( + f"Process HTML, Failed to extract content from the website: {url}, error: {str(e)}" + ) - except InvalidCSSSelectorError as e: - raise ValueError(str(e)) - except Exception as e: - raise ValueError(f"Process HTML, Failed to extract content from the website: {url}, error: {str(e)}") + # Extract results - handle both dict and ScrapingResult + if isinstance(result, dict): + cleaned_html = sanitize_input_encode(result.get("cleaned_html", "")) + media = result.get("media", {}) + links = result.get("links", {}) + metadata = result.get("metadata", {}) + else: + cleaned_html = sanitize_input_encode(result.cleaned_html) + media = result.media.model_dump() + links = result.links.model_dump() + metadata = result.metadata - + # Markdown Generation + markdown_generator: Optional[MarkdownGenerationStrategy] = ( + config.markdown_generator or DefaultMarkdownGenerator() + ) - # Extract results - handle both 
dict and ScrapingResult - if isinstance(result, dict): - cleaned_html = sanitize_input_encode(result.get("cleaned_html", "")) - media = result.get("media", {}) - links = result.get("links", {}) - metadata = result.get("metadata", {}) - else: - cleaned_html = sanitize_input_encode(result.cleaned_html) - media = result.media.model_dump() - links = result.links.model_dump() - metadata = result.metadata + # Uncomment if by default we want to use PruningContentFilter + # if not config.content_filter and not markdown_generator.content_filter: + # markdown_generator.content_filter = PruningContentFilter() - # Markdown Generation - markdown_generator: Optional[MarkdownGenerationStrategy] = config.markdown_generator or DefaultMarkdownGenerator() - - # Uncomment if by default we want to use PruningContentFilter - # if not config.content_filter and not markdown_generator.content_filter: - # markdown_generator.content_filter = PruningContentFilter() - - markdown_result: MarkdownGenerationResult = markdown_generator.generate_markdown( + markdown_result: MarkdownGenerationResult = ( + markdown_generator.generate_markdown( cleaned_html=cleaned_html, base_url=url, # html2text_options=kwargs.get('html2text', {}) ) - markdown_v2 = markdown_result - markdown = sanitize_input_encode(markdown_result.raw_markdown) + ) + markdown_v2 = markdown_result + markdown = sanitize_input_encode(markdown_result.raw_markdown) - # Log processing completion - self.logger.info( - message="Processed {url:.50}... | Time: {timing}ms", - tag="SCRAPE", - params={ - "url": _url, - "timing": int((time.perf_counter() - t1) * 1000) - } + # Log processing completion + self.logger.info( + message="Processed {url:.50}... | Time: {timing}ms", + tag="SCRAPE", + params={"url": _url, "timing": int((time.perf_counter() - t1) * 1000)}, + ) + + # Handle content extraction if needed + if ( + not bool(extracted_content) + and config.extraction_strategy + and not isinstance(config.extraction_strategy, NoExtractionStrategy) + ): + t1 = time.perf_counter() + + # Choose content based on input_format + content_format = config.extraction_strategy.input_format + if content_format == "fit_markdown" and not markdown_result.fit_markdown: + self.logger.warning( + message="Fit markdown requested but not available. Falling back to raw markdown.", + tag="EXTRACT", + params={"url": _url}, + ) + content_format = "markdown" + + content = { + "markdown": markdown, + "html": html, + "fit_markdown": markdown_result.raw_markdown, + }.get(content_format, markdown) + + # Use IdentityChunking for HTML input, otherwise use provided chunking strategy + chunking = ( + IdentityChunking() + if content_format == "html" + else config.chunking_strategy + ) + sections = chunking.chunk(content) + extracted_content = config.extraction_strategy.run(url, sections) + extracted_content = json.dumps( + extracted_content, indent=4, default=str, ensure_ascii=False ) - # Handle content extraction if needed - if (not bool(extracted_content) and config.extraction_strategy and not isinstance(config.extraction_strategy, NoExtractionStrategy)): - - t1 = time.perf_counter() - - # Choose content based on input_format - content_format = config.extraction_strategy.input_format - if content_format == "fit_markdown" and not markdown_result.fit_markdown: - self.logger.warning( - message="Fit markdown requested but not available. 
Falling back to raw markdown.", - tag="EXTRACT", - params={"url": _url} - ) - content_format = "markdown" + # Log extraction completion + self.logger.info( + message="Completed for {url:.50}... | Time: {timing}s", + tag="EXTRACT", + params={"url": _url, "timing": time.perf_counter() - t1}, + ) - content = { - "markdown": markdown, - "html": html, - "fit_markdown": markdown_result.raw_markdown - }.get(content_format, markdown) - - # Use IdentityChunking for HTML input, otherwise use provided chunking strategy - chunking = IdentityChunking() if content_format == "html" else config.chunking_strategy - sections = chunking.chunk(content) - extracted_content = config.extraction_strategy.run(url, sections) - extracted_content = json.dumps(extracted_content, indent=4, default=str, ensure_ascii=False) + # Handle screenshot and PDF data + screenshot_data = None if not screenshot else screenshot + pdf_data = None if not pdf_data else pdf_data - # Log extraction completion - self.logger.info( - message="Completed for {url:.50}... | Time: {timing}s", - tag="EXTRACT", - params={ - "url": _url, - "timing": time.perf_counter() - t1 - } - ) + # Apply HTML formatting if requested + if config.prettiify: + cleaned_html = fast_format_html(cleaned_html) - # Handle screenshot and PDF data - screenshot_data = None if not screenshot else screenshot - pdf_data = None if not pdf_data else pdf_data - - # Apply HTML formatting if requested - if config.prettiify: - cleaned_html = fast_format_html(cleaned_html) - - # Return complete crawl result - return CrawlResult( - url=url, - html=html, - cleaned_html=cleaned_html, - markdown_v2=markdown_v2, - markdown=markdown, - fit_markdown=markdown_result.fit_markdown, - fit_html=markdown_result.fit_html, - media=media, - links=links, - metadata=metadata, - screenshot=screenshot_data, - pdf=pdf_data, - extracted_content=extracted_content, - success=True, - error_message="", - ) + # Return complete crawl result + return CrawlResult( + url=url, + html=html, + cleaned_html=cleaned_html, + markdown_v2=markdown_v2, + markdown=markdown, + fit_markdown=markdown_result.fit_markdown, + fit_html=markdown_result.fit_html, + media=media, + links=links, + metadata=metadata, + screenshot=screenshot_data, + pdf=pdf_data, + extracted_content=extracted_content, + success=True, + error_message="", + ) async def arun_many( - self, - urls: List[str], - config: Optional[CrawlerRunConfig] = None, - dispatcher: Optional[BaseDispatcher] = None, - # Legacy parameters maintained for backwards compatibility - word_count_threshold=MIN_WORD_THRESHOLD, - extraction_strategy: ExtractionStrategy = None, - chunking_strategy: ChunkingStrategy = RegexChunking(), - content_filter: RelevantContentFilter = None, - cache_mode: Optional[CacheMode] = None, - bypass_cache: bool = False, - css_selector: str = None, - screenshot: bool = False, - pdf: bool = False, - user_agent: str = None, - verbose=True, - **kwargs, - ) -> List[CrawlResult]: - """ - Runs the crawler for multiple URLs concurrently using a configurable dispatcher strategy. 
+ self, + urls: List[str], + config: Optional[CrawlerRunConfig] = None, + dispatcher: Optional[BaseDispatcher] = None, + # Legacy parameters maintained for backwards compatibility + word_count_threshold=MIN_WORD_THRESHOLD, + extraction_strategy: ExtractionStrategy = None, + chunking_strategy: ChunkingStrategy = RegexChunking(), + content_filter: RelevantContentFilter = None, + cache_mode: Optional[CacheMode] = None, + bypass_cache: bool = False, + css_selector: str = None, + screenshot: bool = False, + pdf: bool = False, + user_agent: str = None, + verbose=True, + **kwargs, + ) -> List[CrawlResult]: + """ + Runs the crawler for multiple URLs concurrently using a configurable dispatcher strategy. - Migration Guide: - Old way (deprecated): - results = await crawler.arun_many( - urls, - word_count_threshold=200, - screenshot=True, - ... - ) - - New way (recommended): - config = CrawlerRunConfig( - word_count_threshold=200, - screenshot=True, - dispatcher_config=DispatcherConfig( - enable_rate_limiting=True, - rate_limit_config=RateLimitConfig(...), - ), - ... - ) - results = await crawler.arun_many( - urls, - config=config, - dispatcher_strategy=MemoryAdaptiveDispatcher # Optional, this is the default - ) - - Args: - urls: List of URLs to crawl - config: Configuration object controlling crawl behavior for all URLs - dispatcher_strategy: The dispatcher strategy class to use. Defaults to MemoryAdaptiveDispatcher. - [other parameters maintained for backwards compatibility] - - Returns: - List[CrawlResult]: Results for each URL - """ - # Create config if not provided - if config is None: - config = CrawlerRunConfig( - word_count_threshold=word_count_threshold, - extraction_strategy=extraction_strategy, - chunking_strategy=chunking_strategy, - content_filter=content_filter, - cache_mode=cache_mode, - bypass_cache=bypass_cache, - css_selector=css_selector, - screenshot=screenshot, - pdf=pdf, - verbose=verbose, - **kwargs - ) - - # # Initialize the dispatcher with the selected strategy - # dispatcher = dispatcher_strategy(self, config.dispatcher_config) - - # memory_monitor: CrawlerMonitor = None - # if config.dispatcher_config.enable_monitor: - # memory_monitor = CrawlerMonitor(max_visible_rows=config.dispatcher_config.max_display_rows, display_mode=config.dispatcher_config.display_mode) - - # Create default dispatcher if none provided - if dispatcher is None: - dispatcher = MemoryAdaptiveDispatcher( - self, - rate_limiter=RateLimiter( - base_delay=(1.0, 3.0), - max_delay=60.0, - max_retries=3 - ) - ) - - # Run the URLs through the dispatcher - _results: List[CrawlerTaskResult] = await dispatcher.run_urls( - crawler=self, - urls=urls, - config=config + Migration Guide: + Old way (deprecated): + results = await crawler.arun_many( + urls, + word_count_threshold=200, + screenshot=True, + ... ) - - results: CrawlResult = [] - for res in _results: - _res : CrawlResult = res.result - dispatch_result: DispatchResult = DispatchResult( - task_id=res.task_id, - memory_usage=res.memory_usage, - peak_memory=res.peak_memory, - start_time=res.start_time, - end_time=res.end_time, - error_message=res.error_message - ) - _res.dispatch_result = dispatch_result - results.append(_res) - - return results + + New way (recommended): + config = CrawlerRunConfig( + word_count_threshold=200, + screenshot=True, + dispatcher_config=DispatcherConfig( + enable_rate_limiting=True, + rate_limit_config=RateLimitConfig(...), + ), + ... 
+ ) + results = await crawler.arun_many( + urls, + config=config, + dispatcher_strategy=MemoryAdaptiveDispatcher # Optional, this is the default + ) + + Args: + urls: List of URLs to crawl + config: Configuration object controlling crawl behavior for all URLs + dispatcher_strategy: The dispatcher strategy class to use. Defaults to MemoryAdaptiveDispatcher. + [other parameters maintained for backwards compatibility] + + Returns: + List[CrawlResult]: Results for each URL + """ + # Create config if not provided + if config is None: + config = CrawlerRunConfig( + word_count_threshold=word_count_threshold, + extraction_strategy=extraction_strategy, + chunking_strategy=chunking_strategy, + content_filter=content_filter, + cache_mode=cache_mode, + bypass_cache=bypass_cache, + css_selector=css_selector, + screenshot=screenshot, + pdf=pdf, + verbose=verbose, + **kwargs, + ) + + # # Initialize the dispatcher with the selected strategy + # dispatcher = dispatcher_strategy(self, config.dispatcher_config) + + # memory_monitor: CrawlerMonitor = None + # if config.dispatcher_config.enable_monitor: + # memory_monitor = CrawlerMonitor(max_visible_rows=config.dispatcher_config.max_display_rows, display_mode=config.dispatcher_config.display_mode) + + # Create default dispatcher if none provided + if dispatcher is None: + dispatcher = MemoryAdaptiveDispatcher( + self, + rate_limiter=RateLimiter( + base_delay=(1.0, 3.0), max_delay=60.0, max_retries=3 + ), + ) + + # Run the URLs through the dispatcher + _results: List[CrawlerTaskResult] = await dispatcher.run_urls( + crawler=self, urls=urls, config=config + ) + + results: CrawlResult = [] + for res in _results: + _res: CrawlResult = res.result + dispatch_result: DispatchResult = DispatchResult( + task_id=res.task_id, + memory_usage=res.memory_usage, + peak_memory=res.peak_memory, + start_time=res.start_time, + end_time=res.end_time, + error_message=res.error_message, + ) + _res.dispatch_result = dispatch_result + results.append(_res) + + return results async def aclear_cache(self): """Clear the cache database.""" diff --git a/crawl4ai/cache_context.py b/crawl4ai/cache_context.py index 588edd62..75914b5b 100644 --- a/crawl4ai/cache_context.py +++ b/crawl4ai/cache_context.py @@ -4,7 +4,7 @@ from enum import Enum class CacheMode(Enum): """ Defines the caching behavior for web crawling operations. - + Modes: - ENABLED: Normal caching behavior (read and write) - DISABLED: No caching at all @@ -12,6 +12,7 @@ class CacheMode(Enum): - WRITE_ONLY: Only write to cache, don't read - BYPASS: Bypass cache for this operation """ + ENABLED = "enabled" DISABLED = "disabled" READ_ONLY = "read_only" @@ -22,10 +23,10 @@ class CacheMode(Enum): class CacheContext: """ Encapsulates cache-related decisions and URL handling. - + This class centralizes all cache-related logic and URL type checking, making the caching behavior more predictable and maintainable. - + Attributes: url (str): The URL being processed. cache_mode (CacheMode): The cache mode for the current operation. @@ -36,10 +37,11 @@ class CacheContext: is_raw_html (bool): True if the URL is raw HTML, False otherwise. _url_display (str): The display name for the URL (web, local file, or raw HTML). """ + def __init__(self, url: str, cache_mode: CacheMode, always_bypass: bool = False): """ Initializes the CacheContext with the provided URL and cache mode. - + Args: url (str): The URL being processed. cache_mode (CacheMode): The cache mode for the current operation. 
@@ -48,42 +50,42 @@ class CacheContext: self.url = url self.cache_mode = cache_mode self.always_bypass = always_bypass - self.is_cacheable = url.startswith(('http://', 'https://', 'file://')) - self.is_web_url = url.startswith(('http://', 'https://')) + self.is_cacheable = url.startswith(("http://", "https://", "file://")) + self.is_web_url = url.startswith(("http://", "https://")) self.is_local_file = url.startswith("file://") self.is_raw_html = url.startswith("raw:") self._url_display = url if not self.is_raw_html else "Raw HTML" - + def should_read(self) -> bool: """ Determines if cache should be read based on context. - + How it works: 1. If always_bypass is True or is_cacheable is False, return False. 2. If cache_mode is ENABLED or READ_ONLY, return True. - + Returns: bool: True if cache should be read, False otherwise. """ if self.always_bypass or not self.is_cacheable: return False return self.cache_mode in [CacheMode.ENABLED, CacheMode.READ_ONLY] - + def should_write(self) -> bool: """ Determines if cache should be written based on context. - + How it works: 1. If always_bypass is True or is_cacheable is False, return False. 2. If cache_mode is ENABLED or WRITE_ONLY, return True. - + Returns: bool: True if cache should be written, False otherwise. """ if self.always_bypass or not self.is_cacheable: return False return self.cache_mode in [CacheMode.ENABLED, CacheMode.WRITE_ONLY] - + @property def display_url(self) -> str: """Returns the URL in display format.""" @@ -94,11 +96,11 @@ def _legacy_to_cache_mode( disable_cache: bool = False, bypass_cache: bool = False, no_cache_read: bool = False, - no_cache_write: bool = False + no_cache_write: bool = False, ) -> CacheMode: """ Converts legacy cache parameters to the new CacheMode enum. - + This is an internal function to help transition from the old boolean flags to the new CacheMode system. """ diff --git a/crawl4ai/chunking_strategy.py b/crawl4ai/chunking_strategy.py index 7b8c08ad..ca188d1d 100644 --- a/crawl4ai/chunking_strategy.py +++ b/crawl4ai/chunking_strategy.py @@ -3,49 +3,53 @@ import re from collections import Counter import string from .model_loader import load_nltk_punkt -from .utils import * + # Define the abstract base class for chunking strategies class ChunkingStrategy(ABC): """ Abstract base class for chunking strategies. """ - + @abstractmethod def chunk(self, text: str) -> list: """ Abstract method to chunk the given text. - + Args: text (str): The text to chunk. - + Returns: list: A list of chunks. """ pass + # Create an identity chunking strategy f(x) = [x] class IdentityChunking(ChunkingStrategy): """ Chunking strategy that returns the input text as a single chunk. """ + def chunk(self, text: str) -> list: return [text] + # Regex-based chunking class RegexChunking(ChunkingStrategy): """ Chunking strategy that splits text based on regular expression patterns. """ + def __init__(self, patterns=None, **kwargs): """ Initialize the RegexChunking object. - + Args: patterns (list): A list of regular expression patterns to split text. 
""" if patterns is None: - patterns = [r'\n\n'] # Default split pattern + patterns = [r"\n\n"] # Default split pattern self.patterns = patterns def chunk(self, text: str) -> list: @@ -56,18 +60,19 @@ class RegexChunking(ChunkingStrategy): new_paragraphs.extend(re.split(pattern, paragraph)) paragraphs = new_paragraphs return paragraphs - -# NLP-based sentence chunking + + +# NLP-based sentence chunking class NlpSentenceChunking(ChunkingStrategy): """ Chunking strategy that splits text into sentences using NLTK's sentence tokenizer. - """ + """ + def __init__(self, **kwargs): """ Initialize the NlpSentenceChunking object. """ load_nltk_punkt() - def chunk(self, text: str) -> list: # Improved regex for sentence splitting @@ -75,31 +80,34 @@ class NlpSentenceChunking(ChunkingStrategy): # r'(? list: # Tokenize and remove stopwords and punctuation import nltk as nl + tokens = nl.toknize.word_tokenize(text) - tokens = [token.lower() for token in tokens if token not in nl.corpus.stopwords.words('english') and token not in string.punctuation] + tokens = [ + token.lower() + for token in tokens + if token not in nl.corpus.stopwords.words("english") + and token not in string.punctuation + ] # Calculate frequency distribution freq_dist = Counter(tokens) @@ -123,23 +137,27 @@ class TopicSegmentationChunking(ChunkingStrategy): # Segment the text into topics segments = self.chunk(text) # Extract keywords for each topic segment - segments_with_topics = [(segment, self.extract_keywords(segment)) for segment in segments] + segments_with_topics = [ + (segment, self.extract_keywords(segment)) for segment in segments + ] return segments_with_topics - + + # Fixed-length word chunks class FixedLengthWordChunking(ChunkingStrategy): """ Chunking strategy that splits text into fixed-length word chunks. - + How it works: 1. Split the text into words 2. Create chunks of fixed length 3. Return the list of chunks """ + def __init__(self, chunk_size=100, **kwargs): """ Initialize the fixed-length word chunking strategy with the given chunk size. - + Args: chunk_size (int): The size of each chunk in words. """ @@ -147,23 +165,28 @@ class FixedLengthWordChunking(ChunkingStrategy): def chunk(self, text: str) -> list: words = text.split() - return [' '.join(words[i:i + self.chunk_size]) for i in range(0, len(words), self.chunk_size)] - + return [ + " ".join(words[i : i + self.chunk_size]) + for i in range(0, len(words), self.chunk_size) + ] + + # Sliding window chunking class SlidingWindowChunking(ChunkingStrategy): """ Chunking strategy that splits text into overlapping word chunks. - + How it works: 1. Split the text into words 2. Create chunks of fixed length 3. Return the list of chunks """ + def __init__(self, window_size=100, step=50, **kwargs): """ Initialize the sliding window chunking strategy with the given window size and step size. - + Args: window_size (int): The size of the sliding window in words. step (int): The step size for sliding the window in words. 
@@ -174,35 +197,37 @@ class SlidingWindowChunking(ChunkingStrategy): def chunk(self, text: str) -> list: words = text.split() chunks = [] - + if len(words) <= self.window_size: return [text] - + for i in range(0, len(words) - self.window_size + 1, self.step): - chunk = ' '.join(words[i:i + self.window_size]) + chunk = " ".join(words[i : i + self.window_size]) chunks.append(chunk) - + # Handle the last chunk if it doesn't align perfectly if i + self.window_size < len(words): - chunks.append(' '.join(words[-self.window_size:])) - + chunks.append(" ".join(words[-self.window_size :])) + return chunks - + + class OverlappingWindowChunking(ChunkingStrategy): """ Chunking strategy that splits text into overlapping word chunks. - + How it works: 1. Split the text into words using whitespace 2. Create chunks of fixed length equal to the window size 3. Slide the window by the overlap size 4. Return the list of chunks """ + def __init__(self, window_size=1000, overlap=100, **kwargs): """ Initialize the overlapping window chunking strategy with the given window size and overlap size. - + Args: window_size (int): The size of the window in words. overlap (int): The size of the overlap between consecutive chunks in words. @@ -213,19 +238,19 @@ class OverlappingWindowChunking(ChunkingStrategy): def chunk(self, text: str) -> list: words = text.split() chunks = [] - + if len(words) <= self.window_size: return [text] - + start = 0 while start < len(words): end = start + self.window_size - chunk = ' '.join(words[start:end]) + chunk = " ".join(words[start:end]) chunks.append(chunk) - + if end >= len(words): break - + start = end - self.overlap - - return chunks \ No newline at end of file + + return chunks diff --git a/crawl4ai/cli.py b/crawl4ai/cli.py index 4a01c1c2..b2d2199e 100644 --- a/crawl4ai/cli.py +++ b/crawl4ai/cli.py @@ -8,15 +8,22 @@ from .async_logger import AsyncLogger logger = AsyncLogger(verbose=True) docs_manager = DocsManager(logger) + def print_table(headers: List[str], rows: List[List[str]], padding: int = 2): """Print formatted table with headers and rows""" widths = [max(len(str(cell)) for cell in col) for col in zip(headers, *rows)] - border = '+' + '+'.join('-' * (w + 2 * padding) for w in widths) + '+' - + border = "+" + "+".join("-" * (w + 2 * padding) for w in widths) + "+" + def format_row(row): - return '|' + '|'.join(f"{' ' * padding}{str(cell):<{w}}{' ' * padding}" - for cell, w in zip(row, widths)) + '|' - + return ( + "|" + + "|".join( + f"{' ' * padding}{str(cell):<{w}}{' ' * padding}" + for cell, w in zip(row, widths) + ) + + "|" + ) + click.echo(border) click.echo(format_row(headers)) click.echo(border) @@ -24,19 +31,24 @@ def print_table(headers: List[str], rows: List[List[str]], padding: int = 2): click.echo(format_row(row)) click.echo(border) + @click.group() def cli(): """Crawl4AI Command Line Interface""" pass + @cli.group() def docs(): """Documentation operations""" pass + @docs.command() -@click.argument('sections', nargs=-1) -@click.option('--mode', type=click.Choice(['extended', 'condensed']), default='extended') +@click.argument("sections", nargs=-1) +@click.option( + "--mode", type=click.Choice(["extended", "condensed"]), default="extended" +) def combine(sections: tuple, mode: str): """Combine documentation sections""" try: @@ -46,16 +58,17 @@ def combine(sections: tuple, mode: str): logger.error(str(e), tag="ERROR") sys.exit(1) + @docs.command() -@click.argument('query') -@click.option('--top-k', '-k', default=5) -@click.option('--build-index', is_flag=True, 
help='Build index if missing') +@click.argument("query") +@click.option("--top-k", "-k", default=5) +@click.option("--build-index", is_flag=True, help="Build index if missing") def search(query: str, top_k: int, build_index: bool): """Search documentation""" try: result = docs_manager.search(query, top_k) if result == "No search index available. Call build_search_index() first.": - if build_index or click.confirm('No search index found. Build it now?'): + if build_index or click.confirm("No search index found. Build it now?"): asyncio.run(docs_manager.llm_text.generate_index_files()) result = docs_manager.search(query, top_k) click.echo(result) @@ -63,6 +76,7 @@ def search(query: str, top_k: int, build_index: bool): click.echo(f"Error: {str(e)}", err=True) sys.exit(1) + @docs.command() def update(): """Update docs from GitHub""" @@ -73,22 +87,25 @@ def update(): click.echo(f"Error: {str(e)}", err=True) sys.exit(1) + @docs.command() -@click.option('--force-facts', is_flag=True, help='Force regenerate fact files') -@click.option('--clear-cache', is_flag=True, help='Clear BM25 cache') +@click.option("--force-facts", is_flag=True, help="Force regenerate fact files") +@click.option("--clear-cache", is_flag=True, help="Clear BM25 cache") def index(force_facts: bool, clear_cache: bool): """Build or rebuild search indexes""" try: asyncio.run(docs_manager.ensure_docs_exist()) - asyncio.run(docs_manager.llm_text.generate_index_files( - force_generate_facts=force_facts, - clear_bm25_cache=clear_cache - )) + asyncio.run( + docs_manager.llm_text.generate_index_files( + force_generate_facts=force_facts, clear_bm25_cache=clear_cache + ) + ) click.echo("Search indexes built successfully") except Exception as e: click.echo(f"Error: {str(e)}", err=True) sys.exit(1) + # Add docs list command @docs.command() def list(): @@ -96,10 +113,11 @@ def list(): try: sections = docs_manager.list() print_table(["Sections"], [[section] for section in sections]) - + except Exception as e: click.echo(f"Error: {str(e)}", err=True) sys.exit(1) -if __name__ == '__main__': - cli() \ No newline at end of file + +if __name__ == "__main__": + cli() diff --git a/crawl4ai/config.py b/crawl4ai/config.py index c2be7638..3e26514a 100644 --- a/crawl4ai/config.py +++ b/crawl4ai/config.py @@ -8,7 +8,7 @@ DEFAULT_PROVIDER = "openai/gpt-4o-mini" MODEL_REPO_BRANCH = "new-release-0.0.2" # Provider-model dictionary, ONLY used when the extraction strategy is LLMExtractionStrategy PROVIDER_MODELS = { - "ollama/llama3": "no-token-needed", # Any model from Ollama no need for API token + "ollama/llama3": "no-token-needed", # Any model from Ollama no need for API token "groq/llama3-70b-8192": os.getenv("GROQ_API_KEY"), "groq/llama3-8b-8192": os.getenv("GROQ_API_KEY"), "openai/gpt-4o-mini": os.getenv("OPENAI_API_KEY"), @@ -22,27 +22,49 @@ PROVIDER_MODELS = { } # Chunk token threshold -CHUNK_TOKEN_THRESHOLD = 2 ** 11 # 2048 tokens +CHUNK_TOKEN_THRESHOLD = 2**11 # 2048 tokens OVERLAP_RATE = 0.1 WORD_TOKEN_RATE = 1.3 -# Threshold for the minimum number of word in a HTML tag to be considered +# Threshold for the minimum number of word in a HTML tag to be considered MIN_WORD_THRESHOLD = 1 IMAGE_DESCRIPTION_MIN_WORD_THRESHOLD = 1 -IMPORTANT_ATTRS = ['src', 'href', 'alt', 'title', 'width', 'height'] -ONLY_TEXT_ELIGIBLE_TAGS = ['b', 'i', 'u', 'span', 'del', 'ins', 'sub', 'sup', 'strong', 'em', 'code', 'kbd', 'var', 's', 'q', 'abbr', 'cite', 'dfn', 'time', 'small', 'mark'] +IMPORTANT_ATTRS = ["src", "href", "alt", "title", "width", "height"] 
+ONLY_TEXT_ELIGIBLE_TAGS = [ + "b", + "i", + "u", + "span", + "del", + "ins", + "sub", + "sup", + "strong", + "em", + "code", + "kbd", + "var", + "s", + "q", + "abbr", + "cite", + "dfn", + "time", + "small", + "mark", +] SOCIAL_MEDIA_DOMAINS = [ - 'facebook.com', - 'twitter.com', - 'x.com', - 'linkedin.com', - 'instagram.com', - 'pinterest.com', - 'tiktok.com', - 'snapchat.com', - 'reddit.com', - ] + "facebook.com", + "twitter.com", + "x.com", + "linkedin.com", + "instagram.com", + "pinterest.com", + "tiktok.com", + "snapchat.com", + "reddit.com", +] # Threshold for the Image extraction - Range is 1 to 6 # Images are scored based on point based system, to filter based on usefulness. Points are assigned @@ -60,5 +82,5 @@ NEED_MIGRATION = True URL_LOG_SHORTEN_LENGTH = 30 SHOW_DEPRECATION_WARNINGS = True SCREENSHOT_HEIGHT_TRESHOLD = 10000 -PAGE_TIMEOUT=60000 -DOWNLOAD_PAGE_TIMEOUT=60000 \ No newline at end of file +PAGE_TIMEOUT = 60000 +DOWNLOAD_PAGE_TIMEOUT = 60000 diff --git a/crawl4ai/content_filter_strategy.py b/crawl4ai/content_filter_strategy.py index ce433118..11a33ac2 100644 --- a/crawl4ai/content_filter_strategy.py +++ b/crawl4ai/content_filter_strategy.py @@ -1,59 +1,100 @@ import re from bs4 import BeautifulSoup, Tag -from typing import List, Tuple, Dict +from typing import List, Tuple from rank_bm25 import BM25Okapi -from time import perf_counter from collections import deque -from bs4 import BeautifulSoup, NavigableString, Tag, Comment +from bs4 import NavigableString, Comment from .utils import clean_tokens from abc import ABC, abstractmethod import math from snowballstemmer import stemmer + + class RelevantContentFilter(ABC): """Abstract base class for content filtering strategies""" + def __init__(self, user_query: str = None): self.user_query = user_query self.included_tags = { # Primary structure - 'article', 'main', 'section', 'div', + "article", + "main", + "section", + "div", # List structures - 'ul', 'ol', 'li', 'dl', 'dt', 'dd', + "ul", + "ol", + "li", + "dl", + "dt", + "dd", # Text content - 'p', 'span', 'blockquote', 'pre', 'code', + "p", + "span", + "blockquote", + "pre", + "code", # Headers - 'h1', 'h2', 'h3', 'h4', 'h5', 'h6', + "h1", + "h2", + "h3", + "h4", + "h5", + "h6", # Tables - 'table', 'thead', 'tbody', 'tr', 'td', 'th', + "table", + "thead", + "tbody", + "tr", + "td", + "th", # Other semantic elements - 'figure', 'figcaption', 'details', 'summary', + "figure", + "figcaption", + "details", + "summary", # Text formatting - 'em', 'strong', 'b', 'i', 'mark', 'small', + "em", + "strong", + "b", + "i", + "mark", + "small", # Rich content - 'time', 'address', 'cite', 'q' + "time", + "address", + "cite", + "q", } self.excluded_tags = { - 'nav', 'footer', 'header', 'aside', 'script', - 'style', 'form', 'iframe', 'noscript' + "nav", + "footer", + "header", + "aside", + "script", + "style", + "form", + "iframe", + "noscript", } - self.header_tags = {'h1', 'h2', 'h3', 'h4', 'h5', 'h6'} + self.header_tags = {"h1", "h2", "h3", "h4", "h5", "h6"} self.negative_patterns = re.compile( - r'nav|footer|header|sidebar|ads|comment|promo|advert|social|share', - re.I + r"nav|footer|header|sidebar|ads|comment|promo|advert|social|share", re.I ) self.min_word_count = 2 - + @abstractmethod def filter_content(self, html: str) -> List[str]: """Abstract method to be implemented by specific filtering strategies""" pass - + def extract_page_query(self, soup: BeautifulSoup, body: Tag) -> str: """Common method to extract page metadata with fallbacks""" if self.user_query: return 
self.user_query query_parts = [] - + # Title try: title = soup.title.string @@ -62,109 +103,145 @@ class RelevantContentFilter(ABC): except Exception: pass - if soup.find('h1'): - query_parts.append(soup.find('h1').get_text()) - + if soup.find("h1"): + query_parts.append(soup.find("h1").get_text()) + # Meta tags temp = "" - for meta_name in ['keywords', 'description']: - meta = soup.find('meta', attrs={'name': meta_name}) - if meta and meta.get('content'): - query_parts.append(meta['content']) - temp += meta['content'] - + for meta_name in ["keywords", "description"]: + meta = soup.find("meta", attrs={"name": meta_name}) + if meta and meta.get("content"): + query_parts.append(meta["content"]) + temp += meta["content"] + # If still empty, grab first significant paragraph if not temp: # Find the first tag P thatits text contains more than 50 characters - for p in body.find_all('p'): + for p in body.find_all("p"): if len(p.get_text()) > 150: query_parts.append(p.get_text()[:150]) - break - - return ' '.join(filter(None, query_parts)) + break - def extract_text_chunks(self, body: Tag, min_word_threshold: int = None) -> List[Tuple[str, str]]: + return " ".join(filter(None, query_parts)) + + def extract_text_chunks( + self, body: Tag, min_word_threshold: int = None + ) -> List[Tuple[str, str]]: """ Extracts text chunks from a BeautifulSoup body element while preserving order. Returns list of tuples (text, tag_name) for classification. - + Args: body: BeautifulSoup Tag object representing the body element - + Returns: List of (text, tag_name) tuples """ # Tags to ignore - inline elements that shouldn't break text flow INLINE_TAGS = { - 'a', 'abbr', 'acronym', 'b', 'bdo', 'big', 'br', 'button', 'cite', 'code', - 'dfn', 'em', 'i', 'img', 'input', 'kbd', 'label', 'map', 'object', 'q', - 'samp', 'script', 'select', 'small', 'span', 'strong', 'sub', 'sup', - 'textarea', 'time', 'tt', 'var' + "a", + "abbr", + "acronym", + "b", + "bdo", + "big", + "br", + "button", + "cite", + "code", + "dfn", + "em", + "i", + "img", + "input", + "kbd", + "label", + "map", + "object", + "q", + "samp", + "script", + "select", + "small", + "span", + "strong", + "sub", + "sup", + "textarea", + "time", + "tt", + "var", } - + # Tags that typically contain meaningful headers - HEADER_TAGS = {'h1', 'h2', 'h3', 'h4', 'h5', 'h6', 'header'} - + HEADER_TAGS = {"h1", "h2", "h3", "h4", "h5", "h6", "header"} + chunks = [] current_text = [] chunk_index = 0 - + def should_break_chunk(tag: Tag) -> bool: """Determine if a tag should cause a break in the current text chunk""" - return ( - tag.name not in INLINE_TAGS - and not (tag.name == 'p' and len(current_text) == 0) + return tag.name not in INLINE_TAGS and not ( + tag.name == "p" and len(current_text) == 0 ) - + # Use deque for efficient push/pop operations stack = deque([(body, False)]) - + while stack: element, visited = stack.pop() - + if visited: # End of block element - flush accumulated text if current_text and should_break_chunk(element): - text = ' '.join(''.join(current_text).split()) + text = " ".join("".join(current_text).split()) if text: - tag_type = 'header' if element.name in HEADER_TAGS else 'content' + tag_type = ( + "header" if element.name in HEADER_TAGS else "content" + ) chunks.append((chunk_index, text, tag_type, element)) chunk_index += 1 current_text = [] continue - + if isinstance(element, NavigableString): if str(element).strip(): current_text.append(str(element).strip()) continue - + # Pre-allocate children to avoid multiple list operations children = 
list(element.children) if not children: continue - + # Mark block for revisit after processing children stack.append((element, True)) - + # Add children in reverse order for correct processing for child in reversed(children): if isinstance(child, (Tag, NavigableString)): stack.append((child, False)) - + # Handle any remaining text if current_text: - text = ' '.join(''.join(current_text).split()) + text = " ".join("".join(current_text).split()) if text: - chunks.append((chunk_index, text, 'content', body)) - - if min_word_threshold: - chunks = [chunk for chunk in chunks if len(chunk[1].split()) >= min_word_threshold] - - return chunks + chunks.append((chunk_index, text, "content", body)) - def _deprecated_extract_text_chunks(self, soup: BeautifulSoup) -> List[Tuple[int, str, Tag]]: + if min_word_threshold: + chunks = [ + chunk for chunk in chunks if len(chunk[1].split()) >= min_word_threshold + ] + + return chunks + + def _deprecated_extract_text_chunks( + self, soup: BeautifulSoup + ) -> List[Tuple[int, str, Tag]]: """Common method for extracting text chunks""" _text_cache = {} + def fast_text(element: Tag) -> str: elem_id = id(element) if elem_id in _text_cache: @@ -175,13 +252,13 @@ class RelevantContentFilter(ABC): text = content.strip() if text: texts.append(text) - result = ' '.join(texts) + result = " ".join(texts) _text_cache[elem_id] = result return result - + candidates = [] index = 0 - + def dfs(element): nonlocal index if isinstance(element, Tag): @@ -189,7 +266,7 @@ class RelevantContentFilter(ABC): if not self.is_excluded(element): text = fast_text(element) word_count = len(text.split()) - + # Headers pass through with adjusted minimum if element.name in self.header_tags: if word_count >= 3: # Minimal sanity check for headers @@ -199,7 +276,7 @@ class RelevantContentFilter(ABC): elif word_count >= self.min_word_count: candidates.append((index, text, element)) index += 1 - + for child in element.children: dfs(child) @@ -210,59 +287,67 @@ class RelevantContentFilter(ABC): """Common method for exclusion logic""" if tag.name in self.excluded_tags: return True - class_id = ' '.join(filter(None, [ - ' '.join(tag.get('class', [])), - tag.get('id', '') - ])) + class_id = " ".join( + filter(None, [" ".join(tag.get("class", [])), tag.get("id", "")]) + ) return bool(self.negative_patterns.search(class_id)) def clean_element(self, tag: Tag) -> str: """Common method for cleaning HTML elements with minimal overhead""" if not tag or not isinstance(tag, Tag): return "" - - unwanted_tags = {'script', 'style', 'aside', 'form', 'iframe', 'noscript'} - unwanted_attrs = {'style', 'onclick', 'onmouseover', 'align', 'bgcolor', 'class', 'id'} - + + unwanted_tags = {"script", "style", "aside", "form", "iframe", "noscript"} + unwanted_attrs = { + "style", + "onclick", + "onmouseover", + "align", + "bgcolor", + "class", + "id", + } + + # Use string builder pattern for better performance builder = [] - + def render_tag(elem): if not isinstance(elem, Tag): if isinstance(elem, str): builder.append(elem.strip()) return - + if elem.name in unwanted_tags: return - + # Start tag - builder.append(f'<{elem.name}') - + builder.append(f"<{elem.name}") + + # Add cleaned attributes attrs = {k: v for k, v in elem.attrs.items() if k not in unwanted_attrs} for key, value in attrs.items(): builder.append(f' {key}="{value}"') - - builder.append('>') - + + builder.append(">") + + # Process children for child in elem.children: render_tag(child) - + # Close tag - builder.append(f'</{elem.name}>') - + # Close tag + builder.append(f"</{elem.name}>") 
+ try: render_tag(tag) - return ''.join(builder) + return "".join(builder) except Exception: return str(tag) # Fallback to original if anything fails + class BM25ContentFilter(RelevantContentFilter): """ Content filtering using BM25 algorithm with priority tag handling. - + How it works: 1. Extracts page metadata with fallbacks. 2. Extracts text chunks from the body element. @@ -271,22 +356,28 @@ class BM25ContentFilter(RelevantContentFilter): 5. Filters out chunks below the threshold. 6. Sorts chunks by score in descending order. 7. Returns the top N chunks. - + Attributes: user_query (str): User query for filtering (optional). bm25_threshold (float): BM25 threshold for filtering (default: 1.0). language (str): Language for stemming (default: 'english'). - + Methods: filter_content(self, html: str, min_word_threshold: int = None) """ - def __init__(self, user_query: str = None, bm25_threshold: float = 1.0, language: str = 'english'): + + def __init__( + self, + user_query: str = None, + bm25_threshold: float = 1.0, + language: str = "english", + ): """ Initializes the BM25ContentFilter class, if not provided, falls back to page metadata. - + Note: If no query is given and no page metadata is available, then it tries to pick up the first significant paragraph. - + Args: user_query (str): User query for filtering (optional). bm25_threshold (float): BM25 threshold for filtering (default: 1.0). @@ -295,52 +386,52 @@ class BM25ContentFilter(RelevantContentFilter): super().__init__(user_query=user_query) self.bm25_threshold = bm25_threshold self.priority_tags = { - 'h1': 5.0, - 'h2': 4.0, - 'h3': 3.0, - 'title': 4.0, - 'strong': 2.0, - 'b': 1.5, - 'em': 1.5, - 'blockquote': 2.0, - 'code': 2.0, - 'pre': 1.5, - 'th': 1.5, # Table headers + "h1": 5.0, + "h2": 4.0, + "h3": 3.0, + "title": 4.0, + "strong": 2.0, + "b": 1.5, + "em": 1.5, + "blockquote": 2.0, + "code": 2.0, + "pre": 1.5, + "th": 1.5, # Table headers } self.stemmer = stemmer(language) def filter_content(self, html: str, min_word_threshold: int = None) -> List[str]: """ Implements content filtering using BM25 algorithm with priority tag handling. - + Note: This method implements the filtering logic for the BM25ContentFilter class. It takes HTML content as input and returns a list of filtered text chunks. - + Args: html (str): HTML content to be filtered. min_word_threshold (int): Minimum word threshold for filtering (optional). - + Returns: List[str]: List of filtered text chunks. """ if not html or not isinstance(html, str): return [] - soup = BeautifulSoup(html, 'lxml') - + soup = BeautifulSoup(html, "lxml") + # Check if body is present if not soup.body: # Wrap in body tag if missing - soup = BeautifulSoup(f'
<body>{html}</body>', 'lxml') - body = soup.find('body') - + soup = BeautifulSoup(f"<body>{html}</body>", "lxml") + body = soup.find("body") + query = self.extract_page_query(soup, body) - + if not query: return [] # return [self.clean_element(soup)] - + candidates = self.extract_text_chunks(body, min_word_threshold) if not candidates: @@ -349,16 +440,20 @@ class BM25ContentFilter(RelevantContentFilter): # Tokenize corpus # tokenized_corpus = [chunk.lower().split() for _, chunk, _, _ in candidates] # tokenized_query = query.lower().split() - - # tokenized_corpus = [[ps.stem(word) for word in chunk.lower().split()] - # for _, chunk, _, _ in candidates] - # tokenized_query = [ps.stem(word) for word in query.lower().split()] - - tokenized_corpus = [[self.stemmer.stemWord(word) for word in chunk.lower().split()] - for _, chunk, _, _ in candidates] - tokenized_query = [self.stemmer.stemWord(word) for word in query.lower().split()] - # tokenized_corpus = [[self.stemmer.stemWord(word) for word in tokenize_text(chunk.lower())] + # tokenized_corpus = [[ps.stem(word) for word in chunk.lower().split()] + # for _, chunk, _, _ in candidates] + # tokenized_query = [ps.stem(word) for word in query.lower().split()] + + tokenized_corpus = [ + [self.stemmer.stemWord(word) for word in chunk.lower().split()] + for _, chunk, _, _ in candidates + ] + tokenized_query = [ + self.stemmer.stemWord(word) for word in query.lower().split() + ] + + # tokenized_corpus = [[self.stemmer.stemWord(word) for word in tokenize_text(chunk.lower())] # for _, chunk, _, _ in candidates] # tokenized_query = [self.stemmer.stemWord(word) for word in tokenize_text(query.lower())] @@ -378,7 +473,8 @@ class BM25ContentFilter(RelevantContentFilter): # Filter candidates by threshold selected_candidates = [ - (index, chunk, tag) for adjusted_score, index, chunk, tag in adjusted_candidates + (index, chunk, tag) + for adjusted_score, index, chunk, tag in adjusted_candidates if adjusted_score >= self.bm25_threshold ] @@ -390,10 +486,11 @@ class BM25ContentFilter(RelevantContentFilter): return [self.clean_element(tag) for _, _, tag in selected_candidates] + class PruningContentFilter(RelevantContentFilter): """ Content filtering using pruning algorithm with dynamic threshold. - + How it works: 1. Extracts page metadata with fallbacks. 2. Extracts text chunks from the body element. @@ -407,18 +504,24 @@ class PruningContentFilter(RelevantContentFilter): min_word_threshold (int): Minimum word threshold for filtering (optional). threshold_type (str): Threshold type for dynamic threshold (default: 'fixed'). threshold (float): Fixed threshold value (default: 0.48). - + Methods: filter_content(self, html: str, min_word_threshold: int = None): """ - def __init__(self, user_query: str = None, min_word_threshold: int = None, - threshold_type: str = 'fixed', threshold: float = 0.48): + + def __init__( + self, + user_query: str = None, + min_word_threshold: int = None, + threshold_type: str = "fixed", + threshold: float = 0.48, + ): """ Initializes the PruningContentFilter class, if not provided, falls back to page metadata. - + Note: If no query is given and no page metadata is available, then it tries to pick up the first significant paragraph. - + Args: user_query (str): User query for filtering (optional). min_word_threshold (int): Minimum word threshold for filtering (optional). 
@@ -429,92 +532,92 @@ class PruningContentFilter(RelevantContentFilter): self.min_word_threshold = min_word_threshold self.threshold_type = threshold_type self.threshold = threshold - + # Add tag importance for dynamic threshold self.tag_importance = { - 'article': 1.5, - 'main': 1.4, - 'section': 1.3, - 'p': 1.2, - 'h1': 1.4, - 'h2': 1.3, - 'h3': 1.2, - 'div': 0.7, - 'span': 0.6 + "article": 1.5, + "main": 1.4, + "section": 1.3, + "p": 1.2, + "h1": 1.4, + "h2": 1.3, + "h3": 1.2, + "div": 0.7, + "span": 0.6, } - + # Metric configuration self.metric_config = { - 'text_density': True, - 'link_density': True, - 'tag_weight': True, - 'class_id_weight': True, - 'text_length': True, + "text_density": True, + "link_density": True, + "tag_weight": True, + "class_id_weight": True, + "text_length": True, } - + self.metric_weights = { - 'text_density': 0.4, - 'link_density': 0.2, - 'tag_weight': 0.2, - 'class_id_weight': 0.1, - 'text_length': 0.1, + "text_density": 0.4, + "link_density": 0.2, + "tag_weight": 0.2, + "class_id_weight": 0.1, + "text_length": 0.1, } - + self.tag_weights = { - 'div': 0.5, - 'p': 1.0, - 'article': 1.5, - 'section': 1.0, - 'span': 0.3, - 'li': 0.5, - 'ul': 0.5, - 'ol': 0.5, - 'h1': 1.2, - 'h2': 1.1, - 'h3': 1.0, - 'h4': 0.9, - 'h5': 0.8, - 'h6': 0.7, + "div": 0.5, + "p": 1.0, + "article": 1.5, + "section": 1.0, + "span": 0.3, + "li": 0.5, + "ul": 0.5, + "ol": 0.5, + "h1": 1.2, + "h2": 1.1, + "h3": 1.0, + "h4": 0.9, + "h5": 0.8, + "h6": 0.7, } def filter_content(self, html: str, min_word_threshold: int = None) -> List[str]: """ Implements content filtering using pruning algorithm with dynamic threshold. - + Note: This method implements the filtering logic for the PruningContentFilter class. It takes HTML content as input and returns a list of filtered text chunks. - + Args: html (str): HTML content to be filtered. min_word_threshold (int): Minimum word threshold for filtering (optional). - + Returns: List[str]: List of filtered text chunks. """ if not html or not isinstance(html, str): return [] - - soup = BeautifulSoup(html, 'lxml') + + soup = BeautifulSoup(html, "lxml") if not soup.body: - soup = BeautifulSoup(f'<body>{html}</body>', 'lxml') - + soup = BeautifulSoup(f"<body>{html}</body>", "lxml") + # Remove comments and unwanted tags self._remove_comments(soup) self._remove_unwanted_tags(soup) - + # Prune tree starting from body - body = soup.find('body') + body = soup.find("body") self._prune_tree(body) - + # Extract remaining content as list of HTML strings content_blocks = [] for element in body.children: - if isinstance(element, str) or not hasattr(element, 'name'): + if isinstance(element, str) or not hasattr(element, "name"): continue if len(element.get_text(strip=True)) > 0: content_blocks.append(str(element)) - + return content_blocks def _remove_comments(self, soup): @@ -531,34 +634,38 @@ class PruningContentFilter(RelevantContentFilter): def _prune_tree(self, node): """ Prunes the tree starting from the given node. - + Args: node (Tag): The node from which the pruning starts. 
""" - if not node or not hasattr(node, 'name') or node.name is None: + if not node or not hasattr(node, "name") or node.name is None: return text_len = len(node.get_text(strip=True)) - tag_len = len(node.encode_contents().decode('utf-8')) - link_text_len = sum(len(s.strip()) for s in (a.string for a in node.find_all('a', recursive=False)) if s) + tag_len = len(node.encode_contents().decode("utf-8")) + link_text_len = sum( + len(s.strip()) + for s in (a.string for a in node.find_all("a", recursive=False)) + if s + ) metrics = { - 'node': node, - 'tag_name': node.name, - 'text_len': text_len, - 'tag_len': tag_len, - 'link_text_len': link_text_len + "node": node, + "tag_name": node.name, + "text_len": text_len, + "tag_len": tag_len, + "link_text_len": link_text_len, } score = self._compute_composite_score(metrics, text_len, tag_len, link_text_len) - if self.threshold_type == 'fixed': + if self.threshold_type == "fixed": should_remove = score < self.threshold else: # dynamic tag_importance = self.tag_importance.get(node.name, 0.7) text_ratio = text_len / tag_len if tag_len > 0 else 0 link_ratio = link_text_len / text_len if text_len > 0 else 1 - + threshold = self.threshold # base threshold if tag_importance > 1: threshold *= 0.8 @@ -566,13 +673,13 @@ class PruningContentFilter(RelevantContentFilter): threshold *= 0.9 if link_ratio > 0.6: threshold *= 1.2 - + should_remove = score < threshold if should_remove: node.decompose() else: - children = [child for child in node.children if hasattr(child, 'name')] + children = [child for child in node.children if hasattr(child, "name")] for child in children: self._prune_tree(child) @@ -580,48 +687,48 @@ class PruningContentFilter(RelevantContentFilter): """Computes the composite score""" if self.min_word_threshold: # Get raw text from metrics node - avoid extra processing - text = metrics['node'].get_text(strip=True) - word_count = text.count(' ') + 1 + text = metrics["node"].get_text(strip=True) + word_count = text.count(" ") + 1 if word_count < self.min_word_threshold: return -1.0 # Guaranteed removal score = 0.0 total_weight = 0.0 - if self.metric_config['text_density']: + if self.metric_config["text_density"]: density = text_len / tag_len if tag_len > 0 else 0 - score += self.metric_weights['text_density'] * density - total_weight += self.metric_weights['text_density'] + score += self.metric_weights["text_density"] * density + total_weight += self.metric_weights["text_density"] - if self.metric_config['link_density']: + if self.metric_config["link_density"]: density = 1 - (link_text_len / text_len if text_len > 0 else 0) - score += self.metric_weights['link_density'] * density - total_weight += self.metric_weights['link_density'] + score += self.metric_weights["link_density"] * density + total_weight += self.metric_weights["link_density"] - if self.metric_config['tag_weight']: - tag_score = self.tag_weights.get(metrics['tag_name'], 0.5) - score += self.metric_weights['tag_weight'] * tag_score - total_weight += self.metric_weights['tag_weight'] + if self.metric_config["tag_weight"]: + tag_score = self.tag_weights.get(metrics["tag_name"], 0.5) + score += self.metric_weights["tag_weight"] * tag_score + total_weight += self.metric_weights["tag_weight"] - if self.metric_config['class_id_weight']: - class_score = self._compute_class_id_weight(metrics['node']) - score += self.metric_weights['class_id_weight'] * max(0, class_score) - total_weight += self.metric_weights['class_id_weight'] + if self.metric_config["class_id_weight"]: + class_score = 
self._compute_class_id_weight(metrics["node"]) + score += self.metric_weights["class_id_weight"] * max(0, class_score) + total_weight += self.metric_weights["class_id_weight"] - if self.metric_config['text_length']: - score += self.metric_weights['text_length'] * math.log(text_len + 1) - total_weight += self.metric_weights['text_length'] + if self.metric_config["text_length"]: + score += self.metric_weights["text_length"] * math.log(text_len + 1) + total_weight += self.metric_weights["text_length"] return score / total_weight if total_weight > 0 else 0 def _compute_class_id_weight(self, node): """Computes the class ID weight""" class_id_score = 0 - if 'class' in node.attrs: - classes = ' '.join(node['class']) + if "class" in node.attrs: + classes = " ".join(node["class"]) if self.negative_patterns.match(classes): class_id_score -= 0.5 - if 'id' in node.attrs: - element_id = node['id'] + if "id" in node.attrs: + element_id = node["id"] if self.negative_patterns.match(element_id): class_id_score -= 0.5 - return class_id_score \ No newline at end of file + return class_id_score diff --git a/crawl4ai/content_scraping_strategy.py b/crawl4ai/content_scraping_strategy.py index ae09037d..6cb169db 100644 --- a/crawl4ai/content_scraping_strategy.py +++ b/crawl4ai/content_scraping_strategy.py @@ -1,12 +1,18 @@ -import re +import re from itertools import chain -import time from abc import ABC, abstractmethod from typing import Dict, Any, Optional from bs4 import BeautifulSoup -from concurrent.futures import ThreadPoolExecutor -import asyncio, requests, re, os -from .config import * +import asyncio +import requests +from .config import ( + MIN_WORD_THRESHOLD, + IMAGE_DESCRIPTION_MIN_WORD_THRESHOLD, + IMAGE_SCORE_THRESHOLD, + ONLY_TEXT_ELIGIBLE_TAGS, + IMPORTANT_ATTRS, + SOCIAL_MEDIA_DOMAINS, +) from bs4 import NavigableString, Comment from bs4 import PageElement, Tag from urllib.parse import urljoin @@ -14,18 +20,18 @@ from requests.exceptions import InvalidSchema from .utils import ( extract_metadata, normalize_url, - is_external_url, - get_base_domain, - extract_metadata_using_lxml + is_external_url, + get_base_domain, + extract_metadata_using_lxml, ) from lxml import etree from lxml import html as lhtml -from typing import Dict, Any, List, Tuple +from typing import List from .models import ScrapingResult, MediaItem, Link, Media, Links # Pre-compile regular expressions for Open Graph and Twitter metadata -OG_REGEX = re.compile(r'^og:') -TWITTER_REGEX = re.compile(r'^twitter:') +OG_REGEX = re.compile(r"^og:") +TWITTER_REGEX = re.compile(r"^twitter:") DIMENSION_REGEX = re.compile(r"(\d+)(\D*)") @@ -34,17 +40,22 @@ def parse_srcset(s: str) -> List[Dict]: if not s: return [] variants = [] - for part in s.split(','): + for part in s.split(","): part = part.strip() if not part: continue parts = part.split() if len(parts) >= 1: url = parts[0] - width = parts[1].rstrip('w') if len(parts) > 1 and parts[1].endswith('w') else None - variants.append({'url': url, 'width': width}) + width = ( + parts[1].rstrip("w") + if len(parts) > 1 and parts[1].endswith("w") + else None + ) + variants.append({"url": url, "width": width}) return variants + # Function to parse image height/width value and units def parse_dimension(dimension): if dimension: @@ -52,26 +63,28 @@ def parse_dimension(dimension): match = DIMENSION_REGEX.match(dimension) if match: number = int(match.group(1)) - unit = match.group(2) or 'px' # Default unit is 'px' if not specified + unit = match.group(2) or "px" # Default unit is 'px' if not specified 
return number, unit return None, None + # Fetch image file metadata to extract size and extension def fetch_image_file_size(img, base_url): - #If src is relative path construct full URL, if not it may be CDN URL - img_url = urljoin(base_url,img.get('src')) + # If src is relative path construct full URL, if not it may be CDN URL + img_url = urljoin(base_url, img.get("src")) try: response = requests.head(img_url) if response.status_code == 200: - return response.headers.get('Content-Length',None) + return response.headers.get("Content-Length", None) else: print(f"Failed to retrieve file size for {img_url}") return None - except InvalidSchema as e: + except InvalidSchema: return None finally: return + class ContentScrapingStrategy(ABC): @abstractmethod def scrap(self, url: str, html: str, **kwargs) -> ScrapingResult: @@ -81,10 +94,11 @@ class ContentScrapingStrategy(ABC): async def ascrap(self, url: str, html: str, **kwargs) -> ScrapingResult: pass + class WebScrapingStrategy(ContentScrapingStrategy): """ - Class for web content scraping. Perhaps the most important class. - + Class for web content scraping. Perhaps the most important class. + How it works: 1. Extract content from HTML using BeautifulSoup. 2. Clean the extracted content using a content cleaning strategy. @@ -92,7 +106,7 @@ class WebScrapingStrategy(ContentScrapingStrategy): 4. Generate markdown content from the filtered content. 5. Return the markdown content. """ - + def __init__(self, logger=None): self.logger = logger @@ -101,10 +115,10 @@ class WebScrapingStrategy(ContentScrapingStrategy): if self.logger: log_method = getattr(self.logger, level) log_method(message=message, tag=tag, **kwargs) - + def scrap(self, url: str, html: str, **kwargs) -> ScrapingResult: """ - Main entry point for content scraping. + Main entry point for content scraping. Args: url (str): The URL of the page to scrape. 
@@ -121,20 +135,40 @@ class WebScrapingStrategy(ContentScrapingStrategy): success=False, media=Media(), links=Links(), - metadata={} + metadata={}, ) # Convert media items media = Media( - images=[MediaItem(**img) for img in raw_result.get("media", {}).get("images", []) if img], - videos=[MediaItem(**vid) for vid in raw_result.get("media", {}).get("videos", []) if vid], - audios=[MediaItem(**aud) for aud in raw_result.get("media", {}).get("audios", []) if aud] + images=[ + MediaItem(**img) + for img in raw_result.get("media", {}).get("images", []) + if img + ], + videos=[ + MediaItem(**vid) + for vid in raw_result.get("media", {}).get("videos", []) + if vid + ], + audios=[ + MediaItem(**aud) + for aud in raw_result.get("media", {}).get("audios", []) + if aud + ], ) # Convert links links = Links( - internal=[Link(**link) for link in raw_result.get("links", {}).get("internal", []) if link], - external=[Link(**link) for link in raw_result.get("links", {}).get("external", []) if link] + internal=[ + Link(**link) + for link in raw_result.get("links", {}).get("internal", []) + if link + ], + external=[ + Link(**link) + for link in raw_result.get("links", {}).get("external", []) + if link + ], ) return ScrapingResult( @@ -142,7 +176,7 @@ class WebScrapingStrategy(ContentScrapingStrategy): success=raw_result.get("success", False), media=media, links=links, - metadata=raw_result.get("metadata", {}) + metadata=raw_result.get("metadata", {}), ) async def ascrap(self, url: str, html: str, **kwargs) -> ScrapingResult: @@ -171,7 +205,11 @@ class WebScrapingStrategy(ContentScrapingStrategy): """ if isinstance(node, NavigableString): return node - if len(node.contents) == 1 and isinstance(node.contents[0], Tag) and node.contents[0].name == node.name: + if ( + len(node.contents) == 1 + and isinstance(node.contents[0], Tag) + and node.contents[0].name == node.name + ): return self.flatten_nested_elements(node.contents[0]) node.contents = [self.flatten_nested_elements(child) for child in node.contents] return node @@ -187,23 +225,27 @@ class WebScrapingStrategy(ContentScrapingStrategy): Returns: Tag: The closest parent with useful text, or None if not found. """ - image_description_min_word_threshold = kwargs.get('image_description_min_word_threshold', IMAGE_DESCRIPTION_MIN_WORD_THRESHOLD) + image_description_min_word_threshold = kwargs.get( + "image_description_min_word_threshold", IMAGE_DESCRIPTION_MIN_WORD_THRESHOLD + ) current_tag = tag while current_tag: current_tag = current_tag.parent # Get the text content of the parent tag if current_tag: - text_content = current_tag.get_text(separator=' ',strip=True) + text_content = current_tag.get_text(separator=" ", strip=True) # Check if the text content has at least word_count_threshold if len(text_content.split()) >= image_description_min_word_threshold: return text_content return None - def remove_unwanted_attributes(self, element, important_attrs, keep_data_attributes=False): + def remove_unwanted_attributes( + self, element, important_attrs, keep_data_attributes=False + ): """ Remove unwanted attributes from an HTML element. - Args: + Args: element (Tag): The HTML element to remove attributes from. important_attrs (list): List of important attributes to keep. keep_data_attributes (bool): Whether to keep data attributes. 
@@ -215,18 +257,18 @@ class WebScrapingStrategy(ContentScrapingStrategy): for attr in element.attrs: if attr not in important_attrs: if keep_data_attributes: - if not attr.startswith('data-'): + if not attr.startswith("data-"): attrs_to_remove.append(attr) else: attrs_to_remove.append(attr) - + for attr in attrs_to_remove: del element[attr] def process_image(self, img, url, index, total_images, **kwargs): """ Process an image element. - + How it works: 1. Check if the image has valid display and inside undesired html elements. 2. Score an image for it's usefulness. @@ -244,33 +286,35 @@ class WebScrapingStrategy(ContentScrapingStrategy): Returns: dict: A dictionary containing the processed image information. """ - # parse_srcset = lambda s: [{'url': u.strip().split()[0], 'width': u.strip().split()[-1].rstrip('w') - # if ' ' in u else None} + # parse_srcset = lambda s: [{'url': u.strip().split()[0], 'width': u.strip().split()[-1].rstrip('w') + # if ' ' in u else None} # for u in [f"http{p}" for p in s.split("http") if p]] - + # Constants for checks - classes_to_check = frozenset(['button', 'icon', 'logo']) - tags_to_check = frozenset(['button', 'input']) - image_formats = frozenset(['jpg', 'jpeg', 'png', 'webp', 'avif', 'gif']) - + classes_to_check = frozenset(["button", "icon", "logo"]) + tags_to_check = frozenset(["button", "input"]) + image_formats = frozenset(["jpg", "jpeg", "png", "webp", "avif", "gif"]) + # Pre-fetch commonly used attributes - style = img.get('style', '') - alt = img.get('alt', '') - src = img.get('src', '') - data_src = img.get('data-src', '') - srcset = img.get('srcset', '') - data_srcset = img.get('data-srcset', '') - width = img.get('width') - height = img.get('height') + style = img.get("style", "") + alt = img.get("alt", "") + src = img.get("src", "") + data_src = img.get("data-src", "") + srcset = img.get("srcset", "") + data_srcset = img.get("data-srcset", "") + width = img.get("width") + height = img.get("height") parent = img.parent - parent_classes = parent.get('class', []) + parent_classes = parent.get("class", []) # Quick validation checks - if ('display:none' in style or - parent.name in tags_to_check or - any(c in cls for c in parent_classes for cls in classes_to_check) or - any(c in src for c in classes_to_check) or - any(c in alt for c in classes_to_check)): + if ( + "display:none" in style + or parent.name in tags_to_check + or any(c in cls for c in parent_classes for cls in classes_to_check) + or any(c in src for c in classes_to_check) + or any(c in alt for c in classes_to_check) + ): return None # Quick score calculation @@ -283,30 +327,29 @@ class WebScrapingStrategy(ContentScrapingStrategy): score += 1 if height_val > 150 else 0 if alt: score += 1 - score += index/total_images < 0.5 - + score += index / total_images < 0.5 + # image_format = '' # if "data:image/" in src: # image_format = src.split(',')[0].split(';')[0].split('/')[1].split(';')[0] # else: # image_format = os.path.splitext(src)[1].lower().strip('.').split('?')[0] - + # if image_format in ('jpg', 'png', 'webp', 'avif'): # score += 1 - - + # Check for image format in all possible sources def has_image_format(url): return any(fmt in url.lower() for fmt in image_formats) - + # Score for having proper image sources if any(has_image_format(url) for url in [src, data_src, srcset, data_srcset]): score += 1 if srcset or data_srcset: score += 1 - if img.find_parent('picture'): + if img.find_parent("picture"): score += 1 - + # Detect format from any available source detected_format = None for 
url in [src, data_src, srcset, data_srcset]: @@ -314,62 +357,66 @@ class WebScrapingStrategy(ContentScrapingStrategy): format_matches = [fmt for fmt in image_formats if fmt in url.lower()] if format_matches: detected_format = format_matches[0] - break + break - if score <= kwargs.get('image_score_threshold', IMAGE_SCORE_THRESHOLD): + if score <= kwargs.get("image_score_threshold", IMAGE_SCORE_THRESHOLD): return None # Use set for deduplication unique_urls = set() image_variants = [] - + # Generate a unique group ID for this set of variants - group_id = index - + group_id = index + # Base image info template base_info = { - 'alt': alt, - 'desc': self.find_closest_parent_with_useful_text(img, **kwargs), - 'score': score, - 'type': 'image', - 'group_id': group_id, # Group ID for this set of variants - 'format': detected_format, + "alt": alt, + "desc": self.find_closest_parent_with_useful_text(img, **kwargs), + "score": score, + "type": "image", + "group_id": group_id, # Group ID for this set of variants + "format": detected_format, } # Inline function for adding variants def add_variant(src, width=None): - if src and not src.startswith('data:') and src not in unique_urls: + if src and not src.startswith("data:") and src not in unique_urls: unique_urls.add(src) - image_variants.append({**base_info, 'src': src, 'width': width}) + image_variants.append({**base_info, "src": src, "width": width}) # Process all sources add_variant(src) add_variant(data_src) - + # Handle srcset and data-srcset in one pass - for attr in ('srcset', 'data-srcset'): + for attr in ("srcset", "data-srcset"): if value := img.get(attr): for source in parse_srcset(value): - add_variant(source['url'], source['width']) + add_variant(source["url"], source["width"]) # Quick picture element check - if picture := img.find_parent('picture'): - for source in picture.find_all('source'): - if srcset := source.get('srcset'): + if picture := img.find_parent("picture"): + for source in picture.find_all("source"): + if srcset := source.get("srcset"): for src in parse_srcset(srcset): - add_variant(src['url'], src['width']) + add_variant(src["url"], src["width"]) # Framework-specific attributes in one pass for attr, value in img.attrs.items(): - if attr.startswith('data-') and ('src' in attr or 'srcset' in attr) and 'http' in value: + if ( + attr.startswith("data-") + and ("src" in attr or "srcset" in attr) + and "http" in value + ): add_variant(value) return image_variants if image_variants else None - def process_element(self, url, element: PageElement, **kwargs) -> Dict[str, Any]: + def process_element(self, url, element: PageElement, **kwargs) -> Dict[str, Any]: """ Process an HTML element. - + How it works: 1. Check if the element is an image, video, or audio. 2. Extract the element's attributes and content. @@ -384,89 +431,92 @@ class WebScrapingStrategy(ContentScrapingStrategy): Returns: dict: A dictionary containing the processed element information. 
""" - media = {'images': [], 'videos': [], 'audios': []} + media = {"images": [], "videos": [], "audios": []} internal_links_dict = {} external_links_dict = {} self._process_element( - url, - element, - media, - internal_links_dict, - external_links_dict, - **kwargs + url, element, media, internal_links_dict, external_links_dict, **kwargs ) return { - 'media': media, - 'internal_links_dict': internal_links_dict, - 'external_links_dict': external_links_dict + "media": media, + "internal_links_dict": internal_links_dict, + "external_links_dict": external_links_dict, } - - def _process_element(self, url, element: PageElement, media: Dict[str, Any], internal_links_dict: Dict[str, Any], external_links_dict: Dict[str, Any], **kwargs) -> bool: + + def _process_element( + self, + url, + element: PageElement, + media: Dict[str, Any], + internal_links_dict: Dict[str, Any], + external_links_dict: Dict[str, Any], + **kwargs, + ) -> bool: """ - Process an HTML element. + Process an HTML element. """ try: if isinstance(element, NavigableString): if isinstance(element, Comment): element.extract() return False - + # if element.name == 'img': # process_image(element, url, 0, 1) # return True base_domain = kwargs.get("base_domain", get_base_domain(url)) - if element.name in ['script', 'style', 'link', 'meta', 'noscript']: + if element.name in ["script", "style", "link", "meta", "noscript"]: element.decompose() return False keep_element = False - - exclude_domains = kwargs.get('exclude_domains', []) + + exclude_domains = kwargs.get("exclude_domains", []) # exclude_social_media_domains = kwargs.get('exclude_social_media_domains', set(SOCIAL_MEDIA_DOMAINS)) # exclude_social_media_domains = SOCIAL_MEDIA_DOMAINS + kwargs.get('exclude_social_media_domains', []) # exclude_social_media_domains = list(set(exclude_social_media_domains)) - + try: - if element.name == 'a' and element.get('href'): - href = element.get('href', '').strip() + if element.name == "a" and element.get("href"): + href = element.get("href", "").strip() if not href: # Skip empty hrefs return False - - url_base = url.split('/')[2] - + + # url_base = url.split("/")[2] + # Normalize the URL try: normalized_href = normalize_url(href, url) - except ValueError as e: + except ValueError: # logging.warning(f"Invalid URL format: {href}, Error: {str(e)}") return False - + link_data = { - 'href': normalized_href, - 'text': element.get_text().strip(), - 'title': element.get('title', '').strip(), - 'base_domain': base_domain + "href": normalized_href, + "text": element.get_text().strip(), + "title": element.get("title", "").strip(), + "base_domain": base_domain, } - + is_external = is_external_url(normalized_href, base_domain) - + keep_element = True - + # Handle external link exclusions if is_external: link_base_domain = get_base_domain(normalized_href) - link_data['base_domain'] = link_base_domain - if kwargs.get('exclude_external_links', False): + link_data["base_domain"] = link_base_domain + if kwargs.get("exclude_external_links", False): element.decompose() return False # elif kwargs.get('exclude_social_media_links', False): # if link_base_domain in exclude_social_media_domains: # element.decompose() # return False - # if any(domain in normalized_href.lower() for domain in exclude_social_media_domains): - # element.decompose() - # return False + # if any(domain in normalized_href.lower() for domain in exclude_social_media_domains): + # element.decompose() + # return False elif exclude_domains: if link_base_domain in exclude_domains: element.decompose() 
@@ -482,32 +532,36 @@ class WebScrapingStrategy(ContentScrapingStrategy): if normalized_href not in internal_links_dict: internal_links_dict[normalized_href] = link_data - except Exception as e: raise Exception(f"Error processing links: {str(e)}") try: - if element.name == 'img': - potential_sources = ['src', 'data-src', 'srcset' 'data-lazy-src', 'data-original'] - src = element.get('src', '') + if element.name == "img": + potential_sources = [ + "src", + "data-src", + "data-lazy-src", + "data-original", + ] + src = element.get("src", "") while not src and potential_sources: - src = element.get(potential_sources.pop(0), '') + src = element.get(potential_sources.pop(0), "") if not src: element.decompose() return False - + # If it is srcset pick up the first image - if 'srcset' in element.attrs: - src = element.attrs['srcset'].split(',')[0].split(' ')[0] - + if "srcset" in element.attrs: + src = element.attrs["srcset"].split(",")[0].split(" ")[0] + # If image src is internal, then skip if not is_external_url(src, base_domain): return True - + image_src_base_domain = get_base_domain(src) - + # Check flag if we should remove external images - if kwargs.get('exclude_external_images', False): + if kwargs.get("exclude_external_images", False): element.decompose() return False # src_url_base = src.split('/')[2] @@ -515,78 +569,98 @@ class WebScrapingStrategy(ContentScrapingStrategy): # if url_base not in src_url_base: # element.decompose() # return False - + # if kwargs.get('exclude_social_media_links', False): # if image_src_base_domain in exclude_social_media_domains: # element.decompose() # return False - # src_url_base = src.split('/')[2] - # url_base = url.split('/')[2] - # if any(domain in src for domain in exclude_social_media_domains): - # element.decompose() - # return False - + # src_url_base = src.split('/')[2] + # url_base = url.split('/')[2] + # if any(domain in src for domain in exclude_social_media_domains): + # element.decompose() + # return False + # Handle exclude domains - if exclude_domains: if image_src_base_domain in exclude_domains: element.decompose() return False # if any(domain in src for domain in kwargs.get('exclude_domains', [])): # element.decompose() # return False - + return True # Always keep image elements - except Exception as e: + except Exception: raise Exception("Error processing images") - - + # Check if flag to remove all forms is set - if kwargs.get('remove_forms', False) and element.name == 'form': + if kwargs.get("remove_forms", False) and element.name == "form": element.decompose() return False - - if element.name in ['video', 'audio']: - media[f"{element.name}s"].append({ - 'src': element.get('src'), - 'alt': element.get('alt'), - 'type': element.name, - 'description': self.find_closest_parent_with_useful_text(element, **kwargs) - }) - source_tags = element.find_all('source') + + if element.name in ["video", "audio"]: + media[f"{element.name}s"].append( + { + "src": element.get("src"), + "alt": element.get("alt"), + "type": element.name, + "description": self.find_closest_parent_with_useful_text( + element, **kwargs + ), + } + ) + source_tags = element.find_all("source") for source_tag in source_tags: - media[f"{element.name}s"].append({ - 'src': source_tag.get('src'), - 'alt': element.get('alt'), - 'type': element.name, - 'description': self.find_closest_parent_with_useful_text(element, **kwargs) - }) + media[f"{element.name}s"].append( + { + "src": source_tag.get("src"), + "alt": element.get("alt"), + "type": element.name, + 
"description": self.find_closest_parent_with_useful_text( + element, **kwargs + ), + } + ) return True # Always keep video and audio elements if element.name in ONLY_TEXT_ELIGIBLE_TAGS: - if kwargs.get('only_text', False): + if kwargs.get("only_text", False): element.replace_with(element.get_text()) try: - self.remove_unwanted_attributes(element, IMPORTANT_ATTRS, kwargs.get('keep_data_attributes', False)) + self.remove_unwanted_attributes( + element, IMPORTANT_ATTRS, kwargs.get("keep_data_attributes", False) + ) except Exception as e: # print('Error removing unwanted attributes:', str(e)) - self._log('error', + self._log( + "error", message="Error removing unwanted attributes: {error}", tag="SCRAPE", - params={"error": str(e)} + params={"error": str(e)}, ) # Process children for child in list(element.children): - if isinstance(child, NavigableString) and not isinstance(child, Comment): + if isinstance(child, NavigableString) and not isinstance( + child, Comment + ): if len(child.strip()) > 0: keep_element = True else: - if self._process_element(url, child, media, internal_links_dict, external_links_dict, **kwargs): + if self._process_element( + url, + child, + media, + internal_links_dict, + external_links_dict, + **kwargs, + ): keep_element = True - # Check word count - word_count_threshold = kwargs.get('word_count_threshold', MIN_WORD_THRESHOLD) + word_count_threshold = kwargs.get( + "word_count_threshold", MIN_WORD_THRESHOLD + ) if not keep_element: word_count = len(element.get_text(strip=True).split()) keep_element = word_count >= word_count_threshold @@ -597,14 +671,22 @@ class WebScrapingStrategy(ContentScrapingStrategy): return keep_element except Exception as e: # print('Error processing element:', str(e)) - self._log('error', + self._log( + "error", message="Error processing element: {error}", tag="SCRAPE", - params={"error": str(e)} - ) + params={"error": str(e)}, + ) return False - def _scrap(self, url: str, html: str, word_count_threshold: int = MIN_WORD_THRESHOLD, css_selector: str = None, **kwargs) -> Dict[str, Any]: + def _scrap( + self, + url: str, + html: str, + word_count_threshold: int = MIN_WORD_THRESHOLD, + css_selector: str = None, + **kwargs, + ) -> Dict[str, Any]: """ Extract content from HTML using BeautifulSoup. 
@@ -622,83 +704,93 @@ class WebScrapingStrategy(ContentScrapingStrategy): if not html: return None - parser_type = kwargs.get('parser', 'lxml') + parser_type = kwargs.get("parser", "lxml") soup = BeautifulSoup(html, parser_type) body = soup.body base_domain = get_base_domain(url) - + try: meta = extract_metadata("", soup) except Exception as e: - self._log('error', + self._log( + "error", message="Error extracting metadata: {error}", tag="SCRAPE", - params={"error": str(e)} - ) + params={"error": str(e)}, + ) meta = {} - + # Handle tag-based removal first - faster than CSS selection - excluded_tags = set(kwargs.get('excluded_tags', []) or []) + excluded_tags = set(kwargs.get("excluded_tags", []) or []) if excluded_tags: for element in body.find_all(lambda tag: tag.name in excluded_tags): element.extract() - + # Handle CSS selector-based removal - excluded_selector = kwargs.get('excluded_selector', '') + excluded_selector = kwargs.get("excluded_selector", "") if excluded_selector: - is_single_selector = ',' not in excluded_selector and ' ' not in excluded_selector + is_single_selector = ( + "," not in excluded_selector and " " not in excluded_selector + ) if is_single_selector: while element := body.select_one(excluded_selector): element.extract() else: for element in body.select(excluded_selector): - element.extract() - + element.extract() + if css_selector: selected_elements = body.select(css_selector) if not selected_elements: return { - 'markdown': '', - 'cleaned_html': '', - 'success': True, - 'media': {'images': [], 'videos': [], 'audios': []}, - 'links': {'internal': [], 'external': []}, - 'metadata': {}, - 'message': f"No elements found for CSS selector: {css_selector}" + "markdown": "", + "cleaned_html": "", + "success": True, + "media": {"images": [], "videos": [], "audios": []}, + "links": {"internal": [], "external": []}, + "metadata": {}, + "message": f"No elements found for CSS selector: {css_selector}", } # raise InvalidCSSSelectorError(f"Invalid CSS selector, No elements found for CSS selector: {css_selector}") - body = soup.new_tag('div') + body = soup.new_tag("div") for el in selected_elements: body.append(el) - kwargs['exclude_social_media_domains'] = set(kwargs.get('exclude_social_media_domains', []) + SOCIAL_MEDIA_DOMAINS) - kwargs['exclude_domains'] = set(kwargs.get('exclude_domains', [])) - if kwargs.get('exclude_social_media_links', False): - kwargs['exclude_domains'] = kwargs['exclude_domains'].union(kwargs['exclude_social_media_domains']) - - result_obj = self.process_element( - url, - body, - word_count_threshold = word_count_threshold, - base_domain=base_domain, - **kwargs + kwargs["exclude_social_media_domains"] = set( + kwargs.get("exclude_social_media_domains", []) + SOCIAL_MEDIA_DOMAINS ) - - links = {'internal': [], 'external': []} - media = result_obj['media'] - internal_links_dict = result_obj['internal_links_dict'] - external_links_dict = result_obj['external_links_dict'] - + kwargs["exclude_domains"] = set(kwargs.get("exclude_domains", [])) + if kwargs.get("exclude_social_media_links", False): + kwargs["exclude_domains"] = kwargs["exclude_domains"].union( + kwargs["exclude_social_media_domains"] + ) + + result_obj = self.process_element( + url, + body, + word_count_threshold=word_count_threshold, + base_domain=base_domain, + **kwargs, + ) + + links = {"internal": [], "external": []} + media = result_obj["media"] + internal_links_dict = result_obj["internal_links_dict"] + external_links_dict = result_obj["external_links_dict"] + # Update the links 
dictionary with unique links - links['internal'] = list(internal_links_dict.values()) - links['external'] = list(external_links_dict.values()) + links["internal"] = list(internal_links_dict.values()) + links["external"] = list(external_links_dict.values()) # # Process images using ThreadPoolExecutor - imgs = body.find_all('img') - - media['images'] = [ - img for result in (self.process_image(img, url, i, len(imgs), **kwargs) - for i, img in enumerate(imgs)) + imgs = body.find_all("img") + + media["images"] = [ + img + for result in ( + self.process_image(img, url, i, len(imgs), **kwargs) + for i, img in enumerate(imgs) + ) if result is not None for img in result ] @@ -706,22 +798,22 @@ class WebScrapingStrategy(ContentScrapingStrategy): body = self.flatten_nested_elements(body) base64_pattern = re.compile(r'data:image/[^;]+;base64,([^"]+)') for img in imgs: - src = img.get('src', '') + src = img.get("src", "") if base64_pattern.match(src): # Replace base64 data with empty string - img['src'] = base64_pattern.sub('', src) - + img["src"] = base64_pattern.sub("", src) + str_body = "" try: - str_body = body.encode_contents().decode('utf-8') - except Exception as e: + str_body = body.encode_contents().decode("utf-8") + except Exception: # Reset body to the original HTML success = False - body = BeautifulSoup(html, 'html.parser') - + body = BeautifulSoup(html, "html.parser") + # Create a new div with a special ID - error_div = body.new_tag('div', id='crawl4ai_error_message') - error_div.string = ''' + error_div = body.new_tag("div", id="crawl4ai_error_message") + error_div.string = """ Crawl4AI Error: This page is not fully supported. Possible reasons: @@ -734,128 +826,146 @@ class WebScrapingStrategy(ContentScrapingStrategy): - Set headless=False to visualize what's happening on the page. If the issue persists, please check the page's structure and any potential anti-crawling measures. - ''' - + """ + # Append the error div to the body body.append(error_div) - str_body = body.encode_contents().decode('utf-8') - - print(f"[LOG] 😧 Error: After processing the crawled HTML and removing irrelevant tags, nothing was left in the page. Check the markdown for further details.") - self._log('error', + str_body = body.encode_contents().decode("utf-8") + + print( + "[LOG] 😧 Error: After processing the crawled HTML and removing irrelevant tags, nothing was left in the page. Check the markdown for further details." + ) + self._log( + "error", message="After processing the crawled HTML and removing irrelevant tags, nothing was left in the page. 
Check the markdown for further details.", - tag="SCRAPE" + tag="SCRAPE", ) - cleaned_html = str_body.replace('\n\n', '\n').replace(' ', ' ') + cleaned_html = str_body.replace("\n\n", "\n").replace(" ", " ") - return { # **markdown_content, - 'cleaned_html': cleaned_html, - 'success': success, - 'media': media, - 'links': links, - 'metadata': meta + "cleaned_html": cleaned_html, + "success": success, + "media": media, + "links": links, + "metadata": meta, } + class LXMLWebScrapingStrategy(WebScrapingStrategy): def __init__(self, logger=None): super().__init__(logger) self.DIMENSION_REGEX = re.compile(r"(\d+)(\D*)") self.BASE64_PATTERN = re.compile(r'data:image/[^;]+;base64,([^"]+)') - def _process_element(self, url: str, element: lhtml.HtmlElement, media: Dict[str, List], - internal_links_dict: Dict[str, Any], external_links_dict: Dict[str, Any], **kwargs) -> bool: + def _process_element( + self, + url: str, + element: lhtml.HtmlElement, + media: Dict[str, List], + internal_links_dict: Dict[str, Any], + external_links_dict: Dict[str, Any], + **kwargs, + ) -> bool: base_domain = kwargs.get("base_domain", get_base_domain(url)) - exclude_domains = set(kwargs.get('exclude_domains', [])) - + exclude_domains = set(kwargs.get("exclude_domains", [])) + # Process links - for link in element.xpath('.//a[@href]'): - href = link.get('href', '').strip() + for link in element.xpath(".//a[@href]"): + href = link.get("href", "").strip() if not href: continue - + try: normalized_href = normalize_url(href, url) link_data = { - 'href': normalized_href, - 'text': link.text_content().strip(), - 'title': link.get('title', '').strip(), - 'base_domain': base_domain + "href": normalized_href, + "text": link.text_content().strip(), + "title": link.get("title", "").strip(), + "base_domain": base_domain, } - + is_external = is_external_url(normalized_href, base_domain) if is_external: link_base_domain = get_base_domain(normalized_href) - link_data['base_domain'] = link_base_domain - if kwargs.get('exclude_external_links', False) or link_base_domain in exclude_domains: + link_data["base_domain"] = link_base_domain + if ( + kwargs.get("exclude_external_links", False) + or link_base_domain in exclude_domains + ): link.getparent().remove(link) continue - + if normalized_href not in external_links_dict: external_links_dict[normalized_href] = link_data else: if normalized_href not in internal_links_dict: internal_links_dict[normalized_href] = link_data - + except Exception as e: - self._log('error', f"Error processing link: {str(e)}", "SCRAPE") + self._log("error", f"Error processing link: {str(e)}", "SCRAPE") continue # Process images - images = element.xpath('.//img') + images = element.xpath(".//img") total_images = len(images) - + for idx, img in enumerate(images): - src = img.get('src') or '' + src = img.get("src") or "" img_domain = get_base_domain(src) # Decide if we need to exclude this image # 1) If its domain is in exclude_domains, remove. # 2) Or if exclude_external_images=True and it's an external domain, remove. if (img_domain in exclude_domains) or ( - kwargs.get('exclude_external_images', False) and is_external_url(src, base_domain) + kwargs.get("exclude_external_images", False) + and is_external_url(src, base_domain) ): parent = img.getparent() if parent is not None: parent.remove(img) continue - + # Otherwise, process the image as usual. 
try: - processed_images = self.process_image(img, url, idx, total_images, **kwargs) + processed_images = self.process_image( + img, url, idx, total_images, **kwargs + ) if processed_images: - media['images'].extend(processed_images) + media["images"].extend(processed_images) except Exception as e: - self._log('error', f"Error processing image: {str(e)}", "SCRAPE") + self._log("error", f"Error processing image: {str(e)}", "SCRAPE") # Process videos and audios - for media_type in ['video', 'audio']: - for elem in element.xpath(f'.//{media_type}'): + for media_type in ["video", "audio"]: + for elem in element.xpath(f".//{media_type}"): media_info = { - 'src': elem.get('src'), - 'alt': elem.get('alt'), - 'type': media_type, - 'description': self.find_closest_parent_with_useful_text(elem, **kwargs) + "src": elem.get("src"), + "alt": elem.get("alt"), + "type": media_type, + "description": self.find_closest_parent_with_useful_text( + elem, **kwargs + ), } media[f"{media_type}s"].append(media_info) - + # Process source tags within media elements - for source in elem.xpath('.//source'): - if src := source.get('src'): - media[f"{media_type}s"].append({**media_info, 'src': src}) + for source in elem.xpath(".//source"): + if src := source.get("src"): + media[f"{media_type}s"].append({**media_info, "src": src}) # Clean up unwanted elements - if kwargs.get('remove_forms', False): - for form in element.xpath('.//form'): + if kwargs.get("remove_forms", False): + for form in element.xpath(".//form"): form.getparent().remove(form) - if excluded_tags := kwargs.get('excluded_tags', []): + if excluded_tags := kwargs.get("excluded_tags", []): for tag in excluded_tags: - for elem in element.xpath(f'.//{tag}'): + for elem in element.xpath(f".//{tag}"): elem.getparent().remove(elem) - if excluded_selector := kwargs.get('excluded_selector', ''): + if excluded_selector := kwargs.get("excluded_selector", ""): try: for elem in element.cssselect(excluded_selector): elem.getparent().remove(elem) @@ -864,12 +974,19 @@ class LXMLWebScrapingStrategy(WebScrapingStrategy): return True - def find_closest_parent_with_useful_text(self, element: lhtml.HtmlElement, **kwargs) -> Optional[str]: - image_description_min_word_threshold = kwargs.get('image_description_min_word_threshold', - IMAGE_DESCRIPTION_MIN_WORD_THRESHOLD) + def find_closest_parent_with_useful_text( + self, element: lhtml.HtmlElement, **kwargs + ) -> Optional[str]: + image_description_min_word_threshold = kwargs.get( + "image_description_min_word_threshold", IMAGE_DESCRIPTION_MIN_WORD_THRESHOLD + ) current = element while current is not None: - if current.text and len(current.text_content().split()) >= image_description_min_word_threshold: + if ( + current.text + and len(current.text_content().split()) + >= image_description_min_word_threshold + ): return current.text_content().strip() current = current.getparent() return None @@ -878,52 +995,57 @@ class LXMLWebScrapingStrategy(WebScrapingStrategy): """Flatten nested elements of the same type in LXML tree""" if len(element) == 1 and element.tag == element[0].tag: return self.flatten_nested_elements(element[0]) - + for child in element: child_idx = element.index(child) flattened_child = self.flatten_nested_elements(child) if flattened_child is not child: # Only replace if actually flattened element[child_idx] = flattened_child - + return element - def process_image(self, img: lhtml.HtmlElement, url: str, index: int, total_images: int, **kwargs) -> Optional[List[Dict]]: + def process_image( + self, img: 
lhtml.HtmlElement, url: str, index: int, total_images: int, **kwargs + ) -> Optional[List[Dict]]: # Quick validation checks - style = img.get('style', '') - alt = img.get('alt', '') - src = img.get('src', '') - data_src = img.get('data-src', '') - srcset = img.get('srcset', '') - data_srcset = img.get('data-srcset', '') - - if 'display:none' in style: + style = img.get("style", "") + alt = img.get("alt", "") + src = img.get("src", "") + data_src = img.get("data-src", "") + srcset = img.get("srcset", "") + data_srcset = img.get("data-srcset", "") + + if "display:none" in style: return None parent = img.getparent() - if parent.tag in ['button', 'input']: + if parent.tag in ["button", "input"]: return None - parent_classes = parent.get('class', '').split() - if any('button' in cls or 'icon' in cls or 'logo' in cls for cls in parent_classes): + parent_classes = parent.get("class", "").split() + if any( + "button" in cls or "icon" in cls or "logo" in cls for cls in parent_classes + ): return None - + # If src is in class or alt, likely an icon - if (src and any(c in src for c in ['button', 'icon', 'logo'])) or \ - (alt and any(c in alt for c in ['button', 'icon', 'logo'])): + if (src and any(c in src for c in ["button", "icon", "logo"])) or ( + alt and any(c in alt for c in ["button", "icon", "logo"]) + ): return None # Score calculation score = 0 - if (width := img.get('width')) and width.isdigit(): + if (width := img.get("width")) and width.isdigit(): score += 1 if int(width) > 150 else 0 - if (height := img.get('height')) and height.isdigit(): + if (height := img.get("height")) and height.isdigit(): score += 1 if int(height) > 150 else 0 if alt: score += 1 - score += index/total_images < 0.5 + score += index / total_images < 0.5 # Check formats in all possible sources - image_formats = {'jpg', 'jpeg', 'png', 'webp', 'avif', 'gif'} + image_formats = {"jpg", "jpeg", "png", "webp", "avif", "gif"} detected_format = None for url in [src, data_src, srcset, data_srcset]: if url: @@ -936,51 +1058,55 @@ class LXMLWebScrapingStrategy(WebScrapingStrategy): if srcset or data_srcset: score += 1 - if picture := img.xpath('./ancestor::picture[1]'): + if picture := img.xpath("./ancestor::picture[1]"): score += 1 - if score <= kwargs.get('image_score_threshold', IMAGE_SCORE_THRESHOLD): + if score <= kwargs.get("image_score_threshold", IMAGE_SCORE_THRESHOLD): return None # Process image variants unique_urls = set() image_variants = [] base_info = { - 'alt': alt, - 'desc': self.find_closest_parent_with_useful_text(img, **kwargs), - 'score': score, - 'type': 'image', - 'group_id': index, - 'format': detected_format, + "alt": alt, + "desc": self.find_closest_parent_with_useful_text(img, **kwargs), + "score": score, + "type": "image", + "group_id": index, + "format": detected_format, } def add_variant(src: str, width: Optional[str] = None): - if src and not src.startswith('data:') and src not in unique_urls: + if src and not src.startswith("data:") and src not in unique_urls: unique_urls.add(src) - variant = {**base_info, 'src': src} + variant = {**base_info, "src": src} if width: - variant['width'] = width + variant["width"] = width image_variants.append(variant) # Add variants from different sources add_variant(src) add_variant(data_src) - + for srcset_attr in [srcset, data_srcset]: if srcset_attr: for source in parse_srcset(srcset_attr): - add_variant(source['url'], source['width']) + add_variant(source["url"], source["width"]) # Handle picture element if picture: - for source in 
picture[0].xpath('.//source[@srcset]'): - if source_srcset := source.get('srcset'): + for source in picture[0].xpath(".//source[@srcset]"): + if source_srcset := source.get("srcset"): for src_data in parse_srcset(source_srcset): - add_variant(src_data['url'], src_data['width']) + add_variant(src_data["url"], src_data["width"]) # Check framework-specific attributes for attr, value in img.attrib.items(): - if attr.startswith('data-') and ('src' in attr or 'srcset' in attr) and 'http' in value: + if ( + attr.startswith("data-") + and ("src" in attr or "srcset" in attr) + and "http" in value + ): add_variant(value) return image_variants if image_variants else None @@ -990,33 +1116,44 @@ class LXMLWebScrapingStrategy(WebScrapingStrategy): Remove elements that fall below the desired word threshold in a single pass from the bottom up. Skips non-element nodes like HtmlComment and bypasses certain tags that are allowed to have no content. """ - bypass_tags = {'a', 'img', 'br', 'hr', 'input', 'meta', 'link', 'source', 'track', 'wbr'} - + bypass_tags = { + "a", + "img", + "br", + "hr", + "input", + "meta", + "link", + "source", + "track", + "wbr", + } + for el in reversed(list(root.iterdescendants())): if not isinstance(el, lhtml.HtmlElement): continue - + if el.tag in bypass_tags: continue - + text_content = (el.text_content() or "").strip() - if len(text_content.split()) < word_count_threshold and not el.getchildren(): + if ( + len(text_content.split()) < word_count_threshold + and not el.getchildren() + ): parent = el.getparent() if parent is not None: parent.remove(el) - + return root - + def remove_unwanted_attributes_fast( - self, - root: lhtml.HtmlElement, - important_attrs=None, - keep_data_attributes=False + self, root: lhtml.HtmlElement, important_attrs=None, keep_data_attributes=False ) -> lhtml.HtmlElement: """ Removes all attributes from each element (including root) except those in `important_attrs`. If `keep_data_attributes=True`, also retain any attribute starting with 'data-'. - + Returns the same root element, mutated in-place, for fluent usage. 
""" if important_attrs is None: @@ -1029,26 +1166,32 @@ class LXMLWebScrapingStrategy(WebScrapingStrategy): # We only remove attributes on HtmlElement nodes, skip comments or text nodes if not isinstance(el, lhtml.HtmlElement): continue - + old_attribs = dict(el.attrib) new_attribs = {} - + for attr_name, attr_val in old_attribs.items(): # If it's an important attribute, keep it if attr_name in important_attrs: new_attribs[attr_name] = attr_val # Or if keep_data_attributes is True and it's a 'data-*' attribute - elif keep_data_attributes and attr_name.startswith('data-'): + elif keep_data_attributes and attr_name.startswith("data-"): new_attribs[attr_name] = attr_val # Clear old attributes and set the filtered set el.attrib.clear() el.attrib.update(new_attribs) - + return root - - def _scrap(self, url: str, html: str, word_count_threshold: int = MIN_WORD_THRESHOLD, - css_selector: str = None, **kwargs) -> Dict[str, Any]: + + def _scrap( + self, + url: str, + html: str, + word_count_threshold: int = MIN_WORD_THRESHOLD, + css_selector: str = None, + **kwargs, + ) -> Dict[str, Any]: if not html: return None @@ -1058,38 +1201,42 @@ class LXMLWebScrapingStrategy(WebScrapingStrategy): # Match BeautifulSoup's behavior of using body or full doc # body = doc.xpath('//body')[0] if doc.xpath('//body') else doc body = doc - + base_domain = get_base_domain(url) - - # Add comment removal - if kwargs.get('remove_comments', False): - comments = body.xpath('//comment()') + + # Add comment removal + if kwargs.get("remove_comments", False): + comments = body.xpath("//comment()") for comment in comments: comment.getparent().remove(comment) - + # Handle tag-based removal first - excluded_tags = set(kwargs.get('excluded_tags', []) or []) + excluded_tags = set(kwargs.get("excluded_tags", []) or []) if excluded_tags: for tag in excluded_tags: - for element in body.xpath(f'.//{tag}'): + for element in body.xpath(f".//{tag}"): if element.getparent() is not None: element.getparent().remove(element) - + # Handle CSS selector-based exclusion - excluded_selector = kwargs.get('excluded_selector', '') + excluded_selector = kwargs.get("excluded_selector", "") if excluded_selector: try: for element in body.cssselect(excluded_selector): if element.getparent() is not None: element.getparent().remove(element) except Exception as e: - self._log('error', f"Error with excluded CSS selector: {str(e)}", "SCRAPE") + self._log( + "error", f"Error with excluded CSS selector: {str(e)}", "SCRAPE" + ) # Extract metadata before any content filtering try: - meta = extract_metadata_using_lxml("", doc) # Using same function as BeautifulSoup version + meta = extract_metadata_using_lxml( + "", doc + ) # Using same function as BeautifulSoup version except Exception as e: - self._log('error', f"Error extracting metadata: {str(e)}", "SCRAPE") + self._log("error", f"Error extracting metadata: {str(e)}", "SCRAPE") meta = {} # Handle CSS selector targeting @@ -1098,101 +1245,106 @@ class LXMLWebScrapingStrategy(WebScrapingStrategy): selected_elements = body.cssselect(css_selector) if not selected_elements: return { - 'markdown': '', - 'cleaned_html': '', - 'success': True, - 'media': {'images': [], 'videos': [], 'audios': []}, - 'links': {'internal': [], 'external': []}, - 'metadata': meta, - 'message': f"No elements found for CSS selector: {css_selector}" + "markdown": "", + "cleaned_html": "", + "success": True, + "media": {"images": [], "videos": [], "audios": []}, + "links": {"internal": [], "external": []}, + "metadata": meta, + "message": 
f"No elements found for CSS selector: {css_selector}", } - body = lhtml.Element('div') + body = lhtml.Element("div") body.extend(selected_elements) except Exception as e: - self._log('error', f"Error with CSS selector: {str(e)}", "SCRAPE") + self._log("error", f"Error with CSS selector: {str(e)}", "SCRAPE") return None # Remove script and style tags - for tag in ['script', 'style', 'link', 'meta', 'noscript']: - for element in body.xpath(f'.//{tag}'): + for tag in ["script", "style", "link", "meta", "noscript"]: + for element in body.xpath(f".//{tag}"): if element.getparent() is not None: element.getparent().remove(element) # Handle social media and domain exclusions - kwargs['exclude_domains'] = set(kwargs.get('exclude_domains', [])) - if kwargs.get('exclude_social_media_links', False): - kwargs['exclude_social_media_domains'] = set(kwargs.get('exclude_social_media_domains', []) + SOCIAL_MEDIA_DOMAINS) - kwargs['exclude_domains'].update(kwargs['exclude_social_media_domains']) + kwargs["exclude_domains"] = set(kwargs.get("exclude_domains", [])) + if kwargs.get("exclude_social_media_links", False): + kwargs["exclude_social_media_domains"] = set( + kwargs.get("exclude_social_media_domains", []) + + SOCIAL_MEDIA_DOMAINS + ) + kwargs["exclude_domains"].update(kwargs["exclude_social_media_domains"]) # Process forms if needed - if kwargs.get('remove_forms', False): - for form in body.xpath('.//form'): + if kwargs.get("remove_forms", False): + for form in body.xpath(".//form"): if form.getparent() is not None: form.getparent().remove(form) - # Process content - media = {'images': [], 'videos': [], 'audios': []} + media = {"images": [], "videos": [], "audios": []} internal_links_dict = {} external_links_dict = {} - + self._process_element( - url, - body, - media, + url, + body, + media, internal_links_dict, external_links_dict, base_domain=base_domain, - **kwargs + **kwargs, ) # Handle only_text option - if kwargs.get('only_text', False): + if kwargs.get("only_text", False): for tag in ONLY_TEXT_ELIGIBLE_TAGS: - for element in body.xpath(f'.//{tag}'): + for element in body.xpath(f".//{tag}"): if element.text: - new_text = lhtml.Element('span') + new_text = lhtml.Element("span") new_text.text = element.text_content() if element.getparent() is not None: element.getparent().replace(element, new_text) # Clean base64 images - for img in body.xpath('.//img[@src]'): - src = img.get('src', '') + for img in body.xpath(".//img[@src]"): + src = img.get("src", "") if self.BASE64_PATTERN.match(src): - img.set('src', self.BASE64_PATTERN.sub('', src)) - + img.set("src", self.BASE64_PATTERN.sub("", src)) # Remove empty elements self.remove_empty_elements_fast(body, 1) - - # Remvoe unneeded attributes - self.remove_unwanted_attributes_fast(body, keep_data_attributes=kwargs.get('keep_data_attributes', False)) + # Remvoe unneeded attributes + self.remove_unwanted_attributes_fast( + body, keep_data_attributes=kwargs.get("keep_data_attributes", False) + ) # Generate output HTML - cleaned_html = lhtml.tostring(body, encoding='unicode', - pretty_print=True, - method='html', - with_tail=False).strip() + cleaned_html = lhtml.tostring( + body, + encoding="unicode", + pretty_print=True, + method="html", + with_tail=False, + ).strip() return { - 'cleaned_html': cleaned_html, - 'success': success, - 'media': media, - 'links': { - 'internal': list(internal_links_dict.values()), - 'external': list(external_links_dict.values()) + "cleaned_html": cleaned_html, + "success": success, + "media": media, + "links": { + "internal": 
list(internal_links_dict.values()), + "external": list(external_links_dict.values()), }, - 'metadata': meta + "metadata": meta, } - + except Exception as e: - self._log('error', f"Error processing HTML: {str(e)}", "SCRAPE") + self._log("error", f"Error processing HTML: {str(e)}", "SCRAPE") # Create error message in case of failure - error_body = lhtml.Element('div') + error_body = lhtml.Element("div") # Use etree.SubElement rather than lhtml.SubElement - error_div = etree.SubElement(error_body, 'div', id='crawl4ai_error_message') - error_div.text = f''' + error_div = etree.SubElement(error_body, "div", id="crawl4ai_error_message") + error_div.text = f""" Crawl4AI Error: This page is not fully supported. Error Message: {str(e)} @@ -1207,12 +1359,14 @@ class LXMLWebScrapingStrategy(WebScrapingStrategy): - Set headless=False to visualize what's happening on the page. If the issue persists, please check the page's structure and any potential anti-crawling measures. - ''' - cleaned_html = lhtml.tostring(error_body, encoding='unicode', pretty_print=True) + """ + cleaned_html = lhtml.tostring( + error_body, encoding="unicode", pretty_print=True + ) return { - 'cleaned_html': cleaned_html, - 'success': False, - 'media': {'images': [], 'videos': [], 'audios': []}, - 'links': {'internal': [], 'external': []}, - 'metadata': {} - } \ No newline at end of file + "cleaned_html": cleaned_html, + "success": False, + "media": {"images": [], "videos": [], "audios": []}, + "links": {"internal": [], "external": []}, + "metadata": {}, + } diff --git a/crawl4ai/crawler_strategy.py b/crawl4ai/crawler_strategy.py index 898dcfa8..34e20ecd 100644 --- a/crawl4ai/crawler_strategy.py +++ b/crawl4ai/crawler_strategy.py @@ -15,54 +15,53 @@ import logging, time import base64 from PIL import Image, ImageDraw, ImageFont from io import BytesIO -from typing import List, Callable +from typing import Callable import requests import os from pathlib import Path from .utils import * -logger = logging.getLogger('selenium.webdriver.remote.remote_connection') +logger = logging.getLogger("selenium.webdriver.remote.remote_connection") logger.setLevel(logging.WARNING) -logger_driver = logging.getLogger('selenium.webdriver.common.service') +logger_driver = logging.getLogger("selenium.webdriver.common.service") logger_driver.setLevel(logging.WARNING) -urllib3_logger = logging.getLogger('urllib3.connectionpool') +urllib3_logger = logging.getLogger("urllib3.connectionpool") urllib3_logger.setLevel(logging.WARNING) # Disable http.client logging -http_client_logger = logging.getLogger('http.client') +http_client_logger = logging.getLogger("http.client") http_client_logger.setLevel(logging.WARNING) # Disable driver_finder and service logging -driver_finder_logger = logging.getLogger('selenium.webdriver.common.driver_finder') +driver_finder_logger = logging.getLogger("selenium.webdriver.common.driver_finder") driver_finder_logger.setLevel(logging.WARNING) - - class CrawlerStrategy(ABC): @abstractmethod def crawl(self, url: str, **kwargs) -> str: pass - + @abstractmethod def take_screenshot(self, save_path: str): pass - + @abstractmethod def update_user_agent(self, user_agent: str): pass - + @abstractmethod def set_hook(self, hook_type: str, hook: Callable): pass + class CloudCrawlerStrategy(CrawlerStrategy): - def __init__(self, use_cached_html = False): + def __init__(self, use_cached_html=False): super().__init__() self.use_cached_html = use_cached_html - + def crawl(self, url: str) -> str: data = { "urls": [url], @@ -76,6 +75,7 @@ class 
CloudCrawlerStrategy(CrawlerStrategy): html = response["results"][0]["html"] return sanitize_input_encode(html) + class LocalSeleniumCrawlerStrategy(CrawlerStrategy): def __init__(self, use_cached_html=False, js_code=None, **kwargs): super().__init__() @@ -87,20 +87,25 @@ class LocalSeleniumCrawlerStrategy(CrawlerStrategy): if kwargs.get("user_agent"): self.options.add_argument("--user-agent=" + kwargs.get("user_agent")) else: - user_agent = kwargs.get("user_agent", "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.124 Safari/537.36") + user_agent = kwargs.get( + "user_agent", + "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.124 Safari/537.36", + ) self.options.add_argument(f"--user-agent={user_agent}") - self.options.add_argument("user-agent=Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.124 Safari/537.36") - + self.options.add_argument( + "user-agent=Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.124 Safari/537.36" + ) + self.options.headless = kwargs.get("headless", True) if self.options.headless: self.options.add_argument("--headless") - - self.options.add_argument("--disable-gpu") + + self.options.add_argument("--disable-gpu") self.options.add_argument("--window-size=1920,1080") self.options.add_argument("--no-sandbox") self.options.add_argument("--disable-dev-shm-usage") - self.options.add_argument("--disable-blink-features=AutomationControlled") - + self.options.add_argument("--disable-blink-features=AutomationControlled") + # self.options.add_argument("--disable-dev-shm-usage") self.options.add_argument("--disable-gpu") # self.options.add_argument("--disable-extensions") @@ -120,14 +125,14 @@ class LocalSeleniumCrawlerStrategy(CrawlerStrategy): self.use_cached_html = use_cached_html self.js_code = js_code self.verbose = kwargs.get("verbose", False) - + # Hooks self.hooks = { - 'on_driver_created': None, - 'on_user_agent_updated': None, - 'before_get_url': None, - 'after_get_url': None, - 'before_return_html': None + "on_driver_created": None, + "on_user_agent_updated": None, + "before_get_url": None, + "after_get_url": None, + "before_return_html": None, } # chromedriver_autoinstaller.install() @@ -137,31 +142,28 @@ class LocalSeleniumCrawlerStrategy(CrawlerStrategy): # chromedriver_path = chromedriver_autoinstaller.install() # chromedriver_path = chromedriver_autoinstaller.utils.download_chromedriver() # self.service = Service(chromedriver_autoinstaller.install()) - - + # chromedriver_path = ChromeDriverManager().install() # self.service = Service(chromedriver_path) # self.service.log_path = "NUL" # self.driver = webdriver.Chrome(service=self.service, options=self.options) - + # Use selenium-manager (built into Selenium 4.10.0+) self.service = Service() self.driver = webdriver.Chrome(options=self.options) - - self.driver = self.execute_hook('on_driver_created', self.driver) - + + self.driver = self.execute_hook("on_driver_created", self.driver) + if kwargs.get("cookies"): for cookie in kwargs.get("cookies"): self.driver.add_cookie(cookie) - - def set_hook(self, hook_type: str, hook: Callable): if hook_type in self.hooks: self.hooks[hook_type] = hook else: raise ValueError(f"Invalid hook type: {hook_type}") - + def execute_hook(self, hook_type: str, *args): hook = self.hooks.get(hook_type) if hook: @@ -170,7 +172,9 @@ class LocalSeleniumCrawlerStrategy(CrawlerStrategy): if 
isinstance(result, webdriver.Chrome): return result else: - raise TypeError(f"Hook {hook_type} must return an instance of webdriver.Chrome or None.") + raise TypeError( + f"Hook {hook_type} must return an instance of webdriver.Chrome or None." + ) # If the hook returns None or there is no hook, return self.driver return self.driver @@ -178,60 +182,77 @@ class LocalSeleniumCrawlerStrategy(CrawlerStrategy): self.options.add_argument(f"user-agent={user_agent}") self.driver.quit() self.driver = webdriver.Chrome(service=self.service, options=self.options) - self.driver = self.execute_hook('on_user_agent_updated', self.driver) + self.driver = self.execute_hook("on_user_agent_updated", self.driver) def set_custom_headers(self, headers: dict): # Enable Network domain for sending headers - self.driver.execute_cdp_cmd('Network.enable', {}) + self.driver.execute_cdp_cmd("Network.enable", {}) # Set extra HTTP headers - self.driver.execute_cdp_cmd('Network.setExtraHTTPHeaders', {'headers': headers}) + self.driver.execute_cdp_cmd("Network.setExtraHTTPHeaders", {"headers": headers}) - def _ensure_page_load(self, max_checks=6, check_interval=0.01): + def _ensure_page_load(self, max_checks=6, check_interval=0.01): initial_length = len(self.driver.page_source) - + for ix in range(max_checks): # print(f"Checking page load: {ix}") time.sleep(check_interval) current_length = len(self.driver.page_source) - + if current_length != initial_length: break return self.driver.page_source - + def crawl(self, url: str, **kwargs) -> str: # Create md5 hash of the URL import hashlib + url_hash = hashlib.md5(url.encode()).hexdigest() - + if self.use_cached_html: - cache_file_path = os.path.join(os.getenv("CRAWL4_AI_BASE_DIRECTORY", Path.home()), ".crawl4ai", "cache", url_hash) + cache_file_path = os.path.join( + os.getenv("CRAWL4_AI_BASE_DIRECTORY", Path.home()), + ".crawl4ai", + "cache", + url_hash, + ) if os.path.exists(cache_file_path): with open(cache_file_path, "r") as f: return sanitize_input_encode(f.read()) try: - self.driver = self.execute_hook('before_get_url', self.driver) + self.driver = self.execute_hook("before_get_url", self.driver) if self.verbose: print(f"[LOG] 🕸️ Crawling {url} using LocalSeleniumCrawlerStrategy...") - self.driver.get(url) # - + self.driver.get(url) # + WebDriverWait(self.driver, 20).until( - lambda d: d.execute_script('return document.readyState') == 'complete' + lambda d: d.execute_script("return document.readyState") == "complete" ) WebDriverWait(self.driver, 10).until( EC.presence_of_all_elements_located((By.TAG_NAME, "body")) ) - - self.driver.execute_script("window.scrollTo(0, document.body.scrollHeight);") - - self.driver = self.execute_hook('after_get_url', self.driver) - html = sanitize_input_encode(self._ensure_page_load()) # self.driver.page_source - can_not_be_done_headless = False # Look at my creativity for naming variables - + + self.driver.execute_script( + "window.scrollTo(0, document.body.scrollHeight);" + ) + + self.driver = self.execute_hook("after_get_url", self.driver) + html = sanitize_input_encode( + self._ensure_page_load() + ) # self.driver.page_source + can_not_be_done_headless = ( + False # Look at my creativity for naming variables + ) + # TODO: Very ugly approach, but promise to change it! - if kwargs.get('bypass_headless', False) or html == "": - print("[LOG] 🙌 Page could not be loaded in headless mode. 
Trying non-headless mode...") + if ( + kwargs.get("bypass_headless", False) + or html == "" + ): + print( + "[LOG] 🙌 Page could not be loaded in headless mode. Trying non-headless mode..." + ) can_not_be_done_headless = True options = Options() options.headless = False @@ -239,27 +260,31 @@ class LocalSeleniumCrawlerStrategy(CrawlerStrategy): options.add_argument("--window-size=5,5") driver = webdriver.Chrome(service=self.service, options=options) driver.get(url) - self.driver = self.execute_hook('after_get_url', driver) + self.driver = self.execute_hook("after_get_url", driver) html = sanitize_input_encode(driver.page_source) driver.quit() - + # Execute JS code if provided self.js_code = kwargs.get("js_code", self.js_code) if self.js_code and type(self.js_code) == str: self.driver.execute_script(self.js_code) # Optionally, wait for some condition after executing the JS code WebDriverWait(self.driver, 10).until( - lambda driver: driver.execute_script("return document.readyState") == "complete" + lambda driver: driver.execute_script("return document.readyState") + == "complete" ) elif self.js_code and type(self.js_code) == list: for js in self.js_code: self.driver.execute_script(js) WebDriverWait(self.driver, 10).until( - lambda driver: driver.execute_script("return document.readyState") == "complete" + lambda driver: driver.execute_script( + "return document.readyState" + ) + == "complete" ) - + # Optionally, wait for some condition after executing the JS code : Contributed by (https://github.com/jonymusky) - wait_for = kwargs.get('wait_for', False) + wait_for = kwargs.get("wait_for", False) if wait_for: if callable(wait_for): print("[LOG] 🔄 Waiting for condition...") @@ -268,32 +293,37 @@ class LocalSeleniumCrawlerStrategy(CrawlerStrategy): print("[LOG] 🔄 Waiting for condition...") WebDriverWait(self.driver, 20).until( EC.presence_of_element_located((By.CSS_SELECTOR, wait_for)) - ) - + ) + if not can_not_be_done_headless: html = sanitize_input_encode(self.driver.page_source) - self.driver = self.execute_hook('before_return_html', self.driver, html) - + self.driver = self.execute_hook("before_return_html", self.driver, html) + # Store in cache - cache_file_path = os.path.join(os.getenv("CRAWL4_AI_BASE_DIRECTORY", Path.home()), ".crawl4ai", "cache", url_hash) + cache_file_path = os.path.join( + os.getenv("CRAWL4_AI_BASE_DIRECTORY", Path.home()), + ".crawl4ai", + "cache", + url_hash, + ) with open(cache_file_path, "w", encoding="utf-8") as f: f.write(html) - + if self.verbose: print(f"[LOG] ✅ Crawled {url} successfully!") - + return html except InvalidArgumentException as e: - if not hasattr(e, 'msg'): + if not hasattr(e, "msg"): e.msg = sanitize_input_encode(str(e)) raise InvalidArgumentException(f"Failed to crawl {url}: {e.msg}") except WebDriverException as e: # If e does nlt have msg attribute create it and set it to str(e) - if not hasattr(e, 'msg'): + if not hasattr(e, "msg"): e.msg = sanitize_input_encode(str(e)) - raise WebDriverException(f"Failed to crawl {url}: {e.msg}") + raise WebDriverException(f"Failed to crawl {url}: {e.msg}") except Exception as e: - if not hasattr(e, 'msg'): + if not hasattr(e, "msg"): e.msg = sanitize_input_encode(str(e)) raise Exception(f"Failed to crawl {url}: {e.msg}") @@ -301,7 +331,9 @@ class LocalSeleniumCrawlerStrategy(CrawlerStrategy): try: # Get the dimensions of the page total_width = self.driver.execute_script("return document.body.scrollWidth") - total_height = self.driver.execute_script("return document.body.scrollHeight") + total_height = 
self.driver.execute_script( + "return document.body.scrollHeight" + ) # Set the window size to the dimensions of the page self.driver.set_window_size(total_width, total_height) @@ -313,25 +345,27 @@ class LocalSeleniumCrawlerStrategy(CrawlerStrategy): image = Image.open(BytesIO(screenshot)) # Convert image to RGB mode (this will handle both RGB and RGBA images) - rgb_image = image.convert('RGB') + rgb_image = image.convert("RGB") # Convert to JPEG and compress buffered = BytesIO() rgb_image.save(buffered, format="JPEG", quality=85) - img_base64 = base64.b64encode(buffered.getvalue()).decode('utf-8') + img_base64 = base64.b64encode(buffered.getvalue()).decode("utf-8") if self.verbose: - print(f"[LOG] 📸 Screenshot taken and converted to base64") + print("[LOG] 📸 Screenshot taken and converted to base64") return img_base64 except Exception as e: - error_message = sanitize_input_encode(f"Failed to take screenshot: {str(e)}") + error_message = sanitize_input_encode( + f"Failed to take screenshot: {str(e)}" + ) print(error_message) # Generate an image with black background - img = Image.new('RGB', (800, 600), color='black') + img = Image.new("RGB", (800, 600), color="black") draw = ImageDraw.Draw(img) - + # Load a font try: font = ImageFont.truetype("arial.ttf", 40) @@ -345,16 +379,16 @@ class LocalSeleniumCrawlerStrategy(CrawlerStrategy): # Calculate text position text_position = (10, 10) - + # Draw the text on the image draw.text(text_position, wrapped_text, fill=text_color, font=font) - + # Convert to base64 buffered = BytesIO() img.save(buffered, format="JPEG") - img_base64 = base64.b64encode(buffered.getvalue()).decode('utf-8') + img_base64 = base64.b64encode(buffered.getvalue()).decode("utf-8") return img_base64 - + def quit(self): self.driver.quit() diff --git a/crawl4ai/database.py b/crawl4ai/database.py index 42ad7017..815b6b05 100644 --- a/crawl4ai/database.py +++ b/crawl4ai/database.py @@ -7,11 +7,13 @@ DB_PATH = os.path.join(os.getenv("CRAWL4_AI_BASE_DIRECTORY", Path.home()), ".cra os.makedirs(DB_PATH, exist_ok=True) DB_PATH = os.path.join(DB_PATH, "crawl4ai.db") + def init_db(): global DB_PATH conn = sqlite3.connect(DB_PATH) cursor = conn.cursor() - cursor.execute(''' + cursor.execute( + """ CREATE TABLE IF NOT EXISTS crawled_data ( url TEXT PRIMARY KEY, html TEXT, @@ -24,31 +26,42 @@ def init_db(): metadata TEXT DEFAULT "{}", screenshot TEXT DEFAULT "" ) - ''') + """ + ) conn.commit() conn.close() + def alter_db_add_screenshot(new_column: str = "media"): check_db_path() try: conn = sqlite3.connect(DB_PATH) cursor = conn.cursor() - cursor.execute(f'ALTER TABLE crawled_data ADD COLUMN {new_column} TEXT DEFAULT ""') + cursor.execute( + f'ALTER TABLE crawled_data ADD COLUMN {new_column} TEXT DEFAULT ""' + ) conn.commit() conn.close() except Exception as e: print(f"Error altering database to add screenshot column: {e}") + def check_db_path(): if not DB_PATH: raise ValueError("Database path is not set or is empty.") -def get_cached_url(url: str) -> Optional[Tuple[str, str, str, str, str, str, str, bool, str]]: + +def get_cached_url( + url: str, +) -> Optional[Tuple[str, str, str, str, str, str, str, bool, str]]: check_db_path() try: conn = sqlite3.connect(DB_PATH) cursor = conn.cursor() - cursor.execute('SELECT url, html, cleaned_html, markdown, extracted_content, success, media, links, metadata, screenshot FROM crawled_data WHERE url = ?', (url,)) + cursor.execute( + "SELECT url, html, cleaned_html, markdown, extracted_content, success, media, links, metadata, screenshot FROM 
crawled_data WHERE url = ?", + (url,), + ) result = cursor.fetchone() conn.close() return result @@ -56,12 +69,25 @@ def get_cached_url(url: str) -> Optional[Tuple[str, str, str, str, str, str, str print(f"Error retrieving cached URL: {e}") return None -def cache_url(url: str, html: str, cleaned_html: str, markdown: str, extracted_content: str, success: bool, media : str = "{}", links : str = "{}", metadata : str = "{}", screenshot: str = ""): + +def cache_url( + url: str, + html: str, + cleaned_html: str, + markdown: str, + extracted_content: str, + success: bool, + media: str = "{}", + links: str = "{}", + metadata: str = "{}", + screenshot: str = "", +): check_db_path() try: conn = sqlite3.connect(DB_PATH) cursor = conn.cursor() - cursor.execute(''' + cursor.execute( + """ INSERT INTO crawled_data (url, html, cleaned_html, markdown, extracted_content, success, media, links, metadata, screenshot) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?) ON CONFLICT(url) DO UPDATE SET @@ -74,18 +100,32 @@ def cache_url(url: str, html: str, cleaned_html: str, markdown: str, extracted_c links = excluded.links, metadata = excluded.metadata, screenshot = excluded.screenshot - ''', (url, html, cleaned_html, markdown, extracted_content, success, media, links, metadata, screenshot)) + """, + ( + url, + html, + cleaned_html, + markdown, + extracted_content, + success, + media, + links, + metadata, + screenshot, + ), + ) conn.commit() conn.close() except Exception as e: print(f"Error caching URL: {e}") + def get_total_count() -> int: check_db_path() try: conn = sqlite3.connect(DB_PATH) cursor = conn.cursor() - cursor.execute('SELECT COUNT(*) FROM crawled_data') + cursor.execute("SELECT COUNT(*) FROM crawled_data") result = cursor.fetchone() conn.close() return result[0] @@ -93,43 +133,48 @@ def get_total_count() -> int: print(f"Error getting total count: {e}") return 0 + def clear_db(): check_db_path() try: conn = sqlite3.connect(DB_PATH) cursor = conn.cursor() - cursor.execute('DELETE FROM crawled_data') + cursor.execute("DELETE FROM crawled_data") conn.commit() conn.close() except Exception as e: print(f"Error clearing database: {e}") - + + def flush_db(): check_db_path() try: conn = sqlite3.connect(DB_PATH) cursor = conn.cursor() - cursor.execute('DROP TABLE crawled_data') + cursor.execute("DROP TABLE crawled_data") conn.commit() conn.close() except Exception as e: print(f"Error flushing database: {e}") + def update_existing_records(new_column: str = "media", default_value: str = "{}"): check_db_path() try: conn = sqlite3.connect(DB_PATH) cursor = conn.cursor() - cursor.execute(f'UPDATE crawled_data SET {new_column} = "{default_value}" WHERE screenshot IS NULL') + cursor.execute( + f'UPDATE crawled_data SET {new_column} = "{default_value}" WHERE screenshot IS NULL' + ) conn.commit() conn.close() except Exception as e: print(f"Error updating existing records: {e}") + if __name__ == "__main__": # Delete the existing database file if os.path.exists(DB_PATH): os.remove(DB_PATH) - init_db() + init_db() # alter_db_add_screenshot("COL_NAME") - diff --git a/crawl4ai/docs_manager.py b/crawl4ai/docs_manager.py index aacc5812..9a6096a5 100644 --- a/crawl4ai/docs_manager.py +++ b/crawl4ai/docs_manager.py @@ -4,6 +4,7 @@ from pathlib import Path from crawl4ai.async_logger import AsyncLogger from crawl4ai.llmtxt import AsyncLLMTextManager + class DocsManager: def __init__(self, logger=None): self.docs_dir = Path.home() / ".crawl4ai" / "docs" @@ -21,11 +22,14 @@ class DocsManager: """Copy from local docs or download from 
GitHub""" try: # Try local first - if self.local_docs.exists() and (any(self.local_docs.glob("*.md")) or any(self.local_docs.glob("*.tokens"))): + if self.local_docs.exists() and ( + any(self.local_docs.glob("*.md")) + or any(self.local_docs.glob("*.tokens")) + ): # Empty the local docs directory for file_path in self.docs_dir.glob("*.md"): file_path.unlink() - # for file_path in self.docs_dir.glob("*.tokens"): + # for file_path in self.docs_dir.glob("*.tokens"): # file_path.unlink() for file_path in self.local_docs.glob("*.md"): shutil.copy2(file_path, self.docs_dir / file_path.name) @@ -36,14 +40,14 @@ class DocsManager: # Fallback to GitHub response = requests.get( "https://api.github.com/repos/unclecode/crawl4ai/contents/docs/llm.txt", - headers={'Accept': 'application/vnd.github.v3+json'} + headers={"Accept": "application/vnd.github.v3+json"}, ) response.raise_for_status() - + for item in response.json(): - if item['type'] == 'file' and item['name'].endswith('.md'): - content = requests.get(item['download_url']).text - with open(self.docs_dir / item['name'], 'w', encoding='utf-8') as f: + if item["type"] == "file" and item["name"].endswith(".md"): + content = requests.get(item["download_url"]).text + with open(self.docs_dir / item["name"], "w", encoding="utf-8") as f: f.write(content) return True @@ -57,11 +61,15 @@ class DocsManager: # Remove [0-9]+_ prefix names = [name.split("_", 1)[1] if name[0].isdigit() else name for name in names] # Exclude those end with .xs.md and .q.md - names = [name for name in names if not name.endswith(".xs") and not name.endswith(".q")] + names = [ + name + for name in names + if not name.endswith(".xs") and not name.endswith(".q") + ] return names - + def generate(self, sections, mode="extended"): return self.llm_text.generate(sections, mode) - + def search(self, query: str, top_k: int = 5): - return self.llm_text.search(query, top_k) \ No newline at end of file + return self.llm_text.search(query, top_k) diff --git a/crawl4ai/extraction_strategy.py b/crawl4ai/extraction_strategy.py index f6c62cc9..65cc005d 100644 --- a/crawl4ai/extraction_strategy.py +++ b/crawl4ai/extraction_strategy.py @@ -1,26 +1,54 @@ from abc import ABC, abstractmethod -from typing import Any, List, Dict, Optional, Union +from typing import Any, List, Dict, Optional from concurrent.futures import ThreadPoolExecutor, as_completed -import json, time -# from optimum.intel import IPEXModel -from .prompts import * -from .config import * -from .utils import * -from .models import * +import json +import time +import os + +from .prompts import PROMPT_EXTRACT_BLOCKS +from .config import ( + DEFAULT_PROVIDER, PROVIDER_MODELS, + CHUNK_TOKEN_THRESHOLD, + OVERLAP_RATE, + WORD_TOKEN_RATE, + PROMPT_EXTRACT_BLOCKS_WITH_INSTRUCTION, + PROMPT_EXTRACT_SCHEMA_WITH_INSTRUCTION +) +from .utils import * # noqa: F403 + +from .utils import ( + sanitize_html, + calculate_batch_size, + escape_json_string, + perform_completion_with_backoff, + extract_xml_data, + split_and_parse_json_objects, + sanitize_input_encode, +) +from .models import * # noqa: F403 + +from .models import TokenUsage + +from .model_loader import * # noqa: F403 +from .model_loader import ( + get_device, + load_HF_embedding_model, + load_text_multilabel_classifier, +) + from functools import partial -from .model_loader import * import math import numpy as np import re from bs4 import BeautifulSoup from lxml import html, etree -from dataclasses import dataclass + class ExtractionStrategy(ABC): """ Abstract base class for all extraction 
strategies. """ - + def __init__(self, input_format: str = "markdown", **kwargs): """ Initialize the extraction strategy. @@ -45,7 +73,7 @@ class ExtractionStrategy(ABC): :return: A list of extracted blocks or chunks. """ pass - + def run(self, url: str, sections: List[str], *q, **kwargs) -> List[Dict[str, Any]]: """ Process sections of text in parallel by default. @@ -56,40 +84,49 @@ class ExtractionStrategy(ABC): """ extracted_content = [] with ThreadPoolExecutor() as executor: - futures = [executor.submit(self.extract, url, section, **kwargs) for section in sections] + futures = [ + executor.submit(self.extract, url, section, **kwargs) + for section in sections + ] for future in as_completed(futures): extracted_content.extend(future.result()) - return extracted_content - + return extracted_content + + class NoExtractionStrategy(ExtractionStrategy): """ A strategy that does not extract any meaningful content from the HTML. It simply returns the entire HTML as a single block. """ + def extract(self, url: str, html: str, *q, **kwargs) -> List[Dict[str, Any]]: """ Extract meaningful blocks or chunks from the given HTML. """ return [{"index": 0, "content": html}] - + def run(self, url: str, sections: List[str], *q, **kwargs) -> List[Dict[str, Any]]: - return [{"index": i, "tags": [], "content": section} for i, section in enumerate(sections)] + return [ + {"index": i, "tags": [], "content": section} + for i, section in enumerate(sections) + ] ####################################################### # Strategies using clustering for text data extraction # ####################################################### + class CosineStrategy(ExtractionStrategy): """ Extract meaningful blocks or chunks from the given HTML using cosine similarity. - + How it works: 1. Pre-filter documents using embeddings and semantic_filter. 2. Perform clustering using cosine similarity. 3. Organize texts by their cluster labels, retaining order. 4. Filter clusters by word count. 5. Extract meaningful blocks or chunks from the filtered clusters. - + Attributes: semantic_filter (str): A keyword filter for document filtering. word_count_threshold (int): Minimum number of words per cluster. @@ -98,8 +135,19 @@ class CosineStrategy(ExtractionStrategy): top_k (int): Number of top categories to extract. model_name (str): The name of the sentence-transformers model. sim_threshold (float): The similarity threshold for clustering. - """ - def __init__(self, semantic_filter = None, word_count_threshold=10, max_dist=0.2, linkage_method='ward', top_k=3, model_name = 'sentence-transformers/all-MiniLM-L6-v2', sim_threshold = 0.3, **kwargs): + """ + + def __init__( + self, + semantic_filter=None, + word_count_threshold=10, + max_dist=0.2, + linkage_method="ward", + top_k=3, + model_name="sentence-transformers/all-MiniLM-L6-v2", + sim_threshold=0.3, + **kwargs, + ): """ Initialize the strategy with clustering parameters. @@ -111,9 +159,9 @@ class CosineStrategy(ExtractionStrategy): top_k (int): Number of top categories to extract. 
""" super().__init__(**kwargs) - + import numpy as np - + self.semantic_filter = semantic_filter self.word_count_threshold = word_count_threshold self.max_dist = max_dist @@ -122,14 +170,14 @@ class CosineStrategy(ExtractionStrategy): self.sim_threshold = sim_threshold self.timer = time.time() self.verbose = kwargs.get("verbose", False) - + self.buffer_embeddings = np.array([]) self.get_embedding_method = "direct" - + self.device = get_device() # import torch # self.device = torch.device('cpu') - + self.default_batch_size = calculate_batch_size(self.device) if self.verbose: @@ -143,10 +191,10 @@ class CosineStrategy(ExtractionStrategy): self.tokenizer, self.model = load_HF_embedding_model(model_name) self.model.to(self.device) - self.model.eval() - + self.model.eval() + self.get_embedding_method = "batch" - + self.buffer_embeddings = np.array([]) # if model_name == "bert-base-uncased": @@ -161,18 +209,23 @@ class CosineStrategy(ExtractionStrategy): # self.model = load_onnx_all_MiniLM_l6_v2() # self.tokenizer = self.model.tokenizer # self.get_embedding_method = "direct" - - + if self.verbose: print(f"[LOG] Loading Multilabel Classifier for {self.device.type} device.") - + self.nlp, _ = load_text_multilabel_classifier() # self.default_batch_size = 16 if self.device.type == 'cpu' else 64 - - if self.verbose: - print(f"[LOG] Model loaded {model_name}, models/reuters, took " + str(time.time() - self.timer) + " seconds") - def filter_documents_embeddings(self, documents: List[str], semantic_filter: str, at_least_k: int = 20) -> List[str]: + if self.verbose: + print( + f"[LOG] Model loaded {model_name}, models/reuters, took " + + str(time.time() - self.timer) + + " seconds" + ) + + def filter_documents_embeddings( + self, documents: List[str], semantic_filter: str, at_least_k: int = 20 + ) -> List[str]: """ Filter and sort documents based on the cosine similarity of their embeddings with the semantic_filter embedding. @@ -184,39 +237,51 @@ class CosineStrategy(ExtractionStrategy): Returns: List[str]: A list of filtered and sorted document texts. 
""" - + if not semantic_filter: return documents - + if len(documents) < at_least_k: at_least_k = len(documents) // 2 - + from sklearn.metrics.pairwise import cosine_similarity - + # Compute embedding for the keyword filter query_embedding = self.get_embeddings([semantic_filter])[0] - + # Compute embeddings for the documents document_embeddings = self.get_embeddings(documents) - + # Calculate cosine similarity between the query embedding and document embeddings - similarities = cosine_similarity([query_embedding], document_embeddings).flatten() - + similarities = cosine_similarity( + [query_embedding], document_embeddings + ).flatten() + # Filter documents based on the similarity threshold - filtered_docs = [(doc, sim) for doc, sim in zip(documents, similarities) if sim >= self.sim_threshold] - + filtered_docs = [ + (doc, sim) + for doc, sim in zip(documents, similarities) + if sim >= self.sim_threshold + ] + # If the number of filtered documents is less than at_least_k, sort remaining documents by similarity if len(filtered_docs) < at_least_k: - remaining_docs = [(doc, sim) for doc, sim in zip(documents, similarities) if sim < self.sim_threshold] + remaining_docs = [ + (doc, sim) + for doc, sim in zip(documents, similarities) + if sim < self.sim_threshold + ] remaining_docs.sort(key=lambda x: x[1], reverse=True) - filtered_docs.extend(remaining_docs[:at_least_k - len(filtered_docs)]) - + filtered_docs.extend(remaining_docs[: at_least_k - len(filtered_docs)]) + # Extract the document texts from the tuples filtered_docs = [doc for doc, _ in filtered_docs] - + return filtered_docs[:at_least_k] - - def get_embeddings(self, sentences: List[str], batch_size=None, bypass_buffer=False): + + def get_embeddings( + self, sentences: List[str], batch_size=None, bypass_buffer=False + ): """ Get BERT embeddings for a list of sentences. 
@@ -228,43 +293,48 @@ class CosineStrategy(ExtractionStrategy): """ # if self.buffer_embeddings.any() and not bypass_buffer: # return self.buffer_embeddings - - if self.device.type in [ "cpu", "gpu", "cuda", "mps"]: - import torch + + if self.device.type in ["cpu", "gpu", "cuda", "mps"]: + import torch + # Tokenize sentences and convert to tensor if batch_size is None: batch_size = self.default_batch_size - + all_embeddings = [] for i in range(0, len(sentences), batch_size): - batch_sentences = sentences[i:i + batch_size] - encoded_input = self.tokenizer(batch_sentences, padding=True, truncation=True, return_tensors='pt') - encoded_input = {key: tensor.to(self.device) for key, tensor in encoded_input.items()} - + batch_sentences = sentences[i : i + batch_size] + encoded_input = self.tokenizer( + batch_sentences, padding=True, truncation=True, return_tensors="pt" + ) + encoded_input = { + key: tensor.to(self.device) for key, tensor in encoded_input.items() + } + # Ensure no gradients are calculated with torch.no_grad(): model_output = self.model(**encoded_input) - + # Get embeddings from the last hidden state (mean pooling) embeddings = model_output.last_hidden_state.mean(dim=1).cpu().numpy() all_embeddings.append(embeddings) - + self.buffer_embeddings = np.vstack(all_embeddings) - elif self.device.type == "cpu": + elif self.device.type == "cpu": # self.buffer_embeddings = self.model(sentences) if batch_size is None: batch_size = self.default_batch_size - + all_embeddings = [] for i in range(0, len(sentences), batch_size): - batch_sentences = sentences[i:i + batch_size] + batch_sentences = sentences[i : i + batch_size] embeddings = self.model(batch_sentences) all_embeddings.append(embeddings) - + self.buffer_embeddings = np.vstack(all_embeddings) return self.buffer_embeddings - def hierarchical_clustering(self, sentences: List[str], embeddings = None): + def hierarchical_clustering(self, sentences: List[str], embeddings=None): """ Perform hierarchical clustering on sentences and return cluster labels. @@ -277,18 +347,21 @@ class CosineStrategy(ExtractionStrategy): # Get embeddings from scipy.cluster.hierarchy import linkage, fcluster from scipy.spatial.distance import pdist + self.timer = time.time() embeddings = self.get_embeddings(sentences, bypass_buffer=True) # print(f"[LOG] 🚀 Embeddings computed in {time.time() - self.timer:.2f} seconds") # Compute pairwise cosine distances - distance_matrix = pdist(embeddings, 'cosine') + distance_matrix = pdist(embeddings, "cosine") # Perform agglomerative clustering respecting order linked = linkage(distance_matrix, method=self.linkage_method) # Form flat clusters - labels = fcluster(linked, self.max_dist, criterion='distance') + labels = fcluster(linked, self.max_dist, criterion="distance") return labels - def filter_clusters_by_word_count(self, clusters: Dict[int, List[str]]) -> Dict[int, List[str]]: + def filter_clusters_by_word_count( + self, clusters: Dict[int, List[str]] + ) -> Dict[int, List[str]]: """ Filter clusters to remove those with a word count below the threshold. 
@@ -304,7 +377,7 @@ class CosineStrategy(ExtractionStrategy): full_text = " ".join(texts) # Count words word_count = len(full_text.split()) - + # Keep clusters with word count above the threshold if word_count >= self.word_count_threshold: filtered_clusters[cluster_id] = texts @@ -325,9 +398,11 @@ class CosineStrategy(ExtractionStrategy): # Assume `html` is a list of text chunks for this strategy t = time.time() text_chunks = html.split(self.DEL) # Split by lines or paragraphs as needed - + # Pre-filter documents using embeddings and semantic_filter - text_chunks = self.filter_documents_embeddings(text_chunks, self.semantic_filter) + text_chunks = self.filter_documents_embeddings( + text_chunks, self.semantic_filter + ) if not text_chunks: return [] @@ -346,16 +421,19 @@ class CosineStrategy(ExtractionStrategy): filtered_clusters = self.filter_clusters_by_word_count(clusters) # Convert filtered clusters to a sorted list of dictionaries - cluster_list = [{"index": int(idx), "tags" : [], "content": " ".join(filtered_clusters[idx])} for idx in sorted(filtered_clusters)] - + cluster_list = [ + {"index": int(idx), "tags": [], "content": " ".join(filtered_clusters[idx])} + for idx in sorted(filtered_clusters) + ] + if self.verbose: print(f"[LOG] 🚀 Assign tags using {self.device}") - + if self.device.type in ["gpu", "cuda", "mps", "cpu"]: - labels = self.nlp([cluster['content'] for cluster in cluster_list]) - + labels = self.nlp([cluster["content"] for cluster in cluster_list]) + for cluster, label in zip(cluster_list, labels): - cluster['tags'] = label + cluster["tags"] = label # elif self.device.type == "cpu": # # Process the text with the loaded model # texts = [cluster['content'] for cluster in cluster_list] @@ -366,16 +444,16 @@ class CosineStrategy(ExtractionStrategy): # tok_k = self.top_k # top_categories = sorted(doc.cats.items(), key=lambda x: x[1], reverse=True)[:tok_k] # cluster['tags'] = [cat for cat, _ in top_categories] - - # for cluster in cluster_list: - # doc = self.nlp(cluster['content']) - # tok_k = self.top_k - # top_categories = sorted(doc.cats.items(), key=lambda x: x[1], reverse=True)[:tok_k] - # cluster['tags'] = [cat for cat, _ in top_categories] - + + # for cluster in cluster_list: + # doc = self.nlp(cluster['content']) + # tok_k = self.top_k + # top_categories = sorted(doc.cats.items(), key=lambda x: x[1], reverse=True)[:tok_k] + # cluster['tags'] = [cat for cat, _ in top_categories] + if self.verbose: print(f"[LOG] 🚀 Categorization done in {time.time() - t:.2f} seconds") - + return cluster_list def run(self, url: str, sections: List[str], *q, **kwargs) -> List[Dict[str, Any]]: @@ -389,9 +467,8 @@ class CosineStrategy(ExtractionStrategy): Returns: """ # This strategy processes all sections together - - return self.extract(url, self.DEL.join(sections), **kwargs) + return self.extract(url, self.DEL.join(sections), **kwargs) ####################################################### @@ -400,11 +477,11 @@ class CosineStrategy(ExtractionStrategy): class LLMExtractionStrategy(ExtractionStrategy): """ A strategy that uses an LLM to extract meaningful content from the HTML. - + Attributes: provider: The provider to use for extraction. It follows the formatThis is a test paragraph.
- """) - + """ + ) + async with AsyncWebCrawler(verbose=True) as crawler: # Process local file - local_result = await crawler.arun( - url=f"file://{os.path.abspath(sample_file)}" - ) - + local_result = await crawler.arun(url=f"file://{os.path.abspath(sample_file)}") + # Process raw HTML raw_html = """ @@ -78,16 +79,15 @@ async def local_and_raw_html_example():This is a test of raw HTML processing.
""" - raw_result = await crawler.arun( - url=f"raw:{raw_html}" - ) - + raw_result = await crawler.arun(url=f"raw:{raw_html}") + # Clean up os.remove(sample_file) - + print("Local file content:", local_result.markdown) print("\nRaw HTML content:", raw_result.markdown) + # 3. Enhanced Markdown Generation Example async def markdown_generation_example(): """Example of enhanced markdown generation with citations and LLM-friendly features""" @@ -97,58 +97,66 @@ async def markdown_generation_example(): # user_query="History and cultivation", bm25_threshold=1.0 ) - + result = await crawler.arun( url="https://en.wikipedia.org/wiki/Apple", css_selector="main div#bodyContent", content_filter=content_filter, - cache_mode=CacheMode.BYPASS + cache_mode=CacheMode.BYPASS, ) - - from crawl4ai import AsyncWebCrawler + from crawl4ai.content_filter_strategy import BM25ContentFilter - + result = await crawler.arun( url="https://en.wikipedia.org/wiki/Apple", css_selector="main div#bodyContent", - content_filter=BM25ContentFilter() + content_filter=BM25ContentFilter(), ) print(result.markdown_v2.fit_markdown) - + print("\nMarkdown Generation Results:") print(f"1. Original markdown length: {len(result.markdown)}") - print(f"2. New markdown versions (markdown_v2):") + print("2. New markdown versions (markdown_v2):") print(f" - Raw markdown length: {len(result.markdown_v2.raw_markdown)}") - print(f" - Citations markdown length: {len(result.markdown_v2.markdown_with_citations)}") - print(f" - References section length: {len(result.markdown_v2.references_markdown)}") + print( + f" - Citations markdown length: {len(result.markdown_v2.markdown_with_citations)}" + ) + print( + f" - References section length: {len(result.markdown_v2.references_markdown)}" + ) if result.markdown_v2.fit_markdown: - print(f" - Filtered markdown length: {len(result.markdown_v2.fit_markdown)}") - + print( + f" - Filtered markdown length: {len(result.markdown_v2.fit_markdown)}" + ) + # Save examples to files output_dir = os.path.join(__data__, "markdown_examples") os.makedirs(output_dir, exist_ok=True) - + # Save different versions with open(os.path.join(output_dir, "1_raw_markdown.md"), "w") as f: f.write(result.markdown_v2.raw_markdown) - + with open(os.path.join(output_dir, "2_citations_markdown.md"), "w") as f: f.write(result.markdown_v2.markdown_with_citations) - + with open(os.path.join(output_dir, "3_references.md"), "w") as f: f.write(result.markdown_v2.references_markdown) - + if result.markdown_v2.fit_markdown: with open(os.path.join(output_dir, "4_filtered_markdown.md"), "w") as f: f.write(result.markdown_v2.fit_markdown) - + print(f"\nMarkdown examples saved to: {output_dir}") - + # Show a sample of citations and references print("\nSample of markdown with citations:") print(result.markdown_v2.markdown_with_citations[:500] + "...\n") print("Sample of references:") - print('\n'.join(result.markdown_v2.references_markdown.split('\n')[:10]) + "...") + print( + "\n".join(result.markdown_v2.references_markdown.split("\n")[:10]) + "..." + ) + # 4. 
Browser Management Example async def browser_management_example(): @@ -156,38 +164,38 @@ async def browser_management_example(): # Use the specified user directory path user_data_dir = os.path.join(Path.home(), ".crawl4ai", "browser_profile") os.makedirs(user_data_dir, exist_ok=True) - + print(f"Browser profile will be saved to: {user_data_dir}") - + async with AsyncWebCrawler( use_managed_browser=True, user_data_dir=user_data_dir, headless=False, - verbose=True + verbose=True, ) as crawler: - result = await crawler.arun( url="https://crawl4ai.com", # session_id="persistent_session_1", - cache_mode=CacheMode.BYPASS - ) + cache_mode=CacheMode.BYPASS, + ) # Use GitHub as an example - it's a good test for browser management # because it requires proper browser handling result = await crawler.arun( url="https://github.com/trending", # session_id="persistent_session_1", - cache_mode=CacheMode.BYPASS + cache_mode=CacheMode.BYPASS, ) - + print("\nBrowser session result:", result.success) if result.success: - print("Page title:", result.metadata.get('title', 'No title found')) + print("Page title:", result.metadata.get("title", "No title found")) + # 5. API Usage Example async def api_example(): """Example of using the new API endpoints""" - api_token = os.getenv('CRAWL4AI_API_TOKEN') or "test_api_code" - headers = {'Authorization': f'Bearer {api_token}'} + api_token = os.getenv("CRAWL4AI_API_TOKEN") or "test_api_code" + headers = {"Authorization": f"Bearer {api_token}"} async with aiohttp.ClientSession() as session: # Submit crawl job crawl_request = { @@ -199,25 +207,17 @@ async def api_example(): "name": "Hacker News Articles", "baseSelector": ".athing", "fields": [ - { - "name": "title", - "selector": ".title a", - "type": "text" - }, - { - "name": "score", - "selector": ".score", - "type": "text" - }, + {"name": "title", "selector": ".title a", "type": "text"}, + {"name": "score", "selector": ".score", "type": "text"}, { "name": "url", "selector": ".title a", "type": "attribute", - "attribute": "href" - } - ] + "attribute": "href", + }, + ], } - } + }, }, "crawler_params": { "headless": True, @@ -227,51 +227,50 @@ async def api_example(): # "screenshot": True, # "magic": True } - + async with session.post( - "http://localhost:11235/crawl", - json=crawl_request, - headers=headers + "http://localhost:11235/crawl", json=crawl_request, headers=headers ) as response: task_data = await response.json() task_id = task_data["task_id"] - + # Check task status while True: async with session.get( - f"http://localhost:11235/task/{task_id}", - headers=headers + f"http://localhost:11235/task/{task_id}", headers=headers ) as status_response: result = await status_response.json() print(f"Task status: {result['status']}") - + if result["status"] == "completed": print("Task completed!") print("Results:") - news = json.loads(result["results"][0]['extracted_content']) + news = json.loads(result["results"][0]["extracted_content"]) print(json.dumps(news[:4], indent=2)) break else: await asyncio.sleep(1) + # Main execution async def main(): # print("Running Crawl4AI feature examples...") - + # print("\n1. Running Download Example:") # await download_example() - + # print("\n2. Running Markdown Generation Example:") # await markdown_generation_example() - + # # print("\n3. Running Local and Raw HTML Example:") # await local_and_raw_html_example() - + # # print("\n4. Running Browser Management Example:") await browser_management_example() - + # print("\n5. 
Running API Example:") await api_example() + if __name__ == "__main__": - asyncio.run(main()) \ No newline at end of file + asyncio.run(main()) diff --git a/docs/examples/v0_4_24_walkthrough.py b/docs/examples/v0_4_24_walkthrough.py index 135ac29c..996e7b04 100644 --- a/docs/examples/v0_4_24_walkthrough.py +++ b/docs/examples/v0_4_24_walkthrough.py @@ -10,18 +10,17 @@ import asyncio import os import json import re -from typing import List, Optional, Dict, Any -from pydantic import BaseModel, Field +from typing import List from crawl4ai import ( AsyncWebCrawler, BrowserConfig, CrawlerRunConfig, CacheMode, LLMExtractionStrategy, - JsonCssExtractionStrategy + JsonCssExtractionStrategy, ) from crawl4ai.content_filter_strategy import RelevantContentFilter -from crawl4ai.markdown_generation_strategy import DefaultMarkdownGenerator +from crawl4ai.markdown_generation_strategy import DefaultMarkdownGenerator from bs4 import BeautifulSoup # Sample HTML for demonstrations @@ -52,17 +51,18 @@ SAMPLE_HTML = """ """ + async def demo_ssl_features(): """ Enhanced SSL & Security Features Demo ----------------------------------- - + This example demonstrates the new SSL certificate handling and security features: 1. Custom certificate paths 2. SSL verification options 3. HTTPS error handling 4. Certificate validation configurations - + These features are particularly useful when: - Working with self-signed certificates - Dealing with corporate proxies @@ -76,14 +76,11 @@ async def demo_ssl_features(): run_config = CrawlerRunConfig( cache_mode=CacheMode.BYPASS, - fetch_ssl_certificate=True # Enable SSL certificate fetching + fetch_ssl_certificate=True, # Enable SSL certificate fetching ) async with AsyncWebCrawler(config=browser_config) as crawler: - result = await crawler.arun( - url="https://example.com", - config=run_config - ) + result = await crawler.arun(url="https://example.com", config=run_config) print(f"SSL Crawl Success: {result.success}") result.ssl_certificate.to_json( os.path.join(os.getcwd(), "ssl_certificate.json") @@ -91,11 +88,12 @@ async def demo_ssl_features(): if not result.success: print(f"SSL Error: {result.error_message}") + async def demo_content_filtering(): """ Smart Content Filtering Demo ---------------------- - + Demonstrates advanced content filtering capabilities: 1. Custom filter to identify and extract specific content 2. Integration with markdown generation @@ -110,87 +108,90 @@ async def demo_content_filtering(): super().__init__() # Add news-specific patterns self.negative_patterns = re.compile( - r'nav|footer|header|sidebar|ads|comment|share|related|recommended|popular|trending', - re.I + r"nav|footer|header|sidebar|ads|comment|share|related|recommended|popular|trending", + re.I, ) self.min_word_count = 30 # Higher threshold for news content - def filter_content(self, html: str, min_word_threshold: int = None) -> List[str]: + def filter_content( + self, html: str, min_word_threshold: int = None + ) -> List[str]: """ Implements news-specific content filtering logic. 
- + Args: html (str): HTML content to be filtered min_word_threshold (int, optional): Minimum word count threshold - + Returns: List[str]: List of filtered HTML content blocks """ if not html or not isinstance(html, str): return [] - - soup = BeautifulSoup(html, 'lxml') + + soup = BeautifulSoup(html, "lxml") if not soup.body: - soup = BeautifulSoup(f'{html}', 'lxml') - - body = soup.find('body') - + soup = BeautifulSoup(f"{html}", "lxml") + + body = soup.find("body") + # Extract chunks with metadata - chunks = self.extract_text_chunks(body, min_word_threshold or self.min_word_count) - + chunks = self.extract_text_chunks( + body, min_word_threshold or self.min_word_count + ) + # Filter chunks based on news-specific criteria filtered_chunks = [] for _, text, tag_type, element in chunks: # Skip if element has negative class/id if self.is_excluded(element): continue - + # Headers are important in news articles - if tag_type == 'header': + if tag_type == "header": filtered_chunks.append(self.clean_element(element)) continue - + # For content, check word count and link density text = element.get_text(strip=True) if len(text.split()) >= (min_word_threshold or self.min_word_count): # Calculate link density - links_text = ' '.join(a.get_text(strip=True) for a in element.find_all('a')) + links_text = " ".join( + a.get_text(strip=True) for a in element.find_all("a") + ) link_density = len(links_text) / len(text) if text else 1 - + # Accept if link density is reasonable if link_density < 0.5: filtered_chunks.append(self.clean_element(element)) - + return filtered_chunks # Create markdown generator with custom filter - markdown_gen = DefaultMarkdownGenerator( - content_filter=CustomNewsFilter() - ) + markdown_gen = DefaultMarkdownGenerator(content_filter=CustomNewsFilter()) run_config = CrawlerRunConfig( - markdown_generator=markdown_gen, - cache_mode=CacheMode.BYPASS + markdown_generator=markdown_gen, cache_mode=CacheMode.BYPASS ) async with AsyncWebCrawler() as crawler: result = await crawler.arun( - url="https://news.ycombinator.com", - config=run_config + url="https://news.ycombinator.com", config=run_config ) print("Filtered Content Sample:") print(result.markdown[:500]) # Show first 500 chars + async def demo_json_extraction(): """ Improved JSON Extraction Demo --------------------------- - + Demonstrates the enhanced JSON extraction capabilities: 1. Base element attributes extraction 2. Complex nested structures 3. 
Multiple extraction patterns - + Key features shown: - Extracting attributes from base elements (href, data-* attributes) - Processing repeated patterns @@ -206,7 +207,7 @@ async def demo_json_extraction(): "baseSelector": "div.article-list", "baseFields": [ {"name": "list_id", "type": "attribute", "attribute": "data-list-id"}, - {"name": "category", "type": "attribute", "attribute": "data-category"} + {"name": "category", "type": "attribute", "attribute": "data-category"}, ], "fields": [ { @@ -214,8 +215,16 @@ async def demo_json_extraction(): "selector": "article.post", "type": "nested_list", "baseFields": [ - {"name": "post_id", "type": "attribute", "attribute": "data-post-id"}, - {"name": "author_id", "type": "attribute", "attribute": "data-author"} + { + "name": "post_id", + "type": "attribute", + "attribute": "data-post-id", + }, + { + "name": "author_id", + "type": "attribute", + "attribute": "data-author", + }, ], "fields": [ { @@ -223,60 +232,68 @@ async def demo_json_extraction(): "selector": "h2.title a", "type": "text", "baseFields": [ - {"name": "url", "type": "attribute", "attribute": "href"} - ] + { + "name": "url", + "type": "attribute", + "attribute": "href", + } + ], }, { "name": "author", "selector": "div.meta a.author", "type": "text", "baseFields": [ - {"name": "profile_url", "type": "attribute", "attribute": "href"} - ] - }, - { - "name": "date", - "selector": "span.date", - "type": "text" + { + "name": "profile_url", + "type": "attribute", + "attribute": "href", + } + ], }, + {"name": "date", "selector": "span.date", "type": "text"}, { "name": "read_more", "selector": "a.read-more", "type": "nested", "fields": [ {"name": "text", "type": "text"}, - {"name": "url", "type": "attribute", "attribute": "href"} - ] - } - ] + { + "name": "url", + "type": "attribute", + "attribute": "href", + }, + ], + }, + ], } - ] + ], } ) # Demonstrate extraction from raw HTML run_config = CrawlerRunConfig( - extraction_strategy=json_strategy, - cache_mode=CacheMode.BYPASS + extraction_strategy=json_strategy, cache_mode=CacheMode.BYPASS ) async with AsyncWebCrawler() as crawler: result = await crawler.arun( url="raw:" + SAMPLE_HTML, # Use raw: prefix for raw HTML - config=run_config + config=run_config, ) print("Extracted Content:") print(result.extracted_content) + async def demo_input_formats(): """ Input Format Handling Demo ---------------------- - + Demonstrates how LLM extraction can work with different input formats: 1. Markdown (default) - Good for simple text extraction 2. 
HTML - Better when you need structure and attributes - + This example shows how HTML input can be beneficial when: - You need to understand the DOM structure - You want to extract both visible text and HTML attributes @@ -350,7 +367,7 @@ async def demo_input_formats(): """ - + # Use raw:// prefix to pass HTML content directly url = f"raw://{dummy_html}" @@ -359,18 +376,30 @@ async def demo_input_formats(): # Define our schema using Pydantic class JobRequirement(BaseModel): - category: str = Field(description="Category of the requirement (e.g., Technical, Soft Skills)") - items: List[str] = Field(description="List of specific requirements in this category") - priority: str = Field(description="Priority level (Required/Preferred) based on the HTML class or context") + category: str = Field( + description="Category of the requirement (e.g., Technical, Soft Skills)" + ) + items: List[str] = Field( + description="List of specific requirements in this category" + ) + priority: str = Field( + description="Priority level (Required/Preferred) based on the HTML class or context" + ) class JobPosting(BaseModel): title: str = Field(description="Job title") department: str = Field(description="Department or team") location: str = Field(description="Job location, including remote options") salary_range: Optional[str] = Field(description="Salary range if specified") - requirements: List[JobRequirement] = Field(description="Categorized job requirements") - application_deadline: Optional[str] = Field(description="Application deadline if specified") - contact_info: Optional[dict] = Field(description="Contact information from footer or contact section") + requirements: List[JobRequirement] = Field( + description="Categorized job requirements" + ) + application_deadline: Optional[str] = Field( + description="Application deadline if specified" + ) + contact_info: Optional[dict] = Field( + description="Contact information from footer or contact section" + ) # First try with markdown (default) markdown_strategy = LLMExtractionStrategy( @@ -382,7 +411,7 @@ async def demo_input_formats(): Extract job posting details into structured data. Focus on the visible text content and organize requirements into categories. """, - input_format="markdown" # default + input_format="markdown", # default ) # Then with HTML for better structure understanding @@ -400,34 +429,25 @@ async def demo_input_formats(): Use HTML attributes and classes to enhance extraction accuracy. 
""", - input_format="html" # explicitly use HTML + input_format="html", # explicitly use HTML ) async with AsyncWebCrawler() as crawler: # Try with markdown first - markdown_config = CrawlerRunConfig( - extraction_strategy=markdown_strategy - ) - markdown_result = await crawler.arun( - url=url, - config=markdown_config - ) + markdown_config = CrawlerRunConfig(extraction_strategy=markdown_strategy) + markdown_result = await crawler.arun(url=url, config=markdown_config) print("\nMarkdown-based Extraction Result:") items = json.loads(markdown_result.extracted_content) print(json.dumps(items, indent=2)) # Then with HTML for better structure understanding - html_config = CrawlerRunConfig( - extraction_strategy=html_strategy - ) - html_result = await crawler.arun( - url=url, - config=html_config - ) + html_config = CrawlerRunConfig(extraction_strategy=html_strategy) + html_result = await crawler.arun(url=url, config=html_config) print("\nHTML-based Extraction Result:") items = json.loads(html_result.extracted_content) print(json.dumps(items, indent=2)) + # Main execution async def main(): print("Crawl4AI v0.4.24 Feature Walkthrough") @@ -439,5 +459,6 @@ async def main(): await demo_json_extraction() # await demo_input_formats() + if __name__ == "__main__": asyncio.run(main()) diff --git a/main.py b/main.py index 21e411d0..1f9e01a3 100644 --- a/main.py +++ b/main.py @@ -1,14 +1,9 @@ import asyncio, os -from fastapi import FastAPI, HTTPException, BackgroundTasks, Request -from fastapi.responses import JSONResponse -from fastapi import FastAPI, HTTPException, Request -from fastapi.responses import HTMLResponse, JSONResponse +from fastapi import FastAPI, HTTPException +from fastapi import FastAPI, HTTPException from fastapi.staticfiles import StaticFiles -from fastapi.middleware.cors import CORSMiddleware +from fastapi.middleware.cors import CORSMiddleware from fastapi.templating import Jinja2Templates -from fastapi.exceptions import RequestValidationError -from starlette.middleware.base import BaseHTTPMiddleware -from starlette.responses import FileResponse from fastapi.responses import RedirectResponse from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials from fastapi import Depends, Security @@ -18,13 +13,10 @@ from typing import Optional, List, Dict, Any, Union import psutil import time import uuid -from collections import defaultdict -from urllib.parse import urlparse import math import logging from enum import Enum from dataclasses import dataclass -import json from crawl4ai import AsyncWebCrawler, CrawlResult, CacheMode from crawl4ai.config import MIN_WORD_THRESHOLD from crawl4ai.extraction_strategy import ( @@ -38,30 +30,36 @@ __location__ = os.path.realpath(os.path.join(os.getcwd(), os.path.dirname(__file logging.basicConfig(level=logging.INFO) logger = logging.getLogger(__name__) + class TaskStatus(str, Enum): PENDING = "pending" PROCESSING = "processing" COMPLETED = "completed" FAILED = "failed" + class CrawlerType(str, Enum): BASIC = "basic" LLM = "llm" COSINE = "cosine" JSON_CSS = "json_css" + class ExtractionConfig(BaseModel): type: CrawlerType params: Dict[str, Any] = {} + class ChunkingStrategy(BaseModel): type: str params: Dict[str, Any] = {} + class ContentFilter(BaseModel): type: str = "bm25" params: Dict[str, Any] = {} + class CrawlRequest(BaseModel): urls: Union[HttpUrl, List[HttpUrl]] word_count_threshold: int = MIN_WORD_THRESHOLD @@ -77,9 +75,10 @@ class CrawlRequest(BaseModel): session_id: Optional[str] = None cache_mode: Optional[CacheMode] = 
CacheMode.ENABLED priority: int = Field(default=5, ge=1, le=10) - ttl: Optional[int] = 3600 + ttl: Optional[int] = 3600 crawler_params: Dict[str, Any] = {} + @dataclass class TaskInfo: id: str @@ -89,6 +88,7 @@ class TaskInfo: created_at: float = time.time() ttl: int = 3600 + class ResourceMonitor: def __init__(self, max_concurrent_tasks: int = 10): self.max_concurrent_tasks = max_concurrent_tasks @@ -106,7 +106,9 @@ class ResourceMonitor: mem_usage = psutil.virtual_memory().percent / 100 cpu_usage = psutil.cpu_percent() / 100 - memory_factor = max(0, (self.memory_threshold - mem_usage) / self.memory_threshold) + memory_factor = max( + 0, (self.memory_threshold - mem_usage) / self.memory_threshold + ) cpu_factor = max(0, (self.cpu_threshold - cpu_usage) / self.cpu_threshold) self._last_available_slots = math.floor( @@ -116,6 +118,7 @@ class ResourceMonitor: return self._last_available_slots + class TaskManager: def __init__(self, cleanup_interval: int = 300): self.tasks: Dict[str, TaskInfo] = {} @@ -149,12 +152,16 @@ class TaskManager: except asyncio.TimeoutError: try: # Then try low priority - _, task_id = await asyncio.wait_for(self.low_priority.get(), timeout=0.1) + _, task_id = await asyncio.wait_for( + self.low_priority.get(), timeout=0.1 + ) return task_id except asyncio.TimeoutError: return None - def update_task(self, task_id: str, status: TaskStatus, result: Any = None, error: str = None): + def update_task( + self, task_id: str, status: TaskStatus, result: Any = None, error: str = None + ): if task_id in self.tasks: task_info = self.tasks[task_id] task_info.status = status @@ -180,6 +187,7 @@ class TaskManager: except Exception as e: logger.error(f"Error in cleanup loop: {e}") + class CrawlerPool: def __init__(self, max_size: int = 10): self.max_size = max_size @@ -222,6 +230,7 @@ class CrawlerPool: await crawler.__aexit__(None, None, None) self.active_crawlers.clear() + class CrawlerService: def __init__(self, max_concurrent_tasks: int = 10): self.resource_monitor = ResourceMonitor(max_concurrent_tasks) @@ -258,10 +267,10 @@ class CrawlerService: async def submit_task(self, request: CrawlRequest) -> str: task_id = str(uuid.uuid4()) await self.task_manager.add_task(task_id, request.priority, request.ttl or 3600) - + # Store request data with task self.task_manager.tasks[task_id].request = request - + return task_id async def _process_queue(self): @@ -286,9 +295,11 @@ class CrawlerService: try: crawler = await self.crawler_pool.acquire(**request.crawler_params) - - extraction_strategy = self._create_extraction_strategy(request.extraction_config) - + + extraction_strategy = self._create_extraction_strategy( + request.extraction_config + ) + if isinstance(request.urls, list): results = await crawler.arun_many( urls=[str(url) for url in request.urls], @@ -318,16 +329,21 @@ class CrawlerService: ) await self.crawler_pool.release(crawler) - self.task_manager.update_task(task_id, TaskStatus.COMPLETED, results) + self.task_manager.update_task( + task_id, TaskStatus.COMPLETED, results + ) except Exception as e: logger.error(f"Error processing task {task_id}: {str(e)}") - self.task_manager.update_task(task_id, TaskStatus.FAILED, error=str(e)) + self.task_manager.update_task( + task_id, TaskStatus.FAILED, error=str(e) + ) except Exception as e: logger.error(f"Error in queue processing: {str(e)}") await asyncio.sleep(1) + app = FastAPI(title="Crawl4AI API") # CORS configuration @@ -344,6 +360,7 @@ app.add_middleware( security = HTTPBearer() CRAWL4AI_API_TOKEN = 
os.getenv("CRAWL4AI_API_TOKEN") + async def verify_token(credentials: HTTPAuthorizationCredentials = Security(security)): if not CRAWL4AI_API_TOKEN: return credentials # No token verification if CRAWL4AI_API_TOKEN is not set @@ -351,10 +368,12 @@ async def verify_token(credentials: HTTPAuthorizationCredentials = Security(secu raise HTTPException(status_code=401, detail="Invalid token") return credentials + def secure_endpoint(): """Returns security dependency only if CRAWL4AI_API_TOKEN is set""" return Depends(verify_token) if CRAWL4AI_API_TOKEN else None + # Check if site directory exists if os.path.exists(__location__ + "/site"): # Mount the site directory as a static directory @@ -364,14 +383,17 @@ site_templates = Jinja2Templates(directory=__location__ + "/site") crawler_service = CrawlerService() + @app.on_event("startup") async def startup_event(): await crawler_service.start() + @app.on_event("shutdown") async def shutdown_event(): await crawler_service.stop() + @app.get("/") def read_root(): if os.path.exists(__location__ + "/site"): @@ -379,12 +401,16 @@ def read_root(): # Return a json response return {"message": "Crawl4AI API service is running"} + @app.post("/crawl", dependencies=[secure_endpoint()] if CRAWL4AI_API_TOKEN else []) async def crawl(request: CrawlRequest) -> Dict[str, str]: task_id = await crawler_service.submit_task(request) return {"task_id": task_id} -@app.get("/task/{task_id}", dependencies=[secure_endpoint()] if CRAWL4AI_API_TOKEN else []) + +@app.get( + "/task/{task_id}", dependencies=[secure_endpoint()] if CRAWL4AI_API_TOKEN else [] +) async def get_task_status(task_id: str): task_info = crawler_service.task_manager.get_task(task_id) if not task_info: @@ -406,36 +432,45 @@ async def get_task_status(task_id: str): return response + @app.post("/crawl_sync", dependencies=[secure_endpoint()] if CRAWL4AI_API_TOKEN else []) async def crawl_sync(request: CrawlRequest) -> Dict[str, Any]: task_id = await crawler_service.submit_task(request) - + # Wait up to 60 seconds for task completion for _ in range(60): task_info = crawler_service.task_manager.get_task(task_id) if not task_info: raise HTTPException(status_code=404, detail="Task not found") - + if task_info.status == TaskStatus.COMPLETED: # Return same format as /task/{task_id} endpoint if isinstance(task_info.result, list): - return {"status": task_info.status, "results": [result.dict() for result in task_info.result]} + return { + "status": task_info.status, + "results": [result.dict() for result in task_info.result], + } return {"status": task_info.status, "result": task_info.result.dict()} - + if task_info.status == TaskStatus.FAILED: raise HTTPException(status_code=500, detail=task_info.error) - + await asyncio.sleep(1) - + # If we get here, task didn't complete within timeout raise HTTPException(status_code=408, detail="Task timed out") -@app.post("/crawl_direct", dependencies=[secure_endpoint()] if CRAWL4AI_API_TOKEN else []) + +@app.post( + "/crawl_direct", dependencies=[secure_endpoint()] if CRAWL4AI_API_TOKEN else [] +) async def crawl_direct(request: CrawlRequest) -> Dict[str, Any]: try: crawler = await crawler_service.crawler_pool.acquire(**request.crawler_params) - extraction_strategy = crawler_service._create_extraction_strategy(request.extraction_config) - + extraction_strategy = crawler_service._create_extraction_strategy( + request.extraction_config + ) + try: if isinstance(request.urls, list): results = await crawler.arun_many( @@ -470,7 +505,8 @@ async def crawl_direct(request: CrawlRequest) -> 
Dict[str, Any]: except Exception as e: logger.error(f"Error in direct crawl: {str(e)}") raise HTTPException(status_code=500, detail=str(e)) - + + @app.get("/health") async def health_check(): available_slots = await crawler_service.resource_monitor.get_available_slots() @@ -482,6 +518,8 @@ async def health_check(): "cpu_usage": psutil.cpu_percent(), } + if __name__ == "__main__": import uvicorn - uvicorn.run(app, host="0.0.0.0", port=11235) \ No newline at end of file + + uvicorn.run(app, host="0.0.0.0", port=11235) diff --git a/setup.py b/setup.py index dad3199d..16b1b53c 100644 --- a/setup.py +++ b/setup.py @@ -51,9 +51,7 @@ setup( author_email="unclecode@kidocode.com", license="MIT", packages=find_packages(), - package_data={ - 'crawl4ai': ['js_snippet/*.js'] - }, + package_data={"crawl4ai": ["js_snippet/*.js"]}, classifiers=[ "Development Status :: 3 - Alpha", "Intended Audience :: Developers", diff --git a/tests/async/test_0.4.2_browser_manager.py b/tests/async/test_0.4.2_browser_manager.py index 9bb19582..21b4be11 100644 --- a/tests/async/test_0.4.2_browser_manager.py +++ b/tests/async/test_0.4.2_browser_manager.py @@ -1,17 +1,18 @@ -import os, sys -parent_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) -sys.path.append(parent_dir) -__location__ = os.path.realpath( os.path.join(os.getcwd(), os.path.dirname(__file__))) - -import os, sys +import os +import sys import asyncio from crawl4ai import AsyncWebCrawler, CacheMode -from crawl4ai.content_filter_strategy import PruningContentFilter from crawl4ai.markdown_generation_strategy import DefaultMarkdownGenerator -# Assuming that the changes made allow different configurations +parent_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) +sys.path.append(parent_dir) +__location__ = os.path.realpath(os.path.join(os.getcwd(), os.path.dirname(__file__))) + + +# Assuming that the changes made allow different configurations # for managed browser, persistent context, and so forth. + async def test_default_headless(): async with AsyncWebCrawler( headless=True, @@ -24,13 +25,14 @@ async def test_default_headless(): # Testing normal ephemeral context ) as crawler: result = await crawler.arun( - url='https://www.kidocode.com/degrees/technology', + url="https://www.kidocode.com/degrees/technology", cache_mode=CacheMode.BYPASS, markdown_generator=DefaultMarkdownGenerator(options={"ignore_links": True}), ) print("[test_default_headless] success:", result.success) print("HTML length:", len(result.html if result.html else "")) - + + async def test_managed_browser_persistent(): # Treating use_persistent_context=True as managed_browser scenario. 
async with AsyncWebCrawler( @@ -44,13 +46,14 @@ async def test_managed_browser_persistent(): # This should store and reuse profile data across runs ) as crawler: result = await crawler.arun( - url='https://www.google.com', + url="https://www.google.com", cache_mode=CacheMode.BYPASS, - markdown_generator=DefaultMarkdownGenerator(options={"ignore_links": True}) + markdown_generator=DefaultMarkdownGenerator(options={"ignore_links": True}), ) print("[test_managed_browser_persistent] success:", result.success) print("HTML length:", len(result.html if result.html else "")) + async def test_session_reuse(): # Test creating a session, using it for multiple calls session_id = "my_session" @@ -62,25 +65,25 @@ async def test_session_reuse(): use_managed_browser=False, use_persistent_context=False, ) as crawler: - # First call: create session result1 = await crawler.arun( - url='https://www.example.com', + url="https://www.example.com", cache_mode=CacheMode.BYPASS, session_id=session_id, - markdown_generator=DefaultMarkdownGenerator(options={"ignore_links": True}) + markdown_generator=DefaultMarkdownGenerator(options={"ignore_links": True}), ) print("[test_session_reuse first call] success:", result1.success) - + # Second call: same session, possibly cookie retained result2 = await crawler.arun( - url='https://www.example.com/about', + url="https://www.example.com/about", cache_mode=CacheMode.BYPASS, session_id=session_id, - markdown_generator=DefaultMarkdownGenerator(options={"ignore_links": True}) + markdown_generator=DefaultMarkdownGenerator(options={"ignore_links": True}), ) print("[test_session_reuse second call] success:", result2.success) + async def test_magic_mode(): # Test magic mode with override_navigator and simulate_user async with AsyncWebCrawler( @@ -95,13 +98,14 @@ async def test_magic_mode(): simulate_user=True, ) as crawler: result = await crawler.arun( - url='https://www.kidocode.com/degrees/business', + url="https://www.kidocode.com/degrees/business", cache_mode=CacheMode.BYPASS, - markdown_generator=DefaultMarkdownGenerator(options={"ignore_links": True}) + markdown_generator=DefaultMarkdownGenerator(options={"ignore_links": True}), ) print("[test_magic_mode] success:", result.success) print("HTML length:", len(result.html if result.html else "")) + async def test_proxy_settings(): # Test with a proxy (if available) to ensure code runs with proxy async with AsyncWebCrawler( @@ -113,14 +117,15 @@ async def test_proxy_settings(): use_persistent_context=False, ) as crawler: result = await crawler.arun( - url='https://httpbin.org/ip', + url="https://httpbin.org/ip", cache_mode=CacheMode.BYPASS, - markdown_generator=DefaultMarkdownGenerator(options={"ignore_links": True}) + markdown_generator=DefaultMarkdownGenerator(options={"ignore_links": True}), ) print("[test_proxy_settings] success:", result.success) if result.success: print("HTML preview:", result.html[:200] if result.html else "") + async def test_ignore_https_errors(): # Test ignore HTTPS errors with a self-signed or invalid cert domain # This is just conceptual, the domain should be one that triggers SSL error. 
@@ -134,12 +139,13 @@ async def test_ignore_https_errors():
         use_persistent_context=False,
     ) as crawler:
         result = await crawler.arun(
-            url='https://self-signed.badssl.com/',
+            url="https://self-signed.badssl.com/",
            cache_mode=CacheMode.BYPASS,
-            markdown_generator=DefaultMarkdownGenerator(options={"ignore_links": True})
+            markdown_generator=DefaultMarkdownGenerator(options={"ignore_links": True}),
         )
         print("[test_ignore_https_errors] success:", result.success)
 
+
 async def main():
     print("Running tests...")
     # await test_default_headless()
@@ -149,5 +155,6 @@ async def main():
     # await test_proxy_settings()
     await test_ignore_https_errors()
 
+
 if __name__ == "__main__":
     asyncio.run(main())
diff --git a/tests/async/test_0.4.2_config_params.py b/tests/async/test_0.4.2_config_params.py
index 623ac3ab..9a15f864 100644
--- a/tests/async/test_0.4.2_config_params.py
+++ b/tests/async/test_0.4.2_config_params.py
@@ -1,15 +1,16 @@
 import os, sys
+
 parent_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
 sys.path.append(parent_dir)
 __location__ = os.path.realpath(os.path.join(os.getcwd(), os.path.dirname(__file__)))
 
 import asyncio
 from crawl4ai import AsyncWebCrawler, CacheMode
-from crawl4ai.async_configs import BrowserConfig, CrawlerRunConfig 
+from crawl4ai.async_configs import BrowserConfig, CrawlerRunConfig
 from crawl4ai.content_filter_strategy import PruningContentFilter
 from crawl4ai.extraction_strategy import JsonCssExtractionStrategy
 from crawl4ai.chunking_strategy import RegexChunking
-from crawl4ai.markdown_generation_strategy import DefaultMarkdownGenerator
+
 
 # Category 1: Browser Configuration Tests
 async def test_browser_config_object():
@@ -21,29 +22,31 @@
         viewport_height=1080,
         use_managed_browser=True,
         user_agent_mode="random",
-        user_agent_generator_config={"device_type": "desktop", "os_type": "windows"}
+        user_agent_generator_config={"device_type": "desktop", "os_type": "windows"},
     )
-    
+
     async with AsyncWebCrawler(config=browser_config, verbose=True) as crawler:
-        result = await crawler.arun('https://example.com', cache_mode=CacheMode.BYPASS)
+        result = await crawler.arun("https://example.com", cache_mode=CacheMode.BYPASS)
         assert result.success, "Browser config crawl failed"
         assert len(result.html) > 0, "No HTML content retrieved"
 
+
 async def test_browser_performance_config():
     """Test browser configurations focused on performance"""
     browser_config = BrowserConfig(
         text_mode=True,
         light_mode=True,
-        extra_args=['--disable-gpu', '--disable-software-rasterizer'],
+        extra_args=["--disable-gpu", "--disable-software-rasterizer"],
         ignore_https_errors=True,
-        java_script_enabled=False
+        java_script_enabled=False,
     )
-    
+
     async with AsyncWebCrawler(config=browser_config) as crawler:
-        result = await crawler.arun('https://example.com')
+        result = await crawler.arun("https://example.com")
         assert result.success, "Performance optimized crawl failed"
         assert result.status_code == 200, "Unexpected status code"
 
+
 # Category 2: Content Processing Tests
 async def test_content_extraction_config():
     """Test content extraction with various strategies"""
@@ -53,24 +56,20 @@
             schema={
                 "name": "article",
                 "baseSelector": "div",
-                "fields": [{
-                    "name": "title",
-                    "selector": "h1",
-                    "type": "text"
-                }]
+                "fields": [{"name": "title", "selector": "h1", "type": "text"}],
             }
         ),
         chunking_strategy=RegexChunking(),
-        content_filter=PruningContentFilter()
+        content_filter=PruningContentFilter(),
     )
-    
+
     async with AsyncWebCrawler() as crawler:
         result = await crawler.arun(
-            'https://example.com/article',
-            config=crawler_config
+            "https://example.com/article", config=crawler_config
         )
         assert result.extracted_content is not None, "Content extraction failed"
-        assert 'title' in result.extracted_content, "Missing expected content field"
+        assert "title" in result.extracted_content, "Missing expected content field"
 
+
 # Category 3: Cache and Session Management Tests
 async def test_cache_and_session_management():
@@ -79,25 +78,20 @@
     crawler_config = CrawlerRunConfig(
         cache_mode=CacheMode.WRITE_ONLY,
         process_iframes=True,
-        remove_overlay_elements=True
+        remove_overlay_elements=True,
     )
-    
+
     async with AsyncWebCrawler(config=browser_config) as crawler:
         # First request - should write to cache
-        result1 = await crawler.arun(
-            'https://example.com',
-            config=crawler_config
-        )
-        
+        result1 = await crawler.arun("https://example.com", config=crawler_config)
+
         # Second request - should use fresh fetch due to WRITE_ONLY mode
-        result2 = await crawler.arun(
-            'https://example.com',
-            config=crawler_config
-        )
-        
+        result2 = await crawler.arun("https://example.com", config=crawler_config)
+
         assert result1.success and result2.success, "Cache mode crawl failed"
         assert result1.html == result2.html, "Inconsistent results between requests"
 
+
 # Category 4: Media Handling Tests
 async def test_media_handling_config():
     """Test configurations related to media handling"""
@@ -107,24 +101,22 @@
         viewport_width=1920,
         viewport_height=1080,
         accept_downloads=True,
-        downloads_path= os.path.expanduser("~/.crawl4ai/downloads")
+        downloads_path=os.path.expanduser("~/.crawl4ai/downloads"),
     )
     crawler_config = CrawlerRunConfig(
         screenshot=True,
         pdf=True,
         adjust_viewport_to_content=True,
         wait_for_images=True,
-        screenshot_height_threshold=20000
+        screenshot_height_threshold=20000,
     )
-    
+
     async with AsyncWebCrawler(config=browser_config) as crawler:
-        result = await crawler.arun(
-            'https://example.com',
-            config=crawler_config
-        )
+        result = await crawler.arun("https://example.com", config=crawler_config)
         assert result.screenshot is not None, "Screenshot capture failed"
         assert result.pdf is not None, "PDF generation failed"
 
+
 # Category 5: Anti-Bot and Site Interaction Tests
 async def test_antibot_config():
     """Test configurations for handling anti-bot measures"""
@@ -135,76 +127,64 @@
         wait_for="js:()=>document.querySelector('body')",
         delay_before_return_html=1.0,
         log_console=True,
-        cache_mode=CacheMode.BYPASS
+        cache_mode=CacheMode.BYPASS,
     )
-    
+
     async with AsyncWebCrawler() as crawler:
-        result = await crawler.arun(
-            'https://example.com',
-            config=crawler_config
-        )
+        result = await crawler.arun("https://example.com", config=crawler_config)
         assert result.success, "Anti-bot measure handling failed"
 
+
 # Category 6: Parallel Processing Tests
 async def test_parallel_processing():
     """Test parallel processing capabilities"""
-    crawler_config = CrawlerRunConfig(
-        mean_delay=0.5,
-        max_range=1.0,
-        semaphore_count=5
-    )
-    
-    urls = [
-        'https://example.com/1',
-        'https://example.com/2',
-        'https://example.com/3'
-    ]
-    
+    crawler_config = CrawlerRunConfig(mean_delay=0.5, max_range=1.0, semaphore_count=5)
+
+    urls = ["https://example.com/1", "https://example.com/2", "https://example.com/3"]
+
     async with AsyncWebCrawler() as crawler:
-        results = await crawler.arun_many(
-            urls,
-            config=crawler_config
-        )
+        results = await crawler.arun_many(urls, config=crawler_config)
         assert len(results) == len(urls), "Not all URLs were processed"
         assert all(r.success for r in results), "Some parallel requests failed"
 
+
 # Category 7: Backwards Compatibility Tests
 async def test_legacy_parameter_support():
     """Test that legacy parameters still work"""
     async with AsyncWebCrawler(
-        headless=True,
-        browser_type="chromium",
-        viewport_width=1024,
-        viewport_height=768
+        headless=True, browser_type="chromium", viewport_width=1024, viewport_height=768
     ) as crawler:
         result = await crawler.arun(
-            'https://example.com',
+            "https://example.com",
             screenshot=True,
             word_count_threshold=200,
             bypass_cache=True,
-            css_selector=".main-content"
+            css_selector=".main-content",
         )
         assert result.success, "Legacy parameter support failed"
 
+
 # Category 8: Mixed Configuration Tests
 async def test_mixed_config_usage():
     """Test mixing new config objects with legacy parameters"""
     browser_config = BrowserConfig(headless=True)
     crawler_config = CrawlerRunConfig(screenshot=True)
-    
+
     async with AsyncWebCrawler(
         config=browser_config,
-        verbose=True # legacy parameter
+        verbose=True,  # legacy parameter
     ) as crawler:
         result = await crawler.arun(
-            'https://example.com',
+            "https://example.com",
             config=crawler_config,
             cache_mode=CacheMode.BYPASS, # legacy parameter
-            css_selector="body" # legacy parameter
+            css_selector="body",  # legacy parameter
         )
         assert result.success, "Mixed configuration usage failed"
 
+
 if __name__ == "__main__":
+
     async def run_tests():
         test_functions = [
             test_browser_config_object,
@@ -217,7 +197,7 @@ if __name__ == "__main__":
            # test_legacy_parameter_support,
            # test_mixed_config_usage
         ]
-        
+
         for test in test_functions:
             print(f"\nRunning {test.__name__}...")
             try:
@@ -227,5 +207,5 @@ if __name__ == "__main__":
                 print(f"✗ {test.__name__} failed: {str(e)}")
             except Exception as e:
                 print(f"✗ {test.__name__} error: {str(e)}")
-    
-    asyncio.run(run_tests())
\ No newline at end of file
+
+    asyncio.run(run_tests())
diff --git a/tests/async/test_async_doanloader.py b/tests/async/test_async_doanloader.py
index 4798b4ca..055886cb 100644
--- a/tests/async/test_async_doanloader.py
+++ b/tests/async/test_async_doanloader.py
@@ -4,7 +4,6 @@ import asyncio
 import shutil
 from typing import List
 import tempfile
-import time
 
 # Add the parent directory to the Python path
 parent_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
@@ -12,28 +11,27 @@ sys.path.append(parent_dir)
 
 from crawl4ai.async_webcrawler import AsyncWebCrawler
 
+
 class TestDownloads:
     def __init__(self):
         self.temp_dir = tempfile.mkdtemp(prefix="crawl4ai_test_")
         self.download_dir = os.path.join(self.temp_dir, "downloads")
         os.makedirs(self.download_dir, exist_ok=True)
         self.results: List[str] = []
-    
+
     def cleanup(self):
         shutil.rmtree(self.temp_dir)
-    
+
     def log_result(self, test_name: str, success: bool, message: str = ""):
         result = f"{'✅' if success else '❌'} {test_name}: {message}"
         self.results.append(result)
         print(result)
-    
+
     async def test_basic_download(self):
         """Test basic file download functionality"""
         try:
             async with AsyncWebCrawler(
-                accept_downloads=True,
-                downloads_path=self.download_dir,
-                verbose=True
+                accept_downloads=True, downloads_path=self.download_dir, verbose=True
             ) as crawler:
                 # Python.org downloads page typically has stable download links
                 result = await crawler.arun(
@@ -42,14 +40,19 @@ class TestDownloads:
                     // Click first download link
                     const downloadLink = document.querySelector('a[href$=".exe"]');
                    if (downloadLink) downloadLink.click();
-                    """
+                    """,
+                )
+
+                success = (
+                    result.downloaded_files is not None
+                    and len(result.downloaded_files) > 0
                 )
-                
-                success = result.downloaded_files is not None and len(result.downloaded_files) > 0
                 self.log_result(
                     "Basic Download",
                     success,
-                    f"Downloaded {len(result.downloaded_files or [])} files" if success else "No files downloaded"
+                    f"Downloaded {len(result.downloaded_files or [])} files"
+                    if success
+                    else "No files downloaded",
                 )
         except Exception as e:
             self.log_result("Basic Download", False, str(e))
@@ -59,27 +62,32 @@ class TestDownloads:
         try:
             user_data_dir = os.path.join(self.temp_dir, "user_data")
             os.makedirs(user_data_dir, exist_ok=True)
-            
+
             async with AsyncWebCrawler(
                 accept_downloads=True,
                 downloads_path=self.download_dir,
                 use_persistent_context=True,
                 user_data_dir=user_data_dir,
-                verbose=True
+                verbose=True,
             ) as crawler:
                 result = await crawler.arun(
                     url="https://www.python.org/downloads/",
                     js_code="""
                     const downloadLink = document.querySelector('a[href$=".exe"]');
                     if (downloadLink) downloadLink.click();
-                    """
+                    """,
+                )
+
+                success = (
+                    result.downloaded_files is not None
+                    and len(result.downloaded_files) > 0
                 )
-                
-                success = result.downloaded_files is not None and len(result.downloaded_files) > 0
                 self.log_result(
                     "Persistent Context Download",
                     success,
-                    f"Downloaded {len(result.downloaded_files or [])} files" if success else "No files downloaded"
+                    f"Downloaded {len(result.downloaded_files or [])} files"
+                    if success
+                    else "No files downloaded",
                 )
         except Exception as e:
             self.log_result("Persistent Context Download", False, str(e))
@@ -88,9 +96,7 @@ class TestDownloads:
         """Test multiple simultaneous downloads"""
         try:
             async with AsyncWebCrawler(
-                accept_downloads=True,
-                downloads_path=self.download_dir,
-                verbose=True
+                accept_downloads=True, downloads_path=self.download_dir, verbose=True
             ) as crawler:
                 result = await crawler.arun(
                     url="https://www.python.org/downloads/",
@@ -98,14 +104,19 @@ class TestDownloads:
                    // Click multiple download links
                    const downloadLinks = document.querySelectorAll('a[href$=".exe"]');
                    downloadLinks.forEach(link => link.click());
-                    """
+                    """,
+                )
+
+                success = (
+                    result.downloaded_files is not None
+                    and len(result.downloaded_files) > 1
                 )
-                
-                success = result.downloaded_files is not None and len(result.downloaded_files) > 1
                 self.log_result(
                     "Multiple Downloads",
                     success,
-                    f"Downloaded {len(result.downloaded_files or [])} files" if success else "Not enough files downloaded"
+                    f"Downloaded {len(result.downloaded_files or [])} files"
+                    if success
+                    else "Not enough files downloaded",
                 )
         except Exception as e:
             self.log_result("Multiple Downloads", False, str(e))
@@ -113,49 +124,51 @@ class TestDownloads:
     async def test_different_browsers(self):
         """Test downloads across different browser types"""
         browsers = ["chromium", "firefox", "webkit"]
-        
+
         for browser_type in browsers:
             try:
                 async with AsyncWebCrawler(
                     accept_downloads=True,
                     downloads_path=self.download_dir,
                     browser_type=browser_type,
-                    verbose=True
+                    verbose=True,
                 ) as crawler:
                     result = await crawler.arun(
                         url="https://www.python.org/downloads/",
                         js_code="""
                         const downloadLink = document.querySelector('a[href$=".exe"]');
                         if (downloadLink) downloadLink.click();
-                        """
+                        """,
+                    )
+
+                    success = (
+                        result.downloaded_files is not None
+                        and len(result.downloaded_files) > 0
                     )
-                    
-                    success = result.downloaded_files is not None and len(result.downloaded_files) > 0
                     self.log_result(
                         f"{browser_type.title()} Download",
                         success,
-                        f"Downloaded {len(result.downloaded_files or [])} files" if success else "No files downloaded"
+                        f"Downloaded {len(result.downloaded_files or [])} files"
+                        if success
+                        else "No files downloaded",
                     )
             except Exception as e:
                 self.log_result(f"{browser_type.title()} Download", False, str(e))
 
     async def test_edge_cases(self):
         """Test various edge cases"""
-        
+
         # Test 1: Downloads without specifying download path
         try:
-            async with AsyncWebCrawler(
-                accept_downloads=True,
-                verbose=True
-            ) as crawler:
+            async with AsyncWebCrawler(accept_downloads=True, verbose=True) as crawler:
                 result = await crawler.arun(
                     url="https://www.python.org/downloads/",
-                    js_code="document.querySelector('a[href$=\".exe\"]').click()"
+                    js_code="document.querySelector('a[href$=\".exe\"]').click()",
                 )
                 self.log_result(
                     "Default Download Path",
                     True,
-                    f"Downloaded to default path: {result.downloaded_files[0] if result.downloaded_files else 'None'}"
+                    f"Downloaded to default path: {result.downloaded_files[0] if result.downloaded_files else 'None'}",
                 )
         except Exception as e:
             self.log_result("Default Download Path", False, str(e))
@@ -165,31 +178,34 @@ class TestDownloads:
             async with AsyncWebCrawler(
                 accept_downloads=True,
                 downloads_path="/invalid/path/that/doesnt/exist",
-                verbose=True
+                verbose=True,
             ) as crawler:
                 result = await crawler.arun(
                     url="https://www.python.org/downloads/",
-                    js_code="document.querySelector('a[href$=\".exe\"]').click()"
+                    js_code="document.querySelector('a[href$=\".exe\"]').click()",
                 )
-                self.log_result("Invalid Download Path", False, "Should have raised an error")
-        except Exception as e:
-            self.log_result("Invalid Download Path", True, "Correctly handled invalid path")
+                self.log_result(
+                    "Invalid Download Path", False, "Should have raised an error"
+                )
+        except Exception:
+            self.log_result(
+                "Invalid Download Path", True, "Correctly handled invalid path"
+            )
 
         # Test 3: Download with accept_downloads=False
         try:
-            async with AsyncWebCrawler(
-                accept_downloads=False,
-                verbose=True
-            ) as crawler:
+            async with AsyncWebCrawler(accept_downloads=False, verbose=True) as crawler:
                 result = await crawler.arun(
                     url="https://www.python.org/downloads/",
-                    js_code="document.querySelector('a[href$=\".exe\"]').click()"
+                    js_code="document.querySelector('a[href$=\".exe\"]').click()",
                 )
                 success = result.downloaded_files is None
                 self.log_result(
                     "Disabled Downloads",
                     success,
-                    "Correctly ignored downloads" if success else "Unexpectedly downloaded files"
+                    "Correctly ignored downloads"
+                    if success
+                    else "Unexpectedly downloaded files",
                 )
         except Exception as e:
             self.log_result("Disabled Downloads", False, str(e))
@@ -197,33 +213,35 @@ class TestDownloads:
     async def run_all_tests(self):
         """Run all test cases"""
         print("\n🧪 Running Download Tests...\n")
-        
+
         test_methods = [
             self.test_basic_download,
             self.test_persistent_context_download,
             self.test_multiple_downloads,
             self.test_different_browsers,
-            self.test_edge_cases
+            self.test_edge_cases,
         ]
-        
+
         for test in test_methods:
             print(f"\n📝 Running {test.__doc__}...")
             await test()
             await asyncio.sleep(2) # Brief pause between tests
-        
+
         print("\n📊 Test Results Summary:")
         for result in self.results:
             print(result)
-        
-        successes = len([r for r in self.results if '✅' in r])
+
+        successes = len([r for r in self.results if "✅" in r])
         total = len(self.results)
         print(f"\nTotal: {successes}/{total} tests passed")
-        
+
         self.cleanup()
 
+
 async def main():
     tester = TestDownloads()
     await tester.run_all_tests()
 
+
 if __name__ == "__main__":
-    asyncio.run(main())
\ No newline at end of file
+    asyncio.run(main())
diff --git a/tests/async/test_basic_crawling.py b/tests/async/test_basic_crawling.py
index ce38ac2f..ee4bb633 100644
--- a/tests/async/test_basic_crawling.py
+++ b/tests/async/test_basic_crawling.py
@@ -1,15 +1,17 @@
 import os
 import sys
 import pytest
-import asyncio
 import time
 
 # Add the parent directory to the Python path
-parent_dir = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
+parent_dir = os.path.dirname(
+    os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
+)
 sys.path.append(parent_dir)
 
 from crawl4ai.async_webcrawler import AsyncWebCrawler
 
+
 @pytest.mark.asyncio
 async def test_successful_crawl():
     async with AsyncWebCrawler(verbose=True) as crawler:
@@ -21,6 +23,7 @@ async def test_successful_crawl():
         assert result.markdown
         assert result.cleaned_html
 
+
 @pytest.mark.asyncio
 async def test_invalid_url():
     async with AsyncWebCrawler(verbose=True) as crawler:
@@ -29,19 +32,21 @@ async def test_invalid_url():
         assert not result.success
         assert result.error_message
 
+
 @pytest.mark.asyncio
 async def test_multiple_urls():
     async with AsyncWebCrawler(verbose=True) as crawler:
         urls = [
             "https://www.nbcnews.com/business",
             "https://www.example.com",
-            "https://www.python.org"
+            "https://www.python.org",
         ]
         results = await crawler.arun_many(urls=urls, bypass_cache=True)
         assert len(results) == len(urls)
         assert all(result.success for result in results)
         assert all(result.html for result in results)
 
+
 @pytest.mark.asyncio
 async def test_javascript_execution():
     async with AsyncWebCrawler(verbose=True) as crawler:
@@ -51,6 +56,7 @@ async def test_javascript_execution():
         assert result.success
         assert "