From 8ec12d7d68f2f35183d1c472a256282c2d79aa0d Mon Sep 17 00:00:00 2001 From: UncleCode Date: Mon, 13 Jan 2025 19:19:58 +0800 Subject: [PATCH] Apply Ruff Corrections --- .pre-commit-config.yaml | 8 + crawl4ai/__init__.py | 67 +- crawl4ai/async_configs.py | 59 +- crawl4ai/async_crawler_strategy.py | 535 ++++--- crawl4ai/async_database.py | 316 ++-- crawl4ai/async_dispatcher.py | 304 ++-- crawl4ai/async_logger.py | 88 +- crawl4ai/async_webcrawler.py | 1126 +++++++------- crawl4ai/cache_context.py | 32 +- crawl4ai/chunking_strategy.py | 109 +- crawl4ai/cli.py | 58 +- crawl4ai/config.py | 56 +- crawl4ai/content_filter_strategy.py | 543 ++++--- crawl4ai/content_scraping_strategy.py | 992 +++++++----- crawl4ai/crawler_strategy.py | 206 +-- crawl4ai/database.py | 75 +- crawl4ai/docs_manager.py | 30 +- crawl4ai/extraction_strategy.py | 594 ++++--- crawl4ai/html2text/__init__.py | 61 +- crawl4ai/html2text/_typing.py | 3 +- crawl4ai/html2text/utils.py | 3 +- crawl4ai/install.py | 54 +- crawl4ai/js_snippet/__init__.py | 11 +- crawl4ai/llmtxt.py | 166 +- crawl4ai/markdown_generation_strategy.py | 158 +- crawl4ai/migrations.py | 118 +- crawl4ai/model_loader.py | 132 +- crawl4ai/models.py | 42 +- crawl4ai/ssl_certificate.py | 87 +- crawl4ai/user_agent_generator.py | 150 +- crawl4ai/utils.py | 1360 ++++++++++------- crawl4ai/version_manager.py | 9 +- crawl4ai/web_crawler.py | 337 ++-- .../amazon_product_extraction_direct_url.py | 46 +- .../amazon_product_extraction_using_hooks.py | 83 +- ...product_extraction_using_use_javascript.py | 47 +- .../async_webcrawler_multiple_urls_example.py | 17 +- docs/examples/browser_optimization_example.py | 2 - docs/examples/crawlai_vs_firecrawl.py | 39 +- docs/examples/dispatcher_example.py | 82 +- docs/examples/docker_example.py | 191 +-- .../examples/extraction_strategies_example.py | 70 +- docs/examples/hello_world.py | 13 +- docs/examples/hooks_example.py | 55 +- docs/examples/language_support_example.py | 25 +- .../examples/llm_extraction_openai_pricing.py | 24 +- docs/examples/quickstart_async.config.py | 72 +- docs/examples/quickstart_async.py | 251 +-- docs/examples/quickstart_sync.py | 290 ++-- docs/examples/research_assistant.py | 77 +- docs/examples/rest_call.py | 42 +- docs/examples/ssl_example.py | 31 +- docs/examples/summarize_page.py | 26 +- docs/examples/v0.3.74.overview.py | 155 +- docs/examples/v0_4_24_walkthrough.py | 191 +-- main.py | 106 +- setup.py | 4 +- tests/async/test_0.4.2_browser_manager.py | 55 +- tests/async/test_0.4.2_config_params.py | 126 +- tests/async/test_async_doanloader.py | 130 +- tests/async/test_basic_crawling.py | 29 +- tests/async/test_caching.py | 33 +- ...test_chunking_and_extraction_strategies.py | 25 +- tests/async/test_content_extraction.py | 15 +- tests/async/test_content_filter_bm25.py | 72 +- tests/async/test_content_filter_prune.py | 73 +- tests/async/test_content_scraper_strategy.py | 217 ++- tests/async/test_crawler_strategy.py | 13 +- tests/async/test_database_operations.py | 26 +- tests/async/test_dispatchers.py | 119 +- tests/async/test_edge_cases.py | 38 +- tests/async/test_error_handling.py | 2 +- ...on_scraping_methods_performance.configs.py | 281 ++-- tests/async/test_markdown_genertor.py | 106 +- tests/async/test_parameters_and_options.py | 62 +- tests/async/test_performance.py | 37 +- tests/async/test_screenshot.py | 70 +- tests/docker_example.py | 174 ++- tests/test_cli_docs.py | 13 +- tests/test_docker.py | 135 +- tests/test_llmtxt.py | 17 +- tests/test_main.py | 103 +- tests/test_scraping_strategy.py | 17 +- tests/test_web_crawler.py | 121 +- 84 files changed, 6861 insertions(+), 5076 deletions(-) create mode 100644 .pre-commit-config.yaml diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml new file mode 100644 index 00000000..92f3c1a7 --- /dev/null +++ b/.pre-commit-config.yaml @@ -0,0 +1,8 @@ +# .pre-commit-config.yaml +repos: +- repo: https://github.com/astral-sh/ruff-pre-commit + rev: v0.1.11 + hooks: + - id: ruff + args: [--fix] + - id: ruff-format \ No newline at end of file diff --git a/crawl4ai/__init__.py b/crawl4ai/__init__.py index 86c2cb9e..78ccdb02 100644 --- a/crawl4ai/__init__.py +++ b/crawl4ai/__init__.py @@ -2,14 +2,28 @@ from .async_webcrawler import AsyncWebCrawler, CacheMode from .async_configs import BrowserConfig, CrawlerRunConfig -from .content_scraping_strategy import ContentScrapingStrategy, WebScrapingStrategy, LXMLWebScrapingStrategy -from .extraction_strategy import ExtractionStrategy, LLMExtractionStrategy, CosineStrategy, JsonCssExtractionStrategy +from .content_scraping_strategy import ( + ContentScrapingStrategy, + WebScrapingStrategy, + LXMLWebScrapingStrategy, +) +from .extraction_strategy import ( + ExtractionStrategy, + LLMExtractionStrategy, + CosineStrategy, + JsonCssExtractionStrategy, +) from .chunking_strategy import ChunkingStrategy, RegexChunking from .markdown_generation_strategy import DefaultMarkdownGenerator from .content_filter_strategy import PruningContentFilter, BM25ContentFilter from .models import CrawlResult, MarkdownGenerationResult -from .async_dispatcher import MemoryAdaptiveDispatcher, SemaphoreDispatcher, RateLimiter, CrawlerMonitor, DisplayMode -from .__version__ import __version__ +from .async_dispatcher import ( + MemoryAdaptiveDispatcher, + SemaphoreDispatcher, + RateLimiter, + CrawlerMonitor, + DisplayMode, +) __all__ = [ "AsyncWebCrawler", @@ -18,40 +32,45 @@ __all__ = [ "ContentScrapingStrategy", "WebScrapingStrategy", "LXMLWebScrapingStrategy", - 'BrowserConfig', - 'CrawlerRunConfig', - 'ExtractionStrategy', - 'LLMExtractionStrategy', - 'CosineStrategy', - 'JsonCssExtractionStrategy', - 'ChunkingStrategy', - 'RegexChunking', - 'DefaultMarkdownGenerator', - 'PruningContentFilter', - 'BM25ContentFilter', - 'MemoryAdaptiveDispatcher', - 'SemaphoreDispatcher', - 'RateLimiter', - 'CrawlerMonitor', - 'DisplayMode', - 'MarkdownGenerationResult', + "BrowserConfig", + "CrawlerRunConfig", + "ExtractionStrategy", + "LLMExtractionStrategy", + "CosineStrategy", + "JsonCssExtractionStrategy", + "ChunkingStrategy", + "RegexChunking", + "DefaultMarkdownGenerator", + "PruningContentFilter", + "BM25ContentFilter", + "MemoryAdaptiveDispatcher", + "SemaphoreDispatcher", + "RateLimiter", + "CrawlerMonitor", + "DisplayMode", + "MarkdownGenerationResult", ] + def is_sync_version_installed(): try: import selenium + return True except ImportError: return False + if is_sync_version_installed(): try: from .web_crawler import WebCrawler + __all__.append("WebCrawler") except ImportError: - import warnings - print("Warning: Failed to import WebCrawler even though selenium is installed. This might be due to other missing dependencies.") + print( + "Warning: Failed to import WebCrawler even though selenium is installed. This might be due to other missing dependencies." + ) else: WebCrawler = None # import warnings - # print("Warning: Synchronous WebCrawler is not available. Install crawl4ai[sync] for synchronous support. However, please note that the synchronous version will be deprecated soon.") \ No newline at end of file + # print("Warning: Synchronous WebCrawler is not available. Install crawl4ai[sync] for synchronous support. However, please note that the synchronous version will be deprecated soon.") diff --git a/crawl4ai/async_configs.py b/crawl4ai/async_configs.py index 28f90bb3..c6f25994 100644 --- a/crawl4ai/async_configs.py +++ b/crawl4ai/async_configs.py @@ -5,7 +5,6 @@ from .config import ( PAGE_TIMEOUT, IMAGE_SCORE_THRESHOLD, SOCIAL_MEDIA_DOMAINS, - ) from .user_agent_generator import UserAgentGenerator from .extraction_strategy import ExtractionStrategy @@ -14,6 +13,7 @@ from .markdown_generation_strategy import MarkdownGenerationStrategy from .content_scraping_strategy import ContentScrapingStrategy, WebScrapingStrategy from typing import Union, List + class BrowserConfig: """ Configuration class for setting up a browser instance and its context in AsyncPlaywrightCrawlerStrategy. @@ -84,7 +84,7 @@ class BrowserConfig: proxy: str = None, proxy_config: dict = None, viewport_width: int = 1080, - viewport_height: int = 600, + viewport_height: int = 600, accept_downloads: bool = False, downloads_path: str = None, storage_state=None, @@ -103,7 +103,7 @@ class BrowserConfig: text_mode: bool = False, light_mode: bool = False, extra_args: list = None, - debugging_port : int = 9222, + debugging_port: int = 9222, ): self.browser_type = browser_type self.headless = headless @@ -142,7 +142,7 @@ class BrowserConfig: self.user_agent = user_agenr_generator.generate() else: pass - + self.browser_hint = user_agenr_generator.generate_client_hints(self.user_agent) self.headers.setdefault("sec-ch-ua", self.browser_hint) @@ -313,7 +313,7 @@ class CrawlerRunConfig: Default: True. log_console (bool): If True, log console messages from the page. Default: False. - + # Optional Parameters url: str = None # This is not a compulsory parameter """ @@ -335,10 +335,8 @@ class CrawlerRunConfig: prettiify: bool = False, parser_type: str = "lxml", scraping_strategy: ContentScrapingStrategy = None, - # SSL Parameters fetch_ssl_certificate: bool = False, - # Caching Parameters cache_mode=None, session_id: str = None, @@ -346,7 +344,6 @@ class CrawlerRunConfig: disable_cache: bool = False, no_cache_read: bool = False, no_cache_write: bool = False, - # Page Navigation and Timing Parameters wait_until: str = "domcontentloaded", page_timeout: int = PAGE_TIMEOUT, @@ -356,7 +353,6 @@ class CrawlerRunConfig: mean_delay: float = 0.1, max_range: float = 0.3, semaphore_count: int = 5, - # Page Interaction Parameters js_code: Union[str, List[str]] = None, js_only: bool = False, @@ -369,7 +365,6 @@ class CrawlerRunConfig: override_navigator: bool = False, magic: bool = False, adjust_viewport_to_content: bool = False, - # Media Handling Parameters screenshot: bool = False, screenshot_wait_for: float = None, @@ -378,21 +373,18 @@ class CrawlerRunConfig: image_description_min_word_threshold: int = IMAGE_DESCRIPTION_MIN_WORD_THRESHOLD, image_score_threshold: int = IMAGE_SCORE_THRESHOLD, exclude_external_images: bool = False, - # Link and Domain Handling Parameters exclude_social_media_domains: list = None, exclude_external_links: bool = False, exclude_social_media_links: bool = False, exclude_domains: list = None, - # Debugging and Logging Parameters verbose: bool = True, log_console: bool = False, - url: str = None, ): self.url = url - + # Content Processing Parameters self.word_count_threshold = word_count_threshold self.extraction_strategy = extraction_strategy @@ -453,7 +445,9 @@ class CrawlerRunConfig: self.exclude_external_images = exclude_external_images # Link and Domain Handling Parameters - self.exclude_social_media_domains = exclude_social_media_domains or SOCIAL_MEDIA_DOMAINS + self.exclude_social_media_domains = ( + exclude_social_media_domains or SOCIAL_MEDIA_DOMAINS + ) self.exclude_external_links = exclude_external_links self.exclude_social_media_links = exclude_social_media_links self.exclude_domains = exclude_domains or [] @@ -466,11 +460,15 @@ class CrawlerRunConfig: if self.extraction_strategy is not None and not isinstance( self.extraction_strategy, ExtractionStrategy ): - raise ValueError("extraction_strategy must be an instance of ExtractionStrategy") + raise ValueError( + "extraction_strategy must be an instance of ExtractionStrategy" + ) if self.chunking_strategy is not None and not isinstance( self.chunking_strategy, ChunkingStrategy ): - raise ValueError("chunking_strategy must be an instance of ChunkingStrategy") + raise ValueError( + "chunking_strategy must be an instance of ChunkingStrategy" + ) # Set default chunking strategy if None if self.chunking_strategy is None: @@ -494,10 +492,8 @@ class CrawlerRunConfig: prettiify=kwargs.get("prettiify", False), parser_type=kwargs.get("parser_type", "lxml"), scraping_strategy=kwargs.get("scraping_strategy"), - # SSL Parameters fetch_ssl_certificate=kwargs.get("fetch_ssl_certificate", False), - # Caching Parameters cache_mode=kwargs.get("cache_mode"), session_id=kwargs.get("session_id"), @@ -505,7 +501,6 @@ class CrawlerRunConfig: disable_cache=kwargs.get("disable_cache", False), no_cache_read=kwargs.get("no_cache_read", False), no_cache_write=kwargs.get("no_cache_write", False), - # Page Navigation and Timing Parameters wait_until=kwargs.get("wait_until", "domcontentloaded"), page_timeout=kwargs.get("page_timeout", 60000), @@ -515,7 +510,6 @@ class CrawlerRunConfig: mean_delay=kwargs.get("mean_delay", 0.1), max_range=kwargs.get("max_range", 0.3), semaphore_count=kwargs.get("semaphore_count", 5), - # Page Interaction Parameters js_code=kwargs.get("js_code"), js_only=kwargs.get("js_only", False), @@ -528,29 +522,34 @@ class CrawlerRunConfig: override_navigator=kwargs.get("override_navigator", False), magic=kwargs.get("magic", False), adjust_viewport_to_content=kwargs.get("adjust_viewport_to_content", False), - # Media Handling Parameters screenshot=kwargs.get("screenshot", False), screenshot_wait_for=kwargs.get("screenshot_wait_for"), - screenshot_height_threshold=kwargs.get("screenshot_height_threshold", SCREENSHOT_HEIGHT_TRESHOLD), + screenshot_height_threshold=kwargs.get( + "screenshot_height_threshold", SCREENSHOT_HEIGHT_TRESHOLD + ), pdf=kwargs.get("pdf", False), - image_description_min_word_threshold=kwargs.get("image_description_min_word_threshold", IMAGE_DESCRIPTION_MIN_WORD_THRESHOLD), - image_score_threshold=kwargs.get("image_score_threshold", IMAGE_SCORE_THRESHOLD), + image_description_min_word_threshold=kwargs.get( + "image_description_min_word_threshold", + IMAGE_DESCRIPTION_MIN_WORD_THRESHOLD, + ), + image_score_threshold=kwargs.get( + "image_score_threshold", IMAGE_SCORE_THRESHOLD + ), exclude_external_images=kwargs.get("exclude_external_images", False), - # Link and Domain Handling Parameters - exclude_social_media_domains=kwargs.get("exclude_social_media_domains", SOCIAL_MEDIA_DOMAINS), + exclude_social_media_domains=kwargs.get( + "exclude_social_media_domains", SOCIAL_MEDIA_DOMAINS + ), exclude_external_links=kwargs.get("exclude_external_links", False), exclude_social_media_links=kwargs.get("exclude_social_media_links", False), exclude_domains=kwargs.get("exclude_domains", []), - # Debugging and Logging Parameters verbose=kwargs.get("verbose", True), log_console=kwargs.get("log_console", False), - url=kwargs.get("url"), ) - + # Create a funciton returns dict of the object def to_dict(self): return { diff --git a/crawl4ai/async_crawler_strategy.py b/crawl4ai/async_crawler_strategy.py index ed7407f8..0edefa73 100644 --- a/crawl4ai/async_crawler_strategy.py +++ b/crawl4ai/async_crawler_strategy.py @@ -2,27 +2,25 @@ import asyncio import base64 import time from abc import ABC, abstractmethod -from typing import Callable, Dict, Any, List, Optional, Awaitable, Union -import os, sys, shutil -import tempfile, subprocess -from playwright.async_api import async_playwright, Page, Browser, Error, BrowserContext +from typing import Callable, Dict, Any, List, Optional, Union +import os +import sys +import shutil +import tempfile +import subprocess +from playwright.async_api import Page, Error, BrowserContext from playwright.async_api import TimeoutError as PlaywrightTimeoutError from io import BytesIO from PIL import Image, ImageDraw, ImageFont -from pathlib import Path -from playwright.async_api import ProxySettings -from pydantic import BaseModel import hashlib -import json import uuid from .js_snippet import load_js_script from .models import AsyncCrawlResponse -from .utils import get_error_context from .user_agent_generator import UserAgentGenerator from .config import SCREENSHOT_HEIGHT_TRESHOLD, DOWNLOAD_PAGE_TIMEOUT from .async_configs import BrowserConfig, CrawlerRunConfig from .async_logger import AsyncLogger -from playwright_stealth import StealthConfig, stealth_async +from playwright_stealth import StealthConfig from .ssl_certificate import SSLCertificate stealth_config = StealthConfig( @@ -66,7 +64,7 @@ BROWSER_DISABLE_OPTIONS = [ class ManagedBrowser: """ Manages the browser process and context. This class allows to connect to the browser using CDP protocol. - + Attributes: browser_type (str): The type of browser to launch. Supported values: "chromium", "firefox", "webkit". Default: "chromium". @@ -75,16 +73,16 @@ class ManagedBrowser: headless (bool): Whether to run the browser in headless mode (no visible GUI). Default: True. browser_process (subprocess.Popen): The process object for the browser. - temp_dir (str): Temporary directory for user data if not provided. + temp_dir (str): Temporary directory for user data if not provided. debugging_port (int): Port for debugging the browser. host (str): Host for debugging the browser. - + Methods: start(): Starts the browser process and returns the CDP endpoint URL. _get_browser_path(): Returns the browser executable path based on OS and browser type. _get_browser_args(): Returns browser-specific command line arguments. _get_user_data_dir(): Returns the user data directory path. - _cleanup(): Terminates the browser process and removes the temporary directory. + _cleanup(): Terminates the browser process and removes the temporary directory. """ browser_type: str @@ -94,6 +92,7 @@ class ManagedBrowser: temp_dir: str debugging_port: int host: str + def __init__( self, browser_type: str = "chromium", @@ -105,7 +104,7 @@ class ManagedBrowser: ): """ Initialize the ManagedBrowser instance. - + Args: browser_type (str): The type of browser to launch. Supported values: "chromium", "firefox", "webkit". Default: "chromium". @@ -116,7 +115,7 @@ class ManagedBrowser: logger (logging.Logger): Logger instance for logging messages. Default: None. host (str): Host for debugging the browser. Default: "localhost". debugging_port (int): Port for debugging the browser. Default: 9222. - """ + """ self.browser_type = browser_type self.user_data_dir = user_data_dir self.headless = headless @@ -139,7 +138,7 @@ class ManagedBrowser: self.user_data_dir = self.temp_dir # Get browser path and args based on OS and browser type - browser_path = self._get_browser_path() + # browser_path = self._get_browser_path() args = self._get_browser_args() # Start browser process @@ -158,13 +157,13 @@ class ManagedBrowser: async def _monitor_browser_process(self): """ Monitor the browser process for unexpected termination. - + How it works: 1. Read stdout and stderr from the browser process. 2. If the process has terminated, log the error message and terminate the browser. 3. If the shutting_down flag is set, log the normal termination message. 4. If any other error occurs, log the error message. - + Note: This method should be called in a separate task to avoid blocking the main event loop. """ if self.browser_process: @@ -289,17 +288,18 @@ class ManagedBrowser: class BrowserManager: """ Manages the browser instance and context. - - Attributes: + + Attributes: config (BrowserConfig): Configuration object containing all browser settings logger: Logger instance for recording events and errors browser (Browser): The browser instance - default_context (BrowserContext): The default browser context + default_context (BrowserContext): The default browser context managed_browser (ManagedBrowser): The managed browser instance playwright (Playwright): The Playwright instance sessions (dict): Dictionary to store session information session_ttl (int): Session timeout in seconds """ + def __init__(self, browser_config: BrowserConfig, logger=None): """ Initialize the BrowserManager with a browser configuration. @@ -334,13 +334,13 @@ class BrowserManager: async def start(self): """ Start the browser instance and set up the default context. - + How it works: 1. Check if Playwright is already initialized. 2. If not, initialize Playwright. 3. If managed browser is used, start it and connect to the CDP endpoint. 4. If managed browser is not used, launch the browser and set up the default context. - + Note: This method should be called in a separate task to avoid blocking the main event loop. """ if self.playwright is None: @@ -453,7 +453,12 @@ class BrowserManager: return browser_args - async def setup_context(self, context: BrowserContext, crawlerRunConfig: CrawlerRunConfig = None, is_default=False): + async def setup_context( + self, + context: BrowserContext, + crawlerRunConfig: CrawlerRunConfig = None, + is_default=False, + ): """ Set up a browser context with the configured options. @@ -474,11 +479,11 @@ class BrowserManager: 14. Set default timeouts for navigation and download if enabled. 15. Set user agent if provided. 16. Set browser hints if provided. - + Args: context (BrowserContext): The browser context to set up crawlerRunConfig (CrawlerRunConfig): Configuration object containing all browser settings - is_default (bool): Flag indicating if this is the default context + is_default (bool): Flag indicating if this is the default context Returns: None """ @@ -496,9 +501,9 @@ class BrowserManager: context.set_default_navigation_timeout(DOWNLOAD_PAGE_TIMEOUT) if self.config.downloads_path: context._impl_obj._options["accept_downloads"] = True - context._impl_obj._options["downloads_path"] = ( - self.config.downloads_path - ) + context._impl_obj._options[ + "downloads_path" + ] = self.config.downloads_path # Handle user agent and browser hints if self.config.user_agent: @@ -511,7 +516,15 @@ class BrowserManager: # Add default cookie await context.add_cookies( - [{"name": "cookiesEnabled", "value": "true", "url": crawlerRunConfig.url if crawlerRunConfig else "https://crawl4ai.com/"}] + [ + { + "name": "cookiesEnabled", + "value": "true", + "url": crawlerRunConfig.url + if crawlerRunConfig + else "https://crawl4ai.com/", + } + ] ) # Handle navigator overrides @@ -527,7 +540,7 @@ class BrowserManager: """ Creates and returns a new browser context with configured settings. Applies text-only mode settings if text_mode is enabled in config. - + Returns: Context: Browser context object with the specified configurations """ @@ -538,25 +551,62 @@ class BrowserManager: "height": self.config.viewport_height, } proxy_settings = {"server": self.config.proxy} if self.config.proxy else None - + blocked_extensions = [ # Images - 'jpg', 'jpeg', 'png', 'gif', 'webp', 'svg', 'ico', 'bmp', 'tiff', 'psd', + "jpg", + "jpeg", + "png", + "gif", + "webp", + "svg", + "ico", + "bmp", + "tiff", + "psd", # Fonts - 'woff', 'woff2', 'ttf', 'otf', 'eot', + "woff", + "woff2", + "ttf", + "otf", + "eot", # Styles # 'css', 'less', 'scss', 'sass', # Media - 'mp4', 'webm', 'ogg', 'avi', 'mov', 'wmv', 'flv', 'm4v', - 'mp3', 'wav', 'aac', 'm4a', 'opus', 'flac', + "mp4", + "webm", + "ogg", + "avi", + "mov", + "wmv", + "flv", + "m4v", + "mp3", + "wav", + "aac", + "m4a", + "opus", + "flac", # Documents - 'pdf', 'doc', 'docx', 'xls', 'xlsx', 'ppt', 'pptx', + "pdf", + "doc", + "docx", + "xls", + "xlsx", + "ppt", + "pptx", # Archives - 'zip', 'rar', '7z', 'tar', 'gz', + "zip", + "rar", + "7z", + "tar", + "gz", # Scripts and data - 'xml', 'swf', 'wasm' + "xml", + "swf", + "wasm", ] - + # Common context settings context_settings = { "user_agent": user_agent, @@ -568,7 +618,7 @@ class BrowserManager: "device_scale_factor": 1.0, "java_script_enabled": self.config.java_script_enabled, } - + if self.config.text_mode: text_mode_settings = { "has_touch": False, @@ -576,22 +626,22 @@ class BrowserManager: } # Update context settings with text mode settings context_settings.update(text_mode_settings) - + # Create and return the context with all settings context = await self.browser.new_context(**context_settings) - + # Apply text mode settings if enabled if self.config.text_mode: # Create and apply route patterns for each extension for ext in blocked_extensions: await context.route(f"**/*.{ext}", lambda route: route.abort()) return context - + # async def get_page(self, session_id: Optional[str], user_agent: str): async def get_page(self, crawlerRunConfig: CrawlerRunConfig): """ Get a page for the given session ID, creating a new one if needed. - + Args: crawlerRunConfig (CrawlerRunConfig): Configuration object containing all browser settings @@ -621,8 +671,8 @@ class BrowserManager: async def kill_session(self, session_id: str): """ - Kill a browser session and clean up resources. - + Kill a browser session and clean up resources. + Args: session_id (str): The session ID to kill. """ @@ -672,20 +722,20 @@ class AsyncCrawlerStrategy(ABC): Abstract base class for crawler strategies. Subclasses must implement the crawl method. """ + @abstractmethod async def crawl(self, url: str, **kwargs) -> AsyncCrawlResponse: pass # 4 + 3 - class AsyncPlaywrightCrawlerStrategy(AsyncCrawlerStrategy): """ Crawler strategy using Playwright. - + Attributes: browser_config (BrowserConfig): Configuration object containing browser settings. logger (AsyncLogger): Logger instance for recording events and errors. - _downloaded_files (List[str]): List of downloaded file paths. + _downloaded_files (List[str]): List of downloaded file paths. hooks (Dict[str, Callable]): Dictionary of hooks for custom behavior. browser_manager (BrowserManager): Manager for browser creation and management. @@ -704,8 +754,9 @@ class AsyncPlaywrightCrawlerStrategy(AsyncCrawlerStrategy): Kill a browser session and clean up resources. crawl(self, url, **kwargs): Run the crawler for a single URL. - + """ + def __init__( self, browser_config: BrowserConfig = None, logger: AsyncLogger = None, **kwargs ): @@ -769,10 +820,10 @@ class AsyncPlaywrightCrawlerStrategy(AsyncCrawlerStrategy): async def kill_session(self, session_id: str): """ Kill a browser session and clean up resources. - + Args: session_id (str): The ID of the session to kill. - + Returns: None """ @@ -787,20 +838,20 @@ class AsyncPlaywrightCrawlerStrategy(AsyncCrawlerStrategy): """ Set a hook function for a specific hook type. Following are list of hook types: - on_browser_created: Called when a new browser instance is created. - - on_page_context_created: Called when a new page context is created. - - on_user_agent_updated: Called when the user agent is updated. - - on_execution_started: Called when the execution starts. - - before_goto: Called before a goto operation. - - after_goto: Called after a goto operation. - - before_return_html: Called before returning HTML content. - - before_retrieve_html: Called before retrieving HTML content. - + - on_page_context_created: Called when a new page context is created. + - on_user_agent_updated: Called when the user agent is updated. + - on_execution_started: Called when the execution starts. + - before_goto: Called before a goto operation. + - after_goto: Called after a goto operation. + - before_return_html: Called before returning HTML content. + - before_retrieve_html: Called before retrieving HTML content. + All hooks except on_browser_created accepts a context and a page as arguments and **kwargs. However, on_browser_created accepts a browser and a context as arguments and **kwargs. - + Args: hook_type (str): The type of the hook. hook (Callable): The hook function to set. - + Returns: None """ @@ -812,12 +863,12 @@ class AsyncPlaywrightCrawlerStrategy(AsyncCrawlerStrategy): async def execute_hook(self, hook_type: str, *args, **kwargs): """ Execute a hook function for a specific hook type. - + Args: hook_type (str): The type of the hook. *args: Variable length positional arguments. **kwargs: Keyword arguments. - + Returns: The return value of the hook function, if any. """ @@ -832,42 +883,42 @@ class AsyncPlaywrightCrawlerStrategy(AsyncCrawlerStrategy): def update_user_agent(self, user_agent: str): """ Update the user agent for the browser. - + Args: user_agent (str): The new user agent string. - + Returns: None """ self.user_agent = user_agent def set_custom_headers(self, headers: Dict[str, str]): - """ - Set custom headers for the browser. - + """ + Set custom headers for the browser. + Args: headers (Dict[str, str]): A dictionary of headers to set. - + Returns: None """ self.headers = headers async def smart_wait(self, page: Page, wait_for: str, timeout: float = 30000): - """ + """ Wait for a condition in a smart way. This functions works as below: - + 1. If wait_for starts with 'js:', it assumes it's a JavaScript function and waits for it to return true. 2. If wait_for starts with 'css:', it assumes it's a CSS selector and waits for it to be present. 3. Otherwise, it tries to evaluate wait_for as a JavaScript function and waits for it to return true. 4. If it's not a JavaScript function, it assumes it's a CSS selector and waits for it to be present. - - This is a more advanced version of the wait_for parameter in CrawlerStrategy.crawl(). + + This is a more advanced version of the wait_for parameter in CrawlerStrategy.crawl(). Args: page: Playwright page object wait_for (str): The condition to wait for. Can be a CSS selector, a JavaScript function, or explicitly prefixed with 'js:' or 'css:'. timeout (float): Maximum time to wait in milliseconds - + Returns: None """ @@ -917,18 +968,20 @@ class AsyncPlaywrightCrawlerStrategy(AsyncCrawlerStrategy): "or explicitly prefixed with 'js:' or 'css:'." ) - async def csp_compliant_wait( self, page: Page, user_wait_function: str, timeout: float = 30000 ): + async def csp_compliant_wait( + self, page: Page, user_wait_function: str, timeout: float = 30000 + ): """ Wait for a condition in a CSP-compliant way. - + Args: page: Playwright page object user_wait_function: JavaScript function as string that returns boolean timeout: Maximum time to wait in milliseconds - + Returns: bool: True if condition was met, False if timed out - + Raises: RuntimeError: If there's an error evaluating the condition """ @@ -964,10 +1017,10 @@ class AsyncPlaywrightCrawlerStrategy(AsyncCrawlerStrategy): async def process_iframes(self, page): """ Process iframes on a page. This function will extract the content of each iframe and replace it with a div containing the extracted content. - + Args: page: Playwright page object - + Returns: Playwright page object """ @@ -1029,10 +1082,10 @@ class AsyncPlaywrightCrawlerStrategy(AsyncCrawlerStrategy): """ Creates a new browser session and returns its ID. A browse session is a unique openned page can be reused for multiple crawls. This function is asynchronous and returns a string representing the session ID. - + Args: **kwargs: Optional keyword arguments to configure the session. - + Returns: str: The session ID. """ @@ -1045,7 +1098,9 @@ class AsyncPlaywrightCrawlerStrategy(AsyncCrawlerStrategy): page, context = await self.browser_manager.get_page(session_id, user_agent) return session_id - async def crawl( self, url: str, config: CrawlerRunConfig, **kwargs ) -> AsyncCrawlResponse: + async def crawl( + self, url: str, config: CrawlerRunConfig, **kwargs + ) -> AsyncCrawlResponse: """ Crawls a given URL or processes raw HTML/local file content based on the URL prefix. @@ -1104,7 +1159,9 @@ class AsyncPlaywrightCrawlerStrategy(AsyncCrawlerStrategy): "URL must start with 'http://', 'https://', 'file://', or 'raw:'" ) - async def _crawl_web( self, url: str, config: CrawlerRunConfig ) -> AsyncCrawlResponse: + async def _crawl_web( + self, url: str, config: CrawlerRunConfig + ) -> AsyncCrawlResponse: """ Internal method to crawl web URLs with the specified configuration. @@ -1188,11 +1245,13 @@ class AsyncPlaywrightCrawlerStrategy(AsyncCrawlerStrategy): try: # Generate a unique nonce for this request nonce = hashlib.sha256(os.urandom(32)).hexdigest() - + # Add CSP headers to the request - await page.set_extra_http_headers({ - 'Content-Security-Policy': f"default-src 'self'; script-src 'self' 'nonce-{nonce}' 'strict-dynamic'" - }) + await page.set_extra_http_headers( + { + "Content-Security-Policy": f"default-src 'self'; script-src 'self' 'nonce-{nonce}' 'strict-dynamic'" + } + ) response = await page.goto( url, wait_until=config.wait_until, timeout=config.page_timeout @@ -1200,7 +1259,9 @@ class AsyncPlaywrightCrawlerStrategy(AsyncCrawlerStrategy): except Error as e: raise RuntimeError(f"Failed on navigating ACS-GOTO:\n{str(e)}") - await self.execute_hook("after_goto", page, context=context, url=url, response=response) + await self.execute_hook( + "after_goto", page, context=context, url=url, response=response + ) if response is None: status_code = 200 @@ -1216,7 +1277,7 @@ class AsyncPlaywrightCrawlerStrategy(AsyncCrawlerStrategy): # Wait for body element and visibility try: await page.wait_for_selector("body", state="attached", timeout=30000) - + # Use the new check_visibility function with csp_compliant_wait is_visible = await self.csp_compliant_wait( page, @@ -1229,16 +1290,16 @@ class AsyncPlaywrightCrawlerStrategy(AsyncCrawlerStrategy): style.opacity !== '0'; return isVisible; }""", - timeout=30000 + timeout=30000, ) - + if not is_visible and not config.ignore_body_visibility: visibility_info = await self.check_visibility(page) raise Error(f"Body element is hidden: {visibility_info}") - except Error as e: + except Error: visibility_info = await self.check_visibility(page) - + if self.config.verbose: self.logger.debug( message="Body visibility info: {info}", @@ -1247,19 +1308,18 @@ class AsyncPlaywrightCrawlerStrategy(AsyncCrawlerStrategy): ) if not config.ignore_body_visibility: - raise Error(f"Body element is hidden: {visibility_info}") - - + raise Error(f"Body element is hidden: {visibility_info}") + # try: # await page.wait_for_selector("body", state="attached", timeout=30000) - + # await page.wait_for_function( # """ # () => { # const body = document.body; # const style = window.getComputedStyle(body); - # return style.display !== 'none' && - # style.visibility !== 'hidden' && + # return style.display !== 'none' && + # style.visibility !== 'hidden' && # style.opacity !== '0'; # } # """, @@ -1298,14 +1358,14 @@ class AsyncPlaywrightCrawlerStrategy(AsyncCrawlerStrategy): ): await page.wait_for_load_state("domcontentloaded") await asyncio.sleep(0.1) - + # Check for image loading with improved error handling images_loaded = await self.csp_compliant_wait( page, "() => Array.from(document.getElementsByTagName('img')).every(img => img.complete)", - timeout=1000 + timeout=1000, ) - + if not images_loaded and self.logger: self.logger.warning( message="Some images failed to load within timeout", @@ -1316,8 +1376,8 @@ class AsyncPlaywrightCrawlerStrategy(AsyncCrawlerStrategy): if not self.browser_config.text_mode and config.adjust_viewport_to_content: try: dimensions = await self.get_page_dimensions(page) - page_height = dimensions['height'] - page_width = dimensions['width'] + page_height = dimensions["height"] + page_width = dimensions["width"] # page_width = await page.evaluate( # "document.documentElement.scrollWidth" # ) @@ -1361,16 +1421,18 @@ class AsyncPlaywrightCrawlerStrategy(AsyncCrawlerStrategy): # elif isinstance(config.js_code, list): # for js in config.js_code: # await page.evaluate(js) - + if config.js_code: # execution_result = await self.execute_user_script(page, config.js_code) - execution_result = await self.robust_execute_user_script(page, config.js_code) + execution_result = await self.robust_execute_user_script( + page, config.js_code + ) if not execution_result["success"]: self.logger.warning( message="User script execution had issues: {error}", tag="JS_EXEC", - params={"error": execution_result.get("error")} - ) + params={"error": execution_result.get("error")}, + ) await self.execute_hook("on_execution_started", page, context=context) @@ -1385,7 +1447,7 @@ class AsyncPlaywrightCrawlerStrategy(AsyncCrawlerStrategy): # Todo: Decide how to handle this if not config.wait_for and config.css_selector and False: config.wait_for = f"css:{config.css_selector}" - + if config.wait_for: try: await self.smart_wait( @@ -1425,7 +1487,9 @@ class AsyncPlaywrightCrawlerStrategy(AsyncCrawlerStrategy): # Get final HTML content html = await page.content() - await self.execute_hook("before_return_html", page = page, html = html, context=context) + await self.execute_hook( + "before_return_html", page=page, html=html, context=context + ) # Handle PDF and screenshot generation start_export_time = time.perf_counter() @@ -1475,7 +1539,7 @@ class AsyncPlaywrightCrawlerStrategy(AsyncCrawlerStrategy): except Exception as e: raise e - + finally: # If no session_id is given we should close the page if not config.session_id: @@ -1483,20 +1547,20 @@ class AsyncPlaywrightCrawlerStrategy(AsyncCrawlerStrategy): async def _handle_full_page_scan(self, page: Page, scroll_delay: float = 0.1): """ - Helper method to handle full page scanning. - + Helper method to handle full page scanning. + How it works: 1. Get the viewport height. 2. Scroll to the bottom of the page. 3. Get the total height of the page. 4. Scroll back to the top of the page. - 5. Scroll to the bottom of the page again. + 5. Scroll to the bottom of the page again. 6. Continue scrolling until the bottom of the page is reached. - + Args: page (Page): The Playwright page object scroll_delay (float): The delay between page scrolls - + """ try: viewport_height = page.viewport_size.get( @@ -1511,8 +1575,8 @@ class AsyncPlaywrightCrawlerStrategy(AsyncCrawlerStrategy): # total_height = await page.evaluate("document.documentElement.scrollHeight") dimensions = await self.get_page_dimensions(page) - total_height = dimensions['height'] - + total_height = dimensions["height"] + while current_position < total_height: current_position = min(current_position + viewport_height, total_height) await self.safe_scroll(page, 0, current_position, delay=scroll_delay) @@ -1521,8 +1585,8 @@ class AsyncPlaywrightCrawlerStrategy(AsyncCrawlerStrategy): # new_height = await page.evaluate("document.documentElement.scrollHeight") dimensions = await self.get_page_dimensions(page) - new_height = dimensions['height'] - + new_height = dimensions["height"] + if new_height > total_height: total_height = new_height @@ -1542,7 +1606,7 @@ class AsyncPlaywrightCrawlerStrategy(AsyncCrawlerStrategy): async def _handle_download(self, download): """ Handle file downloads. - + How it works: 1. Get the suggested filename. 2. Get the download path. @@ -1550,10 +1614,10 @@ class AsyncPlaywrightCrawlerStrategy(AsyncCrawlerStrategy): 4. Start the download. 5. Save the downloaded file. 6. Log the completion. - + Args: download (Download): The Playwright download object - + Returns: None """ @@ -1598,7 +1662,8 @@ class AsyncPlaywrightCrawlerStrategy(AsyncCrawlerStrategy): remove_overlays_js = load_js_script("remove_overlay_elements") try: - await page.evaluate(f""" + await page.evaluate( + f""" (() => {{ try {{ {remove_overlays_js} @@ -1611,7 +1676,8 @@ class AsyncPlaywrightCrawlerStrategy(AsyncCrawlerStrategy): }}; }} }})() - """) + """ + ) await page.wait_for_timeout(500) # Wait for any animations to complete except Exception as e: self.logger.warning( @@ -1623,10 +1689,10 @@ class AsyncPlaywrightCrawlerStrategy(AsyncCrawlerStrategy): async def export_pdf(self, page: Page) -> bytes: """ Exports the current page as a PDF. - + Args: page (Page): The Playwright page object - + Returns: bytes: The PDF data """ @@ -1636,16 +1702,16 @@ class AsyncPlaywrightCrawlerStrategy(AsyncCrawlerStrategy): async def take_screenshot(self, page, **kwargs) -> str: """ Take a screenshot of the current page. - + Args: page (Page): The Playwright page object kwargs: Additional keyword arguments - + Returns: str: The base64-encoded screenshot data """ need_scroll = await self.page_need_scroll(page) - + if not need_scroll: # Page is short enough, just take a screenshot return await self.take_screenshot_naive(page) @@ -1656,13 +1722,13 @@ class AsyncPlaywrightCrawlerStrategy(AsyncCrawlerStrategy): async def take_screenshot_from_pdf(self, pdf_data: bytes) -> str: """ - Convert the first page of the PDF to a screenshot. - + Convert the first page of the PDF to a screenshot. + Requires pdf2image and poppler. - + Args: pdf_data (bytes): The PDF data - + Returns: str: The base64-encoded screenshot data """ @@ -1694,21 +1760,21 @@ class AsyncPlaywrightCrawlerStrategy(AsyncCrawlerStrategy): """ Attempt to set a large viewport and take a full-page screenshot. If still too large, segment the page as before. - + Requires pdf2image and poppler. - + Args: page (Page): The Playwright page object kwargs: Additional keyword arguments - + Returns: str: The base64-encoded screenshot data """ try: # Get page height dimensions = await self.get_page_dimensions(page) - page_width = dimensions['width'] - page_height = dimensions['height'] + page_width = dimensions["width"] + page_height = dimensions["height"] # page_height = await page.evaluate("document.documentElement.scrollHeight") # page_width = await page.evaluate("document.documentElement.scrollWidth") @@ -1805,10 +1871,10 @@ class AsyncPlaywrightCrawlerStrategy(AsyncCrawlerStrategy): """ Exports the current storage state (cookies, localStorage, sessionStorage) to a JSON file at the specified path. - + Args: path (str): The path to save the storage state JSON file - + Returns: dict: The exported storage state """ @@ -1826,33 +1892,35 @@ class AsyncPlaywrightCrawlerStrategy(AsyncCrawlerStrategy): tag="WARNING", ) - async def robust_execute_user_script(self, page: Page, js_code: Union[str, List[str]]) -> Dict[str, Any]: + async def robust_execute_user_script( + self, page: Page, js_code: Union[str, List[str]] + ) -> Dict[str, Any]: """ Executes user-provided JavaScript code with proper error handling and context, supporting both synchronous and async user code, plus navigations. - + How it works: 1. Wait for load state 'domcontentloaded' 2. If js_code is a string, execute it directly 3. If js_code is a list, execute each element in sequence - 4. Wait for load state 'networkidle' - 5. Return results - - Args: + 4. Wait for load state 'networkidle' + 5. Return results + + Args: page (Page): The Playwright page instance js_code (Union[str, List[str]]): The JavaScript code to execute - + Returns: Dict[str, Any]: The results of the execution """ try: - await page.wait_for_load_state('domcontentloaded') - + await page.wait_for_load_state("domcontentloaded") + if isinstance(js_code, str): scripts = [js_code] else: scripts = js_code - + results = [] for script in scripts: try: @@ -1861,7 +1929,8 @@ class AsyncPlaywrightCrawlerStrategy(AsyncCrawlerStrategy): # then wait for the new page to load before continuing result = None try: - result = await page.evaluate(f""" + result = await page.evaluate( + f""" (async () => {{ try {{ {script} @@ -1870,53 +1939,62 @@ class AsyncPlaywrightCrawlerStrategy(AsyncCrawlerStrategy): return {{ success: false, error: err.toString(), stack: err.stack }}; }} }})(); - """) + """ + ) except Error as e: # If it's due to navigation destroying the context, handle gracefully if "Execution context was destroyed" in str(e): - self.logger.info("Navigation triggered by script, waiting for load state", tag="JS_EXEC") + self.logger.info( + "Navigation triggered by script, waiting for load state", + tag="JS_EXEC", + ) try: - await page.wait_for_load_state('load', timeout=30000) + await page.wait_for_load_state("load", timeout=30000) except Error as nav_err: self.logger.warning( message="Navigation wait failed: {error}", tag="JS_EXEC", - params={"error": str(nav_err)} + params={"error": str(nav_err)}, ) try: - await page.wait_for_load_state('networkidle', timeout=30000) + await page.wait_for_load_state( + "networkidle", timeout=30000 + ) except Error as nav_err: self.logger.warning( message="Network idle wait failed: {error}", tag="JS_EXEC", - params={"error": str(nav_err)} + params={"error": str(nav_err)}, ) # Return partial success, or adapt as you see fit result = { "success": True, - "info": "Navigation triggered, ignoring context destroyed error" + "info": "Navigation triggered, ignoring context destroyed error", } else: # It's some other error, log and continue self.logger.error( message="Playwright execution error: {error}", tag="JS_EXEC", - params={"error": str(e)} + params={"error": str(e)}, ) result = {"success": False, "error": str(e)} - + # If we made it this far with no repeated error, do post-load waits t1 = time.time() try: - await page.wait_for_load_state('domcontentloaded', timeout=5000) - print("DOM content loaded after script execution in", time.time() - t1) + await page.wait_for_load_state("domcontentloaded", timeout=5000) + print( + "DOM content loaded after script execution in", + time.time() - t1, + ) except Error as e: self.logger.warning( message="DOM content load timeout: {error}", tag="JS_EXEC", - params={"error": str(e)} + params={"error": str(e)}, ) - + # t1 = time.time() # try: # await page.wait_for_load_state('networkidle', timeout=5000) @@ -1935,46 +2013,49 @@ class AsyncPlaywrightCrawlerStrategy(AsyncCrawlerStrategy): self.logger.error( message="Script chunk failed: {error}", tag="JS_EXEC", - params={"error": str(e)} + params={"error": str(e)}, ) results.append({"success": False, "error": str(e)}) return {"success": True, "results": results} - + except Exception as e: self.logger.error( message="Script execution failed: {error}", tag="JS_EXEC", - params={"error": str(e)} + params={"error": str(e)}, ) return {"success": False, "error": str(e)} - async def execute_user_script(self, page: Page, js_code: Union[str, List[str]]) -> Dict[str, Any]: + async def execute_user_script( + self, page: Page, js_code: Union[str, List[str]] + ) -> Dict[str, Any]: """ Executes user-provided JavaScript code with proper error handling and context. - + Args: page: Playwright page object js_code: Single JavaScript string or list of JavaScript code strings - + Returns: Dict containing execution status and results/errors """ try: # Ensure the page is ready for script execution - await page.wait_for_load_state('domcontentloaded') - + await page.wait_for_load_state("domcontentloaded") + # Handle single script or multiple scripts if isinstance(js_code, str): scripts = [js_code] else: scripts = js_code - + results = [] for script in scripts: try: # Execute the script and wait for network idle - result = await page.evaluate(f""" + result = await page.evaluate( + f""" (() => {{ return new Promise((resolve) => {{ try {{ @@ -2007,57 +2088,61 @@ class AsyncPlaywrightCrawlerStrategy(AsyncCrawlerStrategy): }} }}); }})() - """) - + """ + ) + # Wait for network idle after script execution t1 = time.time() - await page.wait_for_load_state('domcontentloaded', timeout=5000) - print("DOM content loaded after script execution in", time.time() - t1) + await page.wait_for_load_state("domcontentloaded", timeout=5000) + print( + "DOM content loaded after script execution in", time.time() - t1 + ) t1 = time.time() - await page.wait_for_load_state('networkidle', timeout=5000) + await page.wait_for_load_state("networkidle", timeout=5000) print("Network idle after script execution in", time.time() - t1) - + results.append(result if result else {"success": True}) - + except Error as e: # Handle Playwright-specific errors self.logger.error( message="Playwright execution error: {error}", tag="JS_EXEC", - params={"error": str(e)} + params={"error": str(e)}, ) results.append({"success": False, "error": str(e)}) - + return {"success": True, "results": results} - + except Exception as e: self.logger.error( message="Script execution failed: {error}", tag="JS_EXEC", - params={"error": str(e)} + params={"error": str(e)}, ) return {"success": False, "error": str(e)} - + except Exception as e: self.logger.error( message="Script execution failed: {error}", tag="JS_EXEC", - params={"error": str(e)} + params={"error": str(e)}, ) return {"success": False, "error": str(e)} async def check_visibility(self, page): """ Checks if an element is visible on the page. - + Args: page: Playwright page object - + Returns: Boolean indicating visibility """ - return await page.evaluate(""" + return await page.evaluate( + """ () => { const element = document.body; if (!element) return false; @@ -2067,31 +2152,32 @@ class AsyncPlaywrightCrawlerStrategy(AsyncCrawlerStrategy): style.opacity !== '0'; return isVisible; } - """) - + """ + ) + async def safe_scroll(self, page: Page, x: int, y: int, delay: float = 0.1): """ Safely scroll the page with rendering time. - + Args: page: Playwright page object x: Horizontal scroll position y: Vertical scroll position """ result = await self.csp_scroll_to(page, x, y) - if result['success']: + if result["success"]: await page.wait_for_timeout(delay * 1000) return result - + async def csp_scroll_to(self, page: Page, x: int, y: int) -> Dict[str, Any]: """ Performs a CSP-compliant scroll operation and returns the result status. - + Args: page: Playwright page object x: Horizontal scroll position y: Vertical scroll position - + Returns: Dict containing scroll status and position information """ @@ -2125,67 +2211,68 @@ class AsyncPlaywrightCrawlerStrategy(AsyncCrawlerStrategy): }} }}""" ) - - if not result['success']: + + if not result["success"]: self.logger.warning( message="Scroll operation failed: {error}", tag="SCROLL", - params={"error": result.get('error')} + params={"error": result.get("error")}, ) - + return result - + except Exception as e: self.logger.error( message="Failed to execute scroll: {error}", tag="SCROLL", - params={"error": str(e)} + params={"error": str(e)}, ) - return { - "success": False, - "error": str(e) - } - + return {"success": False, "error": str(e)} + async def get_page_dimensions(self, page: Page): """ Get the dimensions of the page. - + Args: page: Playwright page object - + Returns: Dict containing width and height of the page """ - return await page.evaluate(""" + return await page.evaluate( + """ () => { const {scrollWidth, scrollHeight} = document.documentElement; return {width: scrollWidth, height: scrollHeight}; } - """) - + """ + ) + async def page_need_scroll(self, page: Page) -> bool: """ Determine whether the page need to scroll - + Args: page: Playwright page object - + Returns: bool: True if page needs scrolling """ try: - need_scroll = await page.evaluate(""" + need_scroll = await page.evaluate( + """ () => { const scrollHeight = document.documentElement.scrollHeight; const viewportHeight = window.innerHeight; return scrollHeight > viewportHeight; } - """) + """ + ) return need_scroll except Exception as e: self.logger.warning( message="Failed to check scroll need: {error}. Defaulting to True for safety.", tag="SCROLL", - params={"error": str(e)} + params={"error": str(e)}, ) - return True # Default to scrolling if check fails \ No newline at end of file + return True # Default to scrolling if check fails diff --git a/crawl4ai/async_database.py b/crawl4ai/async_database.py index aed9c76b..669ddec2 100644 --- a/crawl4ai/async_database.py +++ b/crawl4ai/async_database.py @@ -1,27 +1,29 @@ -import os, sys +import os from pathlib import Path import aiosqlite import asyncio -from typing import Optional, Tuple, Dict +from typing import Optional, Dict from contextlib import asynccontextmanager import logging import json # Added for serialization/deserialization from .utils import ensure_content_dirs, generate_content_hash from .models import CrawlResult, MarkdownGenerationResult -import xxhash import aiofiles -from .config import NEED_MIGRATION from .version_manager import VersionManager from .async_logger import AsyncLogger from .utils import get_error_context, create_box_message + # Set up logging logging.basicConfig(level=logging.INFO) logger = logging.getLogger(__name__) -base_directory = DB_PATH = os.path.join(os.getenv("CRAWL4_AI_BASE_DIRECTORY", Path.home()), ".crawl4ai") +base_directory = DB_PATH = os.path.join( + os.getenv("CRAWL4_AI_BASE_DIRECTORY", Path.home()), ".crawl4ai" +) os.makedirs(DB_PATH, exist_ok=True) DB_PATH = os.path.join(base_directory, "crawl4ai.db") + class AsyncDatabaseManager: def __init__(self, pool_size: int = 10, max_retries: int = 3): self.db_path = DB_PATH @@ -32,28 +34,27 @@ class AsyncDatabaseManager: self.pool_lock = asyncio.Lock() self.init_lock = asyncio.Lock() self.connection_semaphore = asyncio.Semaphore(pool_size) - self._initialized = False + self._initialized = False self.version_manager = VersionManager() self.logger = AsyncLogger( log_file=os.path.join(base_directory, ".crawl4ai", "crawler_db.log"), verbose=False, - tag_width=10 + tag_width=10, ) - - + async def initialize(self): """Initialize the database and connection pool""" try: self.logger.info("Initializing database", tag="INIT") # Ensure the database file exists os.makedirs(os.path.dirname(self.db_path), exist_ok=True) - + # Check if version update is needed needs_update = self.version_manager.needs_update() - + # Always ensure base table exists await self.ainit_db() - + # Verify the table exists async with aiosqlite.connect(self.db_path, timeout=30.0) as db: async with db.execute( @@ -62,33 +63,37 @@ class AsyncDatabaseManager: result = await cursor.fetchone() if not result: raise Exception("crawled_data table was not created") - + # If version changed or fresh install, run updates if needs_update: self.logger.info("New version detected, running updates", tag="INIT") await self.update_db_schema() - from .migrations import run_migration # Import here to avoid circular imports + from .migrations import ( + run_migration, + ) # Import here to avoid circular imports + await run_migration() self.version_manager.update_version() # Update stored version after successful migration - self.logger.success("Version update completed successfully", tag="COMPLETE") + self.logger.success( + "Version update completed successfully", tag="COMPLETE" + ) else: - self.logger.success("Database initialization completed successfully", tag="COMPLETE") + self.logger.success( + "Database initialization completed successfully", tag="COMPLETE" + ) - except Exception as e: self.logger.error( message="Database initialization error: {error}", tag="ERROR", - params={"error": str(e)} + params={"error": str(e)}, ) self.logger.info( - message="Database will be initialized on first use", - tag="INIT" + message="Database will be initialized on first use", tag="INIT" ) - + raise - async def cleanup(self): """Cleanup connections when shutting down""" async with self.pool_lock: @@ -107,6 +112,7 @@ class AsyncDatabaseManager: self._initialized = True except Exception as e: import sys + error_context = get_error_context(sys.exc_info()) self.logger.error( message="Database initialization failed:\n{error}\n\nContext:\n{context}\n\nTraceback:\n{traceback}", @@ -115,41 +121,52 @@ class AsyncDatabaseManager: params={ "error": str(e), "context": error_context["code_context"], - "traceback": error_context["full_traceback"] - } + "traceback": error_context["full_traceback"], + }, ) raise await self.connection_semaphore.acquire() task_id = id(asyncio.current_task()) - + try: async with self.pool_lock: if task_id not in self.connection_pool: try: - conn = await aiosqlite.connect( - self.db_path, - timeout=30.0 - ) - await conn.execute('PRAGMA journal_mode = WAL') - await conn.execute('PRAGMA busy_timeout = 5000') - + conn = await aiosqlite.connect(self.db_path, timeout=30.0) + await conn.execute("PRAGMA journal_mode = WAL") + await conn.execute("PRAGMA busy_timeout = 5000") + # Verify database structure - async with conn.execute("PRAGMA table_info(crawled_data)") as cursor: + async with conn.execute( + "PRAGMA table_info(crawled_data)" + ) as cursor: columns = await cursor.fetchall() column_names = [col[1] for col in columns] expected_columns = { - 'url', 'html', 'cleaned_html', 'markdown', 'extracted_content', - 'success', 'media', 'links', 'metadata', 'screenshot', - 'response_headers', 'downloaded_files' + "url", + "html", + "cleaned_html", + "markdown", + "extracted_content", + "success", + "media", + "links", + "metadata", + "screenshot", + "response_headers", + "downloaded_files", } missing_columns = expected_columns - set(column_names) if missing_columns: - raise ValueError(f"Database missing columns: {missing_columns}") - + raise ValueError( + f"Database missing columns: {missing_columns}" + ) + self.connection_pool[task_id] = conn except Exception as e: import sys + error_context = get_error_context(sys.exc_info()) error_message = ( f"Unexpected error in db get_connection at line {error_context['line_no']} " @@ -158,7 +175,7 @@ class AsyncDatabaseManager: f"Code context:\n{error_context['code_context']}" ) self.logger.error( - message=create_box_message(error_message, type= "error"), + message=create_box_message(error_message, type="error"), ) raise @@ -167,6 +184,7 @@ class AsyncDatabaseManager: except Exception as e: import sys + error_context = get_error_context(sys.exc_info()) error_message = ( f"Unexpected error in db get_connection at line {error_context['line_no']} " @@ -175,7 +193,7 @@ class AsyncDatabaseManager: f"Code context:\n{error_context['code_context']}" ) self.logger.error( - message=create_box_message(error_message, type= "error"), + message=create_box_message(error_message, type="error"), ) raise finally: @@ -185,7 +203,6 @@ class AsyncDatabaseManager: del self.connection_pool[task_id] self.connection_semaphore.release() - async def execute_with_retry(self, operation, *args): """Execute database operations with retry logic""" for attempt in range(self.max_retries): @@ -200,18 +217,16 @@ class AsyncDatabaseManager: message="Operation failed after {retries} attempts: {error}", tag="ERROR", force_verbose=True, - params={ - "retries": self.max_retries, - "error": str(e) - } - ) + params={"retries": self.max_retries, "error": str(e)}, + ) raise await asyncio.sleep(1 * (attempt + 1)) # Exponential backoff async def ainit_db(self): """Initialize database schema""" async with aiosqlite.connect(self.db_path, timeout=30.0) as db: - await db.execute(''' + await db.execute( + """ CREATE TABLE IF NOT EXISTS crawled_data ( url TEXT PRIMARY KEY, html TEXT, @@ -226,21 +241,27 @@ class AsyncDatabaseManager: response_headers TEXT DEFAULT "{}", downloaded_files TEXT DEFAULT "{}" -- New column added ) - ''') + """ + ) await db.commit() - - async def update_db_schema(self): """Update database schema if needed""" async with aiosqlite.connect(self.db_path, timeout=30.0) as db: cursor = await db.execute("PRAGMA table_info(crawled_data)") columns = await cursor.fetchall() column_names = [column[1] for column in columns] - + # List of new columns to add - new_columns = ['media', 'links', 'metadata', 'screenshot', 'response_headers', 'downloaded_files'] - + new_columns = [ + "media", + "links", + "metadata", + "screenshot", + "response_headers", + "downloaded_files", + ] + for column in new_columns: if column not in column_names: await self.aalter_db_add_column(column, db) @@ -248,75 +269,91 @@ class AsyncDatabaseManager: async def aalter_db_add_column(self, new_column: str, db): """Add new column to the database""" - if new_column == 'response_headers': - await db.execute(f'ALTER TABLE crawled_data ADD COLUMN {new_column} TEXT DEFAULT "{{}}"') + if new_column == "response_headers": + await db.execute( + f'ALTER TABLE crawled_data ADD COLUMN {new_column} TEXT DEFAULT "{{}}"' + ) else: - await db.execute(f'ALTER TABLE crawled_data ADD COLUMN {new_column} TEXT DEFAULT ""') + await db.execute( + f'ALTER TABLE crawled_data ADD COLUMN {new_column} TEXT DEFAULT ""' + ) self.logger.info( message="Added column '{column}' to the database", tag="INIT", - params={"column": new_column} - ) - + params={"column": new_column}, + ) async def aget_cached_url(self, url: str) -> Optional[CrawlResult]: """Retrieve cached URL data as CrawlResult""" + async def _get(db): async with db.execute( - 'SELECT * FROM crawled_data WHERE url = ?', (url,) + "SELECT * FROM crawled_data WHERE url = ?", (url,) ) as cursor: row = await cursor.fetchone() if not row: return None - + # Get column names columns = [description[0] for description in cursor.description] # Create dict from row data row_dict = dict(zip(columns, row)) - + # Load content from files using stored hashes content_fields = { - 'html': row_dict['html'], - 'cleaned_html': row_dict['cleaned_html'], - 'markdown': row_dict['markdown'], - 'extracted_content': row_dict['extracted_content'], - 'screenshot': row_dict['screenshot'], - 'screenshots': row_dict['screenshot'], + "html": row_dict["html"], + "cleaned_html": row_dict["cleaned_html"], + "markdown": row_dict["markdown"], + "extracted_content": row_dict["extracted_content"], + "screenshot": row_dict["screenshot"], + "screenshots": row_dict["screenshot"], } - + for field, hash_value in content_fields.items(): if hash_value: content = await self._load_content( - hash_value, - field.split('_')[0] # Get content type from field name + hash_value, + field.split("_")[0], # Get content type from field name ) row_dict[field] = content or "" else: row_dict[field] = "" # Parse JSON fields - json_fields = ['media', 'links', 'metadata', 'response_headers', 'markdown'] + json_fields = [ + "media", + "links", + "metadata", + "response_headers", + "markdown", + ] for field in json_fields: try: - row_dict[field] = json.loads(row_dict[field]) if row_dict[field] else {} + row_dict[field] = ( + json.loads(row_dict[field]) if row_dict[field] else {} + ) except json.JSONDecodeError: row_dict[field] = {} - if isinstance(row_dict['markdown'], Dict): - row_dict['markdown_v2'] = row_dict['markdown'] - if row_dict['markdown'].get('raw_markdown'): - row_dict['markdown'] = row_dict['markdown']['raw_markdown'] - + if isinstance(row_dict["markdown"], Dict): + row_dict["markdown_v2"] = row_dict["markdown"] + if row_dict["markdown"].get("raw_markdown"): + row_dict["markdown"] = row_dict["markdown"]["raw_markdown"] + # Parse downloaded_files try: - row_dict['downloaded_files'] = json.loads(row_dict['downloaded_files']) if row_dict['downloaded_files'] else [] + row_dict["downloaded_files"] = ( + json.loads(row_dict["downloaded_files"]) + if row_dict["downloaded_files"] + else [] + ) except json.JSONDecodeError: - row_dict['downloaded_files'] = [] + row_dict["downloaded_files"] = [] # Remove any fields not in CrawlResult model valid_fields = CrawlResult.__annotations__.keys() filtered_dict = {k: v for k, v in row_dict.items() if k in valid_fields} - + return CrawlResult(**filtered_dict) try: @@ -326,7 +363,7 @@ class AsyncDatabaseManager: message="Error retrieving cached URL: {error}", tag="ERROR", force_verbose=True, - params={"error": str(e)} + params={"error": str(e)}, ) return None @@ -334,37 +371,52 @@ class AsyncDatabaseManager: """Cache CrawlResult data""" # Store content files and get hashes content_map = { - 'html': (result.html, 'html'), - 'cleaned_html': (result.cleaned_html or "", 'cleaned'), - 'markdown': None, - 'extracted_content': (result.extracted_content or "", 'extracted'), - 'screenshot': (result.screenshot or "", 'screenshots') + "html": (result.html, "html"), + "cleaned_html": (result.cleaned_html or "", "cleaned"), + "markdown": None, + "extracted_content": (result.extracted_content or "", "extracted"), + "screenshot": (result.screenshot or "", "screenshots"), } try: if isinstance(result.markdown, MarkdownGenerationResult): - content_map['markdown'] = (result.markdown.model_dump_json(), 'markdown') - elif hasattr(result, 'markdown_v2'): - content_map['markdown'] = (result.markdown_v2.model_dump_json(), 'markdown') + content_map["markdown"] = ( + result.markdown.model_dump_json(), + "markdown", + ) + elif hasattr(result, "markdown_v2"): + content_map["markdown"] = ( + result.markdown_v2.model_dump_json(), + "markdown", + ) elif isinstance(result.markdown, str): markdown_result = MarkdownGenerationResult(raw_markdown=result.markdown) - content_map['markdown'] = (markdown_result.model_dump_json(), 'markdown') + content_map["markdown"] = ( + markdown_result.model_dump_json(), + "markdown", + ) else: - content_map['markdown'] = (MarkdownGenerationResult().model_dump_json(), 'markdown') + content_map["markdown"] = ( + MarkdownGenerationResult().model_dump_json(), + "markdown", + ) except Exception as e: self.logger.warning( - message=f"Error processing markdown content: {str(e)}", - tag="WARNING" + message=f"Error processing markdown content: {str(e)}", tag="WARNING" ) # Fallback to empty markdown result - content_map['markdown'] = (MarkdownGenerationResult().model_dump_json(), 'markdown') - + content_map["markdown"] = ( + MarkdownGenerationResult().model_dump_json(), + "markdown", + ) + content_hashes = {} for field, (content, content_type) in content_map.items(): content_hashes[field] = await self._store_content(content, content_type) async def _cache(db): - await db.execute(''' + await db.execute( + """ INSERT INTO crawled_data ( url, html, cleaned_html, markdown, extracted_content, success, media, links, metadata, @@ -383,20 +435,22 @@ class AsyncDatabaseManager: screenshot = excluded.screenshot, response_headers = excluded.response_headers, downloaded_files = excluded.downloaded_files - ''', ( - result.url, - content_hashes['html'], - content_hashes['cleaned_html'], - content_hashes['markdown'], - content_hashes['extracted_content'], - result.success, - json.dumps(result.media), - json.dumps(result.links), - json.dumps(result.metadata or {}), - content_hashes['screenshot'], - json.dumps(result.response_headers or {}), - json.dumps(result.downloaded_files or []) - )) + """, + ( + result.url, + content_hashes["html"], + content_hashes["cleaned_html"], + content_hashes["markdown"], + content_hashes["extracted_content"], + result.success, + json.dumps(result.media), + json.dumps(result.links), + json.dumps(result.metadata or {}), + content_hashes["screenshot"], + json.dumps(result.response_headers or {}), + json.dumps(result.downloaded_files or []), + ), + ) try: await self.execute_with_retry(_cache) @@ -405,14 +459,14 @@ class AsyncDatabaseManager: message="Error caching URL: {error}", tag="ERROR", force_verbose=True, - params={"error": str(e)} + params={"error": str(e)}, ) - async def aget_total_count(self) -> int: """Get total number of cached URLs""" + async def _count(db): - async with db.execute('SELECT COUNT(*) FROM crawled_data') as cursor: + async with db.execute("SELECT COUNT(*) FROM crawled_data") as cursor: result = await cursor.fetchone() return result[0] if result else 0 @@ -423,14 +477,15 @@ class AsyncDatabaseManager: message="Error getting total count: {error}", tag="ERROR", force_verbose=True, - params={"error": str(e)} + params={"error": str(e)}, ) return 0 async def aclear_db(self): """Clear all data from the database""" + async def _clear(db): - await db.execute('DELETE FROM crawled_data') + await db.execute("DELETE FROM crawled_data") try: await self.execute_with_retry(_clear) @@ -439,13 +494,14 @@ class AsyncDatabaseManager: message="Error clearing database: {error}", tag="ERROR", force_verbose=True, - params={"error": str(e)} + params={"error": str(e)}, ) async def aflush_db(self): """Drop the entire table""" + async def _flush(db): - await db.execute('DROP TABLE IF EXISTS crawled_data') + await db.execute("DROP TABLE IF EXISTS crawled_data") try: await self.execute_with_retry(_flush) @@ -454,42 +510,44 @@ class AsyncDatabaseManager: message="Error flushing database: {error}", tag="ERROR", force_verbose=True, - params={"error": str(e)} + params={"error": str(e)}, ) - - + async def _store_content(self, content: str, content_type: str) -> str: """Store content in filesystem and return hash""" if not content: return "" - + content_hash = generate_content_hash(content) file_path = os.path.join(self.content_paths[content_type], content_hash) - + # Only write if file doesn't exist if not os.path.exists(file_path): - async with aiofiles.open(file_path, 'w', encoding='utf-8') as f: + async with aiofiles.open(file_path, "w", encoding="utf-8") as f: await f.write(content) - + return content_hash - async def _load_content(self, content_hash: str, content_type: str) -> Optional[str]: + async def _load_content( + self, content_hash: str, content_type: str + ) -> Optional[str]: """Load content from filesystem by hash""" if not content_hash: return None - + file_path = os.path.join(self.content_paths[content_type], content_hash) try: - async with aiofiles.open(file_path, 'r', encoding='utf-8') as f: + async with aiofiles.open(file_path, "r", encoding="utf-8") as f: return await f.read() except: self.logger.error( message="Failed to load content: {file_path}", tag="ERROR", force_verbose=True, - params={"file_path": file_path} + params={"file_path": file_path}, ) return None + # Create a singleton instance async_db_manager = AsyncDatabaseManager() diff --git a/crawl4ai/async_dispatcher.py b/crawl4ai/async_dispatcher.py index 8f5fbe81..b796a92b 100644 --- a/crawl4ai/async_dispatcher.py +++ b/crawl4ai/async_dispatcher.py @@ -1,14 +1,19 @@ -from typing import Dict, Optional, List -from .async_configs import * -from .models import * +from typing import Dict, Optional, List, Tuple +from .async_configs import CrawlerRunConfig +from .models import ( + CrawlResult, + CrawlerTaskResult, + CrawlStatus, + DisplayMode, + CrawlStats, + DomainState, +) from rich.live import Live from rich.table import Table from rich.console import Console -from rich.style import Style from rich import box from datetime import datetime, timedelta -from dataclasses import dataclass import time import psutil @@ -26,63 +31,66 @@ class RateLimiter: base_delay: Tuple[float, float] = (1.0, 3.0), max_delay: float = 60.0, max_retries: int = 3, - rate_limit_codes: List[int] = None + rate_limit_codes: List[int] = None, ): self.base_delay = base_delay self.max_delay = max_delay self.max_retries = max_retries self.rate_limit_codes = rate_limit_codes or [429, 503] self.domains: Dict[str, DomainState] = {} - + def get_domain(self, url: str) -> str: return urlparse(url).netloc - + async def wait_if_needed(self, url: str) -> None: domain = self.get_domain(url) state = self.domains.get(domain) - + if not state: self.domains[domain] = DomainState() state = self.domains[domain] - + now = time.time() if state.last_request_time: wait_time = max(0, state.current_delay - (now - state.last_request_time)) if wait_time > 0: await asyncio.sleep(wait_time) - + # Random delay within base range if no current delay if state.current_delay == 0: state.current_delay = random.uniform(*self.base_delay) - + state.last_request_time = time.time() - + def update_delay(self, url: str, status_code: int) -> bool: domain = self.get_domain(url) state = self.domains[domain] - + if status_code in self.rate_limit_codes: state.fail_count += 1 if state.fail_count > self.max_retries: return False - + # Exponential backoff with random jitter state.current_delay = min( - state.current_delay * 2 * random.uniform(0.75, 1.25), - self.max_delay + state.current_delay * 2 * random.uniform(0.75, 1.25), self.max_delay ) else: # Gradually reduce delay on success state.current_delay = max( - random.uniform(*self.base_delay), - state.current_delay * 0.75 + random.uniform(*self.base_delay), state.current_delay * 0.75 ) state.fail_count = 0 - + return True + class CrawlerMonitor: - def __init__(self, max_visible_rows: int = 15, display_mode: DisplayMode = DisplayMode.DETAILED): + def __init__( + self, + max_visible_rows: int = 15, + display_mode: DisplayMode = DisplayMode.DETAILED, + ): self.console = Console() self.max_visible_rows = max_visible_rows self.display_mode = display_mode @@ -90,23 +98,25 @@ class CrawlerMonitor: self.process = psutil.Process() self.start_time = datetime.now() self.live = Live(self._create_table(), refresh_per_second=2) - + def start(self): self.live.start() - + def stop(self): self.live.stop() - + def add_task(self, task_id: str, url: str): - self.stats[task_id] = CrawlStats(task_id=task_id, url=url, status=CrawlStatus.QUEUED) + self.stats[task_id] = CrawlStats( + task_id=task_id, url=url, status=CrawlStatus.QUEUED + ) self.live.update(self._create_table()) - + def update_task(self, task_id: str, **kwargs): if task_id in self.stats: for key, value in kwargs.items(): setattr(self.stats[task_id], key, value) self.live.update(self._create_table()) - + def _create_aggregated_table(self) -> Table: """Creates a compact table showing only aggregated statistics""" table = Table( @@ -114,78 +124,78 @@ class CrawlerMonitor: title="Crawler Status Overview", title_style="bold magenta", header_style="bold blue", - show_lines=True + show_lines=True, ) - + # Calculate statistics total_tasks = len(self.stats) - queued = sum(1 for stat in self.stats.values() if stat.status == CrawlStatus.QUEUED) - in_progress = sum(1 for stat in self.stats.values() if stat.status == CrawlStatus.IN_PROGRESS) - completed = sum(1 for stat in self.stats.values() if stat.status == CrawlStatus.COMPLETED) - failed = sum(1 for stat in self.stats.values() if stat.status == CrawlStatus.FAILED) - + queued = sum( + 1 for stat in self.stats.values() if stat.status == CrawlStatus.QUEUED + ) + in_progress = sum( + 1 for stat in self.stats.values() if stat.status == CrawlStatus.IN_PROGRESS + ) + completed = sum( + 1 for stat in self.stats.values() if stat.status == CrawlStatus.COMPLETED + ) + failed = sum( + 1 for stat in self.stats.values() if stat.status == CrawlStatus.FAILED + ) + # Memory statistics current_memory = self.process.memory_info().rss / (1024 * 1024) total_task_memory = sum(stat.memory_usage for stat in self.stats.values()) - peak_memory = max((stat.peak_memory for stat in self.stats.values()), default=0.0) - + peak_memory = max( + (stat.peak_memory for stat in self.stats.values()), default=0.0 + ) + # Duration duration = datetime.now() - self.start_time - + # Create status row table.add_column("Status", style="bold cyan") table.add_column("Count", justify="right") table.add_column("Percentage", justify="right") - - table.add_row( - "Total Tasks", - str(total_tasks), - "100%" - ) + + table.add_row("Total Tasks", str(total_tasks), "100%") table.add_row( "[yellow]In Queue[/yellow]", str(queued), - f"{(queued/total_tasks*100):.1f}%" if total_tasks > 0 else "0%" + f"{(queued/total_tasks*100):.1f}%" if total_tasks > 0 else "0%", ) table.add_row( "[blue]In Progress[/blue]", str(in_progress), - f"{(in_progress/total_tasks*100):.1f}%" if total_tasks > 0 else "0%" + f"{(in_progress/total_tasks*100):.1f}%" if total_tasks > 0 else "0%", ) table.add_row( "[green]Completed[/green]", str(completed), - f"{(completed/total_tasks*100):.1f}%" if total_tasks > 0 else "0%" + f"{(completed/total_tasks*100):.1f}%" if total_tasks > 0 else "0%", ) table.add_row( "[red]Failed[/red]", str(failed), - f"{(failed/total_tasks*100):.1f}%" if total_tasks > 0 else "0%" + f"{(failed/total_tasks*100):.1f}%" if total_tasks > 0 else "0%", ) - + # Add memory information table.add_section() table.add_row( - "[magenta]Current Memory[/magenta]", - f"{current_memory:.1f} MB", - "" + "[magenta]Current Memory[/magenta]", f"{current_memory:.1f} MB", "" ) table.add_row( - "[magenta]Total Task Memory[/magenta]", - f"{total_task_memory:.1f} MB", - "" + "[magenta]Total Task Memory[/magenta]", f"{total_task_memory:.1f} MB", "" ) table.add_row( - "[magenta]Peak Task Memory[/magenta]", - f"{peak_memory:.1f} MB", - "" + "[magenta]Peak Task Memory[/magenta]", f"{peak_memory:.1f} MB", "" ) table.add_row( "[yellow]Runtime[/yellow]", str(timedelta(seconds=int(duration.total_seconds()))), - "" + "", ) - + return table def _create_detailed_table(self) -> Table: @@ -193,9 +203,9 @@ class CrawlerMonitor: box=box.ROUNDED, title="Crawler Performance Monitor", title_style="bold magenta", - header_style="bold blue" + header_style="bold blue", ) - + # Add columns table.add_column("Task ID", style="cyan", no_wrap=True) table.add_column("URL", style="cyan", no_wrap=True) @@ -204,47 +214,54 @@ class CrawlerMonitor: table.add_column("Peak (MB)", justify="right") table.add_column("Duration", justify="right") table.add_column("Info", style="italic") - + # Add summary row total_memory = sum(stat.memory_usage for stat in self.stats.values()) - active_count = sum(1 for stat in self.stats.values() - if stat.status == CrawlStatus.IN_PROGRESS) - completed_count = sum(1 for stat in self.stats.values() - if stat.status == CrawlStatus.COMPLETED) - failed_count = sum(1 for stat in self.stats.values() - if stat.status == CrawlStatus.FAILED) - + active_count = sum( + 1 for stat in self.stats.values() if stat.status == CrawlStatus.IN_PROGRESS + ) + completed_count = sum( + 1 for stat in self.stats.values() if stat.status == CrawlStatus.COMPLETED + ) + failed_count = sum( + 1 for stat in self.stats.values() if stat.status == CrawlStatus.FAILED + ) + table.add_row( "[bold yellow]SUMMARY", f"Total: {len(self.stats)}", f"Active: {active_count}", f"{total_memory:.1f}", f"{self.process.memory_info().rss / (1024 * 1024):.1f}", - str(timedelta(seconds=int((datetime.now() - self.start_time).total_seconds()))), + str( + timedelta( + seconds=int((datetime.now() - self.start_time).total_seconds()) + ) + ), f"✓{completed_count} ✗{failed_count}", - style="bold" + style="bold", ) - + table.add_section() - + # Add rows for each task visible_stats = sorted( self.stats.values(), key=lambda x: ( x.status != CrawlStatus.IN_PROGRESS, x.status != CrawlStatus.QUEUED, - x.end_time or datetime.max - ) - )[:self.max_visible_rows] - + x.end_time or datetime.max, + ), + )[: self.max_visible_rows] + for stat in visible_stats: status_style = { CrawlStatus.QUEUED: "white", CrawlStatus.IN_PROGRESS: "yellow", CrawlStatus.COMPLETED: "green", - CrawlStatus.FAILED: "red" + CrawlStatus.FAILED: "red", }[stat.status] - + table.add_row( stat.task_id[:8], # Show first 8 chars of task ID stat.url[:40] + "..." if len(stat.url) > 40 else stat.url, @@ -252,9 +269,9 @@ class CrawlerMonitor: f"{stat.memory_usage:.1f}", f"{stat.peak_memory:.1f}", stat.duration, - stat.error_message[:40] if stat.error_message else "" + stat.error_message[:40] if stat.error_message else "", ) - + return table def _create_table(self) -> Table: @@ -268,7 +285,7 @@ class BaseDispatcher(ABC): def __init__( self, rate_limiter: Optional[RateLimiter] = None, - monitor: Optional[CrawlerMonitor] = None + monitor: Optional[CrawlerMonitor] = None, ): self.crawler = None self._domain_last_hit: Dict[str, float] = {} @@ -278,24 +295,25 @@ class BaseDispatcher(ABC): @abstractmethod async def crawl_url( - self, - url: str, - config: CrawlerRunConfig, + self, + url: str, + config: CrawlerRunConfig, task_id: str, - monitor: Optional[CrawlerMonitor] = None + monitor: Optional[CrawlerMonitor] = None, ) -> CrawlerTaskResult: pass @abstractmethod async def run_urls( - self, - urls: List[str], - crawler: "AsyncWebCrawler", + self, + urls: List[str], + crawler: "AsyncWebCrawler", # noqa: F821 config: CrawlerRunConfig, - monitor: Optional[CrawlerMonitor] = None + monitor: Optional[CrawlerMonitor] = None, ) -> List[CrawlerTaskResult]: pass + class MemoryAdaptiveDispatcher(BaseDispatcher): def __init__( self, @@ -304,39 +322,41 @@ class MemoryAdaptiveDispatcher(BaseDispatcher): max_session_permit: int = 20, memory_wait_timeout: float = 300.0, # 5 minutes default timeout rate_limiter: Optional[RateLimiter] = None, - monitor: Optional[CrawlerMonitor] = None + monitor: Optional[CrawlerMonitor] = None, ): super().__init__(rate_limiter, monitor) self.memory_threshold_percent = memory_threshold_percent self.check_interval = check_interval self.max_session_permit = max_session_permit self.memory_wait_timeout = memory_wait_timeout - + async def crawl_url( - self, - url: str, - config: CrawlerRunConfig, + self, + url: str, + config: CrawlerRunConfig, task_id: str, ) -> CrawlerTaskResult: start_time = datetime.now() error_message = "" memory_usage = peak_memory = 0.0 - + try: if self.monitor: - self.monitor.update_task(task_id, status=CrawlStatus.IN_PROGRESS, start_time=start_time) + self.monitor.update_task( + task_id, status=CrawlStatus.IN_PROGRESS, start_time=start_time + ) self.concurrent_sessions += 1 - + if self.rate_limiter: await self.rate_limiter.wait_if_needed(url) - + process = psutil.Process() start_memory = process.memory_info().rss / (1024 * 1024) result = await self.crawler.arun(url, config=config, session_id=task_id) end_memory = process.memory_info().rss / (1024 * 1024) - + memory_usage = peak_memory = end_memory - start_memory - + if self.rate_limiter and result.status_code: if not self.rate_limiter.update_delay(url, result.status_code): error_message = f"Rate limit retry count exceeded for domain {urlparse(url).netloc}" @@ -350,22 +370,24 @@ class MemoryAdaptiveDispatcher(BaseDispatcher): peak_memory=peak_memory, start_time=start_time, end_time=datetime.now(), - error_message=error_message + error_message=error_message, ) - + if not result.success: error_message = result.error_message if self.monitor: self.monitor.update_task(task_id, status=CrawlStatus.FAILED) elif self.monitor: self.monitor.update_task(task_id, status=CrawlStatus.COMPLETED) - + except Exception as e: error_message = str(e) if self.monitor: self.monitor.update_task(task_id, status=CrawlStatus.FAILED) - result = CrawlResult(url=url, html="", metadata={}, success=False, error_message=str(e)) - + result = CrawlResult( + url=url, html="", metadata={}, success=False, error_message=str(e) + ) + finally: end_time = datetime.now() if self.monitor: @@ -374,10 +396,10 @@ class MemoryAdaptiveDispatcher(BaseDispatcher): end_time=end_time, memory_usage=memory_usage, peak_memory=peak_memory, - error_message=error_message + error_message=error_message, ) self.concurrent_sessions -= 1 - + return CrawlerTaskResult( task_id=task_id, url=url, @@ -386,20 +408,20 @@ class MemoryAdaptiveDispatcher(BaseDispatcher): peak_memory=peak_memory, start_time=start_time, end_time=end_time, - error_message=error_message + error_message=error_message, ) async def run_urls( - self, - urls: List[str], - crawler: "AsyncWebCrawler", + self, + urls: List[str], + crawler: "AsyncWebCrawler", # noqa: F821 config: CrawlerRunConfig, ) -> List[CrawlerTaskResult]: self.crawler = crawler - + if self.monitor: self.monitor.start() - + try: pending_tasks = [] active_tasks = [] @@ -417,23 +439,24 @@ class MemoryAdaptiveDispatcher(BaseDispatcher): if psutil.virtual_memory().percent >= self.memory_threshold_percent: # Check if we've exceeded the timeout if time.time() - wait_start_time > self.memory_wait_timeout: - raise MemoryError(f"Memory usage above threshold ({self.memory_threshold_percent}%) for more than {self.memory_wait_timeout} seconds") + raise MemoryError( + f"Memory usage above threshold ({self.memory_threshold_percent}%) for more than {self.memory_wait_timeout} seconds" + ) await asyncio.sleep(self.check_interval) continue - + url, task_id = task_queue.pop(0) task = asyncio.create_task(self.crawl_url(url, config, task_id)) active_tasks.append(task) - + if not active_tasks: await asyncio.sleep(self.check_interval) continue - + done, pending = await asyncio.wait( - active_tasks, - return_when=asyncio.FIRST_COMPLETED + active_tasks, return_when=asyncio.FIRST_COMPLETED ) - + pending_tasks.extend(done) active_tasks = list(pending) @@ -442,24 +465,25 @@ class MemoryAdaptiveDispatcher(BaseDispatcher): if self.monitor: self.monitor.stop() + class SemaphoreDispatcher(BaseDispatcher): def __init__( self, semaphore_count: int = 5, max_session_permit: int = 20, rate_limiter: Optional[RateLimiter] = None, - monitor: Optional[CrawlerMonitor] = None + monitor: Optional[CrawlerMonitor] = None, ): super().__init__(rate_limiter, monitor) self.semaphore_count = semaphore_count - self.max_session_permit = max_session_permit - + self.max_session_permit = max_session_permit + async def crawl_url( - self, - url: str, - config: CrawlerRunConfig, + self, + url: str, + config: CrawlerRunConfig, task_id: str, - semaphore: asyncio.Semaphore = None + semaphore: asyncio.Semaphore = None, ) -> CrawlerTaskResult: start_time = datetime.now() error_message = "" @@ -467,7 +491,9 @@ class SemaphoreDispatcher(BaseDispatcher): try: if self.monitor: - self.monitor.update_task(task_id, status=CrawlStatus.IN_PROGRESS, start_time=start_time) + self.monitor.update_task( + task_id, status=CrawlStatus.IN_PROGRESS, start_time=start_time + ) if self.rate_limiter: await self.rate_limiter.wait_if_needed(url) @@ -477,7 +503,7 @@ class SemaphoreDispatcher(BaseDispatcher): start_memory = process.memory_info().rss / (1024 * 1024) result = await self.crawler.arun(url, config=config, session_id=task_id) end_memory = process.memory_info().rss / (1024 * 1024) - + memory_usage = peak_memory = end_memory - start_memory if self.rate_limiter and result.status_code: @@ -493,7 +519,7 @@ class SemaphoreDispatcher(BaseDispatcher): peak_memory=peak_memory, start_time=start_time, end_time=datetime.now(), - error_message=error_message + error_message=error_message, ) if not result.success: @@ -507,7 +533,9 @@ class SemaphoreDispatcher(BaseDispatcher): error_message = str(e) if self.monitor: self.monitor.update_task(task_id, status=CrawlStatus.FAILED) - result = CrawlResult(url=url, html="", metadata={}, success=False, error_message=str(e)) + result = CrawlResult( + url=url, html="", metadata={}, success=False, error_message=str(e) + ) finally: end_time = datetime.now() @@ -517,7 +545,7 @@ class SemaphoreDispatcher(BaseDispatcher): end_time=end_time, memory_usage=memory_usage, peak_memory=peak_memory, - error_message=error_message + error_message=error_message, ) return CrawlerTaskResult( @@ -528,13 +556,13 @@ class SemaphoreDispatcher(BaseDispatcher): peak_memory=peak_memory, start_time=start_time, end_time=end_time, - error_message=error_message + error_message=error_message, ) async def run_urls( - self, - crawler: "AsyncWebCrawler", - urls: List[str], + self, + crawler: "AsyncWebCrawler", # noqa: F821 + urls: List[str], config: CrawlerRunConfig, ) -> List[CrawlerTaskResult]: self.crawler = crawler @@ -557,4 +585,4 @@ class SemaphoreDispatcher(BaseDispatcher): return await asyncio.gather(*tasks, return_exceptions=True) finally: if self.monitor: - self.monitor.stop() \ No newline at end of file + self.monitor.stop() diff --git a/crawl4ai/async_logger.py b/crawl4ai/async_logger.py index 5d2d54b5..0e049289 100644 --- a/crawl4ai/async_logger.py +++ b/crawl4ai/async_logger.py @@ -1,10 +1,10 @@ from enum import Enum -from typing import Optional, Dict, Any, Union -from colorama import Fore, Back, Style, init -import time +from typing import Optional, Dict, Any +from colorama import Fore, Style, init import os from datetime import datetime + class LogLevel(Enum): DEBUG = 1 INFO = 2 @@ -12,23 +12,24 @@ class LogLevel(Enum): WARNING = 4 ERROR = 5 + class AsyncLogger: """ Asynchronous logger with support for colored console output and file logging. Supports templated messages with colored components. """ - + DEFAULT_ICONS = { - 'INIT': '→', - 'READY': '✓', - 'FETCH': '↓', - 'SCRAPE': '◆', - 'EXTRACT': '■', - 'COMPLETE': '●', - 'ERROR': '×', - 'DEBUG': '⋯', - 'INFO': 'ℹ', - 'WARNING': '⚠', + "INIT": "→", + "READY": "✓", + "FETCH": "↓", + "SCRAPE": "◆", + "EXTRACT": "■", + "COMPLETE": "●", + "ERROR": "×", + "DEBUG": "⋯", + "INFO": "ℹ", + "WARNING": "⚠", } DEFAULT_COLORS = { @@ -46,11 +47,11 @@ class AsyncLogger: tag_width: int = 10, icons: Optional[Dict[str, str]] = None, colors: Optional[Dict[LogLevel, str]] = None, - verbose: bool = True + verbose: bool = True, ): """ Initialize the logger. - + Args: log_file: Optional file path for logging log_level: Minimum log level to display @@ -66,7 +67,7 @@ class AsyncLogger: self.icons = icons or self.DEFAULT_ICONS self.colors = colors or self.DEFAULT_COLORS self.verbose = verbose - + # Create log file directory if needed if log_file: os.makedirs(os.path.dirname(os.path.abspath(log_file)), exist_ok=True) @@ -77,18 +78,20 @@ class AsyncLogger: def _get_icon(self, tag: str) -> str: """Get the icon for a tag, defaulting to info icon if not found.""" - return self.icons.get(tag, self.icons['INFO']) + return self.icons.get(tag, self.icons["INFO"]) def _write_to_file(self, message: str): """Write a message to the log file if configured.""" if self.log_file: - timestamp = datetime.now().strftime('%Y-%m-%d %H:%M:%S.%f')[:-3] - with open(self.log_file, 'a', encoding='utf-8') as f: + timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S.%f")[:-3] + with open(self.log_file, "a", encoding="utf-8") as f: # Strip ANSI color codes for file output - clean_message = message.replace(Fore.RESET, '').replace(Style.RESET_ALL, '') + clean_message = message.replace(Fore.RESET, "").replace( + Style.RESET_ALL, "" + ) for color in vars(Fore).values(): if isinstance(color, str): - clean_message = clean_message.replace(color, '') + clean_message = clean_message.replace(color, "") f.write(f"[{timestamp}] {clean_message}\n") def _log( @@ -99,11 +102,11 @@ class AsyncLogger: params: Optional[Dict[str, Any]] = None, colors: Optional[Dict[str, str]] = None, base_color: Optional[str] = None, - **kwargs + **kwargs, ): """ Core logging method that handles message formatting and output. - + Args: level: Log level for this message message: Message template string @@ -120,7 +123,7 @@ class AsyncLogger: try: # First format the message with raw parameters formatted_message = message.format(**params) - + # Then apply colors if specified if colors: for key, color in colors.items(): @@ -128,12 +131,13 @@ class AsyncLogger: if key in params: value_str = str(params[key]) formatted_message = formatted_message.replace( - value_str, - f"{color}{value_str}{Style.RESET_ALL}" + value_str, f"{color}{value_str}{Style.RESET_ALL}" ) - + except KeyError as e: - formatted_message = f"LOGGING ERROR: Missing parameter {e} in message template" + formatted_message = ( + f"LOGGING ERROR: Missing parameter {e} in message template" + ) level = LogLevel.ERROR else: formatted_message = message @@ -175,11 +179,11 @@ class AsyncLogger: success: bool, timing: float, tag: str = "FETCH", - url_length: int = 50 + url_length: int = 50, ): """ Convenience method for logging URL fetch status. - + Args: url: The URL being processed success: Whether the operation was successful @@ -195,24 +199,20 @@ class AsyncLogger: "url": url, "url_length": url_length, "status": success, - "timing": timing + "timing": timing, }, colors={ "status": Fore.GREEN if success else Fore.RED, - "timing": Fore.YELLOW - } + "timing": Fore.YELLOW, + }, ) def error_status( - self, - url: str, - error: str, - tag: str = "ERROR", - url_length: int = 50 + self, url: str, error: str, tag: str = "ERROR", url_length: int = 50 ): """ Convenience method for logging error status. - + Args: url: The URL being processed error: Error message @@ -223,9 +223,5 @@ class AsyncLogger: level=LogLevel.ERROR, message="{url:.{url_length}}... | Error: {error}", tag=tag, - params={ - "url": url, - "url_length": url_length, - "error": error - } - ) \ No newline at end of file + params={"url": url, "url_length": url_length, "error": error}, + ) diff --git a/crawl4ai/async_webcrawler.py b/crawl4ai/async_webcrawler.py index 9ef19966..a7596a55 100644 --- a/crawl4ai/async_webcrawler.py +++ b/crawl4ai/async_webcrawler.py @@ -1,49 +1,54 @@ -import os, sys +import os +import sys import time import warnings -from enum import Enum -from colorama import init, Fore, Back, Style +from colorama import Fore from pathlib import Path -from typing import Optional, List, Union +from typing import Optional, List import json import asyncio + # from contextlib import nullcontext, asynccontextmanager from contextlib import asynccontextmanager -from .models import CrawlResult, MarkdownGenerationResult, CrawlerTaskResult +from .models import CrawlResult, MarkdownGenerationResult, CrawlerTaskResult, DispatchResult, RateLimiter from .async_database import async_db_manager -from .chunking_strategy import * -from .content_filter_strategy import * -from .extraction_strategy import * -from .async_crawler_strategy import AsyncCrawlerStrategy, AsyncPlaywrightCrawlerStrategy, AsyncCrawlResponse +from .chunking_strategy import * # noqa: F403 +from .chunking_strategy import RegexChunking, ChunkingStrategy, IdentityChunking +from .content_filter_strategy import * # noqa: F403 +from .content_filter_strategy import RelevantContentFilter +from .extraction_strategy import * # noqa: F403 +from .extraction_strategy import NoExtractionStrategy, ExtractionStrategy +from .async_crawler_strategy import ( + AsyncCrawlerStrategy, + AsyncPlaywrightCrawlerStrategy, + AsyncCrawlResponse, +) from .cache_context import CacheMode, CacheContext, _legacy_to_cache_mode -from .markdown_generation_strategy import DefaultMarkdownGenerator, MarkdownGenerationStrategy -from .content_scraping_strategy import WebScrapingStrategy, LXMLWebScrapingStrategy +from .markdown_generation_strategy import ( + DefaultMarkdownGenerator, + MarkdownGenerationStrategy, +) from .async_logger import AsyncLogger from .async_configs import BrowserConfig, CrawlerRunConfig -from .async_dispatcher import * +from .async_dispatcher import * # noqa: F403 +from .async_dispatcher import BaseDispatcher, MemoryAdaptiveDispatcher -from .config import ( - MIN_WORD_THRESHOLD, - IMAGE_DESCRIPTION_MIN_WORD_THRESHOLD, - URL_LOG_SHORTEN_LENGTH -) +from .config import MIN_WORD_THRESHOLD from .utils import ( sanitize_input_encode, InvalidCSSSelectorError, - format_html, fast_format_html, - create_box_message + create_box_message, + get_error_context, ) -from urllib.parse import urlparse -import random from .__version__ import __version__ as crawl4ai_version class AsyncWebCrawler: """ Asynchronous web crawler with flexible caching capabilities. - + There are two ways to use the crawler: 1. Using context manager (recommended for simple cases): @@ -56,23 +61,23 @@ class AsyncWebCrawler: ```python crawler = AsyncWebCrawler() await crawler.start() - + # Use the crawler multiple times result1 = await crawler.arun(url="https://example.com") result2 = await crawler.arun(url="https://another.com") - + await crawler.close() ``` - + Migration Guide: Old way (deprecated): crawler = AsyncWebCrawler(always_by_pass_cache=True, browser_type="chromium", headless=True) - + New way (recommended): browser_config = BrowserConfig(browser_type="chromium", headless=True) crawler = AsyncWebCrawler(config=browser_config) - - + + Attributes: browser_config (BrowserConfig): Configuration object for browser settings. crawler_strategy (AsyncCrawlerStrategy): Strategy for crawling web pages. @@ -81,7 +86,7 @@ class AsyncWebCrawler: crawl4ai_folder (str): Directory for storing cache. base_directory (str): Base directory for storing cache. ready (bool): Whether the crawler is ready for use. - + Methods: start(): Start the crawler explicitly without using context manager. close(): Close the crawler explicitly without using context manager. @@ -89,21 +94,22 @@ class AsyncWebCrawler: awarmup(): Perform warmup sequence. arun_many(): Run the crawler for multiple sources. aprocess_html(): Process HTML content. - + Typical Usage: async with AsyncWebCrawler() as crawler: result = await crawler.arun(url="https://example.com") print(result.markdown) - + Using configuration: browser_config = BrowserConfig(browser_type="chromium", headless=True) async with AsyncWebCrawler(config=browser_config) as crawler: crawler_config = CrawlerRunConfig( - cache_mode=CacheMode.BYPASS + cache_mode=CacheMode.BYPASS ) result = await crawler.arun(url="https://example.com", config=crawler_config) print(result.markdown) """ + _domain_last_hit = {} def __init__( @@ -127,43 +133,48 @@ class AsyncWebCrawler: base_directory: Base directory for storing cache thread_safe: Whether to use thread-safe operations **kwargs: Additional arguments for backwards compatibility - """ + """ # Handle browser configuration browser_config = config if browser_config is not None: - if any(k in kwargs for k in ["browser_type", "headless", "viewport_width", "viewport_height"]): + if any( + k in kwargs + for k in [ + "browser_type", + "headless", + "viewport_width", + "viewport_height", + ] + ): self.logger.warning( message="Both browser_config and legacy browser parameters provided. browser_config will take precedence.", - tag="WARNING" + tag="WARNING", ) else: # Create browser config from kwargs for backwards compatibility browser_config = BrowserConfig.from_kwargs(kwargs) self.browser_config = browser_config - + # Initialize logger first since other components may need it self.logger = AsyncLogger( log_file=os.path.join(base_directory, ".crawl4ai", "crawler.log"), - verbose=self.browser_config.verbose, - tag_width=10 + verbose=self.browser_config.verbose, + tag_width=10, ) - # Initialize crawler strategy - params = { - k:v for k, v in kwargs.items() if k in ['browser_congig', 'logger'] - } + params = {k: v for k, v in kwargs.items() if k in ["browser_congig", "logger"]} self.crawler_strategy = crawler_strategy or AsyncPlaywrightCrawlerStrategy( browser_config=browser_config, logger=self.logger, - **params # Pass remaining kwargs for backwards compatibility + **params, # Pass remaining kwargs for backwards compatibility ) - + # If craweler strategy doesnt have logger, use crawler logger if not self.crawler_strategy.logger: self.crawler_strategy.logger = self.logger - + # Handle deprecated cache parameter if always_by_pass_cache is not None: if kwargs.get("warning", True): @@ -172,7 +183,7 @@ class AsyncWebCrawler: "Use 'always_bypass_cache' instead. " "Pass warning=False to suppress this warning.", DeprecationWarning, - stacklevel=2 + stacklevel=2, ) self.always_bypass_cache = always_by_pass_cache else: @@ -180,24 +191,24 @@ class AsyncWebCrawler: # Thread safety setup self._lock = asyncio.Lock() if thread_safe else None - + # Initialize directories self.crawl4ai_folder = os.path.join(base_directory, ".crawl4ai") os.makedirs(self.crawl4ai_folder, exist_ok=True) os.makedirs(f"{self.crawl4ai_folder}/cache", exist_ok=True) - + self.ready = False async def start(self): """ Start the crawler explicitly without using context manager. This is equivalent to using 'async with' but gives more control over the lifecycle. - + This method will: 1. Initialize the browser and context 2. Perform warmup sequence 3. Return the crawler instance for method chaining - + Returns: AsyncWebCrawler: The initialized crawler instance """ @@ -209,7 +220,7 @@ class AsyncWebCrawler: """ Close the crawler explicitly without using context manager. This should be called when you're done with the crawler if you used start(). - + This method will: 1. Clean up browser resources 2. Close any open pages and contexts @@ -221,11 +232,11 @@ class AsyncWebCrawler: async def __aexit__(self, exc_type, exc_val, exc_tb): await self.close() - + async def awarmup(self): """ Initialize the crawler with warm-up sequence. - + This method: 1. Logs initialization info 2. Sets up browser configuration @@ -240,548 +251,553 @@ class AsyncWebCrawler: yield async def arun( - self, - url: str, - config: Optional[CrawlerRunConfig] = None, - # Legacy parameters maintained for backwards compatibility - word_count_threshold=MIN_WORD_THRESHOLD, - extraction_strategy: ExtractionStrategy = None, - chunking_strategy: ChunkingStrategy = RegexChunking(), - content_filter: RelevantContentFilter = None, - cache_mode: Optional[CacheMode] = None, - # Deprecated cache parameters - bypass_cache: bool = False, - disable_cache: bool = False, - no_cache_read: bool = False, - no_cache_write: bool = False, - # Other legacy parameters - css_selector: str = None, - screenshot: bool = False, - pdf: bool = False, - user_agent: str = None, - verbose=True, - **kwargs, - ) -> CrawlResult: - """ - Runs the crawler for a single source: URL (web, local file, or raw HTML). + self, + url: str, + config: Optional[CrawlerRunConfig] = None, + # Legacy parameters maintained for backwards compatibility + word_count_threshold=MIN_WORD_THRESHOLD, + extraction_strategy: ExtractionStrategy = None, + chunking_strategy: ChunkingStrategy = RegexChunking(), + content_filter: RelevantContentFilter = None, + cache_mode: Optional[CacheMode] = None, + # Deprecated cache parameters + bypass_cache: bool = False, + disable_cache: bool = False, + no_cache_read: bool = False, + no_cache_write: bool = False, + # Other legacy parameters + css_selector: str = None, + screenshot: bool = False, + pdf: bool = False, + user_agent: str = None, + verbose=True, + **kwargs, + ) -> CrawlResult: + """ + Runs the crawler for a single source: URL (web, local file, or raw HTML). - Migration Guide: - Old way (deprecated): - result = await crawler.arun( - url="https://example.com", - word_count_threshold=200, - screenshot=True, - ... - ) + Migration Guide: + Old way (deprecated): + result = await crawler.arun( + url="https://example.com", + word_count_threshold=200, + screenshot=True, + ... + ) - New way (recommended): - config = CrawlerRunConfig( - word_count_threshold=200, - screenshot=True, - ... - ) - result = await crawler.arun(url="https://example.com", crawler_config=config) + New way (recommended): + config = CrawlerRunConfig( + word_count_threshold=200, + screenshot=True, + ... + ) + result = await crawler.arun(url="https://example.com", crawler_config=config) - Args: - url: The URL to crawl (http://, https://, file://, or raw:) - crawler_config: Configuration object controlling crawl behavior - [other parameters maintained for backwards compatibility] + Args: + url: The URL to crawl (http://, https://, file://, or raw:) + crawler_config: Configuration object controlling crawl behavior + [other parameters maintained for backwards compatibility] - Returns: - CrawlResult: The result of crawling and processing - """ - crawler_config = config - if not isinstance(url, str) or not url: - raise ValueError("Invalid URL, make sure the URL is a non-empty string") - - async with self._lock or self.nullcontext(): - try: - # Handle configuration - if crawler_config is not None: - # if any(param is not None for param in [ - # word_count_threshold, extraction_strategy, chunking_strategy, - # content_filter, cache_mode, css_selector, screenshot, pdf - # ]): - # self.logger.warning( - # message="Both crawler_config and legacy parameters provided. crawler_config will take precedence.", - # tag="WARNING" - # ) - config = crawler_config - else: - # Merge all parameters into a single kwargs dict for config creation - config_kwargs = { - "word_count_threshold": word_count_threshold, - "extraction_strategy": extraction_strategy, - "chunking_strategy": chunking_strategy, - "content_filter": content_filter, - "cache_mode": cache_mode, - "bypass_cache": bypass_cache, - "disable_cache": disable_cache, - "no_cache_read": no_cache_read, - "no_cache_write": no_cache_write, - "css_selector": css_selector, - "screenshot": screenshot, - "pdf": pdf, - "verbose": verbose, - **kwargs - } - config = CrawlerRunConfig.from_kwargs(config_kwargs) + Returns: + CrawlResult: The result of crawling and processing + """ + crawler_config = config + if not isinstance(url, str) or not url: + raise ValueError("Invalid URL, make sure the URL is a non-empty string") - # Handle deprecated cache parameters - if any([bypass_cache, disable_cache, no_cache_read, no_cache_write]): - if kwargs.get("warning", True): - warnings.warn( - "Cache control boolean flags are deprecated and will be removed in version 0.5.0. " - "Use 'cache_mode' parameter instead.", - DeprecationWarning, - stacklevel=2 - ) - - # Convert legacy parameters if cache_mode not provided - if config.cache_mode is None: - config.cache_mode = _legacy_to_cache_mode( - disable_cache=disable_cache, - bypass_cache=bypass_cache, - no_cache_read=no_cache_read, - no_cache_write=no_cache_write - ) - - # Default to ENABLED if no cache mode specified + async with self._lock or self.nullcontext(): + try: + # Handle configuration + if crawler_config is not None: + # if any(param is not None for param in [ + # word_count_threshold, extraction_strategy, chunking_strategy, + # content_filter, cache_mode, css_selector, screenshot, pdf + # ]): + # self.logger.warning( + # message="Both crawler_config and legacy parameters provided. crawler_config will take precedence.", + # tag="WARNING" + # ) + config = crawler_config + else: + # Merge all parameters into a single kwargs dict for config creation + config_kwargs = { + "word_count_threshold": word_count_threshold, + "extraction_strategy": extraction_strategy, + "chunking_strategy": chunking_strategy, + "content_filter": content_filter, + "cache_mode": cache_mode, + "bypass_cache": bypass_cache, + "disable_cache": disable_cache, + "no_cache_read": no_cache_read, + "no_cache_write": no_cache_write, + "css_selector": css_selector, + "screenshot": screenshot, + "pdf": pdf, + "verbose": verbose, + **kwargs, + } + config = CrawlerRunConfig.from_kwargs(config_kwargs) + + # Handle deprecated cache parameters + if any([bypass_cache, disable_cache, no_cache_read, no_cache_write]): + if kwargs.get("warning", True): + warnings.warn( + "Cache control boolean flags are deprecated and will be removed in version 0.5.0. " + "Use 'cache_mode' parameter instead.", + DeprecationWarning, + stacklevel=2, + ) + + # Convert legacy parameters if cache_mode not provided if config.cache_mode is None: - config.cache_mode = CacheMode.ENABLED - - # Create cache context - cache_context = CacheContext(url, config.cache_mode, self.always_bypass_cache) - - # Initialize processing variables - async_response: AsyncCrawlResponse = None - cached_result: CrawlResult = None - screenshot_data = None - pdf_data = None - extracted_content = None - start_time = time.perf_counter() - - # Try to get cached result if appropriate - if cache_context.should_read(): - cached_result = await async_db_manager.aget_cached_url(url) - - if cached_result: - html = sanitize_input_encode(cached_result.html) - extracted_content = sanitize_input_encode(cached_result.extracted_content or "") - extracted_content = None if not extracted_content or extracted_content == "[]" else extracted_content - # If screenshot is requested but its not in cache, then set cache_result to None - screenshot_data = cached_result.screenshot - pdf_data = cached_result.pdf - if config.screenshot and not screenshot or config.pdf and not pdf: - cached_result = None - - self.logger.url_status( - url=cache_context.display_url, - success=bool(html), - timing=time.perf_counter() - start_time, - tag="FETCH" + config.cache_mode = _legacy_to_cache_mode( + disable_cache=disable_cache, + bypass_cache=bypass_cache, + no_cache_read=no_cache_read, + no_cache_write=no_cache_write, ) - # Fetch fresh content if needed - if not cached_result or not html: - t1 = time.perf_counter() - - if user_agent: - self.crawler_strategy.update_user_agent(user_agent) - - # Pass config to crawl method - async_response = await self.crawler_strategy.crawl( - url, - config=config # Pass the entire config object - ) - - html = sanitize_input_encode(async_response.html) - screenshot_data = async_response.screenshot - pdf_data = async_response.pdf_data - - t2 = time.perf_counter() - self.logger.url_status( - url=cache_context.display_url, - success=bool(html), - timing=t2 - t1, - tag="FETCH" - ) + # Default to ENABLED if no cache mode specified + if config.cache_mode is None: + config.cache_mode = CacheMode.ENABLED - # Process the HTML content - crawl_result = await self.aprocess_html( - url=url, - html=html, - extracted_content=extracted_content, - config=config, # Pass the config object instead of individual parameters - screenshot=screenshot_data, - pdf_data=pdf_data, - verbose=config.verbose, - is_raw_html = True if url.startswith("raw:") else False, - **kwargs - ) + # Create cache context + cache_context = CacheContext( + url, config.cache_mode, self.always_bypass_cache + ) - crawl_result.status_code = async_response.status_code - crawl_result.response_headers = async_response.response_headers - crawl_result.downloaded_files = async_response.downloaded_files - crawl_result.ssl_certificate = async_response.ssl_certificate # Add SSL certificate + # Initialize processing variables + async_response: AsyncCrawlResponse = None + cached_result: CrawlResult = None + screenshot_data = None + pdf_data = None + extracted_content = None + start_time = time.perf_counter() - # # Check and set values from async_response to crawl_result - # try: - # for key in vars(async_response): - # if hasattr(crawl_result, key): - # value = getattr(async_response, key, None) - # current_value = getattr(crawl_result, key, None) - # if value is not None and not current_value: - # try: - # setattr(crawl_result, key, value) - # except Exception as e: - # self.logger.warning( - # message=f"Failed to set attribute {key}: {str(e)}", - # tag="WARNING" - # ) - # except Exception as e: - # self.logger.warning( - # message=f"Error copying response attributes: {str(e)}", - # tag="WARNING" - # ) + # Try to get cached result if appropriate + if cache_context.should_read(): + cached_result = await async_db_manager.aget_cached_url(url) - crawl_result.success = bool(html) - crawl_result.session_id = getattr(config, 'session_id', None) - - self.logger.success( - message="{url:.50}... | Status: {status} | Total: {timing}", - tag="COMPLETE", - params={ - "url": cache_context.display_url, - "status": crawl_result.success, - "timing": f"{time.perf_counter() - start_time:.2f}s" - }, - colors={ - "status": Fore.GREEN if crawl_result.success else Fore.RED, - "timing": Fore.YELLOW - } - ) - - # Update cache if appropriate - if cache_context.should_write() and not bool(cached_result): - await async_db_manager.acache_url(crawl_result) - - return crawl_result - - else: - self.logger.success( - message="{url:.50}... | Status: {status} | Total: {timing}", - tag="COMPLETE", - params={ - "url": cache_context.display_url, - "status": True, - "timing": f"{time.perf_counter() - start_time:.2f}s" - }, - colors={ - "status": Fore.GREEN, - "timing": Fore.YELLOW - } - ) - - cached_result.success = bool(html) - cached_result.session_id = getattr(config, 'session_id', None) - return cached_result - - except Exception as e: - error_context = get_error_context(sys.exc_info()) - - error_message = ( - f"Unexpected error in _crawl_web at line {error_context['line_no']} " - f"in {error_context['function']} ({error_context['filename']}):\n" - f"Error: {str(e)}\n\n" - f"Code context:\n{error_context['code_context']}" + if cached_result: + html = sanitize_input_encode(cached_result.html) + extracted_content = sanitize_input_encode( + cached_result.extracted_content or "" ) - # if not hasattr(e, "msg"): - # e.msg = str(e) - - self.logger.error_status( + extracted_content = ( + None + if not extracted_content or extracted_content == "[]" + else extracted_content + ) + # If screenshot is requested but its not in cache, then set cache_result to None + screenshot_data = cached_result.screenshot + pdf_data = cached_result.pdf + if config.screenshot and not screenshot or config.pdf and not pdf: + cached_result = None + + self.logger.url_status( + url=cache_context.display_url, + success=bool(html), + timing=time.perf_counter() - start_time, + tag="FETCH", + ) + + # Fetch fresh content if needed + if not cached_result or not html: + t1 = time.perf_counter() + + if user_agent: + self.crawler_strategy.update_user_agent(user_agent) + + # Pass config to crawl method + async_response = await self.crawler_strategy.crawl( + url, + config=config, # Pass the entire config object + ) + + html = sanitize_input_encode(async_response.html) + screenshot_data = async_response.screenshot + pdf_data = async_response.pdf_data + + t2 = time.perf_counter() + self.logger.url_status( + url=cache_context.display_url, + success=bool(html), + timing=t2 - t1, + tag="FETCH", + ) + + # Process the HTML content + crawl_result = await self.aprocess_html( url=url, - error=create_box_message(error_message, type="error"), - tag="ERROR" + html=html, + extracted_content=extracted_content, + config=config, # Pass the config object instead of individual parameters + screenshot=screenshot_data, + pdf_data=pdf_data, + verbose=config.verbose, + is_raw_html=True if url.startswith("raw:") else False, + **kwargs, ) - - return CrawlResult( - url=url, - html="", - success=False, - error_message=error_message + + crawl_result.status_code = async_response.status_code + crawl_result.response_headers = async_response.response_headers + crawl_result.downloaded_files = async_response.downloaded_files + crawl_result.ssl_certificate = ( + async_response.ssl_certificate + ) # Add SSL certificate + + # # Check and set values from async_response to crawl_result + # try: + # for key in vars(async_response): + # if hasattr(crawl_result, key): + # value = getattr(async_response, key, None) + # current_value = getattr(crawl_result, key, None) + # if value is not None and not current_value: + # try: + # setattr(crawl_result, key, value) + # except Exception as e: + # self.logger.warning( + # message=f"Failed to set attribute {key}: {str(e)}", + # tag="WARNING" + # ) + # except Exception as e: + # self.logger.warning( + # message=f"Error copying response attributes: {str(e)}", + # tag="WARNING" + # ) + + crawl_result.success = bool(html) + crawl_result.session_id = getattr(config, "session_id", None) + + self.logger.success( + message="{url:.50}... | Status: {status} | Total: {timing}", + tag="COMPLETE", + params={ + "url": cache_context.display_url, + "status": crawl_result.success, + "timing": f"{time.perf_counter() - start_time:.2f}s", + }, + colors={ + "status": Fore.GREEN if crawl_result.success else Fore.RED, + "timing": Fore.YELLOW, + }, ) + # Update cache if appropriate + if cache_context.should_write() and not bool(cached_result): + await async_db_manager.acache_url(crawl_result) + + return crawl_result + + else: + self.logger.success( + message="{url:.50}... | Status: {status} | Total: {timing}", + tag="COMPLETE", + params={ + "url": cache_context.display_url, + "status": True, + "timing": f"{time.perf_counter() - start_time:.2f}s", + }, + colors={"status": Fore.GREEN, "timing": Fore.YELLOW}, + ) + + cached_result.success = bool(html) + cached_result.session_id = getattr(config, "session_id", None) + return cached_result + + except Exception as e: + error_context = get_error_context(sys.exc_info()) + + error_message = ( + f"Unexpected error in _crawl_web at line {error_context['line_no']} " + f"in {error_context['function']} ({error_context['filename']}):\n" + f"Error: {str(e)}\n\n" + f"Code context:\n{error_context['code_context']}" + ) + # if not hasattr(e, "msg"): + # e.msg = str(e) + + self.logger.error_status( + url=url, + error=create_box_message(error_message, type="error"), + tag="ERROR", + ) + + return CrawlResult( + url=url, html="", success=False, error_message=error_message + ) + async def aprocess_html( - self, - url: str, - html: str, - extracted_content: str, - config: CrawlerRunConfig, - screenshot: str, - pdf_data: str, - verbose: bool, - **kwargs, - ) -> CrawlResult: - """ - Process HTML content using the provided configuration. - - Args: - url: The URL being processed - html: Raw HTML content - extracted_content: Previously extracted content (if any) - config: Configuration object controlling processing behavior - screenshot: Screenshot data (if any) - pdf_data: PDF data (if any) - verbose: Whether to enable verbose logging - **kwargs: Additional parameters for backwards compatibility - - Returns: - CrawlResult: Processed result containing extracted and formatted content - """ - try: - _url = url if not kwargs.get("is_raw_html", False) else "Raw HTML" - t1 = time.perf_counter() + self, + url: str, + html: str, + extracted_content: str, + config: CrawlerRunConfig, + screenshot: str, + pdf_data: str, + verbose: bool, + **kwargs, + ) -> CrawlResult: + """ + Process HTML content using the provided configuration. - # Get scraping strategy and ensure it has a logger - scraping_strategy = config.scraping_strategy - if not scraping_strategy.logger: - scraping_strategy.logger = self.logger + Args: + url: The URL being processed + html: Raw HTML content + extracted_content: Previously extracted content (if any) + config: Configuration object controlling processing behavior + screenshot: Screenshot data (if any) + pdf_data: PDF data (if any) + verbose: Whether to enable verbose logging + **kwargs: Additional parameters for backwards compatibility - # Process HTML content - params = {k:v for k, v in config.to_dict().items() if k not in ["url"]} - # add keys from kwargs to params that doesn't exist in params - params.update({k:v for k, v in kwargs.items() if k not in params.keys()}) - - result = scraping_strategy.scrap( - url, - html, - **params + Returns: + CrawlResult: Processed result containing extracted and formatted content + """ + try: + _url = url if not kwargs.get("is_raw_html", False) else "Raw HTML" + t1 = time.perf_counter() + + # Get scraping strategy and ensure it has a logger + scraping_strategy = config.scraping_strategy + if not scraping_strategy.logger: + scraping_strategy.logger = self.logger + + # Process HTML content + params = {k: v for k, v in config.to_dict().items() if k not in ["url"]} + # add keys from kwargs to params that doesn't exist in params + params.update({k: v for k, v in kwargs.items() if k not in params.keys()}) + + result = scraping_strategy.scrap(url, html, **params) + + if result is None: + raise ValueError( + f"Process HTML, Failed to extract content from the website: {url}" ) - if result is None: - raise ValueError(f"Process HTML, Failed to extract content from the website: {url}") + except InvalidCSSSelectorError as e: + raise ValueError(str(e)) + except Exception as e: + raise ValueError( + f"Process HTML, Failed to extract content from the website: {url}, error: {str(e)}" + ) - except InvalidCSSSelectorError as e: - raise ValueError(str(e)) - except Exception as e: - raise ValueError(f"Process HTML, Failed to extract content from the website: {url}, error: {str(e)}") + # Extract results - handle both dict and ScrapingResult + if isinstance(result, dict): + cleaned_html = sanitize_input_encode(result.get("cleaned_html", "")) + media = result.get("media", {}) + links = result.get("links", {}) + metadata = result.get("metadata", {}) + else: + cleaned_html = sanitize_input_encode(result.cleaned_html) + media = result.media.model_dump() + links = result.links.model_dump() + metadata = result.metadata - + # Markdown Generation + markdown_generator: Optional[MarkdownGenerationStrategy] = ( + config.markdown_generator or DefaultMarkdownGenerator() + ) - # Extract results - handle both dict and ScrapingResult - if isinstance(result, dict): - cleaned_html = sanitize_input_encode(result.get("cleaned_html", "")) - media = result.get("media", {}) - links = result.get("links", {}) - metadata = result.get("metadata", {}) - else: - cleaned_html = sanitize_input_encode(result.cleaned_html) - media = result.media.model_dump() - links = result.links.model_dump() - metadata = result.metadata + # Uncomment if by default we want to use PruningContentFilter + # if not config.content_filter and not markdown_generator.content_filter: + # markdown_generator.content_filter = PruningContentFilter() - # Markdown Generation - markdown_generator: Optional[MarkdownGenerationStrategy] = config.markdown_generator or DefaultMarkdownGenerator() - - # Uncomment if by default we want to use PruningContentFilter - # if not config.content_filter and not markdown_generator.content_filter: - # markdown_generator.content_filter = PruningContentFilter() - - markdown_result: MarkdownGenerationResult = markdown_generator.generate_markdown( + markdown_result: MarkdownGenerationResult = ( + markdown_generator.generate_markdown( cleaned_html=cleaned_html, base_url=url, # html2text_options=kwargs.get('html2text', {}) ) - markdown_v2 = markdown_result - markdown = sanitize_input_encode(markdown_result.raw_markdown) + ) + markdown_v2 = markdown_result + markdown = sanitize_input_encode(markdown_result.raw_markdown) - # Log processing completion - self.logger.info( - message="Processed {url:.50}... | Time: {timing}ms", - tag="SCRAPE", - params={ - "url": _url, - "timing": int((time.perf_counter() - t1) * 1000) - } + # Log processing completion + self.logger.info( + message="Processed {url:.50}... | Time: {timing}ms", + tag="SCRAPE", + params={"url": _url, "timing": int((time.perf_counter() - t1) * 1000)}, + ) + + # Handle content extraction if needed + if ( + not bool(extracted_content) + and config.extraction_strategy + and not isinstance(config.extraction_strategy, NoExtractionStrategy) + ): + t1 = time.perf_counter() + + # Choose content based on input_format + content_format = config.extraction_strategy.input_format + if content_format == "fit_markdown" and not markdown_result.fit_markdown: + self.logger.warning( + message="Fit markdown requested but not available. Falling back to raw markdown.", + tag="EXTRACT", + params={"url": _url}, + ) + content_format = "markdown" + + content = { + "markdown": markdown, + "html": html, + "fit_markdown": markdown_result.raw_markdown, + }.get(content_format, markdown) + + # Use IdentityChunking for HTML input, otherwise use provided chunking strategy + chunking = ( + IdentityChunking() + if content_format == "html" + else config.chunking_strategy + ) + sections = chunking.chunk(content) + extracted_content = config.extraction_strategy.run(url, sections) + extracted_content = json.dumps( + extracted_content, indent=4, default=str, ensure_ascii=False ) - # Handle content extraction if needed - if (not bool(extracted_content) and config.extraction_strategy and not isinstance(config.extraction_strategy, NoExtractionStrategy)): - - t1 = time.perf_counter() - - # Choose content based on input_format - content_format = config.extraction_strategy.input_format - if content_format == "fit_markdown" and not markdown_result.fit_markdown: - self.logger.warning( - message="Fit markdown requested but not available. Falling back to raw markdown.", - tag="EXTRACT", - params={"url": _url} - ) - content_format = "markdown" + # Log extraction completion + self.logger.info( + message="Completed for {url:.50}... | Time: {timing}s", + tag="EXTRACT", + params={"url": _url, "timing": time.perf_counter() - t1}, + ) - content = { - "markdown": markdown, - "html": html, - "fit_markdown": markdown_result.raw_markdown - }.get(content_format, markdown) - - # Use IdentityChunking for HTML input, otherwise use provided chunking strategy - chunking = IdentityChunking() if content_format == "html" else config.chunking_strategy - sections = chunking.chunk(content) - extracted_content = config.extraction_strategy.run(url, sections) - extracted_content = json.dumps(extracted_content, indent=4, default=str, ensure_ascii=False) + # Handle screenshot and PDF data + screenshot_data = None if not screenshot else screenshot + pdf_data = None if not pdf_data else pdf_data - # Log extraction completion - self.logger.info( - message="Completed for {url:.50}... | Time: {timing}s", - tag="EXTRACT", - params={ - "url": _url, - "timing": time.perf_counter() - t1 - } - ) + # Apply HTML formatting if requested + if config.prettiify: + cleaned_html = fast_format_html(cleaned_html) - # Handle screenshot and PDF data - screenshot_data = None if not screenshot else screenshot - pdf_data = None if not pdf_data else pdf_data - - # Apply HTML formatting if requested - if config.prettiify: - cleaned_html = fast_format_html(cleaned_html) - - # Return complete crawl result - return CrawlResult( - url=url, - html=html, - cleaned_html=cleaned_html, - markdown_v2=markdown_v2, - markdown=markdown, - fit_markdown=markdown_result.fit_markdown, - fit_html=markdown_result.fit_html, - media=media, - links=links, - metadata=metadata, - screenshot=screenshot_data, - pdf=pdf_data, - extracted_content=extracted_content, - success=True, - error_message="", - ) + # Return complete crawl result + return CrawlResult( + url=url, + html=html, + cleaned_html=cleaned_html, + markdown_v2=markdown_v2, + markdown=markdown, + fit_markdown=markdown_result.fit_markdown, + fit_html=markdown_result.fit_html, + media=media, + links=links, + metadata=metadata, + screenshot=screenshot_data, + pdf=pdf_data, + extracted_content=extracted_content, + success=True, + error_message="", + ) async def arun_many( - self, - urls: List[str], - config: Optional[CrawlerRunConfig] = None, - dispatcher: Optional[BaseDispatcher] = None, - # Legacy parameters maintained for backwards compatibility - word_count_threshold=MIN_WORD_THRESHOLD, - extraction_strategy: ExtractionStrategy = None, - chunking_strategy: ChunkingStrategy = RegexChunking(), - content_filter: RelevantContentFilter = None, - cache_mode: Optional[CacheMode] = None, - bypass_cache: bool = False, - css_selector: str = None, - screenshot: bool = False, - pdf: bool = False, - user_agent: str = None, - verbose=True, - **kwargs, - ) -> List[CrawlResult]: - """ - Runs the crawler for multiple URLs concurrently using a configurable dispatcher strategy. + self, + urls: List[str], + config: Optional[CrawlerRunConfig] = None, + dispatcher: Optional[BaseDispatcher] = None, + # Legacy parameters maintained for backwards compatibility + word_count_threshold=MIN_WORD_THRESHOLD, + extraction_strategy: ExtractionStrategy = None, + chunking_strategy: ChunkingStrategy = RegexChunking(), + content_filter: RelevantContentFilter = None, + cache_mode: Optional[CacheMode] = None, + bypass_cache: bool = False, + css_selector: str = None, + screenshot: bool = False, + pdf: bool = False, + user_agent: str = None, + verbose=True, + **kwargs, + ) -> List[CrawlResult]: + """ + Runs the crawler for multiple URLs concurrently using a configurable dispatcher strategy. - Migration Guide: - Old way (deprecated): - results = await crawler.arun_many( - urls, - word_count_threshold=200, - screenshot=True, - ... - ) - - New way (recommended): - config = CrawlerRunConfig( - word_count_threshold=200, - screenshot=True, - dispatcher_config=DispatcherConfig( - enable_rate_limiting=True, - rate_limit_config=RateLimitConfig(...), - ), - ... - ) - results = await crawler.arun_many( - urls, - config=config, - dispatcher_strategy=MemoryAdaptiveDispatcher # Optional, this is the default - ) - - Args: - urls: List of URLs to crawl - config: Configuration object controlling crawl behavior for all URLs - dispatcher_strategy: The dispatcher strategy class to use. Defaults to MemoryAdaptiveDispatcher. - [other parameters maintained for backwards compatibility] - - Returns: - List[CrawlResult]: Results for each URL - """ - # Create config if not provided - if config is None: - config = CrawlerRunConfig( - word_count_threshold=word_count_threshold, - extraction_strategy=extraction_strategy, - chunking_strategy=chunking_strategy, - content_filter=content_filter, - cache_mode=cache_mode, - bypass_cache=bypass_cache, - css_selector=css_selector, - screenshot=screenshot, - pdf=pdf, - verbose=verbose, - **kwargs - ) - - # # Initialize the dispatcher with the selected strategy - # dispatcher = dispatcher_strategy(self, config.dispatcher_config) - - # memory_monitor: CrawlerMonitor = None - # if config.dispatcher_config.enable_monitor: - # memory_monitor = CrawlerMonitor(max_visible_rows=config.dispatcher_config.max_display_rows, display_mode=config.dispatcher_config.display_mode) - - # Create default dispatcher if none provided - if dispatcher is None: - dispatcher = MemoryAdaptiveDispatcher( - self, - rate_limiter=RateLimiter( - base_delay=(1.0, 3.0), - max_delay=60.0, - max_retries=3 - ) - ) - - # Run the URLs through the dispatcher - _results: List[CrawlerTaskResult] = await dispatcher.run_urls( - crawler=self, - urls=urls, - config=config + Migration Guide: + Old way (deprecated): + results = await crawler.arun_many( + urls, + word_count_threshold=200, + screenshot=True, + ... ) - - results: CrawlResult = [] - for res in _results: - _res : CrawlResult = res.result - dispatch_result: DispatchResult = DispatchResult( - task_id=res.task_id, - memory_usage=res.memory_usage, - peak_memory=res.peak_memory, - start_time=res.start_time, - end_time=res.end_time, - error_message=res.error_message - ) - _res.dispatch_result = dispatch_result - results.append(_res) - - return results + + New way (recommended): + config = CrawlerRunConfig( + word_count_threshold=200, + screenshot=True, + dispatcher_config=DispatcherConfig( + enable_rate_limiting=True, + rate_limit_config=RateLimitConfig(...), + ), + ... + ) + results = await crawler.arun_many( + urls, + config=config, + dispatcher_strategy=MemoryAdaptiveDispatcher # Optional, this is the default + ) + + Args: + urls: List of URLs to crawl + config: Configuration object controlling crawl behavior for all URLs + dispatcher_strategy: The dispatcher strategy class to use. Defaults to MemoryAdaptiveDispatcher. + [other parameters maintained for backwards compatibility] + + Returns: + List[CrawlResult]: Results for each URL + """ + # Create config if not provided + if config is None: + config = CrawlerRunConfig( + word_count_threshold=word_count_threshold, + extraction_strategy=extraction_strategy, + chunking_strategy=chunking_strategy, + content_filter=content_filter, + cache_mode=cache_mode, + bypass_cache=bypass_cache, + css_selector=css_selector, + screenshot=screenshot, + pdf=pdf, + verbose=verbose, + **kwargs, + ) + + # # Initialize the dispatcher with the selected strategy + # dispatcher = dispatcher_strategy(self, config.dispatcher_config) + + # memory_monitor: CrawlerMonitor = None + # if config.dispatcher_config.enable_monitor: + # memory_monitor = CrawlerMonitor(max_visible_rows=config.dispatcher_config.max_display_rows, display_mode=config.dispatcher_config.display_mode) + + # Create default dispatcher if none provided + if dispatcher is None: + dispatcher = MemoryAdaptiveDispatcher( + self, + rate_limiter=RateLimiter( + base_delay=(1.0, 3.0), max_delay=60.0, max_retries=3 + ), + ) + + # Run the URLs through the dispatcher + _results: List[CrawlerTaskResult] = await dispatcher.run_urls( + crawler=self, urls=urls, config=config + ) + + results: CrawlResult = [] + for res in _results: + _res: CrawlResult = res.result + dispatch_result: DispatchResult = DispatchResult( + task_id=res.task_id, + memory_usage=res.memory_usage, + peak_memory=res.peak_memory, + start_time=res.start_time, + end_time=res.end_time, + error_message=res.error_message, + ) + _res.dispatch_result = dispatch_result + results.append(_res) + + return results async def aclear_cache(self): """Clear the cache database.""" diff --git a/crawl4ai/cache_context.py b/crawl4ai/cache_context.py index 588edd62..75914b5b 100644 --- a/crawl4ai/cache_context.py +++ b/crawl4ai/cache_context.py @@ -4,7 +4,7 @@ from enum import Enum class CacheMode(Enum): """ Defines the caching behavior for web crawling operations. - + Modes: - ENABLED: Normal caching behavior (read and write) - DISABLED: No caching at all @@ -12,6 +12,7 @@ class CacheMode(Enum): - WRITE_ONLY: Only write to cache, don't read - BYPASS: Bypass cache for this operation """ + ENABLED = "enabled" DISABLED = "disabled" READ_ONLY = "read_only" @@ -22,10 +23,10 @@ class CacheMode(Enum): class CacheContext: """ Encapsulates cache-related decisions and URL handling. - + This class centralizes all cache-related logic and URL type checking, making the caching behavior more predictable and maintainable. - + Attributes: url (str): The URL being processed. cache_mode (CacheMode): The cache mode for the current operation. @@ -36,10 +37,11 @@ class CacheContext: is_raw_html (bool): True if the URL is raw HTML, False otherwise. _url_display (str): The display name for the URL (web, local file, or raw HTML). """ + def __init__(self, url: str, cache_mode: CacheMode, always_bypass: bool = False): """ Initializes the CacheContext with the provided URL and cache mode. - + Args: url (str): The URL being processed. cache_mode (CacheMode): The cache mode for the current operation. @@ -48,42 +50,42 @@ class CacheContext: self.url = url self.cache_mode = cache_mode self.always_bypass = always_bypass - self.is_cacheable = url.startswith(('http://', 'https://', 'file://')) - self.is_web_url = url.startswith(('http://', 'https://')) + self.is_cacheable = url.startswith(("http://", "https://", "file://")) + self.is_web_url = url.startswith(("http://", "https://")) self.is_local_file = url.startswith("file://") self.is_raw_html = url.startswith("raw:") self._url_display = url if not self.is_raw_html else "Raw HTML" - + def should_read(self) -> bool: """ Determines if cache should be read based on context. - + How it works: 1. If always_bypass is True or is_cacheable is False, return False. 2. If cache_mode is ENABLED or READ_ONLY, return True. - + Returns: bool: True if cache should be read, False otherwise. """ if self.always_bypass or not self.is_cacheable: return False return self.cache_mode in [CacheMode.ENABLED, CacheMode.READ_ONLY] - + def should_write(self) -> bool: """ Determines if cache should be written based on context. - + How it works: 1. If always_bypass is True or is_cacheable is False, return False. 2. If cache_mode is ENABLED or WRITE_ONLY, return True. - + Returns: bool: True if cache should be written, False otherwise. """ if self.always_bypass or not self.is_cacheable: return False return self.cache_mode in [CacheMode.ENABLED, CacheMode.WRITE_ONLY] - + @property def display_url(self) -> str: """Returns the URL in display format.""" @@ -94,11 +96,11 @@ def _legacy_to_cache_mode( disable_cache: bool = False, bypass_cache: bool = False, no_cache_read: bool = False, - no_cache_write: bool = False + no_cache_write: bool = False, ) -> CacheMode: """ Converts legacy cache parameters to the new CacheMode enum. - + This is an internal function to help transition from the old boolean flags to the new CacheMode system. """ diff --git a/crawl4ai/chunking_strategy.py b/crawl4ai/chunking_strategy.py index 7b8c08ad..ca188d1d 100644 --- a/crawl4ai/chunking_strategy.py +++ b/crawl4ai/chunking_strategy.py @@ -3,49 +3,53 @@ import re from collections import Counter import string from .model_loader import load_nltk_punkt -from .utils import * + # Define the abstract base class for chunking strategies class ChunkingStrategy(ABC): """ Abstract base class for chunking strategies. """ - + @abstractmethod def chunk(self, text: str) -> list: """ Abstract method to chunk the given text. - + Args: text (str): The text to chunk. - + Returns: list: A list of chunks. """ pass + # Create an identity chunking strategy f(x) = [x] class IdentityChunking(ChunkingStrategy): """ Chunking strategy that returns the input text as a single chunk. """ + def chunk(self, text: str) -> list: return [text] + # Regex-based chunking class RegexChunking(ChunkingStrategy): """ Chunking strategy that splits text based on regular expression patterns. """ + def __init__(self, patterns=None, **kwargs): """ Initialize the RegexChunking object. - + Args: patterns (list): A list of regular expression patterns to split text. """ if patterns is None: - patterns = [r'\n\n'] # Default split pattern + patterns = [r"\n\n"] # Default split pattern self.patterns = patterns def chunk(self, text: str) -> list: @@ -56,18 +60,19 @@ class RegexChunking(ChunkingStrategy): new_paragraphs.extend(re.split(pattern, paragraph)) paragraphs = new_paragraphs return paragraphs - -# NLP-based sentence chunking + + +# NLP-based sentence chunking class NlpSentenceChunking(ChunkingStrategy): """ Chunking strategy that splits text into sentences using NLTK's sentence tokenizer. - """ + """ + def __init__(self, **kwargs): """ Initialize the NlpSentenceChunking object. """ load_nltk_punkt() - def chunk(self, text: str) -> list: # Improved regex for sentence splitting @@ -75,31 +80,34 @@ class NlpSentenceChunking(ChunkingStrategy): # r'(? list: # Tokenize and remove stopwords and punctuation import nltk as nl + tokens = nl.toknize.word_tokenize(text) - tokens = [token.lower() for token in tokens if token not in nl.corpus.stopwords.words('english') and token not in string.punctuation] + tokens = [ + token.lower() + for token in tokens + if token not in nl.corpus.stopwords.words("english") + and token not in string.punctuation + ] # Calculate frequency distribution freq_dist = Counter(tokens) @@ -123,23 +137,27 @@ class TopicSegmentationChunking(ChunkingStrategy): # Segment the text into topics segments = self.chunk(text) # Extract keywords for each topic segment - segments_with_topics = [(segment, self.extract_keywords(segment)) for segment in segments] + segments_with_topics = [ + (segment, self.extract_keywords(segment)) for segment in segments + ] return segments_with_topics - + + # Fixed-length word chunks class FixedLengthWordChunking(ChunkingStrategy): """ Chunking strategy that splits text into fixed-length word chunks. - + How it works: 1. Split the text into words 2. Create chunks of fixed length 3. Return the list of chunks """ + def __init__(self, chunk_size=100, **kwargs): """ Initialize the fixed-length word chunking strategy with the given chunk size. - + Args: chunk_size (int): The size of each chunk in words. """ @@ -147,23 +165,28 @@ class FixedLengthWordChunking(ChunkingStrategy): def chunk(self, text: str) -> list: words = text.split() - return [' '.join(words[i:i + self.chunk_size]) for i in range(0, len(words), self.chunk_size)] - + return [ + " ".join(words[i : i + self.chunk_size]) + for i in range(0, len(words), self.chunk_size) + ] + + # Sliding window chunking class SlidingWindowChunking(ChunkingStrategy): """ Chunking strategy that splits text into overlapping word chunks. - + How it works: 1. Split the text into words 2. Create chunks of fixed length 3. Return the list of chunks """ + def __init__(self, window_size=100, step=50, **kwargs): """ Initialize the sliding window chunking strategy with the given window size and step size. - + Args: window_size (int): The size of the sliding window in words. step (int): The step size for sliding the window in words. @@ -174,35 +197,37 @@ class SlidingWindowChunking(ChunkingStrategy): def chunk(self, text: str) -> list: words = text.split() chunks = [] - + if len(words) <= self.window_size: return [text] - + for i in range(0, len(words) - self.window_size + 1, self.step): - chunk = ' '.join(words[i:i + self.window_size]) + chunk = " ".join(words[i : i + self.window_size]) chunks.append(chunk) - + # Handle the last chunk if it doesn't align perfectly if i + self.window_size < len(words): - chunks.append(' '.join(words[-self.window_size:])) - + chunks.append(" ".join(words[-self.window_size :])) + return chunks - + + class OverlappingWindowChunking(ChunkingStrategy): """ Chunking strategy that splits text into overlapping word chunks. - + How it works: 1. Split the text into words using whitespace 2. Create chunks of fixed length equal to the window size 3. Slide the window by the overlap size 4. Return the list of chunks """ + def __init__(self, window_size=1000, overlap=100, **kwargs): """ Initialize the overlapping window chunking strategy with the given window size and overlap size. - + Args: window_size (int): The size of the window in words. overlap (int): The size of the overlap between consecutive chunks in words. @@ -213,19 +238,19 @@ class OverlappingWindowChunking(ChunkingStrategy): def chunk(self, text: str) -> list: words = text.split() chunks = [] - + if len(words) <= self.window_size: return [text] - + start = 0 while start < len(words): end = start + self.window_size - chunk = ' '.join(words[start:end]) + chunk = " ".join(words[start:end]) chunks.append(chunk) - + if end >= len(words): break - + start = end - self.overlap - - return chunks \ No newline at end of file + + return chunks diff --git a/crawl4ai/cli.py b/crawl4ai/cli.py index 4a01c1c2..b2d2199e 100644 --- a/crawl4ai/cli.py +++ b/crawl4ai/cli.py @@ -8,15 +8,22 @@ from .async_logger import AsyncLogger logger = AsyncLogger(verbose=True) docs_manager = DocsManager(logger) + def print_table(headers: List[str], rows: List[List[str]], padding: int = 2): """Print formatted table with headers and rows""" widths = [max(len(str(cell)) for cell in col) for col in zip(headers, *rows)] - border = '+' + '+'.join('-' * (w + 2 * padding) for w in widths) + '+' - + border = "+" + "+".join("-" * (w + 2 * padding) for w in widths) + "+" + def format_row(row): - return '|' + '|'.join(f"{' ' * padding}{str(cell):<{w}}{' ' * padding}" - for cell, w in zip(row, widths)) + '|' - + return ( + "|" + + "|".join( + f"{' ' * padding}{str(cell):<{w}}{' ' * padding}" + for cell, w in zip(row, widths) + ) + + "|" + ) + click.echo(border) click.echo(format_row(headers)) click.echo(border) @@ -24,19 +31,24 @@ def print_table(headers: List[str], rows: List[List[str]], padding: int = 2): click.echo(format_row(row)) click.echo(border) + @click.group() def cli(): """Crawl4AI Command Line Interface""" pass + @cli.group() def docs(): """Documentation operations""" pass + @docs.command() -@click.argument('sections', nargs=-1) -@click.option('--mode', type=click.Choice(['extended', 'condensed']), default='extended') +@click.argument("sections", nargs=-1) +@click.option( + "--mode", type=click.Choice(["extended", "condensed"]), default="extended" +) def combine(sections: tuple, mode: str): """Combine documentation sections""" try: @@ -46,16 +58,17 @@ def combine(sections: tuple, mode: str): logger.error(str(e), tag="ERROR") sys.exit(1) + @docs.command() -@click.argument('query') -@click.option('--top-k', '-k', default=5) -@click.option('--build-index', is_flag=True, help='Build index if missing') +@click.argument("query") +@click.option("--top-k", "-k", default=5) +@click.option("--build-index", is_flag=True, help="Build index if missing") def search(query: str, top_k: int, build_index: bool): """Search documentation""" try: result = docs_manager.search(query, top_k) if result == "No search index available. Call build_search_index() first.": - if build_index or click.confirm('No search index found. Build it now?'): + if build_index or click.confirm("No search index found. Build it now?"): asyncio.run(docs_manager.llm_text.generate_index_files()) result = docs_manager.search(query, top_k) click.echo(result) @@ -63,6 +76,7 @@ def search(query: str, top_k: int, build_index: bool): click.echo(f"Error: {str(e)}", err=True) sys.exit(1) + @docs.command() def update(): """Update docs from GitHub""" @@ -73,22 +87,25 @@ def update(): click.echo(f"Error: {str(e)}", err=True) sys.exit(1) + @docs.command() -@click.option('--force-facts', is_flag=True, help='Force regenerate fact files') -@click.option('--clear-cache', is_flag=True, help='Clear BM25 cache') +@click.option("--force-facts", is_flag=True, help="Force regenerate fact files") +@click.option("--clear-cache", is_flag=True, help="Clear BM25 cache") def index(force_facts: bool, clear_cache: bool): """Build or rebuild search indexes""" try: asyncio.run(docs_manager.ensure_docs_exist()) - asyncio.run(docs_manager.llm_text.generate_index_files( - force_generate_facts=force_facts, - clear_bm25_cache=clear_cache - )) + asyncio.run( + docs_manager.llm_text.generate_index_files( + force_generate_facts=force_facts, clear_bm25_cache=clear_cache + ) + ) click.echo("Search indexes built successfully") except Exception as e: click.echo(f"Error: {str(e)}", err=True) sys.exit(1) + # Add docs list command @docs.command() def list(): @@ -96,10 +113,11 @@ def list(): try: sections = docs_manager.list() print_table(["Sections"], [[section] for section in sections]) - + except Exception as e: click.echo(f"Error: {str(e)}", err=True) sys.exit(1) -if __name__ == '__main__': - cli() \ No newline at end of file + +if __name__ == "__main__": + cli() diff --git a/crawl4ai/config.py b/crawl4ai/config.py index c2be7638..3e26514a 100644 --- a/crawl4ai/config.py +++ b/crawl4ai/config.py @@ -8,7 +8,7 @@ DEFAULT_PROVIDER = "openai/gpt-4o-mini" MODEL_REPO_BRANCH = "new-release-0.0.2" # Provider-model dictionary, ONLY used when the extraction strategy is LLMExtractionStrategy PROVIDER_MODELS = { - "ollama/llama3": "no-token-needed", # Any model from Ollama no need for API token + "ollama/llama3": "no-token-needed", # Any model from Ollama no need for API token "groq/llama3-70b-8192": os.getenv("GROQ_API_KEY"), "groq/llama3-8b-8192": os.getenv("GROQ_API_KEY"), "openai/gpt-4o-mini": os.getenv("OPENAI_API_KEY"), @@ -22,27 +22,49 @@ PROVIDER_MODELS = { } # Chunk token threshold -CHUNK_TOKEN_THRESHOLD = 2 ** 11 # 2048 tokens +CHUNK_TOKEN_THRESHOLD = 2**11 # 2048 tokens OVERLAP_RATE = 0.1 WORD_TOKEN_RATE = 1.3 -# Threshold for the minimum number of word in a HTML tag to be considered +# Threshold for the minimum number of word in a HTML tag to be considered MIN_WORD_THRESHOLD = 1 IMAGE_DESCRIPTION_MIN_WORD_THRESHOLD = 1 -IMPORTANT_ATTRS = ['src', 'href', 'alt', 'title', 'width', 'height'] -ONLY_TEXT_ELIGIBLE_TAGS = ['b', 'i', 'u', 'span', 'del', 'ins', 'sub', 'sup', 'strong', 'em', 'code', 'kbd', 'var', 's', 'q', 'abbr', 'cite', 'dfn', 'time', 'small', 'mark'] +IMPORTANT_ATTRS = ["src", "href", "alt", "title", "width", "height"] +ONLY_TEXT_ELIGIBLE_TAGS = [ + "b", + "i", + "u", + "span", + "del", + "ins", + "sub", + "sup", + "strong", + "em", + "code", + "kbd", + "var", + "s", + "q", + "abbr", + "cite", + "dfn", + "time", + "small", + "mark", +] SOCIAL_MEDIA_DOMAINS = [ - 'facebook.com', - 'twitter.com', - 'x.com', - 'linkedin.com', - 'instagram.com', - 'pinterest.com', - 'tiktok.com', - 'snapchat.com', - 'reddit.com', - ] + "facebook.com", + "twitter.com", + "x.com", + "linkedin.com", + "instagram.com", + "pinterest.com", + "tiktok.com", + "snapchat.com", + "reddit.com", +] # Threshold for the Image extraction - Range is 1 to 6 # Images are scored based on point based system, to filter based on usefulness. Points are assigned @@ -60,5 +82,5 @@ NEED_MIGRATION = True URL_LOG_SHORTEN_LENGTH = 30 SHOW_DEPRECATION_WARNINGS = True SCREENSHOT_HEIGHT_TRESHOLD = 10000 -PAGE_TIMEOUT=60000 -DOWNLOAD_PAGE_TIMEOUT=60000 \ No newline at end of file +PAGE_TIMEOUT = 60000 +DOWNLOAD_PAGE_TIMEOUT = 60000 diff --git a/crawl4ai/content_filter_strategy.py b/crawl4ai/content_filter_strategy.py index ce433118..11a33ac2 100644 --- a/crawl4ai/content_filter_strategy.py +++ b/crawl4ai/content_filter_strategy.py @@ -1,59 +1,100 @@ import re from bs4 import BeautifulSoup, Tag -from typing import List, Tuple, Dict +from typing import List, Tuple from rank_bm25 import BM25Okapi -from time import perf_counter from collections import deque -from bs4 import BeautifulSoup, NavigableString, Tag, Comment +from bs4 import NavigableString, Comment from .utils import clean_tokens from abc import ABC, abstractmethod import math from snowballstemmer import stemmer + + class RelevantContentFilter(ABC): """Abstract base class for content filtering strategies""" + def __init__(self, user_query: str = None): self.user_query = user_query self.included_tags = { # Primary structure - 'article', 'main', 'section', 'div', + "article", + "main", + "section", + "div", # List structures - 'ul', 'ol', 'li', 'dl', 'dt', 'dd', + "ul", + "ol", + "li", + "dl", + "dt", + "dd", # Text content - 'p', 'span', 'blockquote', 'pre', 'code', + "p", + "span", + "blockquote", + "pre", + "code", # Headers - 'h1', 'h2', 'h3', 'h4', 'h5', 'h6', + "h1", + "h2", + "h3", + "h4", + "h5", + "h6", # Tables - 'table', 'thead', 'tbody', 'tr', 'td', 'th', + "table", + "thead", + "tbody", + "tr", + "td", + "th", # Other semantic elements - 'figure', 'figcaption', 'details', 'summary', + "figure", + "figcaption", + "details", + "summary", # Text formatting - 'em', 'strong', 'b', 'i', 'mark', 'small', + "em", + "strong", + "b", + "i", + "mark", + "small", # Rich content - 'time', 'address', 'cite', 'q' + "time", + "address", + "cite", + "q", } self.excluded_tags = { - 'nav', 'footer', 'header', 'aside', 'script', - 'style', 'form', 'iframe', 'noscript' + "nav", + "footer", + "header", + "aside", + "script", + "style", + "form", + "iframe", + "noscript", } - self.header_tags = {'h1', 'h2', 'h3', 'h4', 'h5', 'h6'} + self.header_tags = {"h1", "h2", "h3", "h4", "h5", "h6"} self.negative_patterns = re.compile( - r'nav|footer|header|sidebar|ads|comment|promo|advert|social|share', - re.I + r"nav|footer|header|sidebar|ads|comment|promo|advert|social|share", re.I ) self.min_word_count = 2 - + @abstractmethod def filter_content(self, html: str) -> List[str]: """Abstract method to be implemented by specific filtering strategies""" pass - + def extract_page_query(self, soup: BeautifulSoup, body: Tag) -> str: """Common method to extract page metadata with fallbacks""" if self.user_query: return self.user_query query_parts = [] - + # Title try: title = soup.title.string @@ -62,109 +103,145 @@ class RelevantContentFilter(ABC): except Exception: pass - if soup.find('h1'): - query_parts.append(soup.find('h1').get_text()) - + if soup.find("h1"): + query_parts.append(soup.find("h1").get_text()) + # Meta tags temp = "" - for meta_name in ['keywords', 'description']: - meta = soup.find('meta', attrs={'name': meta_name}) - if meta and meta.get('content'): - query_parts.append(meta['content']) - temp += meta['content'] - + for meta_name in ["keywords", "description"]: + meta = soup.find("meta", attrs={"name": meta_name}) + if meta and meta.get("content"): + query_parts.append(meta["content"]) + temp += meta["content"] + # If still empty, grab first significant paragraph if not temp: # Find the first tag P thatits text contains more than 50 characters - for p in body.find_all('p'): + for p in body.find_all("p"): if len(p.get_text()) > 150: query_parts.append(p.get_text()[:150]) - break - - return ' '.join(filter(None, query_parts)) + break - def extract_text_chunks(self, body: Tag, min_word_threshold: int = None) -> List[Tuple[str, str]]: + return " ".join(filter(None, query_parts)) + + def extract_text_chunks( + self, body: Tag, min_word_threshold: int = None + ) -> List[Tuple[str, str]]: """ Extracts text chunks from a BeautifulSoup body element while preserving order. Returns list of tuples (text, tag_name) for classification. - + Args: body: BeautifulSoup Tag object representing the body element - + Returns: List of (text, tag_name) tuples """ # Tags to ignore - inline elements that shouldn't break text flow INLINE_TAGS = { - 'a', 'abbr', 'acronym', 'b', 'bdo', 'big', 'br', 'button', 'cite', 'code', - 'dfn', 'em', 'i', 'img', 'input', 'kbd', 'label', 'map', 'object', 'q', - 'samp', 'script', 'select', 'small', 'span', 'strong', 'sub', 'sup', - 'textarea', 'time', 'tt', 'var' + "a", + "abbr", + "acronym", + "b", + "bdo", + "big", + "br", + "button", + "cite", + "code", + "dfn", + "em", + "i", + "img", + "input", + "kbd", + "label", + "map", + "object", + "q", + "samp", + "script", + "select", + "small", + "span", + "strong", + "sub", + "sup", + "textarea", + "time", + "tt", + "var", } - + # Tags that typically contain meaningful headers - HEADER_TAGS = {'h1', 'h2', 'h3', 'h4', 'h5', 'h6', 'header'} - + HEADER_TAGS = {"h1", "h2", "h3", "h4", "h5", "h6", "header"} + chunks = [] current_text = [] chunk_index = 0 - + def should_break_chunk(tag: Tag) -> bool: """Determine if a tag should cause a break in the current text chunk""" - return ( - tag.name not in INLINE_TAGS - and not (tag.name == 'p' and len(current_text) == 0) + return tag.name not in INLINE_TAGS and not ( + tag.name == "p" and len(current_text) == 0 ) - + # Use deque for efficient push/pop operations stack = deque([(body, False)]) - + while stack: element, visited = stack.pop() - + if visited: # End of block element - flush accumulated text if current_text and should_break_chunk(element): - text = ' '.join(''.join(current_text).split()) + text = " ".join("".join(current_text).split()) if text: - tag_type = 'header' if element.name in HEADER_TAGS else 'content' + tag_type = ( + "header" if element.name in HEADER_TAGS else "content" + ) chunks.append((chunk_index, text, tag_type, element)) chunk_index += 1 current_text = [] continue - + if isinstance(element, NavigableString): if str(element).strip(): current_text.append(str(element).strip()) continue - + # Pre-allocate children to avoid multiple list operations children = list(element.children) if not children: continue - + # Mark block for revisit after processing children stack.append((element, True)) - + # Add children in reverse order for correct processing for child in reversed(children): if isinstance(child, (Tag, NavigableString)): stack.append((child, False)) - + # Handle any remaining text if current_text: - text = ' '.join(''.join(current_text).split()) + text = " ".join("".join(current_text).split()) if text: - chunks.append((chunk_index, text, 'content', body)) - - if min_word_threshold: - chunks = [chunk for chunk in chunks if len(chunk[1].split()) >= min_word_threshold] - - return chunks + chunks.append((chunk_index, text, "content", body)) - def _deprecated_extract_text_chunks(self, soup: BeautifulSoup) -> List[Tuple[int, str, Tag]]: + if min_word_threshold: + chunks = [ + chunk for chunk in chunks if len(chunk[1].split()) >= min_word_threshold + ] + + return chunks + + def _deprecated_extract_text_chunks( + self, soup: BeautifulSoup + ) -> List[Tuple[int, str, Tag]]: """Common method for extracting text chunks""" _text_cache = {} + def fast_text(element: Tag) -> str: elem_id = id(element) if elem_id in _text_cache: @@ -175,13 +252,13 @@ class RelevantContentFilter(ABC): text = content.strip() if text: texts.append(text) - result = ' '.join(texts) + result = " ".join(texts) _text_cache[elem_id] = result return result - + candidates = [] index = 0 - + def dfs(element): nonlocal index if isinstance(element, Tag): @@ -189,7 +266,7 @@ class RelevantContentFilter(ABC): if not self.is_excluded(element): text = fast_text(element) word_count = len(text.split()) - + # Headers pass through with adjusted minimum if element.name in self.header_tags: if word_count >= 3: # Minimal sanity check for headers @@ -199,7 +276,7 @@ class RelevantContentFilter(ABC): elif word_count >= self.min_word_count: candidates.append((index, text, element)) index += 1 - + for child in element.children: dfs(child) @@ -210,59 +287,67 @@ class RelevantContentFilter(ABC): """Common method for exclusion logic""" if tag.name in self.excluded_tags: return True - class_id = ' '.join(filter(None, [ - ' '.join(tag.get('class', [])), - tag.get('id', '') - ])) + class_id = " ".join( + filter(None, [" ".join(tag.get("class", [])), tag.get("id", "")]) + ) return bool(self.negative_patterns.search(class_id)) def clean_element(self, tag: Tag) -> str: """Common method for cleaning HTML elements with minimal overhead""" if not tag or not isinstance(tag, Tag): return "" - - unwanted_tags = {'script', 'style', 'aside', 'form', 'iframe', 'noscript'} - unwanted_attrs = {'style', 'onclick', 'onmouseover', 'align', 'bgcolor', 'class', 'id'} - + + unwanted_tags = {"script", "style", "aside", "form", "iframe", "noscript"} + unwanted_attrs = { + "style", + "onclick", + "onmouseover", + "align", + "bgcolor", + "class", + "id", + } + # Use string builder pattern for better performance builder = [] - + def render_tag(elem): if not isinstance(elem, Tag): if isinstance(elem, str): builder.append(elem.strip()) return - + if elem.name in unwanted_tags: return - + # Start tag - builder.append(f'<{elem.name}') - + builder.append(f"<{elem.name}") + # Add cleaned attributes attrs = {k: v for k, v in elem.attrs.items() if k not in unwanted_attrs} for key, value in attrs.items(): builder.append(f' {key}="{value}"') - - builder.append('>') - + + builder.append(">") + # Process children for child in elem.children: render_tag(child) - + # Close tag - builder.append(f'') - + builder.append(f"") + try: render_tag(tag) - return ''.join(builder) + return "".join(builder) except Exception: return str(tag) # Fallback to original if anything fails + class BM25ContentFilter(RelevantContentFilter): """ Content filtering using BM25 algorithm with priority tag handling. - + How it works: 1. Extracts page metadata with fallbacks. 2. Extracts text chunks from the body element. @@ -271,22 +356,28 @@ class BM25ContentFilter(RelevantContentFilter): 5. Filters out chunks below the threshold. 6. Sorts chunks by score in descending order. 7. Returns the top N chunks. - + Attributes: user_query (str): User query for filtering (optional). bm25_threshold (float): BM25 threshold for filtering (default: 1.0). language (str): Language for stemming (default: 'english'). - + Methods: filter_content(self, html: str, min_word_threshold: int = None) """ - def __init__(self, user_query: str = None, bm25_threshold: float = 1.0, language: str = 'english'): + + def __init__( + self, + user_query: str = None, + bm25_threshold: float = 1.0, + language: str = "english", + ): """ Initializes the BM25ContentFilter class, if not provided, falls back to page metadata. - + Note: If no query is given and no page metadata is available, then it tries to pick up the first significant paragraph. - + Args: user_query (str): User query for filtering (optional). bm25_threshold (float): BM25 threshold for filtering (default: 1.0). @@ -295,52 +386,52 @@ class BM25ContentFilter(RelevantContentFilter): super().__init__(user_query=user_query) self.bm25_threshold = bm25_threshold self.priority_tags = { - 'h1': 5.0, - 'h2': 4.0, - 'h3': 3.0, - 'title': 4.0, - 'strong': 2.0, - 'b': 1.5, - 'em': 1.5, - 'blockquote': 2.0, - 'code': 2.0, - 'pre': 1.5, - 'th': 1.5, # Table headers + "h1": 5.0, + "h2": 4.0, + "h3": 3.0, + "title": 4.0, + "strong": 2.0, + "b": 1.5, + "em": 1.5, + "blockquote": 2.0, + "code": 2.0, + "pre": 1.5, + "th": 1.5, # Table headers } self.stemmer = stemmer(language) def filter_content(self, html: str, min_word_threshold: int = None) -> List[str]: """ Implements content filtering using BM25 algorithm with priority tag handling. - + Note: This method implements the filtering logic for the BM25ContentFilter class. It takes HTML content as input and returns a list of filtered text chunks. - + Args: html (str): HTML content to be filtered. min_word_threshold (int): Minimum word threshold for filtering (optional). - + Returns: List[str]: List of filtered text chunks. """ if not html or not isinstance(html, str): return [] - soup = BeautifulSoup(html, 'lxml') - + soup = BeautifulSoup(html, "lxml") + # Check if body is present if not soup.body: # Wrap in body tag if missing - soup = BeautifulSoup(f'{html}', 'lxml') - body = soup.find('body') - + soup = BeautifulSoup(f"{html}", "lxml") + body = soup.find("body") + query = self.extract_page_query(soup, body) - + if not query: return [] # return [self.clean_element(soup)] - + candidates = self.extract_text_chunks(body, min_word_threshold) if not candidates: @@ -349,16 +440,20 @@ class BM25ContentFilter(RelevantContentFilter): # Tokenize corpus # tokenized_corpus = [chunk.lower().split() for _, chunk, _, _ in candidates] # tokenized_query = query.lower().split() - - # tokenized_corpus = [[ps.stem(word) for word in chunk.lower().split()] - # for _, chunk, _, _ in candidates] - # tokenized_query = [ps.stem(word) for word in query.lower().split()] - - tokenized_corpus = [[self.stemmer.stemWord(word) for word in chunk.lower().split()] - for _, chunk, _, _ in candidates] - tokenized_query = [self.stemmer.stemWord(word) for word in query.lower().split()] - # tokenized_corpus = [[self.stemmer.stemWord(word) for word in tokenize_text(chunk.lower())] + # tokenized_corpus = [[ps.stem(word) for word in chunk.lower().split()] + # for _, chunk, _, _ in candidates] + # tokenized_query = [ps.stem(word) for word in query.lower().split()] + + tokenized_corpus = [ + [self.stemmer.stemWord(word) for word in chunk.lower().split()] + for _, chunk, _, _ in candidates + ] + tokenized_query = [ + self.stemmer.stemWord(word) for word in query.lower().split() + ] + + # tokenized_corpus = [[self.stemmer.stemWord(word) for word in tokenize_text(chunk.lower())] # for _, chunk, _, _ in candidates] # tokenized_query = [self.stemmer.stemWord(word) for word in tokenize_text(query.lower())] @@ -378,7 +473,8 @@ class BM25ContentFilter(RelevantContentFilter): # Filter candidates by threshold selected_candidates = [ - (index, chunk, tag) for adjusted_score, index, chunk, tag in adjusted_candidates + (index, chunk, tag) + for adjusted_score, index, chunk, tag in adjusted_candidates if adjusted_score >= self.bm25_threshold ] @@ -390,10 +486,11 @@ class BM25ContentFilter(RelevantContentFilter): return [self.clean_element(tag) for _, _, tag in selected_candidates] + class PruningContentFilter(RelevantContentFilter): """ Content filtering using pruning algorithm with dynamic threshold. - + How it works: 1. Extracts page metadata with fallbacks. 2. Extracts text chunks from the body element. @@ -407,18 +504,24 @@ class PruningContentFilter(RelevantContentFilter): min_word_threshold (int): Minimum word threshold for filtering (optional). threshold_type (str): Threshold type for dynamic threshold (default: 'fixed'). threshold (float): Fixed threshold value (default: 0.48). - + Methods: filter_content(self, html: str, min_word_threshold: int = None): """ - def __init__(self, user_query: str = None, min_word_threshold: int = None, - threshold_type: str = 'fixed', threshold: float = 0.48): + + def __init__( + self, + user_query: str = None, + min_word_threshold: int = None, + threshold_type: str = "fixed", + threshold: float = 0.48, + ): """ Initializes the PruningContentFilter class, if not provided, falls back to page metadata. - + Note: If no query is given and no page metadata is available, then it tries to pick up the first significant paragraph. - + Args: user_query (str): User query for filtering (optional). min_word_threshold (int): Minimum word threshold for filtering (optional). @@ -429,92 +532,92 @@ class PruningContentFilter(RelevantContentFilter): self.min_word_threshold = min_word_threshold self.threshold_type = threshold_type self.threshold = threshold - + # Add tag importance for dynamic threshold self.tag_importance = { - 'article': 1.5, - 'main': 1.4, - 'section': 1.3, - 'p': 1.2, - 'h1': 1.4, - 'h2': 1.3, - 'h3': 1.2, - 'div': 0.7, - 'span': 0.6 + "article": 1.5, + "main": 1.4, + "section": 1.3, + "p": 1.2, + "h1": 1.4, + "h2": 1.3, + "h3": 1.2, + "div": 0.7, + "span": 0.6, } - + # Metric configuration self.metric_config = { - 'text_density': True, - 'link_density': True, - 'tag_weight': True, - 'class_id_weight': True, - 'text_length': True, + "text_density": True, + "link_density": True, + "tag_weight": True, + "class_id_weight": True, + "text_length": True, } - + self.metric_weights = { - 'text_density': 0.4, - 'link_density': 0.2, - 'tag_weight': 0.2, - 'class_id_weight': 0.1, - 'text_length': 0.1, + "text_density": 0.4, + "link_density": 0.2, + "tag_weight": 0.2, + "class_id_weight": 0.1, + "text_length": 0.1, } - + self.tag_weights = { - 'div': 0.5, - 'p': 1.0, - 'article': 1.5, - 'section': 1.0, - 'span': 0.3, - 'li': 0.5, - 'ul': 0.5, - 'ol': 0.5, - 'h1': 1.2, - 'h2': 1.1, - 'h3': 1.0, - 'h4': 0.9, - 'h5': 0.8, - 'h6': 0.7, + "div": 0.5, + "p": 1.0, + "article": 1.5, + "section": 1.0, + "span": 0.3, + "li": 0.5, + "ul": 0.5, + "ol": 0.5, + "h1": 1.2, + "h2": 1.1, + "h3": 1.0, + "h4": 0.9, + "h5": 0.8, + "h6": 0.7, } def filter_content(self, html: str, min_word_threshold: int = None) -> List[str]: """ Implements content filtering using pruning algorithm with dynamic threshold. - + Note: This method implements the filtering logic for the PruningContentFilter class. It takes HTML content as input and returns a list of filtered text chunks. - + Args: html (str): HTML content to be filtered. min_word_threshold (int): Minimum word threshold for filtering (optional). - + Returns: List[str]: List of filtered text chunks. """ if not html or not isinstance(html, str): return [] - - soup = BeautifulSoup(html, 'lxml') + + soup = BeautifulSoup(html, "lxml") if not soup.body: - soup = BeautifulSoup(f'{html}', 'lxml') - + soup = BeautifulSoup(f"{html}", "lxml") + # Remove comments and unwanted tags self._remove_comments(soup) self._remove_unwanted_tags(soup) - + # Prune tree starting from body - body = soup.find('body') + body = soup.find("body") self._prune_tree(body) - + # Extract remaining content as list of HTML strings content_blocks = [] for element in body.children: - if isinstance(element, str) or not hasattr(element, 'name'): + if isinstance(element, str) or not hasattr(element, "name"): continue if len(element.get_text(strip=True)) > 0: content_blocks.append(str(element)) - + return content_blocks def _remove_comments(self, soup): @@ -531,34 +634,38 @@ class PruningContentFilter(RelevantContentFilter): def _prune_tree(self, node): """ Prunes the tree starting from the given node. - + Args: node (Tag): The node from which the pruning starts. """ - if not node or not hasattr(node, 'name') or node.name is None: + if not node or not hasattr(node, "name") or node.name is None: return text_len = len(node.get_text(strip=True)) - tag_len = len(node.encode_contents().decode('utf-8')) - link_text_len = sum(len(s.strip()) for s in (a.string for a in node.find_all('a', recursive=False)) if s) + tag_len = len(node.encode_contents().decode("utf-8")) + link_text_len = sum( + len(s.strip()) + for s in (a.string for a in node.find_all("a", recursive=False)) + if s + ) metrics = { - 'node': node, - 'tag_name': node.name, - 'text_len': text_len, - 'tag_len': tag_len, - 'link_text_len': link_text_len + "node": node, + "tag_name": node.name, + "text_len": text_len, + "tag_len": tag_len, + "link_text_len": link_text_len, } score = self._compute_composite_score(metrics, text_len, tag_len, link_text_len) - if self.threshold_type == 'fixed': + if self.threshold_type == "fixed": should_remove = score < self.threshold else: # dynamic tag_importance = self.tag_importance.get(node.name, 0.7) text_ratio = text_len / tag_len if tag_len > 0 else 0 link_ratio = link_text_len / text_len if text_len > 0 else 1 - + threshold = self.threshold # base threshold if tag_importance > 1: threshold *= 0.8 @@ -566,13 +673,13 @@ class PruningContentFilter(RelevantContentFilter): threshold *= 0.9 if link_ratio > 0.6: threshold *= 1.2 - + should_remove = score < threshold if should_remove: node.decompose() else: - children = [child for child in node.children if hasattr(child, 'name')] + children = [child for child in node.children if hasattr(child, "name")] for child in children: self._prune_tree(child) @@ -580,48 +687,48 @@ class PruningContentFilter(RelevantContentFilter): """Computes the composite score""" if self.min_word_threshold: # Get raw text from metrics node - avoid extra processing - text = metrics['node'].get_text(strip=True) - word_count = text.count(' ') + 1 + text = metrics["node"].get_text(strip=True) + word_count = text.count(" ") + 1 if word_count < self.min_word_threshold: return -1.0 # Guaranteed removal score = 0.0 total_weight = 0.0 - if self.metric_config['text_density']: + if self.metric_config["text_density"]: density = text_len / tag_len if tag_len > 0 else 0 - score += self.metric_weights['text_density'] * density - total_weight += self.metric_weights['text_density'] + score += self.metric_weights["text_density"] * density + total_weight += self.metric_weights["text_density"] - if self.metric_config['link_density']: + if self.metric_config["link_density"]: density = 1 - (link_text_len / text_len if text_len > 0 else 0) - score += self.metric_weights['link_density'] * density - total_weight += self.metric_weights['link_density'] + score += self.metric_weights["link_density"] * density + total_weight += self.metric_weights["link_density"] - if self.metric_config['tag_weight']: - tag_score = self.tag_weights.get(metrics['tag_name'], 0.5) - score += self.metric_weights['tag_weight'] * tag_score - total_weight += self.metric_weights['tag_weight'] + if self.metric_config["tag_weight"]: + tag_score = self.tag_weights.get(metrics["tag_name"], 0.5) + score += self.metric_weights["tag_weight"] * tag_score + total_weight += self.metric_weights["tag_weight"] - if self.metric_config['class_id_weight']: - class_score = self._compute_class_id_weight(metrics['node']) - score += self.metric_weights['class_id_weight'] * max(0, class_score) - total_weight += self.metric_weights['class_id_weight'] + if self.metric_config["class_id_weight"]: + class_score = self._compute_class_id_weight(metrics["node"]) + score += self.metric_weights["class_id_weight"] * max(0, class_score) + total_weight += self.metric_weights["class_id_weight"] - if self.metric_config['text_length']: - score += self.metric_weights['text_length'] * math.log(text_len + 1) - total_weight += self.metric_weights['text_length'] + if self.metric_config["text_length"]: + score += self.metric_weights["text_length"] * math.log(text_len + 1) + total_weight += self.metric_weights["text_length"] return score / total_weight if total_weight > 0 else 0 def _compute_class_id_weight(self, node): """Computes the class ID weight""" class_id_score = 0 - if 'class' in node.attrs: - classes = ' '.join(node['class']) + if "class" in node.attrs: + classes = " ".join(node["class"]) if self.negative_patterns.match(classes): class_id_score -= 0.5 - if 'id' in node.attrs: - element_id = node['id'] + if "id" in node.attrs: + element_id = node["id"] if self.negative_patterns.match(element_id): class_id_score -= 0.5 - return class_id_score \ No newline at end of file + return class_id_score diff --git a/crawl4ai/content_scraping_strategy.py b/crawl4ai/content_scraping_strategy.py index ae09037d..6cb169db 100644 --- a/crawl4ai/content_scraping_strategy.py +++ b/crawl4ai/content_scraping_strategy.py @@ -1,12 +1,18 @@ -import re +import re from itertools import chain -import time from abc import ABC, abstractmethod from typing import Dict, Any, Optional from bs4 import BeautifulSoup -from concurrent.futures import ThreadPoolExecutor -import asyncio, requests, re, os -from .config import * +import asyncio +import requests +from .config import ( + MIN_WORD_THRESHOLD, + IMAGE_DESCRIPTION_MIN_WORD_THRESHOLD, + IMAGE_SCORE_THRESHOLD, + ONLY_TEXT_ELIGIBLE_TAGS, + IMPORTANT_ATTRS, + SOCIAL_MEDIA_DOMAINS, +) from bs4 import NavigableString, Comment from bs4 import PageElement, Tag from urllib.parse import urljoin @@ -14,18 +20,18 @@ from requests.exceptions import InvalidSchema from .utils import ( extract_metadata, normalize_url, - is_external_url, - get_base_domain, - extract_metadata_using_lxml + is_external_url, + get_base_domain, + extract_metadata_using_lxml, ) from lxml import etree from lxml import html as lhtml -from typing import Dict, Any, List, Tuple +from typing import List from .models import ScrapingResult, MediaItem, Link, Media, Links # Pre-compile regular expressions for Open Graph and Twitter metadata -OG_REGEX = re.compile(r'^og:') -TWITTER_REGEX = re.compile(r'^twitter:') +OG_REGEX = re.compile(r"^og:") +TWITTER_REGEX = re.compile(r"^twitter:") DIMENSION_REGEX = re.compile(r"(\d+)(\D*)") @@ -34,17 +40,22 @@ def parse_srcset(s: str) -> List[Dict]: if not s: return [] variants = [] - for part in s.split(','): + for part in s.split(","): part = part.strip() if not part: continue parts = part.split() if len(parts) >= 1: url = parts[0] - width = parts[1].rstrip('w') if len(parts) > 1 and parts[1].endswith('w') else None - variants.append({'url': url, 'width': width}) + width = ( + parts[1].rstrip("w") + if len(parts) > 1 and parts[1].endswith("w") + else None + ) + variants.append({"url": url, "width": width}) return variants + # Function to parse image height/width value and units def parse_dimension(dimension): if dimension: @@ -52,26 +63,28 @@ def parse_dimension(dimension): match = DIMENSION_REGEX.match(dimension) if match: number = int(match.group(1)) - unit = match.group(2) or 'px' # Default unit is 'px' if not specified + unit = match.group(2) or "px" # Default unit is 'px' if not specified return number, unit return None, None + # Fetch image file metadata to extract size and extension def fetch_image_file_size(img, base_url): - #If src is relative path construct full URL, if not it may be CDN URL - img_url = urljoin(base_url,img.get('src')) + # If src is relative path construct full URL, if not it may be CDN URL + img_url = urljoin(base_url, img.get("src")) try: response = requests.head(img_url) if response.status_code == 200: - return response.headers.get('Content-Length',None) + return response.headers.get("Content-Length", None) else: print(f"Failed to retrieve file size for {img_url}") return None - except InvalidSchema as e: + except InvalidSchema: return None finally: return + class ContentScrapingStrategy(ABC): @abstractmethod def scrap(self, url: str, html: str, **kwargs) -> ScrapingResult: @@ -81,10 +94,11 @@ class ContentScrapingStrategy(ABC): async def ascrap(self, url: str, html: str, **kwargs) -> ScrapingResult: pass + class WebScrapingStrategy(ContentScrapingStrategy): """ - Class for web content scraping. Perhaps the most important class. - + Class for web content scraping. Perhaps the most important class. + How it works: 1. Extract content from HTML using BeautifulSoup. 2. Clean the extracted content using a content cleaning strategy. @@ -92,7 +106,7 @@ class WebScrapingStrategy(ContentScrapingStrategy): 4. Generate markdown content from the filtered content. 5. Return the markdown content. """ - + def __init__(self, logger=None): self.logger = logger @@ -101,10 +115,10 @@ class WebScrapingStrategy(ContentScrapingStrategy): if self.logger: log_method = getattr(self.logger, level) log_method(message=message, tag=tag, **kwargs) - + def scrap(self, url: str, html: str, **kwargs) -> ScrapingResult: """ - Main entry point for content scraping. + Main entry point for content scraping. Args: url (str): The URL of the page to scrape. @@ -121,20 +135,40 @@ class WebScrapingStrategy(ContentScrapingStrategy): success=False, media=Media(), links=Links(), - metadata={} + metadata={}, ) # Convert media items media = Media( - images=[MediaItem(**img) for img in raw_result.get("media", {}).get("images", []) if img], - videos=[MediaItem(**vid) for vid in raw_result.get("media", {}).get("videos", []) if vid], - audios=[MediaItem(**aud) for aud in raw_result.get("media", {}).get("audios", []) if aud] + images=[ + MediaItem(**img) + for img in raw_result.get("media", {}).get("images", []) + if img + ], + videos=[ + MediaItem(**vid) + for vid in raw_result.get("media", {}).get("videos", []) + if vid + ], + audios=[ + MediaItem(**aud) + for aud in raw_result.get("media", {}).get("audios", []) + if aud + ], ) # Convert links links = Links( - internal=[Link(**link) for link in raw_result.get("links", {}).get("internal", []) if link], - external=[Link(**link) for link in raw_result.get("links", {}).get("external", []) if link] + internal=[ + Link(**link) + for link in raw_result.get("links", {}).get("internal", []) + if link + ], + external=[ + Link(**link) + for link in raw_result.get("links", {}).get("external", []) + if link + ], ) return ScrapingResult( @@ -142,7 +176,7 @@ class WebScrapingStrategy(ContentScrapingStrategy): success=raw_result.get("success", False), media=media, links=links, - metadata=raw_result.get("metadata", {}) + metadata=raw_result.get("metadata", {}), ) async def ascrap(self, url: str, html: str, **kwargs) -> ScrapingResult: @@ -171,7 +205,11 @@ class WebScrapingStrategy(ContentScrapingStrategy): """ if isinstance(node, NavigableString): return node - if len(node.contents) == 1 and isinstance(node.contents[0], Tag) and node.contents[0].name == node.name: + if ( + len(node.contents) == 1 + and isinstance(node.contents[0], Tag) + and node.contents[0].name == node.name + ): return self.flatten_nested_elements(node.contents[0]) node.contents = [self.flatten_nested_elements(child) for child in node.contents] return node @@ -187,23 +225,27 @@ class WebScrapingStrategy(ContentScrapingStrategy): Returns: Tag: The closest parent with useful text, or None if not found. """ - image_description_min_word_threshold = kwargs.get('image_description_min_word_threshold', IMAGE_DESCRIPTION_MIN_WORD_THRESHOLD) + image_description_min_word_threshold = kwargs.get( + "image_description_min_word_threshold", IMAGE_DESCRIPTION_MIN_WORD_THRESHOLD + ) current_tag = tag while current_tag: current_tag = current_tag.parent # Get the text content of the parent tag if current_tag: - text_content = current_tag.get_text(separator=' ',strip=True) + text_content = current_tag.get_text(separator=" ", strip=True) # Check if the text content has at least word_count_threshold if len(text_content.split()) >= image_description_min_word_threshold: return text_content return None - def remove_unwanted_attributes(self, element, important_attrs, keep_data_attributes=False): + def remove_unwanted_attributes( + self, element, important_attrs, keep_data_attributes=False + ): """ Remove unwanted attributes from an HTML element. - Args: + Args: element (Tag): The HTML element to remove attributes from. important_attrs (list): List of important attributes to keep. keep_data_attributes (bool): Whether to keep data attributes. @@ -215,18 +257,18 @@ class WebScrapingStrategy(ContentScrapingStrategy): for attr in element.attrs: if attr not in important_attrs: if keep_data_attributes: - if not attr.startswith('data-'): + if not attr.startswith("data-"): attrs_to_remove.append(attr) else: attrs_to_remove.append(attr) - + for attr in attrs_to_remove: del element[attr] def process_image(self, img, url, index, total_images, **kwargs): """ Process an image element. - + How it works: 1. Check if the image has valid display and inside undesired html elements. 2. Score an image for it's usefulness. @@ -244,33 +286,35 @@ class WebScrapingStrategy(ContentScrapingStrategy): Returns: dict: A dictionary containing the processed image information. """ - # parse_srcset = lambda s: [{'url': u.strip().split()[0], 'width': u.strip().split()[-1].rstrip('w') - # if ' ' in u else None} + # parse_srcset = lambda s: [{'url': u.strip().split()[0], 'width': u.strip().split()[-1].rstrip('w') + # if ' ' in u else None} # for u in [f"http{p}" for p in s.split("http") if p]] - + # Constants for checks - classes_to_check = frozenset(['button', 'icon', 'logo']) - tags_to_check = frozenset(['button', 'input']) - image_formats = frozenset(['jpg', 'jpeg', 'png', 'webp', 'avif', 'gif']) - + classes_to_check = frozenset(["button", "icon", "logo"]) + tags_to_check = frozenset(["button", "input"]) + image_formats = frozenset(["jpg", "jpeg", "png", "webp", "avif", "gif"]) + # Pre-fetch commonly used attributes - style = img.get('style', '') - alt = img.get('alt', '') - src = img.get('src', '') - data_src = img.get('data-src', '') - srcset = img.get('srcset', '') - data_srcset = img.get('data-srcset', '') - width = img.get('width') - height = img.get('height') + style = img.get("style", "") + alt = img.get("alt", "") + src = img.get("src", "") + data_src = img.get("data-src", "") + srcset = img.get("srcset", "") + data_srcset = img.get("data-srcset", "") + width = img.get("width") + height = img.get("height") parent = img.parent - parent_classes = parent.get('class', []) + parent_classes = parent.get("class", []) # Quick validation checks - if ('display:none' in style or - parent.name in tags_to_check or - any(c in cls for c in parent_classes for cls in classes_to_check) or - any(c in src for c in classes_to_check) or - any(c in alt for c in classes_to_check)): + if ( + "display:none" in style + or parent.name in tags_to_check + or any(c in cls for c in parent_classes for cls in classes_to_check) + or any(c in src for c in classes_to_check) + or any(c in alt for c in classes_to_check) + ): return None # Quick score calculation @@ -283,30 +327,29 @@ class WebScrapingStrategy(ContentScrapingStrategy): score += 1 if height_val > 150 else 0 if alt: score += 1 - score += index/total_images < 0.5 - + score += index / total_images < 0.5 + # image_format = '' # if "data:image/" in src: # image_format = src.split(',')[0].split(';')[0].split('/')[1].split(';')[0] # else: # image_format = os.path.splitext(src)[1].lower().strip('.').split('?')[0] - + # if image_format in ('jpg', 'png', 'webp', 'avif'): # score += 1 - - + # Check for image format in all possible sources def has_image_format(url): return any(fmt in url.lower() for fmt in image_formats) - + # Score for having proper image sources if any(has_image_format(url) for url in [src, data_src, srcset, data_srcset]): score += 1 if srcset or data_srcset: score += 1 - if img.find_parent('picture'): + if img.find_parent("picture"): score += 1 - + # Detect format from any available source detected_format = None for url in [src, data_src, srcset, data_srcset]: @@ -314,62 +357,66 @@ class WebScrapingStrategy(ContentScrapingStrategy): format_matches = [fmt for fmt in image_formats if fmt in url.lower()] if format_matches: detected_format = format_matches[0] - break + break - if score <= kwargs.get('image_score_threshold', IMAGE_SCORE_THRESHOLD): + if score <= kwargs.get("image_score_threshold", IMAGE_SCORE_THRESHOLD): return None # Use set for deduplication unique_urls = set() image_variants = [] - + # Generate a unique group ID for this set of variants - group_id = index - + group_id = index + # Base image info template base_info = { - 'alt': alt, - 'desc': self.find_closest_parent_with_useful_text(img, **kwargs), - 'score': score, - 'type': 'image', - 'group_id': group_id, # Group ID for this set of variants - 'format': detected_format, + "alt": alt, + "desc": self.find_closest_parent_with_useful_text(img, **kwargs), + "score": score, + "type": "image", + "group_id": group_id, # Group ID for this set of variants + "format": detected_format, } # Inline function for adding variants def add_variant(src, width=None): - if src and not src.startswith('data:') and src not in unique_urls: + if src and not src.startswith("data:") and src not in unique_urls: unique_urls.add(src) - image_variants.append({**base_info, 'src': src, 'width': width}) + image_variants.append({**base_info, "src": src, "width": width}) # Process all sources add_variant(src) add_variant(data_src) - + # Handle srcset and data-srcset in one pass - for attr in ('srcset', 'data-srcset'): + for attr in ("srcset", "data-srcset"): if value := img.get(attr): for source in parse_srcset(value): - add_variant(source['url'], source['width']) + add_variant(source["url"], source["width"]) # Quick picture element check - if picture := img.find_parent('picture'): - for source in picture.find_all('source'): - if srcset := source.get('srcset'): + if picture := img.find_parent("picture"): + for source in picture.find_all("source"): + if srcset := source.get("srcset"): for src in parse_srcset(srcset): - add_variant(src['url'], src['width']) + add_variant(src["url"], src["width"]) # Framework-specific attributes in one pass for attr, value in img.attrs.items(): - if attr.startswith('data-') and ('src' in attr or 'srcset' in attr) and 'http' in value: + if ( + attr.startswith("data-") + and ("src" in attr or "srcset" in attr) + and "http" in value + ): add_variant(value) return image_variants if image_variants else None - def process_element(self, url, element: PageElement, **kwargs) -> Dict[str, Any]: + def process_element(self, url, element: PageElement, **kwargs) -> Dict[str, Any]: """ Process an HTML element. - + How it works: 1. Check if the element is an image, video, or audio. 2. Extract the element's attributes and content. @@ -384,89 +431,92 @@ class WebScrapingStrategy(ContentScrapingStrategy): Returns: dict: A dictionary containing the processed element information. """ - media = {'images': [], 'videos': [], 'audios': []} + media = {"images": [], "videos": [], "audios": []} internal_links_dict = {} external_links_dict = {} self._process_element( - url, - element, - media, - internal_links_dict, - external_links_dict, - **kwargs + url, element, media, internal_links_dict, external_links_dict, **kwargs ) return { - 'media': media, - 'internal_links_dict': internal_links_dict, - 'external_links_dict': external_links_dict + "media": media, + "internal_links_dict": internal_links_dict, + "external_links_dict": external_links_dict, } - - def _process_element(self, url, element: PageElement, media: Dict[str, Any], internal_links_dict: Dict[str, Any], external_links_dict: Dict[str, Any], **kwargs) -> bool: + + def _process_element( + self, + url, + element: PageElement, + media: Dict[str, Any], + internal_links_dict: Dict[str, Any], + external_links_dict: Dict[str, Any], + **kwargs, + ) -> bool: """ - Process an HTML element. + Process an HTML element. """ try: if isinstance(element, NavigableString): if isinstance(element, Comment): element.extract() return False - + # if element.name == 'img': # process_image(element, url, 0, 1) # return True base_domain = kwargs.get("base_domain", get_base_domain(url)) - if element.name in ['script', 'style', 'link', 'meta', 'noscript']: + if element.name in ["script", "style", "link", "meta", "noscript"]: element.decompose() return False keep_element = False - - exclude_domains = kwargs.get('exclude_domains', []) + + exclude_domains = kwargs.get("exclude_domains", []) # exclude_social_media_domains = kwargs.get('exclude_social_media_domains', set(SOCIAL_MEDIA_DOMAINS)) # exclude_social_media_domains = SOCIAL_MEDIA_DOMAINS + kwargs.get('exclude_social_media_domains', []) # exclude_social_media_domains = list(set(exclude_social_media_domains)) - + try: - if element.name == 'a' and element.get('href'): - href = element.get('href', '').strip() + if element.name == "a" and element.get("href"): + href = element.get("href", "").strip() if not href: # Skip empty hrefs return False - - url_base = url.split('/')[2] - + + # url_base = url.split("/")[2] + # Normalize the URL try: normalized_href = normalize_url(href, url) - except ValueError as e: + except ValueError: # logging.warning(f"Invalid URL format: {href}, Error: {str(e)}") return False - + link_data = { - 'href': normalized_href, - 'text': element.get_text().strip(), - 'title': element.get('title', '').strip(), - 'base_domain': base_domain + "href": normalized_href, + "text": element.get_text().strip(), + "title": element.get("title", "").strip(), + "base_domain": base_domain, } - + is_external = is_external_url(normalized_href, base_domain) - + keep_element = True - + # Handle external link exclusions if is_external: link_base_domain = get_base_domain(normalized_href) - link_data['base_domain'] = link_base_domain - if kwargs.get('exclude_external_links', False): + link_data["base_domain"] = link_base_domain + if kwargs.get("exclude_external_links", False): element.decompose() return False # elif kwargs.get('exclude_social_media_links', False): # if link_base_domain in exclude_social_media_domains: # element.decompose() # return False - # if any(domain in normalized_href.lower() for domain in exclude_social_media_domains): - # element.decompose() - # return False + # if any(domain in normalized_href.lower() for domain in exclude_social_media_domains): + # element.decompose() + # return False elif exclude_domains: if link_base_domain in exclude_domains: element.decompose() @@ -482,32 +532,36 @@ class WebScrapingStrategy(ContentScrapingStrategy): if normalized_href not in internal_links_dict: internal_links_dict[normalized_href] = link_data - except Exception as e: raise Exception(f"Error processing links: {str(e)}") try: - if element.name == 'img': - potential_sources = ['src', 'data-src', 'srcset' 'data-lazy-src', 'data-original'] - src = element.get('src', '') + if element.name == "img": + potential_sources = [ + "src", + "data-src", + "srcset" "data-lazy-src", + "data-original", + ] + src = element.get("src", "") while not src and potential_sources: - src = element.get(potential_sources.pop(0), '') + src = element.get(potential_sources.pop(0), "") if not src: element.decompose() return False - + # If it is srcset pick up the first image - if 'srcset' in element.attrs: - src = element.attrs['srcset'].split(',')[0].split(' ')[0] - + if "srcset" in element.attrs: + src = element.attrs["srcset"].split(",")[0].split(" ")[0] + # If image src is internal, then skip if not is_external_url(src, base_domain): return True - + image_src_base_domain = get_base_domain(src) - + # Check flag if we should remove external images - if kwargs.get('exclude_external_images', False): + if kwargs.get("exclude_external_images", False): element.decompose() return False # src_url_base = src.split('/')[2] @@ -515,78 +569,98 @@ class WebScrapingStrategy(ContentScrapingStrategy): # if url_base not in src_url_base: # element.decompose() # return False - + # if kwargs.get('exclude_social_media_links', False): # if image_src_base_domain in exclude_social_media_domains: # element.decompose() # return False - # src_url_base = src.split('/')[2] - # url_base = url.split('/')[2] - # if any(domain in src for domain in exclude_social_media_domains): - # element.decompose() - # return False - + # src_url_base = src.split('/')[2] + # url_base = url.split('/')[2] + # if any(domain in src for domain in exclude_social_media_domains): + # element.decompose() + # return False + # Handle exclude domains - if exclude_domains: + if exclude_domains: if image_src_base_domain in exclude_domains: element.decompose() return False # if any(domain in src for domain in kwargs.get('exclude_domains', [])): # element.decompose() # return False - + return True # Always keep image elements - except Exception as e: + except Exception: raise "Error processing images" - - + # Check if flag to remove all forms is set - if kwargs.get('remove_forms', False) and element.name == 'form': + if kwargs.get("remove_forms", False) and element.name == "form": element.decompose() return False - - if element.name in ['video', 'audio']: - media[f"{element.name}s"].append({ - 'src': element.get('src'), - 'alt': element.get('alt'), - 'type': element.name, - 'description': self.find_closest_parent_with_useful_text(element, **kwargs) - }) - source_tags = element.find_all('source') + + if element.name in ["video", "audio"]: + media[f"{element.name}s"].append( + { + "src": element.get("src"), + "alt": element.get("alt"), + "type": element.name, + "description": self.find_closest_parent_with_useful_text( + element, **kwargs + ), + } + ) + source_tags = element.find_all("source") for source_tag in source_tags: - media[f"{element.name}s"].append({ - 'src': source_tag.get('src'), - 'alt': element.get('alt'), - 'type': element.name, - 'description': self.find_closest_parent_with_useful_text(element, **kwargs) - }) + media[f"{element.name}s"].append( + { + "src": source_tag.get("src"), + "alt": element.get("alt"), + "type": element.name, + "description": self.find_closest_parent_with_useful_text( + element, **kwargs + ), + } + ) return True # Always keep video and audio elements if element.name in ONLY_TEXT_ELIGIBLE_TAGS: - if kwargs.get('only_text', False): + if kwargs.get("only_text", False): element.replace_with(element.get_text()) try: - self.remove_unwanted_attributes(element, IMPORTANT_ATTRS, kwargs.get('keep_data_attributes', False)) + self.remove_unwanted_attributes( + element, IMPORTANT_ATTRS, kwargs.get("keep_data_attributes", False) + ) except Exception as e: # print('Error removing unwanted attributes:', str(e)) - self._log('error', + self._log( + "error", message="Error removing unwanted attributes: {error}", tag="SCRAPE", - params={"error": str(e)} + params={"error": str(e)}, ) # Process children for child in list(element.children): - if isinstance(child, NavigableString) and not isinstance(child, Comment): + if isinstance(child, NavigableString) and not isinstance( + child, Comment + ): if len(child.strip()) > 0: keep_element = True else: - if self._process_element(url, child, media, internal_links_dict, external_links_dict, **kwargs): + if self._process_element( + url, + child, + media, + internal_links_dict, + external_links_dict, + **kwargs, + ): keep_element = True - # Check word count - word_count_threshold = kwargs.get('word_count_threshold', MIN_WORD_THRESHOLD) + word_count_threshold = kwargs.get( + "word_count_threshold", MIN_WORD_THRESHOLD + ) if not keep_element: word_count = len(element.get_text(strip=True).split()) keep_element = word_count >= word_count_threshold @@ -597,14 +671,22 @@ class WebScrapingStrategy(ContentScrapingStrategy): return keep_element except Exception as e: # print('Error processing element:', str(e)) - self._log('error', + self._log( + "error", message="Error processing element: {error}", tag="SCRAPE", - params={"error": str(e)} - ) + params={"error": str(e)}, + ) return False - def _scrap(self, url: str, html: str, word_count_threshold: int = MIN_WORD_THRESHOLD, css_selector: str = None, **kwargs) -> Dict[str, Any]: + def _scrap( + self, + url: str, + html: str, + word_count_threshold: int = MIN_WORD_THRESHOLD, + css_selector: str = None, + **kwargs, + ) -> Dict[str, Any]: """ Extract content from HTML using BeautifulSoup. @@ -622,83 +704,93 @@ class WebScrapingStrategy(ContentScrapingStrategy): if not html: return None - parser_type = kwargs.get('parser', 'lxml') + parser_type = kwargs.get("parser", "lxml") soup = BeautifulSoup(html, parser_type) body = soup.body base_domain = get_base_domain(url) - + try: meta = extract_metadata("", soup) except Exception as e: - self._log('error', + self._log( + "error", message="Error extracting metadata: {error}", tag="SCRAPE", - params={"error": str(e)} - ) + params={"error": str(e)}, + ) meta = {} - + # Handle tag-based removal first - faster than CSS selection - excluded_tags = set(kwargs.get('excluded_tags', []) or []) + excluded_tags = set(kwargs.get("excluded_tags", []) or []) if excluded_tags: for element in body.find_all(lambda tag: tag.name in excluded_tags): element.extract() - + # Handle CSS selector-based removal - excluded_selector = kwargs.get('excluded_selector', '') + excluded_selector = kwargs.get("excluded_selector", "") if excluded_selector: - is_single_selector = ',' not in excluded_selector and ' ' not in excluded_selector + is_single_selector = ( + "," not in excluded_selector and " " not in excluded_selector + ) if is_single_selector: while element := body.select_one(excluded_selector): element.extract() else: for element in body.select(excluded_selector): - element.extract() - + element.extract() + if css_selector: selected_elements = body.select(css_selector) if not selected_elements: return { - 'markdown': '', - 'cleaned_html': '', - 'success': True, - 'media': {'images': [], 'videos': [], 'audios': []}, - 'links': {'internal': [], 'external': []}, - 'metadata': {}, - 'message': f"No elements found for CSS selector: {css_selector}" + "markdown": "", + "cleaned_html": "", + "success": True, + "media": {"images": [], "videos": [], "audios": []}, + "links": {"internal": [], "external": []}, + "metadata": {}, + "message": f"No elements found for CSS selector: {css_selector}", } # raise InvalidCSSSelectorError(f"Invalid CSS selector, No elements found for CSS selector: {css_selector}") - body = soup.new_tag('div') + body = soup.new_tag("div") for el in selected_elements: body.append(el) - kwargs['exclude_social_media_domains'] = set(kwargs.get('exclude_social_media_domains', []) + SOCIAL_MEDIA_DOMAINS) - kwargs['exclude_domains'] = set(kwargs.get('exclude_domains', [])) - if kwargs.get('exclude_social_media_links', False): - kwargs['exclude_domains'] = kwargs['exclude_domains'].union(kwargs['exclude_social_media_domains']) - - result_obj = self.process_element( - url, - body, - word_count_threshold = word_count_threshold, - base_domain=base_domain, - **kwargs + kwargs["exclude_social_media_domains"] = set( + kwargs.get("exclude_social_media_domains", []) + SOCIAL_MEDIA_DOMAINS ) - - links = {'internal': [], 'external': []} - media = result_obj['media'] - internal_links_dict = result_obj['internal_links_dict'] - external_links_dict = result_obj['external_links_dict'] - + kwargs["exclude_domains"] = set(kwargs.get("exclude_domains", [])) + if kwargs.get("exclude_social_media_links", False): + kwargs["exclude_domains"] = kwargs["exclude_domains"].union( + kwargs["exclude_social_media_domains"] + ) + + result_obj = self.process_element( + url, + body, + word_count_threshold=word_count_threshold, + base_domain=base_domain, + **kwargs, + ) + + links = {"internal": [], "external": []} + media = result_obj["media"] + internal_links_dict = result_obj["internal_links_dict"] + external_links_dict = result_obj["external_links_dict"] + # Update the links dictionary with unique links - links['internal'] = list(internal_links_dict.values()) - links['external'] = list(external_links_dict.values()) + links["internal"] = list(internal_links_dict.values()) + links["external"] = list(external_links_dict.values()) # # Process images using ThreadPoolExecutor - imgs = body.find_all('img') - - media['images'] = [ - img for result in (self.process_image(img, url, i, len(imgs), **kwargs) - for i, img in enumerate(imgs)) + imgs = body.find_all("img") + + media["images"] = [ + img + for result in ( + self.process_image(img, url, i, len(imgs), **kwargs) + for i, img in enumerate(imgs) + ) if result is not None for img in result ] @@ -706,22 +798,22 @@ class WebScrapingStrategy(ContentScrapingStrategy): body = self.flatten_nested_elements(body) base64_pattern = re.compile(r'data:image/[^;]+;base64,([^"]+)') for img in imgs: - src = img.get('src', '') + src = img.get("src", "") if base64_pattern.match(src): # Replace base64 data with empty string - img['src'] = base64_pattern.sub('', src) - + img["src"] = base64_pattern.sub("", src) + str_body = "" try: - str_body = body.encode_contents().decode('utf-8') - except Exception as e: + str_body = body.encode_contents().decode("utf-8") + except Exception: # Reset body to the original HTML success = False - body = BeautifulSoup(html, 'html.parser') - + body = BeautifulSoup(html, "html.parser") + # Create a new div with a special ID - error_div = body.new_tag('div', id='crawl4ai_error_message') - error_div.string = ''' + error_div = body.new_tag("div", id="crawl4ai_error_message") + error_div.string = """ Crawl4AI Error: This page is not fully supported. Possible reasons: @@ -734,128 +826,146 @@ class WebScrapingStrategy(ContentScrapingStrategy): - Set headless=False to visualize what's happening on the page. If the issue persists, please check the page's structure and any potential anti-crawling measures. - ''' - + """ + # Append the error div to the body body.append(error_div) - str_body = body.encode_contents().decode('utf-8') - - print(f"[LOG] 😧 Error: After processing the crawled HTML and removing irrelevant tags, nothing was left in the page. Check the markdown for further details.") - self._log('error', + str_body = body.encode_contents().decode("utf-8") + + print( + "[LOG] 😧 Error: After processing the crawled HTML and removing irrelevant tags, nothing was left in the page. Check the markdown for further details." + ) + self._log( + "error", message="After processing the crawled HTML and removing irrelevant tags, nothing was left in the page. Check the markdown for further details.", - tag="SCRAPE" + tag="SCRAPE", ) - cleaned_html = str_body.replace('\n\n', '\n').replace(' ', ' ') + cleaned_html = str_body.replace("\n\n", "\n").replace(" ", " ") - return { # **markdown_content, - 'cleaned_html': cleaned_html, - 'success': success, - 'media': media, - 'links': links, - 'metadata': meta + "cleaned_html": cleaned_html, + "success": success, + "media": media, + "links": links, + "metadata": meta, } + class LXMLWebScrapingStrategy(WebScrapingStrategy): def __init__(self, logger=None): super().__init__(logger) self.DIMENSION_REGEX = re.compile(r"(\d+)(\D*)") self.BASE64_PATTERN = re.compile(r'data:image/[^;]+;base64,([^"]+)') - def _process_element(self, url: str, element: lhtml.HtmlElement, media: Dict[str, List], - internal_links_dict: Dict[str, Any], external_links_dict: Dict[str, Any], **kwargs) -> bool: + def _process_element( + self, + url: str, + element: lhtml.HtmlElement, + media: Dict[str, List], + internal_links_dict: Dict[str, Any], + external_links_dict: Dict[str, Any], + **kwargs, + ) -> bool: base_domain = kwargs.get("base_domain", get_base_domain(url)) - exclude_domains = set(kwargs.get('exclude_domains', [])) - + exclude_domains = set(kwargs.get("exclude_domains", [])) + # Process links - for link in element.xpath('.//a[@href]'): - href = link.get('href', '').strip() + for link in element.xpath(".//a[@href]"): + href = link.get("href", "").strip() if not href: continue - + try: normalized_href = normalize_url(href, url) link_data = { - 'href': normalized_href, - 'text': link.text_content().strip(), - 'title': link.get('title', '').strip(), - 'base_domain': base_domain + "href": normalized_href, + "text": link.text_content().strip(), + "title": link.get("title", "").strip(), + "base_domain": base_domain, } - + is_external = is_external_url(normalized_href, base_domain) if is_external: link_base_domain = get_base_domain(normalized_href) - link_data['base_domain'] = link_base_domain - if kwargs.get('exclude_external_links', False) or link_base_domain in exclude_domains: + link_data["base_domain"] = link_base_domain + if ( + kwargs.get("exclude_external_links", False) + or link_base_domain in exclude_domains + ): link.getparent().remove(link) continue - + if normalized_href not in external_links_dict: external_links_dict[normalized_href] = link_data else: if normalized_href not in internal_links_dict: internal_links_dict[normalized_href] = link_data - + except Exception as e: - self._log('error', f"Error processing link: {str(e)}", "SCRAPE") + self._log("error", f"Error processing link: {str(e)}", "SCRAPE") continue # Process images - images = element.xpath('.//img') + images = element.xpath(".//img") total_images = len(images) - + for idx, img in enumerate(images): - src = img.get('src') or '' + src = img.get("src") or "" img_domain = get_base_domain(src) # Decide if we need to exclude this image # 1) If its domain is in exclude_domains, remove. # 2) Or if exclude_external_images=True and it's an external domain, remove. if (img_domain in exclude_domains) or ( - kwargs.get('exclude_external_images', False) and is_external_url(src, base_domain) + kwargs.get("exclude_external_images", False) + and is_external_url(src, base_domain) ): parent = img.getparent() if parent is not None: parent.remove(img) continue - + # Otherwise, process the image as usual. try: - processed_images = self.process_image(img, url, idx, total_images, **kwargs) + processed_images = self.process_image( + img, url, idx, total_images, **kwargs + ) if processed_images: - media['images'].extend(processed_images) + media["images"].extend(processed_images) except Exception as e: - self._log('error', f"Error processing image: {str(e)}", "SCRAPE") + self._log("error", f"Error processing image: {str(e)}", "SCRAPE") # Process videos and audios - for media_type in ['video', 'audio']: - for elem in element.xpath(f'.//{media_type}'): + for media_type in ["video", "audio"]: + for elem in element.xpath(f".//{media_type}"): media_info = { - 'src': elem.get('src'), - 'alt': elem.get('alt'), - 'type': media_type, - 'description': self.find_closest_parent_with_useful_text(elem, **kwargs) + "src": elem.get("src"), + "alt": elem.get("alt"), + "type": media_type, + "description": self.find_closest_parent_with_useful_text( + elem, **kwargs + ), } media[f"{media_type}s"].append(media_info) - + # Process source tags within media elements - for source in elem.xpath('.//source'): - if src := source.get('src'): - media[f"{media_type}s"].append({**media_info, 'src': src}) + for source in elem.xpath(".//source"): + if src := source.get("src"): + media[f"{media_type}s"].append({**media_info, "src": src}) # Clean up unwanted elements - if kwargs.get('remove_forms', False): - for form in element.xpath('.//form'): + if kwargs.get("remove_forms", False): + for form in element.xpath(".//form"): form.getparent().remove(form) - if excluded_tags := kwargs.get('excluded_tags', []): + if excluded_tags := kwargs.get("excluded_tags", []): for tag in excluded_tags: - for elem in element.xpath(f'.//{tag}'): + for elem in element.xpath(f".//{tag}"): elem.getparent().remove(elem) - if excluded_selector := kwargs.get('excluded_selector', ''): + if excluded_selector := kwargs.get("excluded_selector", ""): try: for elem in element.cssselect(excluded_selector): elem.getparent().remove(elem) @@ -864,12 +974,19 @@ class LXMLWebScrapingStrategy(WebScrapingStrategy): return True - def find_closest_parent_with_useful_text(self, element: lhtml.HtmlElement, **kwargs) -> Optional[str]: - image_description_min_word_threshold = kwargs.get('image_description_min_word_threshold', - IMAGE_DESCRIPTION_MIN_WORD_THRESHOLD) + def find_closest_parent_with_useful_text( + self, element: lhtml.HtmlElement, **kwargs + ) -> Optional[str]: + image_description_min_word_threshold = kwargs.get( + "image_description_min_word_threshold", IMAGE_DESCRIPTION_MIN_WORD_THRESHOLD + ) current = element while current is not None: - if current.text and len(current.text_content().split()) >= image_description_min_word_threshold: + if ( + current.text + and len(current.text_content().split()) + >= image_description_min_word_threshold + ): return current.text_content().strip() current = current.getparent() return None @@ -878,52 +995,57 @@ class LXMLWebScrapingStrategy(WebScrapingStrategy): """Flatten nested elements of the same type in LXML tree""" if len(element) == 1 and element.tag == element[0].tag: return self.flatten_nested_elements(element[0]) - + for child in element: child_idx = element.index(child) flattened_child = self.flatten_nested_elements(child) if flattened_child is not child: # Only replace if actually flattened element[child_idx] = flattened_child - + return element - def process_image(self, img: lhtml.HtmlElement, url: str, index: int, total_images: int, **kwargs) -> Optional[List[Dict]]: + def process_image( + self, img: lhtml.HtmlElement, url: str, index: int, total_images: int, **kwargs + ) -> Optional[List[Dict]]: # Quick validation checks - style = img.get('style', '') - alt = img.get('alt', '') - src = img.get('src', '') - data_src = img.get('data-src', '') - srcset = img.get('srcset', '') - data_srcset = img.get('data-srcset', '') - - if 'display:none' in style: + style = img.get("style", "") + alt = img.get("alt", "") + src = img.get("src", "") + data_src = img.get("data-src", "") + srcset = img.get("srcset", "") + data_srcset = img.get("data-srcset", "") + + if "display:none" in style: return None parent = img.getparent() - if parent.tag in ['button', 'input']: + if parent.tag in ["button", "input"]: return None - parent_classes = parent.get('class', '').split() - if any('button' in cls or 'icon' in cls or 'logo' in cls for cls in parent_classes): + parent_classes = parent.get("class", "").split() + if any( + "button" in cls or "icon" in cls or "logo" in cls for cls in parent_classes + ): return None - + # If src is in class or alt, likely an icon - if (src and any(c in src for c in ['button', 'icon', 'logo'])) or \ - (alt and any(c in alt for c in ['button', 'icon', 'logo'])): + if (src and any(c in src for c in ["button", "icon", "logo"])) or ( + alt and any(c in alt for c in ["button", "icon", "logo"]) + ): return None # Score calculation score = 0 - if (width := img.get('width')) and width.isdigit(): + if (width := img.get("width")) and width.isdigit(): score += 1 if int(width) > 150 else 0 - if (height := img.get('height')) and height.isdigit(): + if (height := img.get("height")) and height.isdigit(): score += 1 if int(height) > 150 else 0 if alt: score += 1 - score += index/total_images < 0.5 + score += index / total_images < 0.5 # Check formats in all possible sources - image_formats = {'jpg', 'jpeg', 'png', 'webp', 'avif', 'gif'} + image_formats = {"jpg", "jpeg", "png", "webp", "avif", "gif"} detected_format = None for url in [src, data_src, srcset, data_srcset]: if url: @@ -936,51 +1058,55 @@ class LXMLWebScrapingStrategy(WebScrapingStrategy): if srcset or data_srcset: score += 1 - if picture := img.xpath('./ancestor::picture[1]'): + if picture := img.xpath("./ancestor::picture[1]"): score += 1 - if score <= kwargs.get('image_score_threshold', IMAGE_SCORE_THRESHOLD): + if score <= kwargs.get("image_score_threshold", IMAGE_SCORE_THRESHOLD): return None # Process image variants unique_urls = set() image_variants = [] base_info = { - 'alt': alt, - 'desc': self.find_closest_parent_with_useful_text(img, **kwargs), - 'score': score, - 'type': 'image', - 'group_id': index, - 'format': detected_format, + "alt": alt, + "desc": self.find_closest_parent_with_useful_text(img, **kwargs), + "score": score, + "type": "image", + "group_id": index, + "format": detected_format, } def add_variant(src: str, width: Optional[str] = None): - if src and not src.startswith('data:') and src not in unique_urls: + if src and not src.startswith("data:") and src not in unique_urls: unique_urls.add(src) - variant = {**base_info, 'src': src} + variant = {**base_info, "src": src} if width: - variant['width'] = width + variant["width"] = width image_variants.append(variant) # Add variants from different sources add_variant(src) add_variant(data_src) - + for srcset_attr in [srcset, data_srcset]: if srcset_attr: for source in parse_srcset(srcset_attr): - add_variant(source['url'], source['width']) + add_variant(source["url"], source["width"]) # Handle picture element if picture: - for source in picture[0].xpath('.//source[@srcset]'): - if source_srcset := source.get('srcset'): + for source in picture[0].xpath(".//source[@srcset]"): + if source_srcset := source.get("srcset"): for src_data in parse_srcset(source_srcset): - add_variant(src_data['url'], src_data['width']) + add_variant(src_data["url"], src_data["width"]) # Check framework-specific attributes for attr, value in img.attrib.items(): - if attr.startswith('data-') and ('src' in attr or 'srcset' in attr) and 'http' in value: + if ( + attr.startswith("data-") + and ("src" in attr or "srcset" in attr) + and "http" in value + ): add_variant(value) return image_variants if image_variants else None @@ -990,33 +1116,44 @@ class LXMLWebScrapingStrategy(WebScrapingStrategy): Remove elements that fall below the desired word threshold in a single pass from the bottom up. Skips non-element nodes like HtmlComment and bypasses certain tags that are allowed to have no content. """ - bypass_tags = {'a', 'img', 'br', 'hr', 'input', 'meta', 'link', 'source', 'track', 'wbr'} - + bypass_tags = { + "a", + "img", + "br", + "hr", + "input", + "meta", + "link", + "source", + "track", + "wbr", + } + for el in reversed(list(root.iterdescendants())): if not isinstance(el, lhtml.HtmlElement): continue - + if el.tag in bypass_tags: continue - + text_content = (el.text_content() or "").strip() - if len(text_content.split()) < word_count_threshold and not el.getchildren(): + if ( + len(text_content.split()) < word_count_threshold + and not el.getchildren() + ): parent = el.getparent() if parent is not None: parent.remove(el) - + return root - + def remove_unwanted_attributes_fast( - self, - root: lhtml.HtmlElement, - important_attrs=None, - keep_data_attributes=False + self, root: lhtml.HtmlElement, important_attrs=None, keep_data_attributes=False ) -> lhtml.HtmlElement: """ Removes all attributes from each element (including root) except those in `important_attrs`. If `keep_data_attributes=True`, also retain any attribute starting with 'data-'. - + Returns the same root element, mutated in-place, for fluent usage. """ if important_attrs is None: @@ -1029,26 +1166,32 @@ class LXMLWebScrapingStrategy(WebScrapingStrategy): # We only remove attributes on HtmlElement nodes, skip comments or text nodes if not isinstance(el, lhtml.HtmlElement): continue - + old_attribs = dict(el.attrib) new_attribs = {} - + for attr_name, attr_val in old_attribs.items(): # If it's an important attribute, keep it if attr_name in important_attrs: new_attribs[attr_name] = attr_val # Or if keep_data_attributes is True and it's a 'data-*' attribute - elif keep_data_attributes and attr_name.startswith('data-'): + elif keep_data_attributes and attr_name.startswith("data-"): new_attribs[attr_name] = attr_val # Clear old attributes and set the filtered set el.attrib.clear() el.attrib.update(new_attribs) - + return root - - def _scrap(self, url: str, html: str, word_count_threshold: int = MIN_WORD_THRESHOLD, - css_selector: str = None, **kwargs) -> Dict[str, Any]: + + def _scrap( + self, + url: str, + html: str, + word_count_threshold: int = MIN_WORD_THRESHOLD, + css_selector: str = None, + **kwargs, + ) -> Dict[str, Any]: if not html: return None @@ -1058,38 +1201,42 @@ class LXMLWebScrapingStrategy(WebScrapingStrategy): # Match BeautifulSoup's behavior of using body or full doc # body = doc.xpath('//body')[0] if doc.xpath('//body') else doc body = doc - + base_domain = get_base_domain(url) - - # Add comment removal - if kwargs.get('remove_comments', False): - comments = body.xpath('//comment()') + + # Add comment removal + if kwargs.get("remove_comments", False): + comments = body.xpath("//comment()") for comment in comments: comment.getparent().remove(comment) - + # Handle tag-based removal first - excluded_tags = set(kwargs.get('excluded_tags', []) or []) + excluded_tags = set(kwargs.get("excluded_tags", []) or []) if excluded_tags: for tag in excluded_tags: - for element in body.xpath(f'.//{tag}'): + for element in body.xpath(f".//{tag}"): if element.getparent() is not None: element.getparent().remove(element) - + # Handle CSS selector-based exclusion - excluded_selector = kwargs.get('excluded_selector', '') + excluded_selector = kwargs.get("excluded_selector", "") if excluded_selector: try: for element in body.cssselect(excluded_selector): if element.getparent() is not None: element.getparent().remove(element) except Exception as e: - self._log('error', f"Error with excluded CSS selector: {str(e)}", "SCRAPE") + self._log( + "error", f"Error with excluded CSS selector: {str(e)}", "SCRAPE" + ) # Extract metadata before any content filtering try: - meta = extract_metadata_using_lxml("", doc) # Using same function as BeautifulSoup version + meta = extract_metadata_using_lxml( + "", doc + ) # Using same function as BeautifulSoup version except Exception as e: - self._log('error', f"Error extracting metadata: {str(e)}", "SCRAPE") + self._log("error", f"Error extracting metadata: {str(e)}", "SCRAPE") meta = {} # Handle CSS selector targeting @@ -1098,101 +1245,106 @@ class LXMLWebScrapingStrategy(WebScrapingStrategy): selected_elements = body.cssselect(css_selector) if not selected_elements: return { - 'markdown': '', - 'cleaned_html': '', - 'success': True, - 'media': {'images': [], 'videos': [], 'audios': []}, - 'links': {'internal': [], 'external': []}, - 'metadata': meta, - 'message': f"No elements found for CSS selector: {css_selector}" + "markdown": "", + "cleaned_html": "", + "success": True, + "media": {"images": [], "videos": [], "audios": []}, + "links": {"internal": [], "external": []}, + "metadata": meta, + "message": f"No elements found for CSS selector: {css_selector}", } - body = lhtml.Element('div') + body = lhtml.Element("div") body.extend(selected_elements) except Exception as e: - self._log('error', f"Error with CSS selector: {str(e)}", "SCRAPE") + self._log("error", f"Error with CSS selector: {str(e)}", "SCRAPE") return None # Remove script and style tags - for tag in ['script', 'style', 'link', 'meta', 'noscript']: - for element in body.xpath(f'.//{tag}'): + for tag in ["script", "style", "link", "meta", "noscript"]: + for element in body.xpath(f".//{tag}"): if element.getparent() is not None: element.getparent().remove(element) # Handle social media and domain exclusions - kwargs['exclude_domains'] = set(kwargs.get('exclude_domains', [])) - if kwargs.get('exclude_social_media_links', False): - kwargs['exclude_social_media_domains'] = set(kwargs.get('exclude_social_media_domains', []) + SOCIAL_MEDIA_DOMAINS) - kwargs['exclude_domains'].update(kwargs['exclude_social_media_domains']) + kwargs["exclude_domains"] = set(kwargs.get("exclude_domains", [])) + if kwargs.get("exclude_social_media_links", False): + kwargs["exclude_social_media_domains"] = set( + kwargs.get("exclude_social_media_domains", []) + + SOCIAL_MEDIA_DOMAINS + ) + kwargs["exclude_domains"].update(kwargs["exclude_social_media_domains"]) # Process forms if needed - if kwargs.get('remove_forms', False): - for form in body.xpath('.//form'): + if kwargs.get("remove_forms", False): + for form in body.xpath(".//form"): if form.getparent() is not None: form.getparent().remove(form) - # Process content - media = {'images': [], 'videos': [], 'audios': []} + media = {"images": [], "videos": [], "audios": []} internal_links_dict = {} external_links_dict = {} - + self._process_element( - url, - body, - media, + url, + body, + media, internal_links_dict, external_links_dict, base_domain=base_domain, - **kwargs + **kwargs, ) # Handle only_text option - if kwargs.get('only_text', False): + if kwargs.get("only_text", False): for tag in ONLY_TEXT_ELIGIBLE_TAGS: - for element in body.xpath(f'.//{tag}'): + for element in body.xpath(f".//{tag}"): if element.text: - new_text = lhtml.Element('span') + new_text = lhtml.Element("span") new_text.text = element.text_content() if element.getparent() is not None: element.getparent().replace(element, new_text) # Clean base64 images - for img in body.xpath('.//img[@src]'): - src = img.get('src', '') + for img in body.xpath(".//img[@src]"): + src = img.get("src", "") if self.BASE64_PATTERN.match(src): - img.set('src', self.BASE64_PATTERN.sub('', src)) - + img.set("src", self.BASE64_PATTERN.sub("", src)) # Remove empty elements self.remove_empty_elements_fast(body, 1) - - # Remvoe unneeded attributes - self.remove_unwanted_attributes_fast(body, keep_data_attributes=kwargs.get('keep_data_attributes', False)) + # Remvoe unneeded attributes + self.remove_unwanted_attributes_fast( + body, keep_data_attributes=kwargs.get("keep_data_attributes", False) + ) # Generate output HTML - cleaned_html = lhtml.tostring(body, encoding='unicode', - pretty_print=True, - method='html', - with_tail=False).strip() + cleaned_html = lhtml.tostring( + body, + encoding="unicode", + pretty_print=True, + method="html", + with_tail=False, + ).strip() return { - 'cleaned_html': cleaned_html, - 'success': success, - 'media': media, - 'links': { - 'internal': list(internal_links_dict.values()), - 'external': list(external_links_dict.values()) + "cleaned_html": cleaned_html, + "success": success, + "media": media, + "links": { + "internal": list(internal_links_dict.values()), + "external": list(external_links_dict.values()), }, - 'metadata': meta + "metadata": meta, } - + except Exception as e: - self._log('error', f"Error processing HTML: {str(e)}", "SCRAPE") + self._log("error", f"Error processing HTML: {str(e)}", "SCRAPE") # Create error message in case of failure - error_body = lhtml.Element('div') + error_body = lhtml.Element("div") # Use etree.SubElement rather than lhtml.SubElement - error_div = etree.SubElement(error_body, 'div', id='crawl4ai_error_message') - error_div.text = f''' + error_div = etree.SubElement(error_body, "div", id="crawl4ai_error_message") + error_div.text = f""" Crawl4AI Error: This page is not fully supported. Error Message: {str(e)} @@ -1207,12 +1359,14 @@ class LXMLWebScrapingStrategy(WebScrapingStrategy): - Set headless=False to visualize what's happening on the page. If the issue persists, please check the page's structure and any potential anti-crawling measures. - ''' - cleaned_html = lhtml.tostring(error_body, encoding='unicode', pretty_print=True) + """ + cleaned_html = lhtml.tostring( + error_body, encoding="unicode", pretty_print=True + ) return { - 'cleaned_html': cleaned_html, - 'success': False, - 'media': {'images': [], 'videos': [], 'audios': []}, - 'links': {'internal': [], 'external': []}, - 'metadata': {} - } \ No newline at end of file + "cleaned_html": cleaned_html, + "success": False, + "media": {"images": [], "videos": [], "audios": []}, + "links": {"internal": [], "external": []}, + "metadata": {}, + } diff --git a/crawl4ai/crawler_strategy.py b/crawl4ai/crawler_strategy.py index 898dcfa8..34e20ecd 100644 --- a/crawl4ai/crawler_strategy.py +++ b/crawl4ai/crawler_strategy.py @@ -15,54 +15,53 @@ import logging, time import base64 from PIL import Image, ImageDraw, ImageFont from io import BytesIO -from typing import List, Callable +from typing import Callable import requests import os from pathlib import Path from .utils import * -logger = logging.getLogger('selenium.webdriver.remote.remote_connection') +logger = logging.getLogger("selenium.webdriver.remote.remote_connection") logger.setLevel(logging.WARNING) -logger_driver = logging.getLogger('selenium.webdriver.common.service') +logger_driver = logging.getLogger("selenium.webdriver.common.service") logger_driver.setLevel(logging.WARNING) -urllib3_logger = logging.getLogger('urllib3.connectionpool') +urllib3_logger = logging.getLogger("urllib3.connectionpool") urllib3_logger.setLevel(logging.WARNING) # Disable http.client logging -http_client_logger = logging.getLogger('http.client') +http_client_logger = logging.getLogger("http.client") http_client_logger.setLevel(logging.WARNING) # Disable driver_finder and service logging -driver_finder_logger = logging.getLogger('selenium.webdriver.common.driver_finder') +driver_finder_logger = logging.getLogger("selenium.webdriver.common.driver_finder") driver_finder_logger.setLevel(logging.WARNING) - - class CrawlerStrategy(ABC): @abstractmethod def crawl(self, url: str, **kwargs) -> str: pass - + @abstractmethod def take_screenshot(self, save_path: str): pass - + @abstractmethod def update_user_agent(self, user_agent: str): pass - + @abstractmethod def set_hook(self, hook_type: str, hook: Callable): pass + class CloudCrawlerStrategy(CrawlerStrategy): - def __init__(self, use_cached_html = False): + def __init__(self, use_cached_html=False): super().__init__() self.use_cached_html = use_cached_html - + def crawl(self, url: str) -> str: data = { "urls": [url], @@ -76,6 +75,7 @@ class CloudCrawlerStrategy(CrawlerStrategy): html = response["results"][0]["html"] return sanitize_input_encode(html) + class LocalSeleniumCrawlerStrategy(CrawlerStrategy): def __init__(self, use_cached_html=False, js_code=None, **kwargs): super().__init__() @@ -87,20 +87,25 @@ class LocalSeleniumCrawlerStrategy(CrawlerStrategy): if kwargs.get("user_agent"): self.options.add_argument("--user-agent=" + kwargs.get("user_agent")) else: - user_agent = kwargs.get("user_agent", "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.124 Safari/537.36") + user_agent = kwargs.get( + "user_agent", + "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.124 Safari/537.36", + ) self.options.add_argument(f"--user-agent={user_agent}") - self.options.add_argument("user-agent=Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.124 Safari/537.36") - + self.options.add_argument( + "user-agent=Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.124 Safari/537.36" + ) + self.options.headless = kwargs.get("headless", True) if self.options.headless: self.options.add_argument("--headless") - - self.options.add_argument("--disable-gpu") + + self.options.add_argument("--disable-gpu") self.options.add_argument("--window-size=1920,1080") self.options.add_argument("--no-sandbox") self.options.add_argument("--disable-dev-shm-usage") - self.options.add_argument("--disable-blink-features=AutomationControlled") - + self.options.add_argument("--disable-blink-features=AutomationControlled") + # self.options.add_argument("--disable-dev-shm-usage") self.options.add_argument("--disable-gpu") # self.options.add_argument("--disable-extensions") @@ -120,14 +125,14 @@ class LocalSeleniumCrawlerStrategy(CrawlerStrategy): self.use_cached_html = use_cached_html self.js_code = js_code self.verbose = kwargs.get("verbose", False) - + # Hooks self.hooks = { - 'on_driver_created': None, - 'on_user_agent_updated': None, - 'before_get_url': None, - 'after_get_url': None, - 'before_return_html': None + "on_driver_created": None, + "on_user_agent_updated": None, + "before_get_url": None, + "after_get_url": None, + "before_return_html": None, } # chromedriver_autoinstaller.install() @@ -137,31 +142,28 @@ class LocalSeleniumCrawlerStrategy(CrawlerStrategy): # chromedriver_path = chromedriver_autoinstaller.install() # chromedriver_path = chromedriver_autoinstaller.utils.download_chromedriver() # self.service = Service(chromedriver_autoinstaller.install()) - - + # chromedriver_path = ChromeDriverManager().install() # self.service = Service(chromedriver_path) # self.service.log_path = "NUL" # self.driver = webdriver.Chrome(service=self.service, options=self.options) - + # Use selenium-manager (built into Selenium 4.10.0+) self.service = Service() self.driver = webdriver.Chrome(options=self.options) - - self.driver = self.execute_hook('on_driver_created', self.driver) - + + self.driver = self.execute_hook("on_driver_created", self.driver) + if kwargs.get("cookies"): for cookie in kwargs.get("cookies"): self.driver.add_cookie(cookie) - - def set_hook(self, hook_type: str, hook: Callable): if hook_type in self.hooks: self.hooks[hook_type] = hook else: raise ValueError(f"Invalid hook type: {hook_type}") - + def execute_hook(self, hook_type: str, *args): hook = self.hooks.get(hook_type) if hook: @@ -170,7 +172,9 @@ class LocalSeleniumCrawlerStrategy(CrawlerStrategy): if isinstance(result, webdriver.Chrome): return result else: - raise TypeError(f"Hook {hook_type} must return an instance of webdriver.Chrome or None.") + raise TypeError( + f"Hook {hook_type} must return an instance of webdriver.Chrome or None." + ) # If the hook returns None or there is no hook, return self.driver return self.driver @@ -178,60 +182,77 @@ class LocalSeleniumCrawlerStrategy(CrawlerStrategy): self.options.add_argument(f"user-agent={user_agent}") self.driver.quit() self.driver = webdriver.Chrome(service=self.service, options=self.options) - self.driver = self.execute_hook('on_user_agent_updated', self.driver) + self.driver = self.execute_hook("on_user_agent_updated", self.driver) def set_custom_headers(self, headers: dict): # Enable Network domain for sending headers - self.driver.execute_cdp_cmd('Network.enable', {}) + self.driver.execute_cdp_cmd("Network.enable", {}) # Set extra HTTP headers - self.driver.execute_cdp_cmd('Network.setExtraHTTPHeaders', {'headers': headers}) + self.driver.execute_cdp_cmd("Network.setExtraHTTPHeaders", {"headers": headers}) - def _ensure_page_load(self, max_checks=6, check_interval=0.01): + def _ensure_page_load(self, max_checks=6, check_interval=0.01): initial_length = len(self.driver.page_source) - + for ix in range(max_checks): # print(f"Checking page load: {ix}") time.sleep(check_interval) current_length = len(self.driver.page_source) - + if current_length != initial_length: break return self.driver.page_source - + def crawl(self, url: str, **kwargs) -> str: # Create md5 hash of the URL import hashlib + url_hash = hashlib.md5(url.encode()).hexdigest() - + if self.use_cached_html: - cache_file_path = os.path.join(os.getenv("CRAWL4_AI_BASE_DIRECTORY", Path.home()), ".crawl4ai", "cache", url_hash) + cache_file_path = os.path.join( + os.getenv("CRAWL4_AI_BASE_DIRECTORY", Path.home()), + ".crawl4ai", + "cache", + url_hash, + ) if os.path.exists(cache_file_path): with open(cache_file_path, "r") as f: return sanitize_input_encode(f.read()) try: - self.driver = self.execute_hook('before_get_url', self.driver) + self.driver = self.execute_hook("before_get_url", self.driver) if self.verbose: print(f"[LOG] 🕸️ Crawling {url} using LocalSeleniumCrawlerStrategy...") - self.driver.get(url) # - + self.driver.get(url) # + WebDriverWait(self.driver, 20).until( - lambda d: d.execute_script('return document.readyState') == 'complete' + lambda d: d.execute_script("return document.readyState") == "complete" ) WebDriverWait(self.driver, 10).until( EC.presence_of_all_elements_located((By.TAG_NAME, "body")) ) - - self.driver.execute_script("window.scrollTo(0, document.body.scrollHeight);") - - self.driver = self.execute_hook('after_get_url', self.driver) - html = sanitize_input_encode(self._ensure_page_load()) # self.driver.page_source - can_not_be_done_headless = False # Look at my creativity for naming variables - + + self.driver.execute_script( + "window.scrollTo(0, document.body.scrollHeight);" + ) + + self.driver = self.execute_hook("after_get_url", self.driver) + html = sanitize_input_encode( + self._ensure_page_load() + ) # self.driver.page_source + can_not_be_done_headless = ( + False # Look at my creativity for naming variables + ) + # TODO: Very ugly approach, but promise to change it! - if kwargs.get('bypass_headless', False) or html == "": - print("[LOG] 🙌 Page could not be loaded in headless mode. Trying non-headless mode...") + if ( + kwargs.get("bypass_headless", False) + or html == "" + ): + print( + "[LOG] 🙌 Page could not be loaded in headless mode. Trying non-headless mode..." + ) can_not_be_done_headless = True options = Options() options.headless = False @@ -239,27 +260,31 @@ class LocalSeleniumCrawlerStrategy(CrawlerStrategy): options.add_argument("--window-size=5,5") driver = webdriver.Chrome(service=self.service, options=options) driver.get(url) - self.driver = self.execute_hook('after_get_url', driver) + self.driver = self.execute_hook("after_get_url", driver) html = sanitize_input_encode(driver.page_source) driver.quit() - + # Execute JS code if provided self.js_code = kwargs.get("js_code", self.js_code) if self.js_code and type(self.js_code) == str: self.driver.execute_script(self.js_code) # Optionally, wait for some condition after executing the JS code WebDriverWait(self.driver, 10).until( - lambda driver: driver.execute_script("return document.readyState") == "complete" + lambda driver: driver.execute_script("return document.readyState") + == "complete" ) elif self.js_code and type(self.js_code) == list: for js in self.js_code: self.driver.execute_script(js) WebDriverWait(self.driver, 10).until( - lambda driver: driver.execute_script("return document.readyState") == "complete" + lambda driver: driver.execute_script( + "return document.readyState" + ) + == "complete" ) - + # Optionally, wait for some condition after executing the JS code : Contributed by (https://github.com/jonymusky) - wait_for = kwargs.get('wait_for', False) + wait_for = kwargs.get("wait_for", False) if wait_for: if callable(wait_for): print("[LOG] 🔄 Waiting for condition...") @@ -268,32 +293,37 @@ class LocalSeleniumCrawlerStrategy(CrawlerStrategy): print("[LOG] 🔄 Waiting for condition...") WebDriverWait(self.driver, 20).until( EC.presence_of_element_located((By.CSS_SELECTOR, wait_for)) - ) - + ) + if not can_not_be_done_headless: html = sanitize_input_encode(self.driver.page_source) - self.driver = self.execute_hook('before_return_html', self.driver, html) - + self.driver = self.execute_hook("before_return_html", self.driver, html) + # Store in cache - cache_file_path = os.path.join(os.getenv("CRAWL4_AI_BASE_DIRECTORY", Path.home()), ".crawl4ai", "cache", url_hash) + cache_file_path = os.path.join( + os.getenv("CRAWL4_AI_BASE_DIRECTORY", Path.home()), + ".crawl4ai", + "cache", + url_hash, + ) with open(cache_file_path, "w", encoding="utf-8") as f: f.write(html) - + if self.verbose: print(f"[LOG] ✅ Crawled {url} successfully!") - + return html except InvalidArgumentException as e: - if not hasattr(e, 'msg'): + if not hasattr(e, "msg"): e.msg = sanitize_input_encode(str(e)) raise InvalidArgumentException(f"Failed to crawl {url}: {e.msg}") except WebDriverException as e: # If e does nlt have msg attribute create it and set it to str(e) - if not hasattr(e, 'msg'): + if not hasattr(e, "msg"): e.msg = sanitize_input_encode(str(e)) - raise WebDriverException(f"Failed to crawl {url}: {e.msg}") + raise WebDriverException(f"Failed to crawl {url}: {e.msg}") except Exception as e: - if not hasattr(e, 'msg'): + if not hasattr(e, "msg"): e.msg = sanitize_input_encode(str(e)) raise Exception(f"Failed to crawl {url}: {e.msg}") @@ -301,7 +331,9 @@ class LocalSeleniumCrawlerStrategy(CrawlerStrategy): try: # Get the dimensions of the page total_width = self.driver.execute_script("return document.body.scrollWidth") - total_height = self.driver.execute_script("return document.body.scrollHeight") + total_height = self.driver.execute_script( + "return document.body.scrollHeight" + ) # Set the window size to the dimensions of the page self.driver.set_window_size(total_width, total_height) @@ -313,25 +345,27 @@ class LocalSeleniumCrawlerStrategy(CrawlerStrategy): image = Image.open(BytesIO(screenshot)) # Convert image to RGB mode (this will handle both RGB and RGBA images) - rgb_image = image.convert('RGB') + rgb_image = image.convert("RGB") # Convert to JPEG and compress buffered = BytesIO() rgb_image.save(buffered, format="JPEG", quality=85) - img_base64 = base64.b64encode(buffered.getvalue()).decode('utf-8') + img_base64 = base64.b64encode(buffered.getvalue()).decode("utf-8") if self.verbose: - print(f"[LOG] 📸 Screenshot taken and converted to base64") + print("[LOG] 📸 Screenshot taken and converted to base64") return img_base64 except Exception as e: - error_message = sanitize_input_encode(f"Failed to take screenshot: {str(e)}") + error_message = sanitize_input_encode( + f"Failed to take screenshot: {str(e)}" + ) print(error_message) # Generate an image with black background - img = Image.new('RGB', (800, 600), color='black') + img = Image.new("RGB", (800, 600), color="black") draw = ImageDraw.Draw(img) - + # Load a font try: font = ImageFont.truetype("arial.ttf", 40) @@ -345,16 +379,16 @@ class LocalSeleniumCrawlerStrategy(CrawlerStrategy): # Calculate text position text_position = (10, 10) - + # Draw the text on the image draw.text(text_position, wrapped_text, fill=text_color, font=font) - + # Convert to base64 buffered = BytesIO() img.save(buffered, format="JPEG") - img_base64 = base64.b64encode(buffered.getvalue()).decode('utf-8') + img_base64 = base64.b64encode(buffered.getvalue()).decode("utf-8") return img_base64 - + def quit(self): self.driver.quit() diff --git a/crawl4ai/database.py b/crawl4ai/database.py index 42ad7017..815b6b05 100644 --- a/crawl4ai/database.py +++ b/crawl4ai/database.py @@ -7,11 +7,13 @@ DB_PATH = os.path.join(os.getenv("CRAWL4_AI_BASE_DIRECTORY", Path.home()), ".cra os.makedirs(DB_PATH, exist_ok=True) DB_PATH = os.path.join(DB_PATH, "crawl4ai.db") + def init_db(): global DB_PATH conn = sqlite3.connect(DB_PATH) cursor = conn.cursor() - cursor.execute(''' + cursor.execute( + """ CREATE TABLE IF NOT EXISTS crawled_data ( url TEXT PRIMARY KEY, html TEXT, @@ -24,31 +26,42 @@ def init_db(): metadata TEXT DEFAULT "{}", screenshot TEXT DEFAULT "" ) - ''') + """ + ) conn.commit() conn.close() + def alter_db_add_screenshot(new_column: str = "media"): check_db_path() try: conn = sqlite3.connect(DB_PATH) cursor = conn.cursor() - cursor.execute(f'ALTER TABLE crawled_data ADD COLUMN {new_column} TEXT DEFAULT ""') + cursor.execute( + f'ALTER TABLE crawled_data ADD COLUMN {new_column} TEXT DEFAULT ""' + ) conn.commit() conn.close() except Exception as e: print(f"Error altering database to add screenshot column: {e}") + def check_db_path(): if not DB_PATH: raise ValueError("Database path is not set or is empty.") -def get_cached_url(url: str) -> Optional[Tuple[str, str, str, str, str, str, str, bool, str]]: + +def get_cached_url( + url: str, +) -> Optional[Tuple[str, str, str, str, str, str, str, bool, str]]: check_db_path() try: conn = sqlite3.connect(DB_PATH) cursor = conn.cursor() - cursor.execute('SELECT url, html, cleaned_html, markdown, extracted_content, success, media, links, metadata, screenshot FROM crawled_data WHERE url = ?', (url,)) + cursor.execute( + "SELECT url, html, cleaned_html, markdown, extracted_content, success, media, links, metadata, screenshot FROM crawled_data WHERE url = ?", + (url,), + ) result = cursor.fetchone() conn.close() return result @@ -56,12 +69,25 @@ def get_cached_url(url: str) -> Optional[Tuple[str, str, str, str, str, str, str print(f"Error retrieving cached URL: {e}") return None -def cache_url(url: str, html: str, cleaned_html: str, markdown: str, extracted_content: str, success: bool, media : str = "{}", links : str = "{}", metadata : str = "{}", screenshot: str = ""): + +def cache_url( + url: str, + html: str, + cleaned_html: str, + markdown: str, + extracted_content: str, + success: bool, + media: str = "{}", + links: str = "{}", + metadata: str = "{}", + screenshot: str = "", +): check_db_path() try: conn = sqlite3.connect(DB_PATH) cursor = conn.cursor() - cursor.execute(''' + cursor.execute( + """ INSERT INTO crawled_data (url, html, cleaned_html, markdown, extracted_content, success, media, links, metadata, screenshot) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?) ON CONFLICT(url) DO UPDATE SET @@ -74,18 +100,32 @@ def cache_url(url: str, html: str, cleaned_html: str, markdown: str, extracted_c links = excluded.links, metadata = excluded.metadata, screenshot = excluded.screenshot - ''', (url, html, cleaned_html, markdown, extracted_content, success, media, links, metadata, screenshot)) + """, + ( + url, + html, + cleaned_html, + markdown, + extracted_content, + success, + media, + links, + metadata, + screenshot, + ), + ) conn.commit() conn.close() except Exception as e: print(f"Error caching URL: {e}") + def get_total_count() -> int: check_db_path() try: conn = sqlite3.connect(DB_PATH) cursor = conn.cursor() - cursor.execute('SELECT COUNT(*) FROM crawled_data') + cursor.execute("SELECT COUNT(*) FROM crawled_data") result = cursor.fetchone() conn.close() return result[0] @@ -93,43 +133,48 @@ def get_total_count() -> int: print(f"Error getting total count: {e}") return 0 + def clear_db(): check_db_path() try: conn = sqlite3.connect(DB_PATH) cursor = conn.cursor() - cursor.execute('DELETE FROM crawled_data') + cursor.execute("DELETE FROM crawled_data") conn.commit() conn.close() except Exception as e: print(f"Error clearing database: {e}") - + + def flush_db(): check_db_path() try: conn = sqlite3.connect(DB_PATH) cursor = conn.cursor() - cursor.execute('DROP TABLE crawled_data') + cursor.execute("DROP TABLE crawled_data") conn.commit() conn.close() except Exception as e: print(f"Error flushing database: {e}") + def update_existing_records(new_column: str = "media", default_value: str = "{}"): check_db_path() try: conn = sqlite3.connect(DB_PATH) cursor = conn.cursor() - cursor.execute(f'UPDATE crawled_data SET {new_column} = "{default_value}" WHERE screenshot IS NULL') + cursor.execute( + f'UPDATE crawled_data SET {new_column} = "{default_value}" WHERE screenshot IS NULL' + ) conn.commit() conn.close() except Exception as e: print(f"Error updating existing records: {e}") + if __name__ == "__main__": # Delete the existing database file if os.path.exists(DB_PATH): os.remove(DB_PATH) - init_db() + init_db() # alter_db_add_screenshot("COL_NAME") - diff --git a/crawl4ai/docs_manager.py b/crawl4ai/docs_manager.py index aacc5812..9a6096a5 100644 --- a/crawl4ai/docs_manager.py +++ b/crawl4ai/docs_manager.py @@ -4,6 +4,7 @@ from pathlib import Path from crawl4ai.async_logger import AsyncLogger from crawl4ai.llmtxt import AsyncLLMTextManager + class DocsManager: def __init__(self, logger=None): self.docs_dir = Path.home() / ".crawl4ai" / "docs" @@ -21,11 +22,14 @@ class DocsManager: """Copy from local docs or download from GitHub""" try: # Try local first - if self.local_docs.exists() and (any(self.local_docs.glob("*.md")) or any(self.local_docs.glob("*.tokens"))): + if self.local_docs.exists() and ( + any(self.local_docs.glob("*.md")) + or any(self.local_docs.glob("*.tokens")) + ): # Empty the local docs directory for file_path in self.docs_dir.glob("*.md"): file_path.unlink() - # for file_path in self.docs_dir.glob("*.tokens"): + # for file_path in self.docs_dir.glob("*.tokens"): # file_path.unlink() for file_path in self.local_docs.glob("*.md"): shutil.copy2(file_path, self.docs_dir / file_path.name) @@ -36,14 +40,14 @@ class DocsManager: # Fallback to GitHub response = requests.get( "https://api.github.com/repos/unclecode/crawl4ai/contents/docs/llm.txt", - headers={'Accept': 'application/vnd.github.v3+json'} + headers={"Accept": "application/vnd.github.v3+json"}, ) response.raise_for_status() - + for item in response.json(): - if item['type'] == 'file' and item['name'].endswith('.md'): - content = requests.get(item['download_url']).text - with open(self.docs_dir / item['name'], 'w', encoding='utf-8') as f: + if item["type"] == "file" and item["name"].endswith(".md"): + content = requests.get(item["download_url"]).text + with open(self.docs_dir / item["name"], "w", encoding="utf-8") as f: f.write(content) return True @@ -57,11 +61,15 @@ class DocsManager: # Remove [0-9]+_ prefix names = [name.split("_", 1)[1] if name[0].isdigit() else name for name in names] # Exclude those end with .xs.md and .q.md - names = [name for name in names if not name.endswith(".xs") and not name.endswith(".q")] + names = [ + name + for name in names + if not name.endswith(".xs") and not name.endswith(".q") + ] return names - + def generate(self, sections, mode="extended"): return self.llm_text.generate(sections, mode) - + def search(self, query: str, top_k: int = 5): - return self.llm_text.search(query, top_k) \ No newline at end of file + return self.llm_text.search(query, top_k) diff --git a/crawl4ai/extraction_strategy.py b/crawl4ai/extraction_strategy.py index f6c62cc9..65cc005d 100644 --- a/crawl4ai/extraction_strategy.py +++ b/crawl4ai/extraction_strategy.py @@ -1,26 +1,54 @@ from abc import ABC, abstractmethod -from typing import Any, List, Dict, Optional, Union +from typing import Any, List, Dict, Optional from concurrent.futures import ThreadPoolExecutor, as_completed -import json, time -# from optimum.intel import IPEXModel -from .prompts import * -from .config import * -from .utils import * -from .models import * +import json +import time +import os + +from .prompts import PROMPT_EXTRACT_BLOCKS +from .config import ( + DEFAULT_PROVIDER, PROVIDER_MODELS, + CHUNK_TOKEN_THRESHOLD, + OVERLAP_RATE, + WORD_TOKEN_RATE, + PROMPT_EXTRACT_BLOCKS_WITH_INSTRUCTION, + PROMPT_EXTRACT_SCHEMA_WITH_INSTRUCTION +) +from .utils import * # noqa: F403 + +from .utils import ( + sanitize_html, + calculate_batch_size, + escape_json_string, + perform_completion_with_backoff, + extract_xml_data, + split_and_parse_json_objects, + sanitize_input_encode, +) +from .models import * # noqa: F403 + +from .models import TokenUsage + +from .model_loader import * # noqa: F403 +from .model_loader import ( + get_device, + load_HF_embedding_model, + load_text_multilabel_classifier, +) + from functools import partial -from .model_loader import * import math import numpy as np import re from bs4 import BeautifulSoup from lxml import html, etree -from dataclasses import dataclass + class ExtractionStrategy(ABC): """ Abstract base class for all extraction strategies. """ - + def __init__(self, input_format: str = "markdown", **kwargs): """ Initialize the extraction strategy. @@ -45,7 +73,7 @@ class ExtractionStrategy(ABC): :return: A list of extracted blocks or chunks. """ pass - + def run(self, url: str, sections: List[str], *q, **kwargs) -> List[Dict[str, Any]]: """ Process sections of text in parallel by default. @@ -56,40 +84,49 @@ class ExtractionStrategy(ABC): """ extracted_content = [] with ThreadPoolExecutor() as executor: - futures = [executor.submit(self.extract, url, section, **kwargs) for section in sections] + futures = [ + executor.submit(self.extract, url, section, **kwargs) + for section in sections + ] for future in as_completed(futures): extracted_content.extend(future.result()) - return extracted_content - + return extracted_content + + class NoExtractionStrategy(ExtractionStrategy): """ A strategy that does not extract any meaningful content from the HTML. It simply returns the entire HTML as a single block. """ + def extract(self, url: str, html: str, *q, **kwargs) -> List[Dict[str, Any]]: """ Extract meaningful blocks or chunks from the given HTML. """ return [{"index": 0, "content": html}] - + def run(self, url: str, sections: List[str], *q, **kwargs) -> List[Dict[str, Any]]: - return [{"index": i, "tags": [], "content": section} for i, section in enumerate(sections)] + return [ + {"index": i, "tags": [], "content": section} + for i, section in enumerate(sections) + ] ####################################################### # Strategies using clustering for text data extraction # ####################################################### + class CosineStrategy(ExtractionStrategy): """ Extract meaningful blocks or chunks from the given HTML using cosine similarity. - + How it works: 1. Pre-filter documents using embeddings and semantic_filter. 2. Perform clustering using cosine similarity. 3. Organize texts by their cluster labels, retaining order. 4. Filter clusters by word count. 5. Extract meaningful blocks or chunks from the filtered clusters. - + Attributes: semantic_filter (str): A keyword filter for document filtering. word_count_threshold (int): Minimum number of words per cluster. @@ -98,8 +135,19 @@ class CosineStrategy(ExtractionStrategy): top_k (int): Number of top categories to extract. model_name (str): The name of the sentence-transformers model. sim_threshold (float): The similarity threshold for clustering. - """ - def __init__(self, semantic_filter = None, word_count_threshold=10, max_dist=0.2, linkage_method='ward', top_k=3, model_name = 'sentence-transformers/all-MiniLM-L6-v2', sim_threshold = 0.3, **kwargs): + """ + + def __init__( + self, + semantic_filter=None, + word_count_threshold=10, + max_dist=0.2, + linkage_method="ward", + top_k=3, + model_name="sentence-transformers/all-MiniLM-L6-v2", + sim_threshold=0.3, + **kwargs, + ): """ Initialize the strategy with clustering parameters. @@ -111,9 +159,9 @@ class CosineStrategy(ExtractionStrategy): top_k (int): Number of top categories to extract. """ super().__init__(**kwargs) - + import numpy as np - + self.semantic_filter = semantic_filter self.word_count_threshold = word_count_threshold self.max_dist = max_dist @@ -122,14 +170,14 @@ class CosineStrategy(ExtractionStrategy): self.sim_threshold = sim_threshold self.timer = time.time() self.verbose = kwargs.get("verbose", False) - + self.buffer_embeddings = np.array([]) self.get_embedding_method = "direct" - + self.device = get_device() # import torch # self.device = torch.device('cpu') - + self.default_batch_size = calculate_batch_size(self.device) if self.verbose: @@ -143,10 +191,10 @@ class CosineStrategy(ExtractionStrategy): self.tokenizer, self.model = load_HF_embedding_model(model_name) self.model.to(self.device) - self.model.eval() - + self.model.eval() + self.get_embedding_method = "batch" - + self.buffer_embeddings = np.array([]) # if model_name == "bert-base-uncased": @@ -161,18 +209,23 @@ class CosineStrategy(ExtractionStrategy): # self.model = load_onnx_all_MiniLM_l6_v2() # self.tokenizer = self.model.tokenizer # self.get_embedding_method = "direct" - - + if self.verbose: print(f"[LOG] Loading Multilabel Classifier for {self.device.type} device.") - + self.nlp, _ = load_text_multilabel_classifier() # self.default_batch_size = 16 if self.device.type == 'cpu' else 64 - - if self.verbose: - print(f"[LOG] Model loaded {model_name}, models/reuters, took " + str(time.time() - self.timer) + " seconds") - def filter_documents_embeddings(self, documents: List[str], semantic_filter: str, at_least_k: int = 20) -> List[str]: + if self.verbose: + print( + f"[LOG] Model loaded {model_name}, models/reuters, took " + + str(time.time() - self.timer) + + " seconds" + ) + + def filter_documents_embeddings( + self, documents: List[str], semantic_filter: str, at_least_k: int = 20 + ) -> List[str]: """ Filter and sort documents based on the cosine similarity of their embeddings with the semantic_filter embedding. @@ -184,39 +237,51 @@ class CosineStrategy(ExtractionStrategy): Returns: List[str]: A list of filtered and sorted document texts. """ - + if not semantic_filter: return documents - + if len(documents) < at_least_k: at_least_k = len(documents) // 2 - + from sklearn.metrics.pairwise import cosine_similarity - + # Compute embedding for the keyword filter query_embedding = self.get_embeddings([semantic_filter])[0] - + # Compute embeddings for the documents document_embeddings = self.get_embeddings(documents) - + # Calculate cosine similarity between the query embedding and document embeddings - similarities = cosine_similarity([query_embedding], document_embeddings).flatten() - + similarities = cosine_similarity( + [query_embedding], document_embeddings + ).flatten() + # Filter documents based on the similarity threshold - filtered_docs = [(doc, sim) for doc, sim in zip(documents, similarities) if sim >= self.sim_threshold] - + filtered_docs = [ + (doc, sim) + for doc, sim in zip(documents, similarities) + if sim >= self.sim_threshold + ] + # If the number of filtered documents is less than at_least_k, sort remaining documents by similarity if len(filtered_docs) < at_least_k: - remaining_docs = [(doc, sim) for doc, sim in zip(documents, similarities) if sim < self.sim_threshold] + remaining_docs = [ + (doc, sim) + for doc, sim in zip(documents, similarities) + if sim < self.sim_threshold + ] remaining_docs.sort(key=lambda x: x[1], reverse=True) - filtered_docs.extend(remaining_docs[:at_least_k - len(filtered_docs)]) - + filtered_docs.extend(remaining_docs[: at_least_k - len(filtered_docs)]) + # Extract the document texts from the tuples filtered_docs = [doc for doc, _ in filtered_docs] - + return filtered_docs[:at_least_k] - - def get_embeddings(self, sentences: List[str], batch_size=None, bypass_buffer=False): + + def get_embeddings( + self, sentences: List[str], batch_size=None, bypass_buffer=False + ): """ Get BERT embeddings for a list of sentences. @@ -228,43 +293,48 @@ class CosineStrategy(ExtractionStrategy): """ # if self.buffer_embeddings.any() and not bypass_buffer: # return self.buffer_embeddings - - if self.device.type in [ "cpu", "gpu", "cuda", "mps"]: - import torch + + if self.device.type in ["cpu", "gpu", "cuda", "mps"]: + import torch + # Tokenize sentences and convert to tensor if batch_size is None: batch_size = self.default_batch_size - + all_embeddings = [] for i in range(0, len(sentences), batch_size): - batch_sentences = sentences[i:i + batch_size] - encoded_input = self.tokenizer(batch_sentences, padding=True, truncation=True, return_tensors='pt') - encoded_input = {key: tensor.to(self.device) for key, tensor in encoded_input.items()} - + batch_sentences = sentences[i : i + batch_size] + encoded_input = self.tokenizer( + batch_sentences, padding=True, truncation=True, return_tensors="pt" + ) + encoded_input = { + key: tensor.to(self.device) for key, tensor in encoded_input.items() + } + # Ensure no gradients are calculated with torch.no_grad(): model_output = self.model(**encoded_input) - + # Get embeddings from the last hidden state (mean pooling) embeddings = model_output.last_hidden_state.mean(dim=1).cpu().numpy() all_embeddings.append(embeddings) - + self.buffer_embeddings = np.vstack(all_embeddings) - elif self.device.type == "cpu": + elif self.device.type == "cpu": # self.buffer_embeddings = self.model(sentences) if batch_size is None: batch_size = self.default_batch_size - + all_embeddings = [] for i in range(0, len(sentences), batch_size): - batch_sentences = sentences[i:i + batch_size] + batch_sentences = sentences[i : i + batch_size] embeddings = self.model(batch_sentences) all_embeddings.append(embeddings) - + self.buffer_embeddings = np.vstack(all_embeddings) return self.buffer_embeddings - def hierarchical_clustering(self, sentences: List[str], embeddings = None): + def hierarchical_clustering(self, sentences: List[str], embeddings=None): """ Perform hierarchical clustering on sentences and return cluster labels. @@ -277,18 +347,21 @@ class CosineStrategy(ExtractionStrategy): # Get embeddings from scipy.cluster.hierarchy import linkage, fcluster from scipy.spatial.distance import pdist + self.timer = time.time() embeddings = self.get_embeddings(sentences, bypass_buffer=True) # print(f"[LOG] 🚀 Embeddings computed in {time.time() - self.timer:.2f} seconds") # Compute pairwise cosine distances - distance_matrix = pdist(embeddings, 'cosine') + distance_matrix = pdist(embeddings, "cosine") # Perform agglomerative clustering respecting order linked = linkage(distance_matrix, method=self.linkage_method) # Form flat clusters - labels = fcluster(linked, self.max_dist, criterion='distance') + labels = fcluster(linked, self.max_dist, criterion="distance") return labels - def filter_clusters_by_word_count(self, clusters: Dict[int, List[str]]) -> Dict[int, List[str]]: + def filter_clusters_by_word_count( + self, clusters: Dict[int, List[str]] + ) -> Dict[int, List[str]]: """ Filter clusters to remove those with a word count below the threshold. @@ -304,7 +377,7 @@ class CosineStrategy(ExtractionStrategy): full_text = " ".join(texts) # Count words word_count = len(full_text.split()) - + # Keep clusters with word count above the threshold if word_count >= self.word_count_threshold: filtered_clusters[cluster_id] = texts @@ -325,9 +398,11 @@ class CosineStrategy(ExtractionStrategy): # Assume `html` is a list of text chunks for this strategy t = time.time() text_chunks = html.split(self.DEL) # Split by lines or paragraphs as needed - + # Pre-filter documents using embeddings and semantic_filter - text_chunks = self.filter_documents_embeddings(text_chunks, self.semantic_filter) + text_chunks = self.filter_documents_embeddings( + text_chunks, self.semantic_filter + ) if not text_chunks: return [] @@ -346,16 +421,19 @@ class CosineStrategy(ExtractionStrategy): filtered_clusters = self.filter_clusters_by_word_count(clusters) # Convert filtered clusters to a sorted list of dictionaries - cluster_list = [{"index": int(idx), "tags" : [], "content": " ".join(filtered_clusters[idx])} for idx in sorted(filtered_clusters)] - + cluster_list = [ + {"index": int(idx), "tags": [], "content": " ".join(filtered_clusters[idx])} + for idx in sorted(filtered_clusters) + ] + if self.verbose: print(f"[LOG] 🚀 Assign tags using {self.device}") - + if self.device.type in ["gpu", "cuda", "mps", "cpu"]: - labels = self.nlp([cluster['content'] for cluster in cluster_list]) - + labels = self.nlp([cluster["content"] for cluster in cluster_list]) + for cluster, label in zip(cluster_list, labels): - cluster['tags'] = label + cluster["tags"] = label # elif self.device.type == "cpu": # # Process the text with the loaded model # texts = [cluster['content'] for cluster in cluster_list] @@ -366,16 +444,16 @@ class CosineStrategy(ExtractionStrategy): # tok_k = self.top_k # top_categories = sorted(doc.cats.items(), key=lambda x: x[1], reverse=True)[:tok_k] # cluster['tags'] = [cat for cat, _ in top_categories] - - # for cluster in cluster_list: - # doc = self.nlp(cluster['content']) - # tok_k = self.top_k - # top_categories = sorted(doc.cats.items(), key=lambda x: x[1], reverse=True)[:tok_k] - # cluster['tags'] = [cat for cat, _ in top_categories] - + + # for cluster in cluster_list: + # doc = self.nlp(cluster['content']) + # tok_k = self.top_k + # top_categories = sorted(doc.cats.items(), key=lambda x: x[1], reverse=True)[:tok_k] + # cluster['tags'] = [cat for cat, _ in top_categories] + if self.verbose: print(f"[LOG] 🚀 Categorization done in {time.time() - t:.2f} seconds") - + return cluster_list def run(self, url: str, sections: List[str], *q, **kwargs) -> List[Dict[str, Any]]: @@ -389,9 +467,8 @@ class CosineStrategy(ExtractionStrategy): Returns: """ # This strategy processes all sections together - - return self.extract(url, self.DEL.join(sections), **kwargs) + return self.extract(url, self.DEL.join(sections), **kwargs) ####################################################### @@ -400,11 +477,11 @@ class CosineStrategy(ExtractionStrategy): class LLMExtractionStrategy(ExtractionStrategy): """ A strategy that uses an LLM to extract meaningful content from the HTML. - + Attributes: provider: The provider to use for extraction. It follows the format /, e.g., "ollama/llama3.3". api_token: The API token for the provider. - instruction: The instruction to use for the LLM model. + instruction: The instruction to use for the LLM model. schema: Pydantic model schema for structured data. extraction_type: "block" or "schema". chunk_token_threshold: Maximum tokens per chunk. @@ -419,16 +496,22 @@ class LLMExtractionStrategy(ExtractionStrategy): total_usage: Accumulated token usage. """ - def __init__(self, - provider: str = DEFAULT_PROVIDER, api_token: Optional[str] = None, - instruction:str = None, schema:Dict = None, extraction_type = "block", **kwargs): + def __init__( + self, + provider: str = DEFAULT_PROVIDER, + api_token: Optional[str] = None, + instruction: str = None, + schema: Dict = None, + extraction_type="block", + **kwargs, + ): """ Initialize the strategy with clustering parameters. - + Args: provider: The provider to use for extraction. It follows the format /, e.g., "ollama/llama3.3". api_token: The API token for the provider. - instruction: The instruction to use for the LLM model. + instruction: The instruction to use for the LLM model. schema: Pydantic model schema for structured data. extraction_type: "block" or "schema". chunk_token_threshold: Maximum tokens per chunk. @@ -440,19 +523,25 @@ class LLMExtractionStrategy(ExtractionStrategy): extra_args: Additional arguments for the API request, such as temprature, max_tokens, etc. verbose: Whether to print verbose output. usages: List of individual token usages. - total_usage: Accumulated token usage. + total_usage: Accumulated token usage. """ super().__init__(**kwargs) self.provider = provider - self.api_token = api_token or PROVIDER_MODELS.get(provider, "no-token") or os.getenv("OPENAI_API_KEY") + self.api_token = ( + api_token + or PROVIDER_MODELS.get(provider, "no-token") + or os.getenv("OPENAI_API_KEY") + ) self.instruction = instruction self.extract_type = extraction_type self.schema = schema if schema: self.extract_type = "schema" - - self.chunk_token_threshold = kwargs.get("chunk_token_threshold", CHUNK_TOKEN_THRESHOLD) + + self.chunk_token_threshold = kwargs.get( + "chunk_token_threshold", CHUNK_TOKEN_THRESHOLD + ) self.overlap_rate = kwargs.get("overlap_rate", OVERLAP_RATE) self.word_token_rate = kwargs.get("word_token_rate", WORD_TOKEN_RATE) self.apply_chunking = kwargs.get("apply_chunking", True) @@ -461,29 +550,30 @@ class LLMExtractionStrategy(ExtractionStrategy): self.extra_args = kwargs.get("extra_args", {}) if not self.apply_chunking: self.chunk_token_threshold = 1e9 - + self.verbose = kwargs.get("verbose", False) self.usages = [] # Store individual usages - self.total_usage = TokenUsage() # Accumulated usage - + self.total_usage = TokenUsage() # Accumulated usage + if not self.api_token: - raise ValueError("API token must be provided for LLMExtractionStrategy. Update the config.py or set OPENAI_API_KEY environment variable.") - - - def extract(self, url: str, ix:int, html: str) -> List[Dict[str, Any]]: + raise ValueError( + "API token must be provided for LLMExtractionStrategy. Update the config.py or set OPENAI_API_KEY environment variable." + ) + + def extract(self, url: str, ix: int, html: str) -> List[Dict[str, Any]]: """ Extract meaningful blocks or chunks from the given HTML using an LLM. - + How it works: 1. Construct a prompt with variables. 2. Make a request to the LLM using the prompt. 3. Parse the response and extract blocks or chunks. - + Args: url: The URL of the webpage. ix: Index of the block. html: The HTML content of the webpage. - + Returns: A list of extracted blocks or chunks. """ @@ -495,12 +585,12 @@ class LLMExtractionStrategy(ExtractionStrategy): "URL": url, "HTML": escape_json_string(sanitize_html(html)), } - + prompt_with_variables = PROMPT_EXTRACT_BLOCKS if self.instruction: variable_values["REQUEST"] = self.instruction prompt_with_variables = PROMPT_EXTRACT_BLOCKS_WITH_INSTRUCTION - + if self.extract_type == "schema" and self.schema: variable_values["SCHEMA"] = json.dumps(self.schema, indent=2) prompt_with_variables = PROMPT_EXTRACT_SCHEMA_WITH_INSTRUCTION @@ -509,60 +599,72 @@ class LLMExtractionStrategy(ExtractionStrategy): prompt_with_variables = prompt_with_variables.replace( "{" + variable + "}", variable_values[variable] ) - + response = perform_completion_with_backoff( - self.provider, - prompt_with_variables, - self.api_token, + self.provider, + prompt_with_variables, + self.api_token, base_url=self.api_base or self.base_url, - extra_args = self.extra_args - ) # , json_response=self.extract_type == "schema") + extra_args=self.extra_args, + ) # , json_response=self.extract_type == "schema") # Track usage usage = TokenUsage( completion_tokens=response.usage.completion_tokens, prompt_tokens=response.usage.prompt_tokens, total_tokens=response.usage.total_tokens, - completion_tokens_details=response.usage.completion_tokens_details.__dict__ if response.usage.completion_tokens_details else {}, - prompt_tokens_details=response.usage.prompt_tokens_details.__dict__ if response.usage.prompt_tokens_details else {} + completion_tokens_details=response.usage.completion_tokens_details.__dict__ + if response.usage.completion_tokens_details + else {}, + prompt_tokens_details=response.usage.prompt_tokens_details.__dict__ + if response.usage.prompt_tokens_details + else {}, ) self.usages.append(usage) - + # Update totals self.total_usage.completion_tokens += usage.completion_tokens - self.total_usage.prompt_tokens += usage.prompt_tokens + self.total_usage.prompt_tokens += usage.prompt_tokens self.total_usage.total_tokens += usage.total_tokens - + try: - blocks = extract_xml_data(["blocks"], response.choices[0].message.content)['blocks'] + blocks = extract_xml_data(["blocks"], response.choices[0].message.content)[ + "blocks" + ] blocks = json.loads(blocks) for block in blocks: - block['error'] = False - except Exception as e: - parsed, unparsed = split_and_parse_json_objects(response.choices[0].message.content) + block["error"] = False + except Exception: + parsed, unparsed = split_and_parse_json_objects( + response.choices[0].message.content + ) blocks = parsed if unparsed: - blocks.append({ - "index": 0, - "error": True, - "tags": ["error"], - "content": unparsed - }) - + blocks.append( + {"index": 0, "error": True, "tags": ["error"], "content": unparsed} + ) + if self.verbose: - print("[LOG] Extracted", len(blocks), "blocks from URL:", url, "block index:", ix) + print( + "[LOG] Extracted", + len(blocks), + "blocks from URL:", + url, + "block index:", + ix, + ) return blocks - + def _merge(self, documents, chunk_token_threshold, overlap): """ Merge documents into sections based on chunk_token_threshold and overlap. """ - chunks = [] + # chunks = [] sections = [] total_tokens = 0 # Calculate the total tokens across all documents for document in documents: - total_tokens += len(document.split(' ')) * self.word_token_rate + total_tokens += len(document.split(" ")) * self.word_token_rate # Calculate the number of sections needed num_sections = math.floor(total_tokens / chunk_token_threshold) @@ -574,9 +676,9 @@ class LLMExtractionStrategy(ExtractionStrategy): current_chunk = [] for document in documents: - tokens = document.split(' ') + tokens = document.split(" ") token_count = len(tokens) * self.word_token_rate - + if total_token_so_far + token_count <= adjusted_chunk_threshold: current_chunk.extend(tokens) total_token_so_far += token_count @@ -585,56 +687,61 @@ class LLMExtractionStrategy(ExtractionStrategy): if len(sections) == num_sections - 1: current_chunk.extend(tokens) continue - + # Add overlap if specified if overlap > 0 and current_chunk: overlap_tokens = current_chunk[-overlap:] current_chunk.extend(overlap_tokens) - - sections.append(' '.join(current_chunk)) + + sections.append(" ".join(current_chunk)) current_chunk = tokens total_token_so_far = token_count # Add the last chunk if current_chunk: - sections.append(' '.join(current_chunk)) + sections.append(" ".join(current_chunk)) return sections - def run(self, url: str, sections: List[str]) -> List[Dict[str, Any]]: """ Process sections sequentially with a delay for rate limiting issues, specifically for LLMExtractionStrategy. - + Args: url: The URL of the webpage. sections: List of sections (strings) to process. - + Returns: A list of extracted blocks or chunks. """ - + merged_sections = self._merge( - sections, self.chunk_token_threshold, - overlap= int(self.chunk_token_threshold * self.overlap_rate) + sections, + self.chunk_token_threshold, + overlap=int(self.chunk_token_threshold * self.overlap_rate), ) extracted_content = [] if self.provider.startswith("groq/"): # Sequential processing with a delay for ix, section in enumerate(merged_sections): extract_func = partial(self.extract, url) - extracted_content.extend(extract_func(ix, sanitize_input_encode(section))) + extracted_content.extend( + extract_func(ix, sanitize_input_encode(section)) + ) time.sleep(0.5) # 500 ms delay between each processing else: # Parallel processing using ThreadPoolExecutor # extract_func = partial(self.extract, url) # for ix, section in enumerate(merged_sections): - # extracted_content.append(extract_func(ix, section)) - + # extracted_content.append(extract_func(ix, section)) + with ThreadPoolExecutor(max_workers=4) as executor: extract_func = partial(self.extract, url) - futures = [executor.submit(extract_func, ix, sanitize_input_encode(section)) for ix, section in enumerate(merged_sections)] - + futures = [ + executor.submit(extract_func, ix, sanitize_input_encode(section)) + for ix, section in enumerate(merged_sections) + ] + for future in as_completed(futures): try: extracted_content.extend(future.result()) @@ -642,17 +749,17 @@ class LLMExtractionStrategy(ExtractionStrategy): if self.verbose: print(f"Error in thread execution: {e}") # Add error information to extracted_content - extracted_content.append({ - "index": 0, - "error": True, - "tags": ["error"], - "content": str(e) - }) + extracted_content.append( + { + "index": 0, + "error": True, + "tags": ["error"], + "content": str(e), + } + ) + + return extracted_content - - return extracted_content - - def show_usage(self) -> None: """Print a detailed token usage report showing total and per-request usage.""" print("\n=== Token Usage Summary ===") @@ -666,13 +773,15 @@ class LLMExtractionStrategy(ExtractionStrategy): print(f"{'Request #':<10} {'Completion':>12} {'Prompt':>12} {'Total':>12}") print("-" * 48) for i, usage in enumerate(self.usages, 1): - print(f"{i:<10} {usage.completion_tokens:>12,} {usage.prompt_tokens:>12,} {usage.total_tokens:>12,}") + print( + f"{i:<10} {usage.completion_tokens:>12,} {usage.prompt_tokens:>12,} {usage.total_tokens:>12,}" + ) - ####################################################### # New extraction strategies for JSON-based extraction # -####################################################### +####################################################### + class JsonElementExtractionStrategy(ExtractionStrategy): """ @@ -706,8 +815,7 @@ class JsonElementExtractionStrategy(ExtractionStrategy): _get_element_attribute(element, attribute): Extracts an attribute's value from an element. """ - - DEL = '\n' + DEL = "\n" def __init__(self, schema: Dict[str, Any], **kwargs): """ @@ -718,9 +826,11 @@ class JsonElementExtractionStrategy(ExtractionStrategy): """ super().__init__(**kwargs) self.schema = schema - self.verbose = kwargs.get('verbose', False) + self.verbose = kwargs.get("verbose", False) - def extract(self, url: str, html_content: str, *q, **kwargs) -> List[Dict[str, Any]]: + def extract( + self, url: str, html_content: str, *q, **kwargs + ) -> List[Dict[str, Any]]: """ Extract structured data from HTML content. @@ -738,27 +848,29 @@ class JsonElementExtractionStrategy(ExtractionStrategy): Returns: List[Dict[str, Any]]: A list of extracted items, each represented as a dictionary. """ - + parsed_html = self._parse_html(html_content) - base_elements = self._get_base_elements(parsed_html, self.schema['baseSelector']) - + base_elements = self._get_base_elements( + parsed_html, self.schema["baseSelector"] + ) + results = [] for element in base_elements: # Extract base element attributes item = {} - if 'baseFields' in self.schema: - for field in self.schema['baseFields']: + if "baseFields" in self.schema: + for field in self.schema["baseFields"]: value = self._extract_single_field(element, field) if value is not None: - item[field['name']] = value - + item[field["name"]] = value + # Extract child fields - field_data = self._extract_item(element, self.schema['fields']) + field_data = self._extract_item(element, self.schema["fields"]) item.update(field_data) - + if item: results.append(item) - + return results @abstractmethod @@ -778,24 +890,28 @@ class JsonElementExtractionStrategy(ExtractionStrategy): def _extract_field(self, element, field): try: - if field['type'] == 'nested': - nested_elements = self._get_elements(element, field['selector']) + if field["type"] == "nested": + nested_elements = self._get_elements(element, field["selector"]) nested_element = nested_elements[0] if nested_elements else None - return self._extract_item(nested_element, field['fields']) if nested_element else {} - - if field['type'] == 'list': - elements = self._get_elements(element, field['selector']) - return [self._extract_list_item(el, field['fields']) for el in elements] - - if field['type'] == 'nested_list': - elements = self._get_elements(element, field['selector']) - return [self._extract_item(el, field['fields']) for el in elements] - + return ( + self._extract_item(nested_element, field["fields"]) + if nested_element + else {} + ) + + if field["type"] == "list": + elements = self._get_elements(element, field["selector"]) + return [self._extract_list_item(el, field["fields"]) for el in elements] + + if field["type"] == "nested_list": + elements = self._get_elements(element, field["selector"]) + return [self._extract_item(el, field["fields"]) for el in elements] + return self._extract_single_field(element, field) except Exception as e: if self.verbose: print(f"Error extracting field {field['name']}: {str(e)}") - return field.get('default') + return field.get("default") def _extract_single_field(self, element, field): """ @@ -813,38 +929,38 @@ class JsonElementExtractionStrategy(ExtractionStrategy): Returns: Any: The extracted field value. """ - - if 'selector' in field: - selected = self._get_elements(element, field['selector']) + + if "selector" in field: + selected = self._get_elements(element, field["selector"]) if not selected: - return field.get('default') + return field.get("default") selected = selected[0] else: selected = element value = None - if field['type'] == 'text': + if field["type"] == "text": value = self._get_element_text(selected) - elif field['type'] == 'attribute': - value = self._get_element_attribute(selected, field['attribute']) - elif field['type'] == 'html': + elif field["type"] == "attribute": + value = self._get_element_attribute(selected, field["attribute"]) + elif field["type"] == "html": value = self._get_element_html(selected) - elif field['type'] == 'regex': + elif field["type"] == "regex": text = self._get_element_text(selected) - match = re.search(field['pattern'], text) + match = re.search(field["pattern"], text) value = match.group(1) if match else None - if 'transform' in field: - value = self._apply_transform(value, field['transform']) + if "transform" in field: + value = self._apply_transform(value, field["transform"]) - return value if value is not None else field.get('default') + return value if value is not None else field.get("default") def _extract_list_item(self, element, fields): item = {} for field in fields: value = self._extract_single_field(element, field) if value is not None: - item[field['name']] = value + item[field["name"]] = value return item def _extract_item(self, element, fields): @@ -863,15 +979,15 @@ class JsonElementExtractionStrategy(ExtractionStrategy): Returns: Dict[str, Any]: A dictionary representing the extracted item. """ - + item = {} for field in fields: - if field['type'] == 'computed': + if field["type"] == "computed": value = self._compute_field(item, field) else: value = self._extract_field(element, field) if value is not None: - item[field['name']] = value + item[field["name"]] = value return item def _apply_transform(self, value, transform): @@ -890,25 +1006,25 @@ class JsonElementExtractionStrategy(ExtractionStrategy): Returns: str: The transformed value. """ - - if transform == 'lowercase': + + if transform == "lowercase": return value.lower() - elif transform == 'uppercase': + elif transform == "uppercase": return value.upper() - elif transform == 'strip': + elif transform == "strip": return value.strip() return value def _compute_field(self, item, field): try: - if 'expression' in field: - return eval(field['expression'], {}, item) - elif 'function' in field: - return field['function'](item) + if "expression" in field: + return eval(field["expression"], {}, item) + elif "function" in field: + return field["function"](item) except Exception as e: if self.verbose: print(f"Error computing field {field['name']}: {str(e)}") - return field.get('default') + return field.get("default") def run(self, url: str, sections: List[str], *q, **kwargs) -> List[Dict[str, Any]]: """ @@ -927,7 +1043,7 @@ class JsonElementExtractionStrategy(ExtractionStrategy): Returns: List[Dict[str, Any]]: A list of extracted items. """ - + combined_html = self.DEL.join(sections) return self.extract(url, combined_html, **kwargs) @@ -946,6 +1062,7 @@ class JsonElementExtractionStrategy(ExtractionStrategy): """Get attribute value from element""" pass + class JsonCssExtractionStrategy(JsonElementExtractionStrategy): """ Concrete implementation of `JsonElementExtractionStrategy` using CSS selectors. @@ -967,13 +1084,13 @@ class JsonCssExtractionStrategy(JsonElementExtractionStrategy): _get_element_html(element): Extracts the raw HTML content of a BeautifulSoup element. _get_element_attribute(element, attribute): Retrieves an attribute value from a BeautifulSoup element. """ - + def __init__(self, schema: Dict[str, Any], **kwargs): - kwargs['input_format'] = 'html' # Force HTML input + kwargs["input_format"] = "html" # Force HTML input super().__init__(schema, **kwargs) def _parse_html(self, html_content: str): - return BeautifulSoup(html_content, 'html.parser') + return BeautifulSoup(html_content, "html.parser") def _get_base_elements(self, parsed_html, selector: str): return parsed_html.select(selector) @@ -992,6 +1109,7 @@ class JsonCssExtractionStrategy(JsonElementExtractionStrategy): def _get_element_attribute(self, element, attribute: str): return element.get(attribute) + class JsonXPathExtractionStrategy(JsonElementExtractionStrategy): """ Concrete implementation of `JsonElementExtractionStrategy` using XPath selectors. @@ -1014,9 +1132,9 @@ class JsonXPathExtractionStrategy(JsonElementExtractionStrategy): _get_element_html(element): Extracts the raw HTML content of an lxml element. _get_element_attribute(element, attribute): Retrieves an attribute value from an lxml element. """ - + def __init__(self, schema: Dict[str, Any], **kwargs): - kwargs['input_format'] = 'html' # Force HTML input + kwargs["input_format"] = "html" # Force HTML input super().__init__(schema, **kwargs) def _parse_html(self, html_content: str): @@ -1027,31 +1145,31 @@ class JsonXPathExtractionStrategy(JsonElementExtractionStrategy): def _css_to_xpath(self, css_selector: str) -> str: """Convert CSS selector to XPath if needed""" - if '/' in css_selector: # Already an XPath + if "/" in css_selector: # Already an XPath return css_selector return self._basic_css_to_xpath(css_selector) def _basic_css_to_xpath(self, css_selector: str) -> str: """Basic CSS to XPath conversion for common cases""" - if ' > ' in css_selector: - parts = css_selector.split(' > ') - return '//' + '/'.join(parts) - if ' ' in css_selector: - parts = css_selector.split(' ') - return '//' + '//'.join(parts) - return '//' + css_selector + if " > " in css_selector: + parts = css_selector.split(" > ") + return "//" + "/".join(parts) + if " " in css_selector: + parts = css_selector.split(" ") + return "//" + "//".join(parts) + return "//" + css_selector def _get_elements(self, element, selector: str): xpath = self._css_to_xpath(selector) - if not xpath.startswith('.'): - xpath = '.' + xpath + if not xpath.startswith("."): + xpath = "." + xpath return element.xpath(xpath) def _get_element_text(self, element) -> str: - return ''.join(element.xpath('.//text()')).strip() + return "".join(element.xpath(".//text()")).strip() def _get_element_html(self, element) -> str: - return etree.tostring(element, encoding='unicode') + return etree.tostring(element, encoding="unicode") def _get_element_attribute(self, element, attribute: str): return element.get(attribute) diff --git a/crawl4ai/html2text/__init__.py b/crawl4ai/html2text/__init__.py index c41258e0..a3349e70 100644 --- a/crawl4ai/html2text/__init__.py +++ b/crawl4ai/html2text/__init__.py @@ -54,13 +54,13 @@ class HTML2Text(html.parser.HTMLParser): self.td_count = 0 self.table_start = False self.unicode_snob = config.UNICODE_SNOB # covered in cli - + self.escape_snob = config.ESCAPE_SNOB # covered in cli self.escape_backslash = config.ESCAPE_BACKSLASH # covered in cli self.escape_dot = config.ESCAPE_DOT # covered in cli self.escape_plus = config.ESCAPE_PLUS # covered in cli self.escape_dash = config.ESCAPE_DASH # covered in cli - + self.links_each_paragraph = config.LINKS_EACH_PARAGRAPH self.body_width = bodywidth # covered in cli self.skip_internal_links = config.SKIP_INTERNAL_LINKS # covered in cli @@ -144,8 +144,8 @@ class HTML2Text(html.parser.HTMLParser): def update_params(self, **kwargs): for key, value in kwargs.items(): - setattr(self, key, value) - + setattr(self, key, value) + def feed(self, data: str) -> None: data = data.replace("", "") super().feed(data) @@ -903,7 +903,13 @@ class HTML2Text(html.parser.HTMLParser): self.empty_link = False if not self.code and not self.pre and not entity_char: - data = escape_md_section(data, snob=self.escape_snob, escape_dot=self.escape_dot, escape_plus=self.escape_plus, escape_dash=self.escape_dash) + data = escape_md_section( + data, + snob=self.escape_snob, + escape_dot=self.escape_dot, + escape_plus=self.escape_plus, + escape_dash=self.escape_dash, + ) self.preceding_data = data self.o(data, puredata=True) @@ -1006,6 +1012,7 @@ class HTML2Text(html.parser.HTMLParser): newlines += 1 return result + def html2text(html: str, baseurl: str = "", bodywidth: Optional[int] = None) -> str: if bodywidth is None: bodywidth = config.BODY_WIDTH @@ -1013,6 +1020,7 @@ def html2text(html: str, baseurl: str = "", bodywidth: Optional[int] = None) -> return h.handle(html) + class CustomHTML2Text(HTML2Text): def __init__(self, *args, handle_code_in_pre=False, **kwargs): super().__init__(*args, **kwargs) @@ -1022,8 +1030,8 @@ class CustomHTML2Text(HTML2Text): self.current_preserved_tag = None self.preserved_content = [] self.preserve_depth = 0 - self.handle_code_in_pre = handle_code_in_pre - + self.handle_code_in_pre = handle_code_in_pre + # Configuration options self.skip_internal_links = False self.single_line_break = False @@ -1041,9 +1049,9 @@ class CustomHTML2Text(HTML2Text): def update_params(self, **kwargs): """Update parameters and set preserved tags.""" for key, value in kwargs.items(): - if key == 'preserve_tags': + if key == "preserve_tags": self.preserve_tags = set(value) - elif key == 'handle_code_in_pre': + elif key == "handle_code_in_pre": self.handle_code_in_pre = value else: setattr(self, key, value) @@ -1056,17 +1064,19 @@ class CustomHTML2Text(HTML2Text): self.current_preserved_tag = tag self.preserved_content = [] # Format opening tag with attributes - attr_str = ''.join(f' {k}="{v}"' for k, v in attrs.items() if v is not None) - self.preserved_content.append(f'<{tag}{attr_str}>') + attr_str = "".join( + f' {k}="{v}"' for k, v in attrs.items() if v is not None + ) + self.preserved_content.append(f"<{tag}{attr_str}>") self.preserve_depth += 1 return else: self.preserve_depth -= 1 if self.preserve_depth == 0: - self.preserved_content.append(f'') + self.preserved_content.append(f"") # Output the preserved HTML block with proper spacing - preserved_html = ''.join(self.preserved_content) - self.o('\n' + preserved_html + '\n') + preserved_html = "".join(self.preserved_content) + self.o("\n" + preserved_html + "\n") self.current_preserved_tag = None return @@ -1074,29 +1084,31 @@ class CustomHTML2Text(HTML2Text): if self.preserve_depth > 0: if start: # Format nested tags with attributes - attr_str = ''.join(f' {k}="{v}"' for k, v in attrs.items() if v is not None) - self.preserved_content.append(f'<{tag}{attr_str}>') + attr_str = "".join( + f' {k}="{v}"' for k, v in attrs.items() if v is not None + ) + self.preserved_content.append(f"<{tag}{attr_str}>") else: - self.preserved_content.append(f'') + self.preserved_content.append(f"") return # Handle pre tags - if tag == 'pre': + if tag == "pre": if start: - self.o('```\n') # Markdown code block start + self.o("```\n") # Markdown code block start self.inside_pre = True else: - self.o('\n```\n') # Markdown code block end + self.o("\n```\n") # Markdown code block end self.inside_pre = False - elif tag == 'code': + elif tag == "code": if self.inside_pre and not self.handle_code_in_pre: # Ignore code tags inside pre blocks if handle_code_in_pre is False return if start: - self.o('`') # Markdown inline code start + self.o("`") # Markdown inline code start self.inside_code = True else: - self.o('`') # Markdown inline code end + self.o("`") # Markdown inline code end self.inside_code = False else: super().handle_tag(tag, attrs, start) @@ -1113,13 +1125,12 @@ class CustomHTML2Text(HTML2Text): return if self.inside_code: # Inline code: no newlines allowed - self.o(data.replace('\n', ' ')) + self.o(data.replace("\n", " ")) return # Default behavior for other tags super().handle_data(data, entity_char) - # # Handle pre tags # if tag == 'pre': # if start: diff --git a/crawl4ai/html2text/_typing.py b/crawl4ai/html2text/_typing.py index eed83251..6e17fed2 100644 --- a/crawl4ai/html2text/_typing.py +++ b/crawl4ai/html2text/_typing.py @@ -1,2 +1,3 @@ class OutCallback: - def __call__(self, s: str) -> None: ... + def __call__(self, s: str) -> None: + ... diff --git a/crawl4ai/html2text/utils.py b/crawl4ai/html2text/utils.py index 1909d2cf..21bf98fb 100644 --- a/crawl4ai/html2text/utils.py +++ b/crawl4ai/html2text/utils.py @@ -210,7 +210,7 @@ def escape_md_section( snob: bool = False, escape_dot: bool = True, escape_plus: bool = True, - escape_dash: bool = True + escape_dash: bool = True, ) -> str: """ Escapes markdown-sensitive characters across whole document sections. @@ -233,6 +233,7 @@ def escape_md_section( return text + def reformat_table(lines: List[str], right_margin: int) -> List[str]: """ Given the lines of a table diff --git a/crawl4ai/install.py b/crawl4ai/install.py index 7efb6800..139be591 100644 --- a/crawl4ai/install.py +++ b/crawl4ai/install.py @@ -6,25 +6,44 @@ from .async_logger import AsyncLogger, LogLevel # Initialize logger logger = AsyncLogger(log_level=LogLevel.DEBUG, verbose=True) + def post_install(): """Run all post-installation tasks""" logger.info("Running post-installation setup...", tag="INIT") install_playwright() run_migration() logger.success("Post-installation setup completed!", tag="COMPLETE") - + + def install_playwright(): logger.info("Installing Playwright browsers...", tag="INIT") try: # subprocess.check_call([sys.executable, "-m", "playwright", "install", "--with-deps", "--force", "chrome"]) - subprocess.check_call([sys.executable, "-m", "playwright", "install", "--with-deps", "--force", "chromium"]) - logger.success("Playwright installation completed successfully.", tag="COMPLETE") - except subprocess.CalledProcessError as e: + subprocess.check_call( + [ + sys.executable, + "-m", + "playwright", + "install", + "--with-deps", + "--force", + "chromium", + ] + ) + logger.success( + "Playwright installation completed successfully.", tag="COMPLETE" + ) + except subprocess.CalledProcessError: # logger.error(f"Error during Playwright installation: {e}", tag="ERROR") - logger.warning(f"Please run '{sys.executable} -m playwright install --with-deps' manually after the installation.") - except Exception as e: + logger.warning( + f"Please run '{sys.executable} -m playwright install --with-deps' manually after the installation." + ) + except Exception: # logger.error(f"Unexpected error during Playwright installation: {e}", tag="ERROR") - logger.warning(f"Please run '{sys.executable} -m playwright install --with-deps' manually after the installation.") + logger.warning( + f"Please run '{sys.executable} -m playwright install --with-deps' manually after the installation." + ) + def run_migration(): """Initialize database during installation""" @@ -33,18 +52,26 @@ def run_migration(): from crawl4ai.async_database import async_db_manager asyncio.run(async_db_manager.initialize()) - logger.success("Database initialization completed successfully.", tag="COMPLETE") + logger.success( + "Database initialization completed successfully.", tag="COMPLETE" + ) except ImportError: logger.warning("Database module not found. Will initialize on first use.") except Exception as e: logger.warning(f"Database initialization failed: {e}") logger.warning("Database will be initialized on first use") + async def run_doctor(): """Test if Crawl4AI is working properly""" logger.info("Running Crawl4AI health check...", tag="INIT") try: - from .async_webcrawler import AsyncWebCrawler, BrowserConfig, CrawlerRunConfig, CacheMode + from .async_webcrawler import ( + AsyncWebCrawler, + BrowserConfig, + CrawlerRunConfig, + CacheMode, + ) browser_config = BrowserConfig( headless=True, @@ -52,7 +79,7 @@ async def run_doctor(): ignore_https_errors=True, light_mode=True, viewport_width=1280, - viewport_height=720 + viewport_height=720, ) run_config = CrawlerRunConfig( @@ -62,10 +89,7 @@ async def run_doctor(): async with AsyncWebCrawler(config=browser_config) as crawler: logger.info("Testing crawling capabilities...", tag="TEST") - result = await crawler.arun( - url="https://crawl4ai.com", - config=run_config - ) + result = await crawler.arun(url="https://crawl4ai.com", config=run_config) if result and result.markdown: logger.success("✅ Crawling test passed!", tag="COMPLETE") @@ -77,7 +101,9 @@ async def run_doctor(): logger.error(f"❌ Test failed: {e}", tag="ERROR") return False + def doctor(): """Entry point for the doctor command""" import asyncio + return asyncio.run(run_doctor()) diff --git a/crawl4ai/js_snippet/__init__.py b/crawl4ai/js_snippet/__init__.py index 73b0c2dd..e51f79d8 100644 --- a/crawl4ai/js_snippet/__init__.py +++ b/crawl4ai/js_snippet/__init__.py @@ -1,15 +1,18 @@ -import os, sys +import os + # Create a function get name of a js script, then load from the CURRENT folder of this script and return its content as string, make sure its error free def load_js_script(script_name): # Get the path of the current script current_script_path = os.path.dirname(os.path.realpath(__file__)) # Get the path of the script to load - script_path = os.path.join(current_script_path, script_name + '.js') + script_path = os.path.join(current_script_path, script_name + ".js") # Check if the script exists if not os.path.exists(script_path): - raise ValueError(f"Script {script_name} not found in the folder {current_script_path}") + raise ValueError( + f"Script {script_name} not found in the folder {current_script_path}" + ) # Load the content of the script - with open(script_path, 'r') as f: + with open(script_path, "r") as f: script_content = f.read() return script_content diff --git a/crawl4ai/llmtxt.py b/crawl4ai/llmtxt.py index 94efe076..30256416 100644 --- a/crawl4ai/llmtxt.py +++ b/crawl4ai/llmtxt.py @@ -11,16 +11,16 @@ from rank_bm25 import BM25Okapi from nltk.tokenize import word_tokenize from nltk.corpus import stopwords from nltk.stem import WordNetLemmatizer -from litellm import completion, batch_completion +from litellm import batch_completion from .async_logger import AsyncLogger import litellm import pickle import hashlib # <--- ADDED for file-hash -from fnmatch import fnmatch import glob litellm.set_verbose = False + def _compute_file_hash(file_path: Path) -> str: """Compute MD5 hash for the file's entire content.""" hash_md5 = hashlib.md5() @@ -29,13 +29,14 @@ def _compute_file_hash(file_path: Path) -> str: hash_md5.update(chunk) return hash_md5.hexdigest() + class AsyncLLMTextManager: def __init__( self, docs_dir: Path, logger: Optional[AsyncLogger] = None, max_concurrent_calls: int = 5, - batch_size: int = 3 + batch_size: int = 3, ) -> None: self.docs_dir = docs_dir self.logger = logger @@ -51,7 +52,7 @@ class AsyncLLMTextManager: contents = [] for file_path in doc_batch: try: - with open(file_path, 'r', encoding='utf-8') as f: + with open(file_path, "r", encoding="utf-8") as f: contents.append(f.read()) except Exception as e: self.logger.error(f"Error reading {file_path}: {str(e)}") @@ -77,43 +78,53 @@ Wrap your response in ... tags. # Prepare messages for batch processing messages_list = [ [ - {"role": "user", "content": f"{prompt}\n\nGenerate index for this documentation:\n\n{content}"} + { + "role": "user", + "content": f"{prompt}\n\nGenerate index for this documentation:\n\n{content}", + } ] - for content in contents if content + for content in contents + if content ] try: responses = batch_completion( model="anthropic/claude-3-5-sonnet-latest", messages=messages_list, - logger_fn=None + logger_fn=None, ) # Process responses and save index files for response, file_path in zip(responses, doc_batch): try: index_content_match = re.search( - r'(.*?)', + r"(.*?)", response.choices[0].message.content, - re.DOTALL + re.DOTALL, ) if not index_content_match: - self.logger.warning(f"No ... content found for {file_path}") + self.logger.warning( + f"No ... content found for {file_path}" + ) continue index_content = re.sub( r"\n\s*\n", "\n", index_content_match.group(1) ).strip() if index_content: - index_file = file_path.with_suffix('.q.md') - with open(index_file, 'w', encoding='utf-8') as f: + index_file = file_path.with_suffix(".q.md") + with open(index_file, "w", encoding="utf-8") as f: f.write(index_content) self.logger.info(f"Created index file: {index_file}") else: - self.logger.warning(f"No index content found in response for {file_path}") + self.logger.warning( + f"No index content found in response for {file_path}" + ) except Exception as e: - self.logger.error(f"Error processing response for {file_path}: {str(e)}") + self.logger.error( + f"Error processing response for {file_path}: {str(e)}" + ) except Exception as e: self.logger.error(f"Error in batch completion: {str(e)}") @@ -171,7 +182,12 @@ Wrap your response in ... tags. lemmatizer = WordNetLemmatizer() stop_words = set(stopwords.words("english")) - { - "how", "what", "when", "where", "why", "which", + "how", + "what", + "when", + "where", + "why", + "which", } tokens = [] @@ -222,7 +238,9 @@ Wrap your response in ... tags. self.logger.info("Checking which .q.md files need (re)indexing...") # Gather all .q.md files - q_files = [self.docs_dir / f for f in os.listdir(self.docs_dir) if f.endswith(".q.md")] + q_files = [ + self.docs_dir / f for f in os.listdir(self.docs_dir) if f.endswith(".q.md") + ] # We'll store known (unchanged) facts in these lists existing_facts: List[str] = [] @@ -243,7 +261,9 @@ Wrap your response in ... tags. # Otherwise, load the existing cache and compare hash cache = self._load_or_create_token_cache(qf) # If the .q.tokens was out of date (i.e. changed hash), we reindex - if len(cache["facts"]) == 0 or cache.get("content_hash") != _compute_file_hash(qf): + if len(cache["facts"]) == 0 or cache.get( + "content_hash" + ) != _compute_file_hash(qf): needSet.append(qf) else: # File is unchanged → retrieve cached token data @@ -255,20 +275,29 @@ Wrap your response in ... tags. if not needSet and not clear_cache: # If no file needs reindexing, try loading existing index if self.maybe_load_bm25_index(clear_cache=False): - self.logger.info("No new/changed .q.md files found. Using existing BM25 index.") + self.logger.info( + "No new/changed .q.md files found. Using existing BM25 index." + ) return else: # If there's no existing index, we must build a fresh index from the old caches - self.logger.info("No existing BM25 index found. Building from cached facts.") + self.logger.info( + "No existing BM25 index found. Building from cached facts." + ) if existing_facts: - self.logger.info(f"Building BM25 index with {len(existing_facts)} cached facts.") + self.logger.info( + f"Building BM25 index with {len(existing_facts)} cached facts." + ) self.bm25_index = BM25Okapi(existing_tokens) self.tokenized_facts = existing_facts with open(self.bm25_index_file, "wb") as f: - pickle.dump({ - "bm25_index": self.bm25_index, - "tokenized_facts": self.tokenized_facts - }, f) + pickle.dump( + { + "bm25_index": self.bm25_index, + "tokenized_facts": self.tokenized_facts, + }, + f, + ) else: self.logger.warning("No facts found at all. Index remains empty.") return @@ -311,7 +340,9 @@ Wrap your response in ... tags. self._save_token_cache(file, fresh_cache) mem_usage = process.memory_info().rss / 1024 / 1024 - self.logger.debug(f"Memory usage after {file.name}: {mem_usage:.2f}MB") + self.logger.debug( + f"Memory usage after {file.name}: {mem_usage:.2f}MB" + ) except Exception as e: self.logger.error(f"Error processing {file}: {str(e)}") @@ -328,40 +359,49 @@ Wrap your response in ... tags. all_tokens = existing_tokens + new_tokens # 3) Build BM25 index from combined facts - self.logger.info(f"Building BM25 index with {len(all_facts)} total facts (old + new).") + self.logger.info( + f"Building BM25 index with {len(all_facts)} total facts (old + new)." + ) self.bm25_index = BM25Okapi(all_tokens) self.tokenized_facts = all_facts # 4) Save the updated BM25 index to disk with open(self.bm25_index_file, "wb") as f: - pickle.dump({ - "bm25_index": self.bm25_index, - "tokenized_facts": self.tokenized_facts - }, f) + pickle.dump( + { + "bm25_index": self.bm25_index, + "tokenized_facts": self.tokenized_facts, + }, + f, + ) final_mem = process.memory_info().rss / 1024 / 1024 self.logger.info(f"Search index updated. Final memory usage: {final_mem:.2f}MB") - async def generate_index_files(self, force_generate_facts: bool = False, clear_bm25_cache: bool = False) -> None: + async def generate_index_files( + self, force_generate_facts: bool = False, clear_bm25_cache: bool = False + ) -> None: """ Generate index files for all documents in parallel batches - + Args: force_generate_facts (bool): If True, regenerate indexes even if they exist clear_bm25_cache (bool): If True, clear existing BM25 index cache """ self.logger.info("Starting index generation for documentation files.") - + md_files = [ - self.docs_dir / f for f in os.listdir(self.docs_dir) - if f.endswith('.md') and not any(f.endswith(x) for x in ['.q.md', '.xs.md']) + self.docs_dir / f + for f in os.listdir(self.docs_dir) + if f.endswith(".md") and not any(f.endswith(x) for x in [".q.md", ".xs.md"]) ] # Filter out files that already have .q files unless force=True if not force_generate_facts: md_files = [ - f for f in md_files - if not (self.docs_dir / f.name.replace('.md', '.q.md')).exists() + f + for f in md_files + if not (self.docs_dir / f.name.replace(".md", ".q.md")).exists() ] if not md_files: @@ -369,8 +409,10 @@ Wrap your response in ... tags. else: # Process documents in batches for i in range(0, len(md_files), self.batch_size): - batch = md_files[i:i + self.batch_size] - self.logger.info(f"Processing batch {i//self.batch_size + 1}/{(len(md_files)//self.batch_size) + 1}") + batch = md_files[i : i + self.batch_size] + self.logger.info( + f"Processing batch {i//self.batch_size + 1}/{(len(md_files)//self.batch_size) + 1}" + ) await self._process_document_batch(batch) self.logger.info("Index generation complete, building/updating search index.") @@ -378,21 +420,31 @@ Wrap your response in ... tags. def generate(self, sections: List[str], mode: str = "extended") -> str: # Get all markdown files - all_files = glob.glob(str(self.docs_dir / "[0-9]*.md")) + \ - glob.glob(str(self.docs_dir / "[0-9]*.xs.md")) - + all_files = glob.glob(str(self.docs_dir / "[0-9]*.md")) + glob.glob( + str(self.docs_dir / "[0-9]*.xs.md") + ) + # Extract base names without extensions - base_docs = {Path(f).name.split('.')[0] for f in all_files - if not Path(f).name.endswith('.q.md')} - + base_docs = { + Path(f).name.split(".")[0] + for f in all_files + if not Path(f).name.endswith(".q.md") + } + # Filter by sections if provided if sections: - base_docs = {doc for doc in base_docs - if any(section.lower() in doc.lower() for section in sections)} - + base_docs = { + doc + for doc in base_docs + if any(section.lower() in doc.lower() for section in sections) + } + # Get file paths based on mode files = [] - for doc in sorted(base_docs, key=lambda x: int(x.split('_')[0]) if x.split('_')[0].isdigit() else 999999): + for doc in sorted( + base_docs, + key=lambda x: int(x.split("_")[0]) if x.split("_")[0].isdigit() else 999999, + ): if mode == "condensed": xs_file = self.docs_dir / f"{doc}.xs.md" regular_file = self.docs_dir / f"{doc}.md" @@ -404,7 +456,7 @@ Wrap your response in ... tags. content = [] for file in files: try: - with open(file, 'r', encoding='utf-8') as f: + with open(file, "r", encoding="utf-8") as f: fname = Path(file).name content.append(f"{'#'*20}\n# {fname}\n{'#'*20}\n\n{f.read()}") except Exception as e: @@ -443,15 +495,9 @@ Wrap your response in ... tags. for file, _ in ranked_files: main_doc = str(file).replace(".q.md", ".md") if os.path.exists(self.docs_dir / main_doc): - with open(self.docs_dir / main_doc, "r", encoding='utf-8') as f: + with open(self.docs_dir / main_doc, "r", encoding="utf-8") as f: only_file_name = main_doc.split("/")[-1] - content = [ - "#" * 20, - f"# {only_file_name}", - "#" * 20, - "", - f.read() - ] + content = ["#" * 20, f"# {only_file_name}", "#" * 20, "", f.read()] results.append("\n".join(content)) return "\n\n---\n\n".join(results) @@ -482,7 +528,9 @@ Wrap your response in ... tags. if len(components) == 3: code_ref = components[2].strip() code_tokens = self.preprocess_text(code_ref) - code_match_score = len(set(query_tokens) & set(code_tokens)) / len(query_tokens) + code_match_score = len(set(query_tokens) & set(code_tokens)) / len( + query_tokens + ) file_data[file_path]["total_score"] += score file_data[file_path]["match_count"] += 1 diff --git a/crawl4ai/markdown_generation_strategy.py b/crawl4ai/markdown_generation_strategy.py index 89e5e34e..1e3f0554 100644 --- a/crawl4ai/markdown_generation_strategy.py +++ b/crawl4ai/markdown_generation_strategy.py @@ -2,77 +2,94 @@ from abc import ABC, abstractmethod from typing import Optional, Dict, Any, Tuple from .models import MarkdownGenerationResult from .html2text import CustomHTML2Text -from .content_filter_strategy import RelevantContentFilter, BM25ContentFilter +from .content_filter_strategy import RelevantContentFilter import re from urllib.parse import urljoin # Pre-compile the regex pattern LINK_PATTERN = re.compile(r'!?\[([^\]]+)\]\(([^)]+?)(?:\s+"([^"]*)")?\)') + def fast_urljoin(base: str, url: str) -> str: """Fast URL joining for common cases.""" - if url.startswith(('http://', 'https://', 'mailto:', '//')): + if url.startswith(("http://", "https://", "mailto:", "//")): return url - if url.startswith('/'): + if url.startswith("/"): # Handle absolute paths - if base.endswith('/'): + if base.endswith("/"): return base[:-1] + url return base + url return urljoin(base, url) + class MarkdownGenerationStrategy(ABC): """Abstract base class for markdown generation strategies.""" - def __init__(self, content_filter: Optional[RelevantContentFilter] = None, options: Optional[Dict[str, Any]] = None): + + def __init__( + self, + content_filter: Optional[RelevantContentFilter] = None, + options: Optional[Dict[str, Any]] = None, + ): self.content_filter = content_filter self.options = options or {} - + @abstractmethod - def generate_markdown(self, - cleaned_html: str, - base_url: str = "", - html2text_options: Optional[Dict[str, Any]] = None, - content_filter: Optional[RelevantContentFilter] = None, - citations: bool = True, - **kwargs) -> MarkdownGenerationResult: + def generate_markdown( + self, + cleaned_html: str, + base_url: str = "", + html2text_options: Optional[Dict[str, Any]] = None, + content_filter: Optional[RelevantContentFilter] = None, + citations: bool = True, + **kwargs, + ) -> MarkdownGenerationResult: """Generate markdown from cleaned HTML.""" pass + class DefaultMarkdownGenerator(MarkdownGenerationStrategy): """ Default implementation of markdown generation strategy. - + How it works: 1. Generate raw markdown from cleaned HTML. 2. Convert links to citations. 3. Generate fit markdown if content filter is provided. 4. Return MarkdownGenerationResult. - + Args: content_filter (Optional[RelevantContentFilter]): Content filter for generating fit markdown. options (Optional[Dict[str, Any]]): Additional options for markdown generation. Defaults to None. - + Returns: MarkdownGenerationResult: Result containing raw markdown, fit markdown, fit HTML, and references markdown. """ - def __init__(self, content_filter: Optional[RelevantContentFilter] = None, options: Optional[Dict[str, Any]] = None): + + def __init__( + self, + content_filter: Optional[RelevantContentFilter] = None, + options: Optional[Dict[str, Any]] = None, + ): super().__init__(content_filter, options) - - def convert_links_to_citations(self, markdown: str, base_url: str = "") -> Tuple[str, str]: + + def convert_links_to_citations( + self, markdown: str, base_url: str = "" + ) -> Tuple[str, str]: """ Convert links in markdown to citations. - + How it works: 1. Find all links in the markdown. 2. Convert links to citations. 3. Return converted markdown and references markdown. - + Note: This function uses a regex pattern to find links in markdown. - + Args: markdown (str): Markdown text. base_url (str): Base URL for URL joins. - + Returns: Tuple[str, str]: Converted markdown and references markdown. """ @@ -81,57 +98,65 @@ class DefaultMarkdownGenerator(MarkdownGenerationStrategy): parts = [] last_end = 0 counter = 1 - + for match in LINK_PATTERN.finditer(markdown): - parts.append(markdown[last_end:match.start()]) + parts.append(markdown[last_end : match.start()]) text, url, title = match.groups() - + # Use cached URL if available, otherwise compute and cache - if base_url and not url.startswith(('http://', 'https://', 'mailto:')): + if base_url and not url.startswith(("http://", "https://", "mailto:")): if url not in url_cache: url_cache[url] = fast_urljoin(base_url, url) url = url_cache[url] - + if url not in link_map: desc = [] - if title: desc.append(title) - if text and text != title: desc.append(text) + if title: + desc.append(title) + if text and text != title: + desc.append(text) link_map[url] = (counter, ": " + " - ".join(desc) if desc else "") counter += 1 - + num = link_map[url][0] - parts.append(f"{text}⟨{num}⟩" if not match.group(0).startswith('!') else f"![{text}⟨{num}⟩]") + parts.append( + f"{text}⟨{num}⟩" + if not match.group(0).startswith("!") + else f"![{text}⟨{num}⟩]" + ) last_end = match.end() - + parts.append(markdown[last_end:]) - converted_text = ''.join(parts) - + converted_text = "".join(parts) + # Pre-build reference strings references = ["\n\n## References\n\n"] references.extend( - f"⟨{num}⟩ {url}{desc}\n" + f"⟨{num}⟩ {url}{desc}\n" for url, (num, desc) in sorted(link_map.items(), key=lambda x: x[1][0]) ) - - return converted_text, ''.join(references) - def generate_markdown(self, - cleaned_html: str, - base_url: str = "", - html2text_options: Optional[Dict[str, Any]] = None, - options: Optional[Dict[str, Any]] = None, - content_filter: Optional[RelevantContentFilter] = None, - citations: bool = True, - **kwargs) -> MarkdownGenerationResult: + return converted_text, "".join(references) + + def generate_markdown( + self, + cleaned_html: str, + base_url: str = "", + html2text_options: Optional[Dict[str, Any]] = None, + options: Optional[Dict[str, Any]] = None, + content_filter: Optional[RelevantContentFilter] = None, + citations: bool = True, + **kwargs, + ) -> MarkdownGenerationResult: """ Generate markdown with citations from cleaned HTML. - + How it works: 1. Generate raw markdown from cleaned HTML. 2. Convert links to citations. 3. Generate fit markdown if content filter is provided. 4. Return MarkdownGenerationResult. - + Args: cleaned_html (str): Cleaned HTML content. base_url (str): Base URL for URL joins. @@ -139,7 +164,7 @@ class DefaultMarkdownGenerator(MarkdownGenerationStrategy): options (Optional[Dict[str, Any]]): Additional options for markdown generation. content_filter (Optional[RelevantContentFilter]): Content filter for generating fit markdown. citations (bool): Whether to generate citations. - + Returns: MarkdownGenerationResult: Result containing raw markdown, fit markdown, fit HTML, and references markdown. """ @@ -147,16 +172,16 @@ class DefaultMarkdownGenerator(MarkdownGenerationStrategy): # Initialize HTML2Text with default options for better conversion h = CustomHTML2Text(baseurl=base_url) default_options = { - 'body_width': 0, # Disable text wrapping - 'ignore_emphasis': False, - 'ignore_links': False, - 'ignore_images': False, - 'protect_links': True, - 'single_line_break': True, - 'mark_code': True, - 'escape_snob': False + "body_width": 0, # Disable text wrapping + "ignore_emphasis": False, + "ignore_links": False, + "ignore_images": False, + "protect_links": True, + "single_line_break": True, + "mark_code": True, + "escape_snob": False, } - + # Update with custom options if provided if html2text_options: default_options.update(html2text_options) @@ -164,7 +189,7 @@ class DefaultMarkdownGenerator(MarkdownGenerationStrategy): default_options.update(options) elif self.options: default_options.update(self.options) - + h.update_params(**default_options) # Ensure we have valid input @@ -178,17 +203,18 @@ class DefaultMarkdownGenerator(MarkdownGenerationStrategy): raw_markdown = h.handle(cleaned_html) except Exception as e: raw_markdown = f"Error converting HTML to markdown: {str(e)}" - - raw_markdown = raw_markdown.replace(' ```', '```') + + raw_markdown = raw_markdown.replace(" ```", "```") # Convert links to citations markdown_with_citations: str = raw_markdown references_markdown: str = "" if citations: try: - markdown_with_citations, references_markdown = self.convert_links_to_citations( - raw_markdown, base_url - ) + ( + markdown_with_citations, + references_markdown, + ) = self.convert_links_to_citations(raw_markdown, base_url) except Exception as e: markdown_with_citations = raw_markdown references_markdown = f"Error generating citations: {str(e)}" @@ -200,7 +226,9 @@ class DefaultMarkdownGenerator(MarkdownGenerationStrategy): try: content_filter = content_filter or self.content_filter filtered_html = content_filter.filter_content(cleaned_html) - filtered_html = '\n'.join('
{}
'.format(s) for s in filtered_html) + filtered_html = "\n".join( + "
{}
".format(s) for s in filtered_html + ) fit_markdown = h.handle(filtered_html) except Exception as e: fit_markdown = f"Error generating fit markdown: {str(e)}" diff --git a/crawl4ai/migrations.py b/crawl4ai/migrations.py index 3386b0fb..d6da292f 100644 --- a/crawl4ai/migrations.py +++ b/crawl4ai/migrations.py @@ -1,13 +1,11 @@ import os import asyncio -import logging from pathlib import Path import aiosqlite from typing import Optional import xxhash import aiofiles import shutil -import time from datetime import datetime from .async_logger import AsyncLogger, LogLevel @@ -17,18 +15,19 @@ logger = AsyncLogger(log_level=LogLevel.DEBUG, verbose=True) # logging.basicConfig(level=logging.INFO) # logger = logging.getLogger(__name__) + class DatabaseMigration: def __init__(self, db_path: str): self.db_path = db_path self.content_paths = self._ensure_content_dirs(os.path.dirname(db_path)) - + def _ensure_content_dirs(self, base_path: str) -> dict: dirs = { - 'html': 'html_content', - 'cleaned': 'cleaned_html', - 'markdown': 'markdown_content', - 'extracted': 'extracted_content', - 'screenshots': 'screenshots' + "html": "html_content", + "cleaned": "cleaned_html", + "markdown": "markdown_content", + "extracted": "extracted_content", + "screenshots": "screenshots", } content_paths = {} for key, dirname in dirs.items(): @@ -47,43 +46,55 @@ class DatabaseMigration: async def _store_content(self, content: str, content_type: str) -> str: if not content: return "" - + content_hash = self._generate_content_hash(content) file_path = os.path.join(self.content_paths[content_type], content_hash) - + if not os.path.exists(file_path): - async with aiofiles.open(file_path, 'w', encoding='utf-8') as f: + async with aiofiles.open(file_path, "w", encoding="utf-8") as f: await f.write(content) - + return content_hash async def migrate_database(self): """Migrate existing database to file-based storage""" # logger.info("Starting database migration...") logger.info("Starting database migration...", tag="INIT") - + try: async with aiosqlite.connect(self.db_path) as db: # Get all rows async with db.execute( - '''SELECT url, html, cleaned_html, markdown, - extracted_content, screenshot FROM crawled_data''' + """SELECT url, html, cleaned_html, markdown, + extracted_content, screenshot FROM crawled_data""" ) as cursor: rows = await cursor.fetchall() migrated_count = 0 for row in rows: - url, html, cleaned_html, markdown, extracted_content, screenshot = row - + ( + url, + html, + cleaned_html, + markdown, + extracted_content, + screenshot, + ) = row + # Store content in files and get hashes - html_hash = await self._store_content(html, 'html') - cleaned_hash = await self._store_content(cleaned_html, 'cleaned') - markdown_hash = await self._store_content(markdown, 'markdown') - extracted_hash = await self._store_content(extracted_content, 'extracted') - screenshot_hash = await self._store_content(screenshot, 'screenshots') + html_hash = await self._store_content(html, "html") + cleaned_hash = await self._store_content(cleaned_html, "cleaned") + markdown_hash = await self._store_content(markdown, "markdown") + extracted_hash = await self._store_content( + extracted_content, "extracted" + ) + screenshot_hash = await self._store_content( + screenshot, "screenshots" + ) # Update database with hashes - await db.execute(''' + await db.execute( + """ UPDATE crawled_data SET html = ?, cleaned_html = ?, @@ -91,40 +102,51 @@ class DatabaseMigration: extracted_content = ?, screenshot = ? WHERE url = ? - ''', (html_hash, cleaned_hash, markdown_hash, - extracted_hash, screenshot_hash, url)) - + """, + ( + html_hash, + cleaned_hash, + markdown_hash, + extracted_hash, + screenshot_hash, + url, + ), + ) + migrated_count += 1 if migrated_count % 100 == 0: logger.info(f"Migrated {migrated_count} records...", tag="INIT") - await db.commit() - logger.success(f"Migration completed. {migrated_count} records processed.", tag="COMPLETE") + logger.success( + f"Migration completed. {migrated_count} records processed.", + tag="COMPLETE", + ) except Exception as e: # logger.error(f"Migration failed: {e}") logger.error( message="Migration failed: {error}", tag="ERROR", - params={"error": str(e)} + params={"error": str(e)}, ) raise e + async def backup_database(db_path: str) -> str: """Create backup of existing database""" if not os.path.exists(db_path): logger.info("No existing database found. Skipping backup.", tag="INIT") return None - + # Create backup with timestamp - timestamp = datetime.now().strftime('%Y%m%d_%H%M%S') + timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") backup_path = f"{db_path}.backup_{timestamp}" - + try: # Wait for any potential write operations to finish await asyncio.sleep(1) - + # Create backup shutil.copy2(db_path, backup_path) logger.info(f"Database backup created at: {backup_path}", tag="COMPLETE") @@ -132,37 +154,41 @@ async def backup_database(db_path: str) -> str: except Exception as e: # logger.error(f"Backup failed: {e}") logger.error( - message="Migration failed: {error}", - tag="ERROR", - params={"error": str(e)} - ) + message="Migration failed: {error}", tag="ERROR", params={"error": str(e)} + ) raise e - + + async def run_migration(db_path: Optional[str] = None): """Run database migration""" if db_path is None: db_path = os.path.join(Path.home(), ".crawl4ai", "crawl4ai.db") - + if not os.path.exists(db_path): logger.info("No existing database found. Skipping migration.", tag="INIT") return - + # Create backup first backup_path = await backup_database(db_path) if not backup_path: return - + migration = DatabaseMigration(db_path) await migration.migrate_database() - + + def main(): """CLI entry point for migration""" import argparse - parser = argparse.ArgumentParser(description='Migrate Crawl4AI database to file-based storage') - parser.add_argument('--db-path', help='Custom database path') + + parser = argparse.ArgumentParser( + description="Migrate Crawl4AI database to file-based storage" + ) + parser.add_argument("--db-path", help="Custom database path") args = parser.parse_args() - + asyncio.run(run_migration(args.db_path)) + if __name__ == "__main__": - main() \ No newline at end of file + main() diff --git a/crawl4ai/model_loader.py b/crawl4ai/model_loader.py index d1872d7e..aa80f673 100644 --- a/crawl4ai/model_loader.py +++ b/crawl4ai/model_loader.py @@ -2,109 +2,125 @@ from functools import lru_cache from pathlib import Path import subprocess, os import shutil -import tarfile from .model_loader import * import argparse -import urllib.request from crawl4ai.config import MODEL_REPO_BRANCH + __location__ = os.path.realpath(os.path.join(os.getcwd(), os.path.dirname(__file__))) + @lru_cache() def get_available_memory(device): import torch - if device.type == 'cuda': + + if device.type == "cuda": return torch.cuda.get_device_properties(device).total_memory - elif device.type == 'mps': - return 48 * 1024 ** 3 # Assuming 8GB for MPS, as a conservative estimate + elif device.type == "mps": + return 48 * 1024**3 # Assuming 8GB for MPS, as a conservative estimate else: return 0 + @lru_cache() def calculate_batch_size(device): available_memory = get_available_memory(device) - - if device.type == 'cpu': + + if device.type == "cpu": return 16 - elif device.type in ['cuda', 'mps']: + elif device.type in ["cuda", "mps"]: # Adjust these thresholds based on your model size and available memory - if available_memory >= 31 * 1024 ** 3: # > 32GB + if available_memory >= 31 * 1024**3: # > 32GB return 256 - elif available_memory >= 15 * 1024 ** 3: # > 16GB to 32GB + elif available_memory >= 15 * 1024**3: # > 16GB to 32GB return 128 - elif available_memory >= 8 * 1024 ** 3: # 8GB to 16GB + elif available_memory >= 8 * 1024**3: # 8GB to 16GB return 64 else: return 32 else: - return 16 # Default batch size - + return 16 # Default batch size + + @lru_cache() def get_device(): import torch + if torch.cuda.is_available(): - device = torch.device('cuda') + device = torch.device("cuda") elif torch.backends.mps.is_available(): - device = torch.device('mps') + device = torch.device("mps") else: - device = torch.device('cpu') - return device - + device = torch.device("cpu") + return device + + def set_model_device(model): device = get_device() - model.to(device) + model.to(device) return model, device + @lru_cache() def get_home_folder(): - home_folder = os.path.join(os.getenv("CRAWL4_AI_BASE_DIRECTORY", Path.home()), ".crawl4ai") + home_folder = os.path.join( + os.getenv("CRAWL4_AI_BASE_DIRECTORY", Path.home()), ".crawl4ai" + ) os.makedirs(home_folder, exist_ok=True) os.makedirs(f"{home_folder}/cache", exist_ok=True) os.makedirs(f"{home_folder}/models", exist_ok=True) - return home_folder + return home_folder + @lru_cache() def load_bert_base_uncased(): - from transformers import BertTokenizer, BertModel, AutoTokenizer, AutoModel - tokenizer = BertTokenizer.from_pretrained('bert-base-uncased', resume_download=None) - model = BertModel.from_pretrained('bert-base-uncased', resume_download=None) + from transformers import BertTokenizer, BertModel + + tokenizer = BertTokenizer.from_pretrained("bert-base-uncased", resume_download=None) + model = BertModel.from_pretrained("bert-base-uncased", resume_download=None) model.eval() model, device = set_model_device(model) return tokenizer, model + @lru_cache() def load_HF_embedding_model(model_name="BAAI/bge-small-en-v1.5") -> tuple: """Load the Hugging Face model for embedding. - + Args: model_name (str, optional): The model name to load. Defaults to "BAAI/bge-small-en-v1.5". - + Returns: tuple: The tokenizer and model. """ - from transformers import BertTokenizer, BertModel, AutoTokenizer, AutoModel + from transformers import AutoTokenizer, AutoModel + tokenizer = AutoTokenizer.from_pretrained(model_name, resume_download=None) model = AutoModel.from_pretrained(model_name, resume_download=None) model.eval() model, device = set_model_device(model) return tokenizer, model + @lru_cache() def load_text_classifier(): from transformers import AutoTokenizer, AutoModelForSequenceClassification from transformers import pipeline - import torch - tokenizer = AutoTokenizer.from_pretrained("dstefa/roberta-base_topic_classification_nyt_news") - model = AutoModelForSequenceClassification.from_pretrained("dstefa/roberta-base_topic_classification_nyt_news") + tokenizer = AutoTokenizer.from_pretrained( + "dstefa/roberta-base_topic_classification_nyt_news" + ) + model = AutoModelForSequenceClassification.from_pretrained( + "dstefa/roberta-base_topic_classification_nyt_news" + ) model.eval() model, device = set_model_device(model) pipe = pipeline("text-classification", model=model, tokenizer=tokenizer) return pipe + @lru_cache() def load_text_multilabel_classifier(): from transformers import AutoModelForSequenceClassification, AutoTokenizer - import numpy as np from scipy.special import expit import torch @@ -116,18 +132,27 @@ def load_text_multilabel_classifier(): # else: # device = torch.device("cpu") # # return load_spacy_model(), torch.device("cpu") - MODEL = "cardiffnlp/tweet-topic-21-multi" tokenizer = AutoTokenizer.from_pretrained(MODEL, resume_download=None) - model = AutoModelForSequenceClassification.from_pretrained(MODEL, resume_download=None) + model = AutoModelForSequenceClassification.from_pretrained( + MODEL, resume_download=None + ) model.eval() model, device = set_model_device(model) class_mapping = model.config.id2label def _classifier(texts, threshold=0.5, max_length=64): - tokens = tokenizer(texts, return_tensors='pt', padding=True, truncation=True, max_length=max_length) - tokens = {key: val.to(device) for key, val in tokens.items()} # Move tokens to the selected device + tokens = tokenizer( + texts, + return_tensors="pt", + padding=True, + truncation=True, + max_length=max_length, + ) + tokens = { + key: val.to(device) for key, val in tokens.items() + } # Move tokens to the selected device with torch.no_grad(): output = model(**tokens) @@ -138,35 +163,41 @@ def load_text_multilabel_classifier(): batch_labels = [] for prediction in predictions: - labels = [class_mapping[i] for i, value in enumerate(prediction) if value == 1] + labels = [ + class_mapping[i] for i, value in enumerate(prediction) if value == 1 + ] batch_labels.append(labels) return batch_labels return _classifier, device + @lru_cache() def load_nltk_punkt(): import nltk + try: - nltk.data.find('tokenizers/punkt') + nltk.data.find("tokenizers/punkt") except LookupError: - nltk.download('punkt') - return nltk.data.find('tokenizers/punkt') + nltk.download("punkt") + return nltk.data.find("tokenizers/punkt") + @lru_cache() def load_spacy_model(): import spacy + name = "models/reuters" home_folder = get_home_folder() model_folder = Path(home_folder) / name - + # Check if the model directory already exists if not (model_folder.exists() and any(model_folder.iterdir())): repo_url = "https://github.com/unclecode/crawl4ai.git" - branch = MODEL_REPO_BRANCH + branch = MODEL_REPO_BRANCH repo_folder = Path(home_folder) / "crawl4ai" - + print("[LOG] ⏬ Downloading Spacy model for the first time...") # Remove existing repo folder if it exists @@ -176,7 +207,9 @@ def load_spacy_model(): if model_folder.exists(): shutil.rmtree(model_folder) except PermissionError: - print("[WARNING] Unable to remove existing folders. Please manually delete the following folders and try again:") + print( + "[WARNING] Unable to remove existing folders. Please manually delete the following folders and try again:" + ) print(f"- {repo_folder}") print(f"- {model_folder}") return None @@ -187,7 +220,7 @@ def load_spacy_model(): ["git", "clone", "-b", branch, repo_url, str(repo_folder)], stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL, - check=True + check=True, ) # Create the models directory if it doesn't exist @@ -215,6 +248,7 @@ def load_spacy_model(): print(f"Error loading spacy model: {e}") return None + def download_all_models(remove_existing=False): """Download all models required for Crawl4AI.""" if remove_existing: @@ -243,14 +277,20 @@ def download_all_models(remove_existing=False): load_nltk_punkt() print("[LOG] ✅ All models downloaded successfully.") + def main(): print("[LOG] Welcome to the Crawl4AI Model Downloader!") print("[LOG] This script will download all the models required for Crawl4AI.") parser = argparse.ArgumentParser(description="Crawl4AI Model Downloader") - parser.add_argument('--remove-existing', action='store_true', help="Remove existing models before downloading") + parser.add_argument( + "--remove-existing", + action="store_true", + help="Remove existing models before downloading", + ) args = parser.parse_args() - + download_all_models(remove_existing=args.remove_existing) + if __name__ == "__main__": main() diff --git a/crawl4ai/models.py b/crawl4ai/models.py index 48aad544..9ab2389b 100644 --- a/crawl4ai/models.py +++ b/crawl4ai/models.py @@ -1,18 +1,12 @@ from pydantic import BaseModel, HttpUrl -from typing import List, Dict, Optional, Callable, Awaitable, Union, Tuple, Any +from typing import List, Dict, Optional, Callable, Awaitable, Union, Any from enum import Enum -from dataclasses import dataclass, field -from .ssl_certificate import SSLCertificate - from dataclasses import dataclass +from .ssl_certificate import SSLCertificate from datetime import datetime -from enum import Enum -from typing import Optional - from datetime import timedelta - ############################### # Dispatcher Models ############################### @@ -22,6 +16,7 @@ class DomainState: current_delay: float = 0 fail_count: int = 0 + @dataclass class CrawlerTaskResult: task_id: str @@ -33,12 +28,14 @@ class CrawlerTaskResult: end_time: datetime error_message: str = "" + class CrawlStatus(Enum): QUEUED = "QUEUED" IN_PROGRESS = "IN_PROGRESS" COMPLETED = "COMPLETED" FAILED = "FAILED" + @dataclass class CrawlStats: task_id: str @@ -49,7 +46,7 @@ class CrawlStats: memory_usage: float = 0.0 peak_memory: float = 0.0 error_message: str = "" - + @property def duration(self) -> str: if not self.start_time: @@ -58,26 +55,29 @@ class CrawlStats: duration = end - self.start_time return str(timedelta(seconds=int(duration.total_seconds()))) + class DisplayMode(Enum): DETAILED = "DETAILED" AGGREGATED = "AGGREGATED" + ############################### # Crawler Models ############################### @dataclass class TokenUsage: completion_tokens: int = 0 - prompt_tokens: int = 0 + prompt_tokens: int = 0 total_tokens: int = 0 completion_tokens_details: Optional[dict] = None prompt_tokens_details: Optional[dict] = None - + class UrlModel(BaseModel): url: HttpUrl forced: bool = False + class MarkdownGenerationResult(BaseModel): raw_markdown: str markdown_with_citations: str @@ -85,6 +85,7 @@ class MarkdownGenerationResult(BaseModel): fit_markdown: Optional[str] = None fit_html: Optional[str] = None + class DispatchResult(BaseModel): task_id: str memory_usage: float @@ -92,6 +93,8 @@ class DispatchResult(BaseModel): start_time: datetime end_time: datetime error_message: str = "" + + class CrawlResult(BaseModel): url: str html: str @@ -101,7 +104,7 @@ class CrawlResult(BaseModel): links: Dict[str, List[Dict]] = {} downloaded_files: Optional[List[str]] = None screenshot: Optional[str] = None - pdf : Optional[bytes] = None + pdf: Optional[bytes] = None markdown: Optional[Union[str, MarkdownGenerationResult]] = None markdown_v2: Optional[MarkdownGenerationResult] = None fit_markdown: Optional[str] = None @@ -114,9 +117,11 @@ class CrawlResult(BaseModel): status_code: Optional[int] = None ssl_certificate: Optional[SSLCertificate] = None dispatch_result: Optional[DispatchResult] = None + class Config: arbitrary_types_allowed = True + class AsyncCrawlResponse(BaseModel): html: str response_headers: Dict[str, str] @@ -130,6 +135,7 @@ class AsyncCrawlResponse(BaseModel): class Config: arbitrary_types_allowed = True + ############################### # Scraping Models ############################### @@ -143,21 +149,29 @@ class MediaItem(BaseModel): format: Optional[str] = None width: Optional[int] = None + class Link(BaseModel): href: str text: str title: Optional[str] = None base_domain: str + class Media(BaseModel): images: List[MediaItem] = [] - videos: List[MediaItem] = [] # Using MediaItem model for now, can be extended with Video model if needed - audios: List[MediaItem] = [] # Using MediaItem model for now, can be extended with Audio model if needed + videos: List[ + MediaItem + ] = [] # Using MediaItem model for now, can be extended with Video model if needed + audios: List[ + MediaItem + ] = [] # Using MediaItem model for now, can be extended with Audio model if needed + class Links(BaseModel): internal: List[Link] = [] external: List[Link] = [] + class ScrapingResult(BaseModel): cleaned_html: str success: bool diff --git a/crawl4ai/ssl_certificate.py b/crawl4ai/ssl_certificate.py index 97529e3e..722bb7f9 100644 --- a/crawl4ai/ssl_certificate.py +++ b/crawl4ai/ssl_certificate.py @@ -13,10 +13,10 @@ from pathlib import Path class SSLCertificate: """ A class representing an SSL certificate with methods to export in various formats. - + Attributes: cert_info (Dict[str, Any]): The certificate information. - + Methods: from_url(url: str, timeout: int = 10) -> Optional['SSLCertificate']: Create SSLCertificate instance from a URL. from_file(file_path: str) -> Optional['SSLCertificate']: Create SSLCertificate instance from a file. @@ -26,32 +26,35 @@ class SSLCertificate: export_as_json() -> Dict[str, Any]: Export the certificate as JSON format. export_as_text() -> str: Export the certificate as text format. """ + def __init__(self, cert_info: Dict[str, Any]): self._cert_info = self._decode_cert_data(cert_info) @staticmethod - def from_url(url: str, timeout: int = 10) -> Optional['SSLCertificate']: + def from_url(url: str, timeout: int = 10) -> Optional["SSLCertificate"]: """ Create SSLCertificate instance from a URL. - + Args: url (str): URL of the website. timeout (int): Timeout for the connection (default: 10). - + Returns: Optional[SSLCertificate]: SSLCertificate instance if successful, None otherwise. """ try: hostname = urlparse(url).netloc - if ':' in hostname: - hostname = hostname.split(':')[0] - + if ":" in hostname: + hostname = hostname.split(":")[0] + context = ssl.create_default_context() with socket.create_connection((hostname, 443), timeout=timeout) as sock: with context.wrap_socket(sock, server_hostname=hostname) as ssock: cert_binary = ssock.getpeercert(binary_form=True) - x509 = OpenSSL.crypto.load_certificate(OpenSSL.crypto.FILETYPE_ASN1, cert_binary) - + x509 = OpenSSL.crypto.load_certificate( + OpenSSL.crypto.FILETYPE_ASN1, cert_binary + ) + cert_info = { "subject": dict(x509.get_subject().get_components()), "issuer": dict(x509.get_issuer().get_components()), @@ -61,32 +64,33 @@ class SSLCertificate: "not_after": x509.get_notAfter(), "fingerprint": x509.digest("sha256").hex(), "signature_algorithm": x509.get_signature_algorithm(), - "raw_cert": base64.b64encode(cert_binary) + "raw_cert": base64.b64encode(cert_binary), } - + # Add extensions extensions = [] for i in range(x509.get_extension_count()): ext = x509.get_extension(i) - extensions.append({ - "name": ext.get_short_name(), - "value": str(ext) - }) + extensions.append( + {"name": ext.get_short_name(), "value": str(ext)} + ) cert_info["extensions"] = extensions - + return SSLCertificate(cert_info) - - except Exception as e: + + except Exception: return None @staticmethod def _decode_cert_data(data: Any) -> Any: """Helper method to decode bytes in certificate data.""" if isinstance(data, bytes): - return data.decode('utf-8') + return data.decode("utf-8") elif isinstance(data, dict): return { - (k.decode('utf-8') if isinstance(k, bytes) else k): SSLCertificate._decode_cert_data(v) + ( + k.decode("utf-8") if isinstance(k, bytes) else k + ): SSLCertificate._decode_cert_data(v) for k, v in data.items() } elif isinstance(data, list): @@ -96,58 +100,57 @@ class SSLCertificate: def to_json(self, filepath: Optional[str] = None) -> Optional[str]: """ Export certificate as JSON. - + Args: filepath (Optional[str]): Path to save the JSON file (default: None). - + Returns: Optional[str]: JSON string if successful, None otherwise. """ json_str = json.dumps(self._cert_info, indent=2, ensure_ascii=False) if filepath: - Path(filepath).write_text(json_str, encoding='utf-8') + Path(filepath).write_text(json_str, encoding="utf-8") return None return json_str def to_pem(self, filepath: Optional[str] = None) -> Optional[str]: """ Export certificate as PEM. - + Args: filepath (Optional[str]): Path to save the PEM file (default: None). - + Returns: Optional[str]: PEM string if successful, None otherwise. """ try: x509 = OpenSSL.crypto.load_certificate( - OpenSSL.crypto.FILETYPE_ASN1, - base64.b64decode(self._cert_info['raw_cert']) + OpenSSL.crypto.FILETYPE_ASN1, + base64.b64decode(self._cert_info["raw_cert"]), ) pem_data = OpenSSL.crypto.dump_certificate( - OpenSSL.crypto.FILETYPE_PEM, - x509 - ).decode('utf-8') - + OpenSSL.crypto.FILETYPE_PEM, x509 + ).decode("utf-8") + if filepath: - Path(filepath).write_text(pem_data, encoding='utf-8') + Path(filepath).write_text(pem_data, encoding="utf-8") return None return pem_data - except Exception as e: + except Exception: return None def to_der(self, filepath: Optional[str] = None) -> Optional[bytes]: """ Export certificate as DER. - + Args: filepath (Optional[str]): Path to save the DER file (default: None). - + Returns: Optional[bytes]: DER bytes if successful, None otherwise. """ try: - der_data = base64.b64decode(self._cert_info['raw_cert']) + der_data = base64.b64decode(self._cert_info["raw_cert"]) if filepath: Path(filepath).write_bytes(der_data) return None @@ -158,24 +161,24 @@ class SSLCertificate: @property def issuer(self) -> Dict[str, str]: """Get certificate issuer information.""" - return self._cert_info.get('issuer', {}) + return self._cert_info.get("issuer", {}) @property def subject(self) -> Dict[str, str]: """Get certificate subject information.""" - return self._cert_info.get('subject', {}) + return self._cert_info.get("subject", {}) @property def valid_from(self) -> str: """Get certificate validity start date.""" - return self._cert_info.get('not_before', '') + return self._cert_info.get("not_before", "") @property def valid_until(self) -> str: """Get certificate validity end date.""" - return self._cert_info.get('not_after', '') + return self._cert_info.get("not_after", "") @property def fingerprint(self) -> str: """Get certificate fingerprint.""" - return self._cert_info.get('fingerprint', '') + return self._cert_info.get("fingerprint", "") diff --git a/crawl4ai/user_agent_generator.py b/crawl4ai/user_agent_generator.py index 6679bb1b..4f0f42cb 100644 --- a/crawl4ai/user_agent_generator.py +++ b/crawl4ai/user_agent_generator.py @@ -6,7 +6,7 @@ import re class UserAgentGenerator: """ Generate random user agents with specified constraints. - + Attributes: desktop_platforms (dict): A dictionary of possible desktop platforms and their corresponding user agent strings. mobile_platforms (dict): A dictionary of possible mobile platforms and their corresponding user agent strings. @@ -18,7 +18,7 @@ class UserAgentGenerator: safari_versions (list): A list of possible Safari browser versions. ios_versions (list): A list of possible iOS browser versions. android_versions (list): A list of possible Android browser versions. - + Methods: generate_user_agent( platform: Literal["desktop", "mobile"] = "desktop", @@ -30,8 +30,9 @@ class UserAgentGenerator: safari_version: Optional[str] = None, ios_version: Optional[str] = None, android_version: Optional[str] = None - ): Generates a random user agent string based on the specified parameters. + ): Generates a random user agent string based on the specified parameters. """ + def __init__(self): # Previous platform definitions remain the same... self.desktop_platforms = { @@ -47,7 +48,7 @@ class UserAgentGenerator: "generic": "(X11; Linux x86_64)", "ubuntu": "(X11; Ubuntu; Linux x86_64)", "chrome_os": "(X11; CrOS x86_64 14541.0.0)", - } + }, } self.mobile_platforms = { @@ -60,26 +61,14 @@ class UserAgentGenerator: "ios": { "iphone": "(iPhone; CPU iPhone OS 16_5 like Mac OS X)", "ipad": "(iPad; CPU OS 16_5 like Mac OS X)", - } + }, } # Browser Combinations self.browser_combinations = { - 1: [ - ["chrome"], - ["firefox"], - ["safari"], - ["edge"] - ], - 2: [ - ["gecko", "firefox"], - ["chrome", "safari"], - ["webkit", "safari"] - ], - 3: [ - ["chrome", "safari", "edge"], - ["webkit", "chrome", "safari"] - ] + 1: [["chrome"], ["firefox"], ["safari"], ["edge"]], + 2: [["gecko", "firefox"], ["chrome", "safari"], ["webkit", "safari"]], + 3: [["chrome", "safari", "edge"], ["webkit", "chrome", "safari"]], } # Rendering Engines with versions @@ -90,7 +79,7 @@ class UserAgentGenerator: "Gecko/20100101", "Gecko/20100101", # Firefox usually uses this constant version "Gecko/2010010", - ] + ], } # Browser Versions @@ -135,25 +124,25 @@ class UserAgentGenerator: def get_browser_stack(self, num_browsers: int = 1) -> List[str]: """ Get a valid combination of browser versions. - + How it works: 1. Check if the number of browsers is supported. 2. Randomly choose a combination of browsers. 3. Iterate through the combination and add browser versions. 4. Return the browser stack. - + Args: num_browsers: Number of browser specifications (1-3) - + Returns: List[str]: A list of browser versions. """ if num_browsers not in self.browser_combinations: raise ValueError(f"Unsupported number of browsers: {num_browsers}") - + combination = random.choice(self.browser_combinations[num_browsers]) browser_stack = [] - + for browser in combination: if browser == "chrome": browser_stack.append(random.choice(self.chrome_versions)) @@ -167,18 +156,20 @@ class UserAgentGenerator: browser_stack.append(random.choice(self.rendering_engines["gecko"])) elif browser == "webkit": browser_stack.append(self.rendering_engines["chrome_webkit"]) - + return browser_stack - def generate(self, - device_type: Optional[Literal['desktop', 'mobile']] = None, - os_type: Optional[str] = None, - device_brand: Optional[str] = None, - browser_type: Optional[Literal['chrome', 'edge', 'safari', 'firefox']] = None, - num_browsers: int = 3) -> str: + def generate( + self, + device_type: Optional[Literal["desktop", "mobile"]] = None, + os_type: Optional[str] = None, + device_brand: Optional[str] = None, + browser_type: Optional[Literal["chrome", "edge", "safari", "firefox"]] = None, + num_browsers: int = 3, + ) -> str: """ Generate a random user agent with specified constraints. - + Args: device_type: 'desktop' or 'mobile' os_type: 'windows', 'macos', 'linux', 'android', 'ios' @@ -188,23 +179,23 @@ class UserAgentGenerator: """ # Get platform string platform = self.get_random_platform(device_type, os_type, device_brand) - + # Start with Mozilla components = ["Mozilla/5.0", platform] - + # Add browser stack browser_stack = self.get_browser_stack(num_browsers) - + # Add appropriate legacy token based on browser stack if "Firefox" in str(browser_stack): components.append(random.choice(self.rendering_engines["gecko"])) elif "Chrome" in str(browser_stack) or "Safari" in str(browser_stack): components.append(self.rendering_engines["chrome_webkit"]) components.append("(KHTML, like Gecko)") - + # Add browser versions components.extend(browser_stack) - + return " ".join(components) def generate_with_client_hints(self, **kwargs) -> Tuple[str, str]: @@ -215,16 +206,20 @@ class UserAgentGenerator: def get_random_platform(self, device_type, os_type, device_brand): """Helper method to get random platform based on constraints""" - platforms = self.desktop_platforms if device_type == 'desktop' else \ - self.mobile_platforms if device_type == 'mobile' else \ - {**self.desktop_platforms, **self.mobile_platforms} - + platforms = ( + self.desktop_platforms + if device_type == "desktop" + else self.mobile_platforms + if device_type == "mobile" + else {**self.desktop_platforms, **self.mobile_platforms} + ) + if os_type: for platform_group in [self.desktop_platforms, self.mobile_platforms]: if os_type in platform_group: platforms = {os_type: platform_group[os_type]} break - + os_key = random.choice(list(platforms.keys())) if device_brand and device_brand in platforms[os_key]: return platforms[os_key][device_brand] @@ -233,73 +228,72 @@ class UserAgentGenerator: def parse_user_agent(self, user_agent: str) -> Dict[str, str]: """Parse a user agent string to extract browser and version information""" browsers = { - 'chrome': r'Chrome/(\d+)', - 'edge': r'Edg/(\d+)', - 'safari': r'Version/(\d+)', - 'firefox': r'Firefox/(\d+)' + "chrome": r"Chrome/(\d+)", + "edge": r"Edg/(\d+)", + "safari": r"Version/(\d+)", + "firefox": r"Firefox/(\d+)", } - + result = {} for browser, pattern in browsers.items(): match = re.search(pattern, user_agent) if match: result[browser] = match.group(1) - + return result def generate_client_hints(self, user_agent: str) -> str: """Generate Sec-CH-UA header value based on user agent string""" browsers = self.parse_user_agent(user_agent) - + # Client hints components hints = [] - + # Handle different browser combinations - if 'chrome' in browsers: + if "chrome" in browsers: hints.append(f'"Chromium";v="{browsers["chrome"]}"') hints.append('"Not_A Brand";v="8"') - - if 'edge' in browsers: + + if "edge" in browsers: hints.append(f'"Microsoft Edge";v="{browsers["edge"]}"') else: hints.append(f'"Google Chrome";v="{browsers["chrome"]}"') - - elif 'firefox' in browsers: + + elif "firefox" in browsers: # Firefox doesn't typically send Sec-CH-UA return '""' - - elif 'safari' in browsers: + + elif "safari" in browsers: # Safari's format for client hints hints.append(f'"Safari";v="{browsers["safari"]}"') hints.append('"Not_A Brand";v="8"') - - return ', '.join(hints) + + return ", ".join(hints) + # Example usage: if __name__ == "__main__": generator = UserAgentGenerator() print(generator.generate()) - + print("\nSingle browser (Chrome):") - print(generator.generate(num_browsers=1, browser_type='chrome')) - + print(generator.generate(num_browsers=1, browser_type="chrome")) + print("\nTwo browsers (Gecko/Firefox):") print(generator.generate(num_browsers=2)) - + print("\nThree browsers (Chrome/Safari/Edge):") print(generator.generate(num_browsers=3)) - + print("\nFirefox on Linux:") - print(generator.generate( - device_type='desktop', - os_type='linux', - browser_type='firefox', - num_browsers=2 - )) - + print( + generator.generate( + device_type="desktop", + os_type="linux", + browser_type="firefox", + num_browsers=2, + ) + ) + print("\nChrome/Safari/Edge on Windows:") - print(generator.generate( - device_type='desktop', - os_type='windows', - num_browsers=3 - )) \ No newline at end of file + print(generator.generate(device_type="desktop", os_type="windows", num_browsers=3)) diff --git a/crawl4ai/utils.py b/crawl4ai/utils.py index 4dbac2a6..63c8a092 100644 --- a/crawl4ai/utils.py +++ b/crawl4ai/utils.py @@ -14,7 +14,7 @@ from typing import Dict, Any from urllib.parse import urljoin import requests from requests.exceptions import InvalidSchema -from typing import Optional, Tuple, Dict, Any +from typing import Dict, Any import xxhash from colorama import Fore, Style, init import textwrap @@ -27,7 +27,14 @@ import asyncio class InvalidCSSSelectorError(Exception): pass -def create_box_message(message: str, type: str = "info", width: int = 120, add_newlines: bool = True, double_line: bool = False) -> str: + +def create_box_message( + message: str, + type: str = "info", + width: int = 120, + add_newlines: bool = True, + double_line: bool = False, +) -> str: """ Create a styled message box with colored borders and formatted text. @@ -53,7 +60,7 @@ def create_box_message(message: str, type: str = "info", width: int = 120, add_n # Define border and text colors for different types styles = { "warning": (Fore.YELLOW, Fore.LIGHTYELLOW_EX, "⚠"), - "info": (Fore.BLUE, Fore.LIGHTBLUE_EX, "ℹ"), + "info": (Fore.BLUE, Fore.LIGHTBLUE_EX, "ℹ"), "success": (Fore.GREEN, Fore.LIGHTGREEN_EX, "✓"), "error": (Fore.RED, Fore.LIGHTRED_EX, "×"), } @@ -63,24 +70,24 @@ def create_box_message(message: str, type: str = "info", width: int = 120, add_n # Define box characters based on line style box_chars = { "single": ("─", "│", "┌", "┐", "└", "┘"), - "double": ("═", "║", "╔", "╗", "╚", "╝") + "double": ("═", "║", "╔", "╗", "╚", "╝"), } line_style = "double" if double_line else "single" h_line, v_line, tl, tr, bl, br = box_chars[line_style] # Process lines with lighter text color formatted_lines = [] - raw_lines = message.split('\n') + raw_lines = message.split("\n") if raw_lines: first_line = f"{prefix} {raw_lines[0].strip()}" - wrapped_first = textwrap.fill(first_line, width=width-4) - formatted_lines.extend(wrapped_first.split('\n')) - + wrapped_first = textwrap.fill(first_line, width=width - 4) + formatted_lines.extend(wrapped_first.split("\n")) + for line in raw_lines[1:]: if line.strip(): - wrapped = textwrap.fill(f" {line.strip()}", width=width-4) - formatted_lines.extend(wrapped.split('\n')) + wrapped = textwrap.fill(f" {line.strip()}", width=width - 4) + formatted_lines.extend(wrapped.split("\n")) else: formatted_lines.append("") @@ -88,8 +95,11 @@ def create_box_message(message: str, type: str = "info", width: int = 120, add_n horizontal_line = h_line * (width - 1) box = [ f"{border_color}{tl}{horizontal_line}{tr}", - *[f"{border_color}{v_line}{text_color} {line:<{width-2}}{border_color}{v_line}" for line in formatted_lines], - f"{border_color}{bl}{horizontal_line}{br}{Style.RESET_ALL}" + *[ + f"{border_color}{v_line}{text_color} {line:<{width-2}}{border_color}{v_line}" + for line in formatted_lines + ], + f"{border_color}{bl}{horizontal_line}{br}{Style.RESET_ALL}", ] result = "\n".join(box) @@ -98,6 +108,7 @@ def create_box_message(message: str, type: str = "info", width: int = 120, add_n return result + def calculate_semaphore_count(): """ Calculate the optimal semaphore count based on system resources. @@ -111,13 +122,14 @@ def calculate_semaphore_count(): Returns: int: The calculated semaphore count. """ - + cpu_count = os.cpu_count() - memory_gb = get_system_memory() / (1024 ** 3) # Convert to GB + memory_gb = get_system_memory() / (1024**3) # Convert to GB base_count = max(1, cpu_count // 2) memory_based_cap = int(memory_gb / 2) # Assume 2GB per instance return min(base_count, memory_based_cap) + def get_system_memory(): """ Get the total system memory in bytes. @@ -136,30 +148,34 @@ def get_system_memory(): system = platform.system() if system == "Linux": - with open('/proc/meminfo', 'r') as mem: + with open("/proc/meminfo", "r") as mem: for line in mem: - if line.startswith('MemTotal:'): + if line.startswith("MemTotal:"): return int(line.split()[1]) * 1024 # Convert KB to bytes elif system == "Darwin": # macOS import subprocess - output = subprocess.check_output(['sysctl', '-n', 'hw.memsize']).decode('utf-8') + + output = subprocess.check_output(["sysctl", "-n", "hw.memsize"]).decode("utf-8") return int(output.strip()) elif system == "Windows": import ctypes + kernel32 = ctypes.windll.kernel32 c_ulonglong = ctypes.c_ulonglong + class MEMORYSTATUSEX(ctypes.Structure): _fields_ = [ - ('dwLength', ctypes.c_ulong), - ('dwMemoryLoad', ctypes.c_ulong), - ('ullTotalPhys', c_ulonglong), - ('ullAvailPhys', c_ulonglong), - ('ullTotalPageFile', c_ulonglong), - ('ullAvailPageFile', c_ulonglong), - ('ullTotalVirtual', c_ulonglong), - ('ullAvailVirtual', c_ulonglong), - ('ullAvailExtendedVirtual', c_ulonglong), + ("dwLength", ctypes.c_ulong), + ("dwMemoryLoad", ctypes.c_ulong), + ("ullTotalPhys", c_ulonglong), + ("ullAvailPhys", c_ulonglong), + ("ullTotalPageFile", c_ulonglong), + ("ullAvailPageFile", c_ulonglong), + ("ullTotalVirtual", c_ulonglong), + ("ullAvailVirtual", c_ulonglong), + ("ullAvailExtendedVirtual", c_ulonglong), ] + memoryStatus = MEMORYSTATUSEX() memoryStatus.dwLength = ctypes.sizeof(MEMORYSTATUSEX) kernel32.GlobalMemoryStatusEx(ctypes.byref(memoryStatus)) @@ -167,6 +183,7 @@ def get_system_memory(): else: raise OSError("Unsupported operating system") + def get_home_folder(): """ Get or create the home folder for Crawl4AI configuration and cache. @@ -180,75 +197,84 @@ def get_home_folder(): str: The path to the Crawl4AI home folder. """ - home_folder = os.path.join(os.getenv("CRAWL4_AI_BASE_DIRECTORY", os.getenv("CRAWL4_AI_BASE_DIRECTORY", Path.home())), ".crawl4ai") + home_folder = os.path.join( + os.getenv( + "CRAWL4_AI_BASE_DIRECTORY", + os.getenv("CRAWL4_AI_BASE_DIRECTORY", Path.home()), + ), + ".crawl4ai", + ) os.makedirs(home_folder, exist_ok=True) os.makedirs(f"{home_folder}/cache", exist_ok=True) os.makedirs(f"{home_folder}/models", exist_ok=True) - return home_folder + return home_folder + def beautify_html(escaped_html): """ Beautifies an escaped HTML string. - + Parameters: escaped_html (str): A string containing escaped HTML. - + Returns: str: A beautifully formatted HTML string. """ # Unescape the HTML string unescaped_html = html.unescape(escaped_html) - + # Use BeautifulSoup to parse and prettify the HTML - soup = BeautifulSoup(unescaped_html, 'html.parser') + soup = BeautifulSoup(unescaped_html, "html.parser") pretty_html = soup.prettify() - + return pretty_html + def split_and_parse_json_objects(json_string): """ Splits a JSON string which is a list of objects and tries to parse each object. - + Parameters: json_string (str): A string representation of a list of JSON objects, e.g., '[{...}, {...}, ...]'. - + Returns: tuple: A tuple containing two lists: - First list contains all successfully parsed JSON objects. - Second list contains the string representations of all segments that couldn't be parsed. """ # Trim the leading '[' and trailing ']' - if json_string.startswith('[') and json_string.endswith(']'): + if json_string.startswith("[") and json_string.endswith("]"): json_string = json_string[1:-1].strip() - + # Split the string into segments that look like individual JSON objects segments = [] depth = 0 start_index = 0 - + for i, char in enumerate(json_string): - if char == '{': + if char == "{": if depth == 0: start_index = i depth += 1 - elif char == '}': + elif char == "}": depth -= 1 if depth == 0: - segments.append(json_string[start_index:i+1]) - + segments.append(json_string[start_index : i + 1]) + # Try parsing each segment parsed_objects = [] unparsed_segments = [] - + for segment in segments: try: obj = json.loads(segment) parsed_objects.append(obj) except json.JSONDecodeError: unparsed_segments.append(segment) - + return parsed_objects, unparsed_segments + def sanitize_html(html): """ Sanitize an HTML string by escaping quotes. @@ -263,7 +289,7 @@ def sanitize_html(html): Returns: str: The sanitized HTML string. """ - + # Replace all unwanted and special characters with an empty string sanitized_html = html # sanitized_html = re.sub(r'[^\w\s.,;:!?=\[\]{}()<>\/\\\-"]', '', html) @@ -273,21 +299,25 @@ def sanitize_html(html): return sanitized_html + def sanitize_input_encode(text: str) -> str: """Sanitize input to handle potential encoding issues.""" try: try: if not text: - return '' + return "" # Attempt to encode and decode as UTF-8 to handle potential encoding issues - return text.encode('utf-8', errors='ignore').decode('utf-8') + return text.encode("utf-8", errors="ignore").decode("utf-8") except UnicodeEncodeError as e: - print(f"Warning: Encoding issue detected. Some characters may be lost. Error: {e}") + print( + f"Warning: Encoding issue detected. Some characters may be lost. Error: {e}" + ) # Fall back to ASCII if UTF-8 fails - return text.encode('ascii', errors='ignore').decode('ascii') + return text.encode("ascii", errors="ignore").decode("ascii") except Exception as e: raise ValueError(f"Error sanitizing input: {str(e)}") from e + def escape_json_string(s): """ Escapes characters in a string to be JSON safe. @@ -299,24 +329,25 @@ def escape_json_string(s): str: The escaped string, safe for JSON encoding. """ # Replace problematic backslash first - s = s.replace('\\', '\\\\') - + s = s.replace("\\", "\\\\") + # Replace the double quote s = s.replace('"', '\\"') - + # Escape control characters - s = s.replace('\b', '\\b') - s = s.replace('\f', '\\f') - s = s.replace('\n', '\\n') - s = s.replace('\r', '\\r') - s = s.replace('\t', '\\t') - + s = s.replace("\b", "\\b") + s = s.replace("\f", "\\f") + s = s.replace("\n", "\\n") + s = s.replace("\r", "\\r") + s = s.replace("\t", "\\t") + # Additional problematic characters # Unicode control characters - s = re.sub(r'[\x00-\x1f\x7f-\x9f]', lambda x: '\\u{:04x}'.format(ord(x.group())), s) - + s = re.sub(r"[\x00-\x1f\x7f-\x9f]", lambda x: "\\u{:04x}".format(ord(x.group())), s) + return s + def replace_inline_tags(soup, tags, only_text=False): """ Replace inline HTML tags with Markdown-style equivalents. @@ -336,37 +367,39 @@ def replace_inline_tags(soup, tags, only_text=False): """ tag_replacements = { - 'b': lambda tag: f"**{tag.text}**", - 'i': lambda tag: f"*{tag.text}*", - 'u': lambda tag: f"__{tag.text}__", - 'span': lambda tag: f"{tag.text}", - 'del': lambda tag: f"~~{tag.text}~~", - 'ins': lambda tag: f"++{tag.text}++", - 'sub': lambda tag: f"~{tag.text}~", - 'sup': lambda tag: f"^^{tag.text}^^", - 'strong': lambda tag: f"**{tag.text}**", - 'em': lambda tag: f"*{tag.text}*", - 'code': lambda tag: f"`{tag.text}`", - 'kbd': lambda tag: f"`{tag.text}`", - 'var': lambda tag: f"_{tag.text}_", - 's': lambda tag: f"~~{tag.text}~~", - 'q': lambda tag: f'"{tag.text}"', - 'abbr': lambda tag: f"{tag.text} ({tag.get('title', '')})", - 'cite': lambda tag: f"_{tag.text}_", - 'dfn': lambda tag: f"_{tag.text}_", - 'time': lambda tag: f"{tag.text}", - 'small': lambda tag: f"{tag.text}", - 'mark': lambda tag: f"=={tag.text}==" + "b": lambda tag: f"**{tag.text}**", + "i": lambda tag: f"*{tag.text}*", + "u": lambda tag: f"__{tag.text}__", + "span": lambda tag: f"{tag.text}", + "del": lambda tag: f"~~{tag.text}~~", + "ins": lambda tag: f"++{tag.text}++", + "sub": lambda tag: f"~{tag.text}~", + "sup": lambda tag: f"^^{tag.text}^^", + "strong": lambda tag: f"**{tag.text}**", + "em": lambda tag: f"*{tag.text}*", + "code": lambda tag: f"`{tag.text}`", + "kbd": lambda tag: f"`{tag.text}`", + "var": lambda tag: f"_{tag.text}_", + "s": lambda tag: f"~~{tag.text}~~", + "q": lambda tag: f'"{tag.text}"', + "abbr": lambda tag: f"{tag.text} ({tag.get('title', '')})", + "cite": lambda tag: f"_{tag.text}_", + "dfn": lambda tag: f"_{tag.text}_", + "time": lambda tag: f"{tag.text}", + "small": lambda tag: f"{tag.text}", + "mark": lambda tag: f"=={tag.text}==", } - - replacement_data = [(tag, tag_replacements.get(tag, lambda t: t.text)) for tag in tags] + + replacement_data = [ + (tag, tag_replacements.get(tag, lambda t: t.text)) for tag in tags + ] for tag_name, replacement_func in replacement_data: for tag in soup.find_all(tag_name): replacement_text = tag.text if only_text else replacement_func(tag) tag.replace_with(replacement_text) - return soup + return soup # for tag_name in tags: # for tag in soup.find_all(tag_name): @@ -378,7 +411,10 @@ def replace_inline_tags(soup, tags, only_text=False): # return soup -def get_content_of_website(url, html, word_count_threshold = MIN_WORD_THRESHOLD, css_selector = None, **kwargs): + +def get_content_of_website( + url, html, word_count_threshold=MIN_WORD_THRESHOLD, css_selector=None, **kwargs +): """ Extract structured content, media, and links from website HTML. @@ -403,120 +439,128 @@ def get_content_of_website(url, html, word_count_threshold = MIN_WORD_THRESHOLD, if not html: return None # Parse HTML content with BeautifulSoup - soup = BeautifulSoup(html, 'html.parser') + soup = BeautifulSoup(html, "html.parser") # Get the content within the tag body = soup.body - + # If css_selector is provided, extract content based on the selector if css_selector: selected_elements = body.select(css_selector) if not selected_elements: - raise InvalidCSSSelectorError(f"Invalid CSS selector , No elements found for CSS selector: {css_selector}") - div_tag = soup.new_tag('div') + raise InvalidCSSSelectorError( + f"Invalid CSS selector , No elements found for CSS selector: {css_selector}" + ) + div_tag = soup.new_tag("div") for el in selected_elements: div_tag.append(el) body = div_tag - - links = { - 'internal': [], - 'external': [] - } - + + links = {"internal": [], "external": []} + # Extract all internal and external links - for a in body.find_all('a', href=True): - href = a['href'] - url_base = url.split('/')[2] - if href.startswith('http') and url_base not in href: - links['external'].append({ - 'href': href, - 'text': a.get_text() - }) + for a in body.find_all("a", href=True): + href = a["href"] + url_base = url.split("/")[2] + if href.startswith("http") and url_base not in href: + links["external"].append({"href": href, "text": a.get_text()}) else: - links['internal'].append( - { - 'href': href, - 'text': a.get_text() - } - ) + links["internal"].append({"href": href, "text": a.get_text()}) # Remove script, style, and other tags that don't carry useful content from body - for tag in body.find_all(['script', 'style', 'link', 'meta', 'noscript']): + for tag in body.find_all(["script", "style", "link", "meta", "noscript"]): tag.decompose() # Remove all attributes from remaining tags in body, except for img tags for tag in body.find_all(): - if tag.name != 'img': + if tag.name != "img": tag.attrs = {} # Extract all img tgas int0 [{src: '', alt: ''}] - media = { - 'images': [], - 'videos': [], - 'audios': [] - } - for img in body.find_all('img'): - media['images'].append({ - 'src': img.get('src'), - 'alt': img.get('alt'), - "type": "image" - }) - + media = {"images": [], "videos": [], "audios": []} + for img in body.find_all("img"): + media["images"].append( + {"src": img.get("src"), "alt": img.get("alt"), "type": "image"} + ) + # Extract all video tags into [{src: '', alt: ''}] - for video in body.find_all('video'): - media['videos'].append({ - 'src': video.get('src'), - 'alt': video.get('alt'), - "type": "video" - }) - + for video in body.find_all("video"): + media["videos"].append( + {"src": video.get("src"), "alt": video.get("alt"), "type": "video"} + ) + # Extract all audio tags into [{src: '', alt: ''}] - for audio in body.find_all('audio'): - media['audios'].append({ - 'src': audio.get('src'), - 'alt': audio.get('alt'), - "type": "audio" - }) - + for audio in body.find_all("audio"): + media["audios"].append( + {"src": audio.get("src"), "alt": audio.get("alt"), "type": "audio"} + ) + # Replace images with their alt text or remove them if no alt text is available - for img in body.find_all('img'): - alt_text = img.get('alt') + for img in body.find_all("img"): + alt_text = img.get("alt") if alt_text: img.replace_with(soup.new_string(alt_text)) else: img.decompose() - # Create a function that replace content of all"pre" tag with its inner text def replace_pre_tags_with_text(node): - for child in node.find_all('pre'): + for child in node.find_all("pre"): # set child inner html to its text child.string = child.get_text() return node - + # Replace all "pre" tags with their inner text body = replace_pre_tags_with_text(body) - + # Replace inline tags with their text content body = replace_inline_tags( - body, - ['b', 'i', 'u', 'span', 'del', 'ins', 'sub', 'sup', 'strong', 'em', 'code', 'kbd', 'var', 's', 'q', 'abbr', 'cite', 'dfn', 'time', 'small', 'mark'], - only_text=kwargs.get('only_text', False) + body, + [ + "b", + "i", + "u", + "span", + "del", + "ins", + "sub", + "sup", + "strong", + "em", + "code", + "kbd", + "var", + "s", + "q", + "abbr", + "cite", + "dfn", + "time", + "small", + "mark", + ], + only_text=kwargs.get("only_text", False), ) # Recursively remove empty elements, their parent elements, and elements with word count below threshold def remove_empty_and_low_word_count_elements(node, word_count_threshold): for child in node.contents: if isinstance(child, element.Tag): - remove_empty_and_low_word_count_elements(child, word_count_threshold) + remove_empty_and_low_word_count_elements( + child, word_count_threshold + ) word_count = len(child.get_text(strip=True).split()) - if (len(child.contents) == 0 and not child.get_text(strip=True)) or word_count < word_count_threshold: + if ( + len(child.contents) == 0 and not child.get_text(strip=True) + ) or word_count < word_count_threshold: child.decompose() return node body = remove_empty_and_low_word_count_elements(body, word_count_threshold) - - def remove_small_text_tags(body: Tag, word_count_threshold: int = MIN_WORD_THRESHOLD): + + def remove_small_text_tags( + body: Tag, word_count_threshold: int = MIN_WORD_THRESHOLD + ): # We'll use a list to collect all tags that don't meet the word count requirement tags_to_remove = [] @@ -535,11 +579,10 @@ def get_content_of_website(url, html, word_count_threshold = MIN_WORD_THRESHOLD, tag.decompose() # or tag.extract() to remove and get the element return body - - + # Remove small text tags - body = remove_small_text_tags(body, word_count_threshold) - + body = remove_small_text_tags(body, word_count_threshold) + def is_empty_or_whitespace(tag: Tag): if isinstance(tag, NavigableString): return not tag.strip() @@ -554,41 +597,43 @@ def get_content_of_website(url, html, word_count_threshold = MIN_WORD_THRESHOLD, while changes: changes = False # Collect all tags that are empty or contain only whitespace - empty_tags = [tag for tag in body.find_all(True) if is_empty_or_whitespace(tag)] + empty_tags = [ + tag for tag in body.find_all(True) if is_empty_or_whitespace(tag) + ] for tag in empty_tags: # If a tag is empty, decompose it tag.decompose() changes = True # Mark that a change was made - return body + return body - # Remove empty tags body = remove_empty_tags(body) - + # Flatten nested elements with only one child of the same type def flatten_nested_elements(node): for child in node.contents: if isinstance(child, element.Tag): flatten_nested_elements(child) - if len(child.contents) == 1 and child.contents[0].name == child.name: + if ( + len(child.contents) == 1 + and child.contents[0].name == child.name + ): # print('Flattening:', child.name) child_content = child.contents[0] child.replace_with(child_content) - + return node body = flatten_nested_elements(body) - - # Remove comments - for comment in soup.find_all(string=lambda text: isinstance(text, Comment)): + for comment in soup.find_all(string=lambda text: isinstance(text, Comment)): comment.extract() # Remove consecutive empty newlines and replace multiple spaces with a single space - cleaned_html = str(body).replace('\n\n', '\n').replace(' ', ' ') - + cleaned_html = str(body).replace("\n\n", "\n").replace(" ", " ") + # Sanitize the cleaned HTML content cleaned_html = sanitize_html(cleaned_html) # sanitized_html = escape_json_string(cleaned_html) @@ -598,81 +643,97 @@ def get_content_of_website(url, html, word_count_threshold = MIN_WORD_THRESHOLD, h = CustomHTML2Text() h.ignore_links = True markdown = h.handle(cleaned_html) - markdown = markdown.replace(' ```', '```') - + markdown = markdown.replace(" ```", "```") + try: meta = extract_metadata(html, soup) except Exception as e: - print('Error extracting metadata:', str(e)) + print("Error extracting metadata:", str(e)) meta = {} - - + # Return the Markdown content - return{ - 'markdown': markdown, - 'cleaned_html': cleaned_html, - 'success': True, - 'media': media, - 'links': links, - 'metadata': meta + return { + "markdown": markdown, + "cleaned_html": cleaned_html, + "success": True, + "media": media, + "links": links, + "metadata": meta, } except Exception as e: - print('Error processing HTML content:', str(e)) + print("Error processing HTML content:", str(e)) raise InvalidCSSSelectorError(f"Invalid CSS selector: {css_selector}") from e -def get_content_of_website_optimized(url: str, html: str, word_count_threshold: int = MIN_WORD_THRESHOLD, css_selector: str = None, **kwargs) -> Dict[str, Any]: + +def get_content_of_website_optimized( + url: str, + html: str, + word_count_threshold: int = MIN_WORD_THRESHOLD, + css_selector: str = None, + **kwargs, +) -> Dict[str, Any]: if not html: return None - soup = BeautifulSoup(html, 'html.parser') + soup = BeautifulSoup(html, "html.parser") body = soup.body - - image_description_min_word_threshold = kwargs.get('image_description_min_word_threshold', IMAGE_DESCRIPTION_MIN_WORD_THRESHOLD) - for tag in kwargs.get('excluded_tags', []) or []: + image_description_min_word_threshold = kwargs.get( + "image_description_min_word_threshold", IMAGE_DESCRIPTION_MIN_WORD_THRESHOLD + ) + + for tag in kwargs.get("excluded_tags", []) or []: for el in body.select(tag): el.decompose() - + if css_selector: selected_elements = body.select(css_selector) if not selected_elements: - raise InvalidCSSSelectorError(f"Invalid CSS selector, No elements found for CSS selector: {css_selector}") - body = soup.new_tag('div') + raise InvalidCSSSelectorError( + f"Invalid CSS selector, No elements found for CSS selector: {css_selector}" + ) + body = soup.new_tag("div") for el in selected_elements: body.append(el) - links = {'internal': [], 'external': []} - media = {'images': [], 'videos': [], 'audios': []} + links = {"internal": [], "external": []} + media = {"images": [], "videos": [], "audios": []} # Extract meaningful text for media files from closest parent def find_closest_parent_with_useful_text(tag): - current_tag = tag - while current_tag: - current_tag = current_tag.parent - # Get the text content from the parent tag - if current_tag: - text_content = current_tag.get_text(separator=' ',strip=True) - # Check if the text content has at least word_count_threshold - if len(text_content.split()) >= image_description_min_word_threshold: - return text_content - return None + current_tag = tag + while current_tag: + current_tag = current_tag.parent + # Get the text content from the parent tag + if current_tag: + text_content = current_tag.get_text(separator=" ", strip=True) + # Check if the text content has at least word_count_threshold + if len(text_content.split()) >= image_description_min_word_threshold: + return text_content + return None def process_image(img, url, index, total_images): - #Check if an image has valid display and inside undesired html elements + # Check if an image has valid display and inside undesired html elements def is_valid_image(img, parent, parent_classes): - style = img.get('style', '') - src = img.get('src', '') - classes_to_check = ['button', 'icon', 'logo'] - tags_to_check = ['button', 'input'] - return all([ - 'display:none' not in style, - src, - not any(s in var for var in [src, img.get('alt', ''), *parent_classes] for s in classes_to_check), - parent.name not in tags_to_check - ]) + style = img.get("style", "") + src = img.get("src", "") + classes_to_check = ["button", "icon", "logo"] + tags_to_check = ["button", "input"] + return all( + [ + "display:none" not in style, + src, + not any( + s in var + for var in [src, img.get("alt", ""), *parent_classes] + for s in classes_to_check + ), + parent.name not in tags_to_check, + ] + ) - #Score an image for it's usefulness + # Score an image for it's usefulness def score_image_for_usefulness(img, base_url, index, images_count): # Function to parse image height/width value and units def parse_dimension(dimension): @@ -680,66 +741,68 @@ def get_content_of_website_optimized(url: str, html: str, word_count_threshold: match = re.match(r"(\d+)(\D*)", dimension) if match: number = int(match.group(1)) - unit = match.group(2) or 'px' # Default unit is 'px' if not specified + unit = ( + match.group(2) or "px" + ) # Default unit is 'px' if not specified return number, unit return None, None # Fetch image file metadata to extract size and extension def fetch_image_file_size(img, base_url): - #If src is relative path construct full URL, if not it may be CDN URL - img_url = urljoin(base_url,img.get('src')) + # If src is relative path construct full URL, if not it may be CDN URL + img_url = urljoin(base_url, img.get("src")) try: response = requests.head(img_url) if response.status_code == 200: - return response.headers.get('Content-Length',None) + return response.headers.get("Content-Length", None) else: print(f"Failed to retrieve file size for {img_url}") return None - except InvalidSchema as e: + except InvalidSchema: return None finally: return - image_height = img.get('height') + image_height = img.get("height") height_value, height_unit = parse_dimension(image_height) - image_width = img.get('width') + image_width = img.get("width") width_value, width_unit = parse_dimension(image_width) - image_size = 0 #int(fetch_image_file_size(img,base_url) or 0) - image_format = os.path.splitext(img.get('src',''))[1].lower() + image_size = 0 # int(fetch_image_file_size(img,base_url) or 0) + image_format = os.path.splitext(img.get("src", ""))[1].lower() # Remove . from format - image_format = image_format.strip('.') + image_format = image_format.strip(".") score = 0 if height_value: - if height_unit == 'px' and height_value > 150: + if height_unit == "px" and height_value > 150: score += 1 - if height_unit in ['%','vh','vmin','vmax'] and height_value >30: + if height_unit in ["%", "vh", "vmin", "vmax"] and height_value > 30: score += 1 if width_value: - if width_unit == 'px' and width_value > 150: + if width_unit == "px" and width_value > 150: score += 1 - if width_unit in ['%','vh','vmin','vmax'] and width_value >30: + if width_unit in ["%", "vh", "vmin", "vmax"] and width_value > 30: score += 1 if image_size > 10000: score += 1 - if img.get('alt') != '': - score+=1 - if any(image_format==format for format in ['jpg','png','webp']): - score+=1 - if index/images_count<0.5: - score+=1 + if img.get("alt") != "": + score += 1 + if any(image_format == format for format in ["jpg", "png", "webp"]): + score += 1 + if index / images_count < 0.5: + score += 1 return score - if not is_valid_image(img, img.parent, img.parent.get('class', [])): + if not is_valid_image(img, img.parent, img.parent.get("class", [])): return None score = score_image_for_usefulness(img, url, index, total_images) if score <= IMAGE_SCORE_THRESHOLD: return None return { - 'src': img.get('src', '').replace('\\"', '"').strip(), - 'alt': img.get('alt', ''), - 'desc': find_closest_parent_with_useful_text(img), - 'score': score, - 'type': 'image' + "src": img.get("src", "").replace('\\"', '"').strip(), + "alt": img.get("alt", ""), + "desc": find_closest_parent_with_useful_text(img), + "score": score, + "type": "image", } def process_element(element: element.PageElement) -> bool: @@ -749,60 +812,89 @@ def get_content_of_website_optimized(url: str, html: str, word_count_threshold: element.extract() return False - if element.name in ['script', 'style', 'link', 'meta', 'noscript']: + if element.name in ["script", "style", "link", "meta", "noscript"]: element.decompose() return False keep_element = False - if element.name == 'a' and element.get('href'): - href = element['href'] - url_base = url.split('/')[2] - link_data = {'href': href, 'text': element.get_text()} - if href.startswith('http') and url_base not in href: - links['external'].append(link_data) + if element.name == "a" and element.get("href"): + href = element["href"] + url_base = url.split("/")[2] + link_data = {"href": href, "text": element.get_text()} + if href.startswith("http") and url_base not in href: + links["external"].append(link_data) else: - links['internal'].append(link_data) + links["internal"].append(link_data) keep_element = True - elif element.name == 'img': + elif element.name == "img": return True # Always keep image elements - elif element.name in ['video', 'audio']: - media[f"{element.name}s"].append({ - 'src': element.get('src'), - 'alt': element.get('alt'), - 'type': element.name, - 'description': find_closest_parent_with_useful_text(element) - }) - source_tags = element.find_all('source') + elif element.name in ["video", "audio"]: + media[f"{element.name}s"].append( + { + "src": element.get("src"), + "alt": element.get("alt"), + "type": element.name, + "description": find_closest_parent_with_useful_text(element), + } + ) + source_tags = element.find_all("source") for source_tag in source_tags: - media[f"{element.name}s"].append({ - 'src': source_tag.get('src'), - 'alt': element.get('alt'), - 'type': element.name, - 'description': find_closest_parent_with_useful_text(element) - }) + media[f"{element.name}s"].append( + { + "src": source_tag.get("src"), + "alt": element.get("alt"), + "type": element.name, + "description": find_closest_parent_with_useful_text( + element + ), + } + ) return True # Always keep video and audio elements - if element.name != 'pre': - if element.name in ['b', 'i', 'u', 'span', 'del', 'ins', 'sub', 'sup', 'strong', 'em', 'code', 'kbd', 'var', 's', 'q', 'abbr', 'cite', 'dfn', 'time', 'small', 'mark']: - if kwargs.get('only_text', False): + if element.name != "pre": + if element.name in [ + "b", + "i", + "u", + "span", + "del", + "ins", + "sub", + "sup", + "strong", + "em", + "code", + "kbd", + "var", + "s", + "q", + "abbr", + "cite", + "dfn", + "time", + "small", + "mark", + ]: + if kwargs.get("only_text", False): element.replace_with(element.get_text()) else: element.unwrap() - elif element.name != 'img': + elif element.name != "img": element.attrs = {} # Process children for child in list(element.children): - if isinstance(child, NavigableString) and not isinstance(child, Comment): + if isinstance(child, NavigableString) and not isinstance( + child, Comment + ): if len(child.strip()) > 0: keep_element = True else: if process_element(child): keep_element = True - # Check word count if not keep_element: @@ -814,14 +906,16 @@ def get_content_of_website_optimized(url: str, html: str, word_count_threshold: return keep_element except Exception as e: - print('Error processing element:', str(e)) + print("Error processing element:", str(e)) return False - #process images by filtering and extracting contextual text from the page - imgs = body.find_all('img') - media['images'] = [ - result for result in - (process_image(img, url, i, len(imgs)) for i, img in enumerate(imgs)) + # process images by filtering and extracting contextual text from the page + imgs = body.find_all("img") + media["images"] = [ + result + for result in ( + process_image(img, url, i, len(imgs)) for i, img in enumerate(imgs) + ) if result is not None ] @@ -830,7 +924,11 @@ def get_content_of_website_optimized(url: str, html: str, word_count_threshold: def flatten_nested_elements(node): if isinstance(node, NavigableString): return node - if len(node.contents) == 1 and isinstance(node.contents[0], element.Tag) and node.contents[0].name == node.name: + if ( + len(node.contents) == 1 + and isinstance(node.contents[0], element.Tag) + and node.contents[0].name == node.name + ): return flatten_nested_elements(node.contents[0]) node.contents = [flatten_nested_elements(child) for child in node.contents] return node @@ -839,86 +937,87 @@ def get_content_of_website_optimized(url: str, html: str, word_count_threshold: base64_pattern = re.compile(r'data:image/[^;]+;base64,([^"]+)') for img in imgs: try: - src = img.get('src', '') + src = img.get("src", "") if base64_pattern.match(src): - img['src'] = base64_pattern.sub('', src) + img["src"] = base64_pattern.sub("", src) except: - pass + pass - cleaned_html = str(body).replace('\n\n', '\n').replace(' ', ' ') + cleaned_html = str(body).replace("\n\n", "\n").replace(" ", " ") cleaned_html = sanitize_html(cleaned_html) h = CustomHTML2Text() h.ignore_links = True markdown = h.handle(cleaned_html) - markdown = markdown.replace(' ```', '```') + markdown = markdown.replace(" ```", "```") try: meta = extract_metadata(html, soup) except Exception as e: - print('Error extracting metadata:', str(e)) + print("Error extracting metadata:", str(e)) meta = {} return { - 'markdown': markdown, - 'cleaned_html': cleaned_html, - 'success': True, - 'media': media, - 'links': links, - 'metadata': meta + "markdown": markdown, + "cleaned_html": cleaned_html, + "success": True, + "media": media, + "links": links, + "metadata": meta, } + def extract_metadata_using_lxml(html, doc=None): """ Extract metadata from HTML using lxml for better performance. """ metadata = {} - + if not html and doc is None: return {} - + if doc is None: try: doc = lhtml.document_fromstring(html) except Exception: return {} - + # Use XPath to find head element - head = doc.xpath('//head') + head = doc.xpath("//head") if not head: return metadata - + head = head[0] - + # Title - using XPath - title = head.xpath('.//title/text()') - metadata['title'] = title[0].strip() if title else None + title = head.xpath(".//title/text()") + metadata["title"] = title[0].strip() if title else None # Meta description - using XPath with multiple attribute conditions description = head.xpath('.//meta[@name="description"]/@content') - metadata['description'] = description[0].strip() if description else None + metadata["description"] = description[0].strip() if description else None # Meta keywords keywords = head.xpath('.//meta[@name="keywords"]/@content') - metadata['keywords'] = keywords[0].strip() if keywords else None + metadata["keywords"] = keywords[0].strip() if keywords else None # Meta author author = head.xpath('.//meta[@name="author"]/@content') - metadata['author'] = author[0].strip() if author else None + metadata["author"] = author[0].strip() if author else None # Open Graph metadata - using starts-with() for performance og_tags = head.xpath('.//meta[starts-with(@property, "og:")]') for tag in og_tags: - property_name = tag.get('property', '').strip() - content = tag.get('content', '').strip() + property_name = tag.get("property", "").strip() + content = tag.get("content", "").strip() if property_name and content: metadata[property_name] = content # Twitter Card metadata twitter_tags = head.xpath('.//meta[starts-with(@name, "twitter:")]') for tag in twitter_tags: - property_name = tag.get('name', '').strip() - content = tag.get('content', '').strip() + property_name = tag.get("name", "").strip() + content = tag.get("content", "").strip() if property_name and content: metadata[property_name] = content @@ -946,66 +1045,74 @@ def extract_metadata(html, soup=None): Returns: Dict[str, Any]: Extracted content including Markdown, cleaned HTML, media, links, and metadata. """ - + metadata = {} - + if not html and not soup: return {} - + if not soup: - soup = BeautifulSoup(html, 'lxml') - + soup = BeautifulSoup(html, "lxml") + head = soup.head if not head: return metadata - + # Title - title_tag = head.find('title') - metadata['title'] = title_tag.string.strip() if title_tag and title_tag.string else None + title_tag = head.find("title") + metadata["title"] = ( + title_tag.string.strip() if title_tag and title_tag.string else None + ) # Meta description - description_tag = head.find('meta', attrs={'name': 'description'}) - metadata['description'] = description_tag.get('content', '').strip() if description_tag else None + description_tag = head.find("meta", attrs={"name": "description"}) + metadata["description"] = ( + description_tag.get("content", "").strip() if description_tag else None + ) # Meta keywords - keywords_tag = head.find('meta', attrs={'name': 'keywords'}) - metadata['keywords'] = keywords_tag.get('content', '').strip() if keywords_tag else None + keywords_tag = head.find("meta", attrs={"name": "keywords"}) + metadata["keywords"] = ( + keywords_tag.get("content", "").strip() if keywords_tag else None + ) # Meta author - author_tag = head.find('meta', attrs={'name': 'author'}) - metadata['author'] = author_tag.get('content', '').strip() if author_tag else None + author_tag = head.find("meta", attrs={"name": "author"}) + metadata["author"] = author_tag.get("content", "").strip() if author_tag else None # Open Graph metadata - og_tags = head.find_all('meta', attrs={'property': re.compile(r'^og:')}) + og_tags = head.find_all("meta", attrs={"property": re.compile(r"^og:")}) for tag in og_tags: - property_name = tag.get('property', '').strip() - content = tag.get('content', '').strip() + property_name = tag.get("property", "").strip() + content = tag.get("content", "").strip() if property_name and content: metadata[property_name] = content # Twitter Card metadata - twitter_tags = head.find_all('meta', attrs={'name': re.compile(r'^twitter:')}) + twitter_tags = head.find_all("meta", attrs={"name": re.compile(r"^twitter:")}) for tag in twitter_tags: - property_name = tag.get('name', '').strip() - content = tag.get('content', '').strip() + property_name = tag.get("name", "").strip() + content = tag.get("content", "").strip() if property_name and content: metadata[property_name] = content - + return metadata + def extract_xml_tags(string): """ Extracts XML tags from a string. - Args: + Args: string (str): The input string containing XML tags. Returns: List[str]: A list of XML tags extracted from the input string. """ - tags = re.findall(r'<(\w+)>', string) + tags = re.findall(r"<(\w+)>", string) return list(set(tags)) + def extract_xml_data(tags, string): """ Extract data for specified XML tags from a string. @@ -1034,15 +1141,16 @@ def extract_xml_data(tags, string): data[tag] = "" return data - + + def perform_completion_with_backoff( - provider, - prompt_with_variables, - api_token, - json_response = False, + provider, + prompt_with_variables, + api_token, + json_response=False, base_url=None, - **kwargs - ): + **kwargs, +): """ Perform an API completion request with exponential backoff. @@ -1062,52 +1170,49 @@ def perform_completion_with_backoff( Returns: dict: The API response or an error message after all retries. """ - - from litellm import completion + + from litellm import completion from litellm.exceptions import RateLimitError + max_attempts = 3 base_delay = 2 # Base delay in seconds, you can adjust this based on your needs - - extra_args = { - "temperature": 0.01, - 'api_key': api_token, - 'base_url': base_url - } + + extra_args = {"temperature": 0.01, "api_key": api_token, "base_url": base_url} if json_response: - extra_args["response_format"] = { "type": "json_object" } - + extra_args["response_format"] = {"type": "json_object"} + if kwargs.get("extra_args"): extra_args.update(kwargs["extra_args"]) - + for attempt in range(max_attempts): try: - - response =completion( + response = completion( model=provider, - messages=[ - {"role": "user", "content": prompt_with_variables} - ], - **extra_args + messages=[{"role": "user", "content": prompt_with_variables}], + **extra_args, ) return response # Return the successful response except RateLimitError as e: print("Rate limit error:", str(e)) - + # Check if we have exhausted our max attempts if attempt < max_attempts - 1: # Calculate the delay and wait - delay = base_delay * (2 ** attempt) # Exponential backoff formula + delay = base_delay * (2**attempt) # Exponential backoff formula print(f"Waiting for {delay} seconds before retrying...") time.sleep(delay) else: # Return an error response after exhausting all retries - return [{ - "index": 0, - "tags": ["error"], - "content": ["Rate limit error. Please try again later."] - }] - -def extract_blocks(url, html, provider = DEFAULT_PROVIDER, api_token = None, base_url = None): + return [ + { + "index": 0, + "tags": ["error"], + "content": ["Rate limit error. Please try again later."], + } + ] + + +def extract_blocks(url, html, provider=DEFAULT_PROVIDER, api_token=None, base_url=None): """ Extract content blocks from website HTML using an AI provider. @@ -1129,7 +1234,7 @@ def extract_blocks(url, html, provider = DEFAULT_PROVIDER, api_token = None, bas # api_token = os.getenv('GROQ_API_KEY', None) if not api_token else api_token api_token = PROVIDER_MODELS.get(provider, None) if not api_token else api_token - + variable_values = { "URL": url, "HTML": escape_json_string(sanitize_html(html)), @@ -1140,29 +1245,33 @@ def extract_blocks(url, html, provider = DEFAULT_PROVIDER, api_token = None, bas prompt_with_variables = prompt_with_variables.replace( "{" + variable + "}", variable_values[variable] ) - - response = perform_completion_with_backoff(provider, prompt_with_variables, api_token, base_url=base_url) - + + response = perform_completion_with_backoff( + provider, prompt_with_variables, api_token, base_url=base_url + ) + try: - blocks = extract_xml_data(["blocks"], response.choices[0].message.content)['blocks'] + blocks = extract_xml_data(["blocks"], response.choices[0].message.content)[ + "blocks" + ] blocks = json.loads(blocks) ## Add error: False to the blocks for block in blocks: - block['error'] = False - except Exception as e: - parsed, unparsed = split_and_parse_json_objects(response.choices[0].message.content) + block["error"] = False + except Exception: + parsed, unparsed = split_and_parse_json_objects( + response.choices[0].message.content + ) blocks = parsed # Append all unparsed segments as onr error block and content is list of unparsed segments if unparsed: - blocks.append({ - "index": 0, - "error": True, - "tags": ["error"], - "content": unparsed - }) + blocks.append( + {"index": 0, "error": True, "tags": ["error"], "content": unparsed} + ) return blocks -def extract_blocks_batch(batch_data, provider = "groq/llama3-70b-8192", api_token = None): + +def extract_blocks_batch(batch_data, provider="groq/llama3-70b-8192", api_token=None): """ Extract content blocks from a batch of website HTMLs. @@ -1180,11 +1289,12 @@ def extract_blocks_batch(batch_data, provider = "groq/llama3-70b-8192", api_toke List[dict]: A list of extracted content blocks from all batch items. """ - api_token = os.getenv('GROQ_API_KEY', None) if not api_token else api_token + api_token = os.getenv("GROQ_API_KEY", None) if not api_token else api_token from litellm import batch_completion + messages = [] - - for url, html in batch_data: + + for url, html in batch_data: variable_values = { "URL": url, "HTML": html, @@ -1195,33 +1305,37 @@ def extract_blocks_batch(batch_data, provider = "groq/llama3-70b-8192", api_toke prompt_with_variables = prompt_with_variables.replace( "{" + variable + "}", variable_values[variable] ) - + messages.append([{"role": "user", "content": prompt_with_variables}]) - - - responses = batch_completion( - model = provider, - messages = messages, - temperature = 0.01 - ) - + + responses = batch_completion(model=provider, messages=messages, temperature=0.01) + all_blocks = [] - for response in responses: + for response in responses: try: - blocks = extract_xml_data(["blocks"], response.choices[0].message.content)['blocks'] + blocks = extract_xml_data(["blocks"], response.choices[0].message.content)[ + "blocks" + ] blocks = json.loads(blocks) - except Exception as e: - blocks = [{ - "index": 0, - "tags": ["error"], - "content": ["Error extracting blocks from the HTML content. Choose another provider/model or try again."], - "questions": ["What went wrong during the block extraction process?"] - }] + except Exception: + blocks = [ + { + "index": 0, + "tags": ["error"], + "content": [ + "Error extracting blocks from the HTML content. Choose another provider/model or try again." + ], + "questions": [ + "What went wrong during the block extraction process?" + ], + } + ] all_blocks.append(blocks) - + return sum(all_blocks, []) + def merge_chunks_based_on_token_threshold(chunks, token_threshold): """ Merges small chunks into larger ones based on the total token threshold. @@ -1235,23 +1349,28 @@ def merge_chunks_based_on_token_threshold(chunks, token_threshold): total_token_so_far = 0 for chunk in chunks: - chunk_token_count = len(chunk.split()) * 1.3 # Estimate token count with a factor + chunk_token_count = ( + len(chunk.split()) * 1.3 + ) # Estimate token count with a factor if total_token_so_far + chunk_token_count < token_threshold: current_chunk.append(chunk) total_token_so_far += chunk_token_count else: if current_chunk: - merged_sections.append('\n\n'.join(current_chunk)) + merged_sections.append("\n\n".join(current_chunk)) current_chunk = [chunk] total_token_so_far = chunk_token_count # Add the last chunk if it exists if current_chunk: - merged_sections.append('\n\n'.join(current_chunk)) + merged_sections.append("\n\n".join(current_chunk)) return merged_sections -def process_sections(url: str, sections: list, provider: str, api_token: str, base_url=None) -> list: + +def process_sections( + url: str, sections: list, provider: str, api_token: str, base_url=None +) -> list: """ Process sections of HTML content sequentially or in parallel. @@ -1275,17 +1394,25 @@ def process_sections(url: str, sections: list, provider: str, api_token: str, ba if provider.startswith("groq/"): # Sequential processing with a delay for section in sections: - extracted_content.extend(extract_blocks(url, section, provider, api_token, base_url=base_url)) + extracted_content.extend( + extract_blocks(url, section, provider, api_token, base_url=base_url) + ) time.sleep(0.5) # 500 ms delay between each processing else: # Parallel processing using ThreadPoolExecutor with ThreadPoolExecutor() as executor: - futures = [executor.submit(extract_blocks, url, section, provider, api_token, base_url=base_url) for section in sections] + futures = [ + executor.submit( + extract_blocks, url, section, provider, api_token, base_url=base_url + ) + for section in sections + ] for future in as_completed(futures): extracted_content.extend(future.result()) - + return extracted_content + def wrap_text(draw, text, font, max_width): """ Wrap text to fit within a specified width for rendering. @@ -1309,11 +1436,14 @@ def wrap_text(draw, text, font, max_width): lines = [] words = text.split() while words: - line = '' - while words and draw.textbbox((0, 0), line + words[0], font=font)[2] <= max_width: - line += (words.pop(0) + ' ') + line = "" + while ( + words and draw.textbbox((0, 0), line + words[0], font=font)[2] <= max_width + ): + line += words.pop(0) + " " lines.append(line) - return '\n'.join(lines) + return "\n".join(lines) + def format_html(html_string): """ @@ -1331,16 +1461,17 @@ def format_html(html_string): str: The prettified HTML string. """ - soup = BeautifulSoup(html_string, 'lxml.parser') + soup = BeautifulSoup(html_string, "lxml.parser") return soup.prettify() + def fast_format_html(html_string): """ A fast HTML formatter that uses string operations instead of parsing. - + Args: html_string (str): The HTML string to format - + Returns: str: The formatted HTML string """ @@ -1349,35 +1480,36 @@ def fast_format_html(html_string): indent_str = " " # Two spaces for indentation formatted = [] in_content = False - + # Split by < and > to separate tags and content - parts = html_string.replace('>', '>\n').replace('<', '\n<').split('\n') - + parts = html_string.replace(">", ">\n").replace("<", "\n<").split("\n") + for part in parts: if not part.strip(): continue - + # Handle closing tags - if part.startswith(''): + elif part.startswith("<") and part.endswith("/>"): formatted.append(indent_str * indent + part) - + # Handle opening tags - elif part.startswith('<'): + elif part.startswith("<"): formatted.append(indent_str * indent + part) indent += 1 - + # Handle content between tags else: content = part.strip() if content: formatted.append(indent_str * indent + content) - - return '\n'.join(formatted) + + return "\n".join(formatted) + def normalize_url(href, base_url): """Normalize URLs to ensure consistent format""" @@ -1392,41 +1524,43 @@ def normalize_url(href, base_url): normalized = urljoin(base_url, href.strip()) return normalized + def normalize_url_tmp(href, base_url): """Normalize URLs to ensure consistent format""" # Extract protocol and domain from base URL try: - base_parts = base_url.split('/') + base_parts = base_url.split("/") protocol = base_parts[0] domain = base_parts[2] except IndexError: raise ValueError(f"Invalid base URL format: {base_url}") - + # Handle special protocols - special_protocols = {'mailto:', 'tel:', 'ftp:', 'file:', 'data:', 'javascript:'} + special_protocols = {"mailto:", "tel:", "ftp:", "file:", "data:", "javascript:"} if any(href.lower().startswith(proto) for proto in special_protocols): return href.strip() - + # Handle anchor links - if href.startswith('#'): + if href.startswith("#"): return f"{base_url}{href}" - + # Handle protocol-relative URLs - if href.startswith('//'): + if href.startswith("//"): return f"{protocol}{href}" - + # Handle root-relative URLs - if href.startswith('/'): + if href.startswith("/"): return f"{protocol}//{domain}{href}" - + # Handle relative URLs - if not href.startswith(('http://', 'https://')): + if not href.startswith(("http://", "https://")): # Remove leading './' if present - href = href.lstrip('./') + href = href.lstrip("./") return f"{protocol}//{domain}/{href}" - + return href.strip() + def get_base_domain(url: str) -> str: """ Extract the base domain from a given URL, handling common edge cases. @@ -1447,25 +1581,37 @@ def get_base_domain(url: str) -> str: domain = urlparse(url).netloc.lower() if not domain: return "" - + # Remove port if present - domain = domain.split(':')[0] - + domain = domain.split(":")[0] + # Remove www - domain = re.sub(r'^www\.', '', domain) - + domain = re.sub(r"^www\.", "", domain) + # Extract last two parts of domain (handles co.uk etc) - parts = domain.split('.') + parts = domain.split(".") if len(parts) > 2 and parts[-2] in { - 'co', 'com', 'org', 'gov', 'edu', 'net', - 'mil', 'int', 'ac', 'ad', 'ae', 'af', 'ag' + "co", + "com", + "org", + "gov", + "edu", + "net", + "mil", + "int", + "ac", + "ad", + "ae", + "af", + "ag", }: - return '.'.join(parts[-3:]) - - return '.'.join(parts[-2:]) + return ".".join(parts[-3:]) + + return ".".join(parts[-2:]) except Exception: return "" + def is_external_url(url: str, base_domain: str) -> bool: """ Extract the base domain from a given URL, handling common edge cases. @@ -1481,24 +1627,25 @@ def is_external_url(url: str, base_domain: str) -> bool: Returns: str: The extracted base domain or an empty string if parsing fails. """ - special = {'mailto:', 'tel:', 'ftp:', 'file:', 'data:', 'javascript:'} + special = {"mailto:", "tel:", "ftp:", "file:", "data:", "javascript:"} if any(url.lower().startswith(p) for p in special): return True - + try: parsed = urlparse(url) if not parsed.netloc: # Relative URL return False - + # Strip 'www.' from both domains for comparison - url_domain = parsed.netloc.lower().replace('www.', '') - base = base_domain.lower().replace('www.', '') - + url_domain = parsed.netloc.lower().replace("www.", "") + base = base_domain.lower().replace("www.", "") + # Check if URL domain ends with base domain return not url_domain.endswith(base) except Exception: return False + def clean_tokens(tokens: list[str]) -> list[str]: """ Clean a list of tokens by removing noise, stop words, and short tokens. @@ -1516,58 +1663,217 @@ def clean_tokens(tokens: list[str]) -> list[str]: """ # Set of tokens to remove - noise = {'ccp', 'up', '↑', '▲', '⬆️', 'a', 'an', 'at', 'by', 'in', 'of', 'on', 'to', 'the'} + noise = { + "ccp", + "up", + "↑", + "▲", + "⬆️", + "a", + "an", + "at", + "by", + "in", + "of", + "on", + "to", + "the", + } STOP_WORDS = { - 'a', 'an', 'and', 'are', 'as', 'at', 'be', 'by', 'for', 'from', - 'has', 'he', 'in', 'is', 'it', 'its', 'of', 'on', 'that', 'the', - 'to', 'was', 'were', 'will', 'with', - + "a", + "an", + "and", + "are", + "as", + "at", + "be", + "by", + "for", + "from", + "has", + "he", + "in", + "is", + "it", + "its", + "of", + "on", + "that", + "the", + "to", + "was", + "were", + "will", + "with", # Pronouns - 'i', 'you', 'he', 'she', 'it', 'we', 'they', - 'me', 'him', 'her', 'us', 'them', - 'my', 'your', 'his', 'her', 'its', 'our', 'their', - 'mine', 'yours', 'hers', 'ours', 'theirs', - 'myself', 'yourself', 'himself', 'herself', 'itself', 'ourselves', 'themselves', - + "i", + "you", + "he", + "she", + "it", + "we", + "they", + "me", + "him", + "her", + "us", + "them", + "my", + "your", + "his", + "her", + "its", + "our", + "their", + "mine", + "yours", + "hers", + "ours", + "theirs", + "myself", + "yourself", + "himself", + "herself", + "itself", + "ourselves", + "themselves", # Common verbs - 'am', 'is', 'are', 'was', 'were', 'be', 'been', 'being', - 'have', 'has', 'had', 'having', 'do', 'does', 'did', 'doing', - + "am", + "is", + "are", + "was", + "were", + "be", + "been", + "being", + "have", + "has", + "had", + "having", + "do", + "does", + "did", + "doing", # Prepositions - 'about', 'above', 'across', 'after', 'against', 'along', 'among', 'around', - 'at', 'before', 'behind', 'below', 'beneath', 'beside', 'between', 'beyond', - 'by', 'down', 'during', 'except', 'for', 'from', 'in', 'inside', 'into', - 'near', 'of', 'off', 'on', 'out', 'outside', 'over', 'past', 'through', - 'to', 'toward', 'under', 'underneath', 'until', 'up', 'upon', 'with', 'within', - + "about", + "above", + "across", + "after", + "against", + "along", + "among", + "around", + "at", + "before", + "behind", + "below", + "beneath", + "beside", + "between", + "beyond", + "by", + "down", + "during", + "except", + "for", + "from", + "in", + "inside", + "into", + "near", + "of", + "off", + "on", + "out", + "outside", + "over", + "past", + "through", + "to", + "toward", + "under", + "underneath", + "until", + "up", + "upon", + "with", + "within", # Conjunctions - 'and', 'but', 'or', 'nor', 'for', 'yet', 'so', - 'although', 'because', 'since', 'unless', - + "and", + "but", + "or", + "nor", + "for", + "yet", + "so", + "although", + "because", + "since", + "unless", # Articles - 'a', 'an', 'the', - + "a", + "an", + "the", # Other common words - 'this', 'that', 'these', 'those', - 'what', 'which', 'who', 'whom', 'whose', - 'when', 'where', 'why', 'how', - 'all', 'any', 'both', 'each', 'few', 'more', 'most', 'other', 'some', 'such', - 'can', 'cannot', "can't", 'could', "couldn't", - 'may', 'might', 'must', "mustn't", - 'shall', 'should', "shouldn't", - 'will', "won't", 'would', "wouldn't", - 'not', "n't", 'no', 'nor', 'none' - } - + "this", + "that", + "these", + "those", + "what", + "which", + "who", + "whom", + "whose", + "when", + "where", + "why", + "how", + "all", + "any", + "both", + "each", + "few", + "more", + "most", + "other", + "some", + "such", + "can", + "cannot", + "can't", + "could", + "couldn't", + "may", + "might", + "must", + "mustn't", + "shall", + "should", + "shouldn't", + "will", + "won't", + "would", + "wouldn't", + "not", + "n't", + "no", + "nor", + "none", + } + # Single comprehension, more efficient than multiple passes - return [token for token in tokens - if len(token) > 2 - and token not in noise - and token not in STOP_WORDS - and not token.startswith('↑') - and not token.startswith('▲') - and not token.startswith('⬆')] + return [ + token + for token in tokens + if len(token) > 2 + and token not in noise + and token not in STOP_WORDS + and not token.startswith("↑") + and not token.startswith("▲") + and not token.startswith("⬆") + ] + def profile_and_time(func): """ @@ -1589,103 +1895,108 @@ def profile_and_time(func): def wrapper(self, *args, **kwargs): # Start timer start_time = time.perf_counter() - + # Setup profiler profiler = cProfile.Profile() profiler.enable() - + # Run function result = func(self, *args, **kwargs) - + # Stop profiler profiler.disable() - + # Calculate elapsed time elapsed_time = time.perf_counter() - start_time - + # Print timing print(f"[PROFILER] Scraping completed in {elapsed_time:.2f} seconds") - + # Print profiling stats stats = pstats.Stats(profiler) - stats.sort_stats('cumulative') # Sort by cumulative time + stats.sort_stats("cumulative") # Sort by cumulative time stats.print_stats(20) # Print top 20 time-consuming functions - + return result + return wrapper + def generate_content_hash(content: str) -> str: """Generate a unique hash for content""" return xxhash.xxh64(content.encode()).hexdigest() # return hashlib.sha256(content.encode()).hexdigest() + def ensure_content_dirs(base_path: str) -> Dict[str, str]: """Create content directories if they don't exist""" dirs = { - 'html': 'html_content', - 'cleaned': 'cleaned_html', - 'markdown': 'markdown_content', - 'extracted': 'extracted_content', - 'screenshots': 'screenshots', - 'screenshot': 'screenshots' + "html": "html_content", + "cleaned": "cleaned_html", + "markdown": "markdown_content", + "extracted": "extracted_content", + "screenshots": "screenshots", + "screenshot": "screenshots", } - + content_paths = {} for key, dirname in dirs.items(): path = os.path.join(base_path, dirname) os.makedirs(path, exist_ok=True) content_paths[key] = path - + return content_paths + def configure_windows_event_loop(): """ Configure the Windows event loop to use ProactorEventLoop. This resolves the NotImplementedError that occurs on Windows when using asyncio subprocesses. - + This function should only be called on Windows systems and before any async operations. On non-Windows systems, this function does nothing. - + Example: ```python from crawl4ai.async_configs import configure_windows_event_loop - + # Call this before any async operations if you're on Windows configure_windows_event_loop() ``` """ - if platform.system() == 'Windows': + if platform.system() == "Windows": asyncio.set_event_loop_policy(asyncio.WindowsProactorEventLoopPolicy()) + def get_error_context(exc_info, context_lines: int = 5): """ Extract error context with more reliable line number tracking. - + Args: exc_info: The exception info from sys.exc_info() context_lines: Number of lines to show before and after the error - + Returns: dict: Error context information """ import traceback import linecache import os - + # Get the full traceback tb = traceback.extract_tb(exc_info[2]) - + # Get the last frame (where the error occurred) last_frame = tb[-1] filename = last_frame.filename line_no = last_frame.lineno func_name = last_frame.name - + # Get the source code context using linecache # This is more reliable than inspect.getsourcelines context_start = max(1, line_no - context_lines) context_end = line_no + context_lines + 1 - + # Build the context lines with line numbers context_lines = [] for i in range(context_start, context_end): @@ -1693,25 +2004,22 @@ def get_error_context(exc_info, context_lines: int = 5): if line: # Remove any trailing whitespace/newlines and add the pointer for error line line = line.rstrip() - pointer = '→' if i == line_no else ' ' + pointer = "→" if i == line_no else " " context_lines.append(f"{i:4d} {pointer} {line}") - + # Join the lines with newlines - code_context = '\n'.join(context_lines) - + code_context = "\n".join(context_lines) + # Get relative path for cleaner output try: rel_path = os.path.relpath(filename) except ValueError: # Fallback if relpath fails (can happen on Windows with different drives) rel_path = filename - + return { "filename": rel_path, "line_no": line_no, "function": func_name, - "code_context": code_context + "code_context": code_context, } - - - \ No newline at end of file diff --git a/crawl4ai/version_manager.py b/crawl4ai/version_manager.py index 8ae2de2e..17d73faa 100644 --- a/crawl4ai/version_manager.py +++ b/crawl4ai/version_manager.py @@ -1,14 +1,14 @@ # version_manager.py -import os from pathlib import Path from packaging import version from . import __version__ + class VersionManager: def __init__(self): self.home_dir = Path.home() / ".crawl4ai" self.version_file = self.home_dir / "version.txt" - + def get_installed_version(self): """Get the version recorded in home directory""" if not self.version_file.exists(): @@ -17,14 +17,13 @@ class VersionManager: return version.parse(self.version_file.read_text().strip()) except: return None - + def update_version(self): """Update the version file to current library version""" self.version_file.write_text(__version__.__version__) - + def needs_update(self): """Check if database needs update based on version""" installed = self.get_installed_version() current = version.parse(__version__.__version__) return installed is None or installed < current - diff --git a/crawl4ai/web_crawler.py b/crawl4ai/web_crawler.py index a32a988d..a92ae6dd 100644 --- a/crawl4ai/web_crawler.py +++ b/crawl4ai/web_crawler.py @@ -1,9 +1,10 @@ import os, time + os.environ["TOKENIZERS_PARALLELISM"] = "false" from pathlib import Path from .models import UrlModel, CrawlResult -from .database import init_db, get_cached_url, cache_url, DB_PATH, flush_db +from .database import init_db, get_cached_url, cache_url from .utils import * from .chunking_strategy import * from .extraction_strategy import * @@ -14,31 +15,44 @@ from .content_scraping_strategy import WebScrapingStrategy from .config import * import warnings import json -warnings.filterwarnings("ignore", message='Field "model_name" has conflict with protected namespace "model_".') + +warnings.filterwarnings( + "ignore", + message='Field "model_name" has conflict with protected namespace "model_".', +) class WebCrawler: - def __init__(self, crawler_strategy: CrawlerStrategy = None, always_by_pass_cache: bool = False, verbose: bool = False): - self.crawler_strategy = crawler_strategy or LocalSeleniumCrawlerStrategy(verbose=verbose) + def __init__( + self, + crawler_strategy: CrawlerStrategy = None, + always_by_pass_cache: bool = False, + verbose: bool = False, + ): + self.crawler_strategy = crawler_strategy or LocalSeleniumCrawlerStrategy( + verbose=verbose + ) self.always_by_pass_cache = always_by_pass_cache - self.crawl4ai_folder = os.path.join(os.getenv("CRAWL4_AI_BASE_DIRECTORY", Path.home()), ".crawl4ai") + self.crawl4ai_folder = os.path.join( + os.getenv("CRAWL4_AI_BASE_DIRECTORY", Path.home()), ".crawl4ai" + ) os.makedirs(self.crawl4ai_folder, exist_ok=True) os.makedirs(f"{self.crawl4ai_folder}/cache", exist_ok=True) init_db() self.ready = False - + def warmup(self): print("[LOG] 🌤️ Warming up the WebCrawler") self.run( - url='https://google.com/', + url="https://google.com/", word_count_threshold=5, extraction_strategy=NoExtractionStrategy(), bypass_cache=False, - verbose=False + verbose=False, ) self.ready = True print("[LOG] 🌞 WebCrawler is ready to crawl") - + def fetch_page( self, url_model: UrlModel, @@ -80,6 +94,7 @@ class WebCrawler: **kwargs, ) -> List[CrawlResult]: extraction_strategy = extraction_strategy or NoExtractionStrategy() + def fetch_page_wrapper(url_model, *args, **kwargs): return self.fetch_page(url_model, *args, **kwargs) @@ -104,150 +119,176 @@ class WebCrawler: return results def run( - self, - url: str, - word_count_threshold=MIN_WORD_THRESHOLD, - extraction_strategy: ExtractionStrategy = None, - chunking_strategy: ChunkingStrategy = RegexChunking(), - bypass_cache: bool = False, - css_selector: str = None, - screenshot: bool = False, - user_agent: str = None, - verbose=True, - **kwargs, - ) -> CrawlResult: - try: - extraction_strategy = extraction_strategy or NoExtractionStrategy() - extraction_strategy.verbose = verbose - if not isinstance(extraction_strategy, ExtractionStrategy): - raise ValueError("Unsupported extraction strategy") - if not isinstance(chunking_strategy, ChunkingStrategy): - raise ValueError("Unsupported chunking strategy") - - word_count_threshold = max(word_count_threshold, MIN_WORD_THRESHOLD) + self, + url: str, + word_count_threshold=MIN_WORD_THRESHOLD, + extraction_strategy: ExtractionStrategy = None, + chunking_strategy: ChunkingStrategy = RegexChunking(), + bypass_cache: bool = False, + css_selector: str = None, + screenshot: bool = False, + user_agent: str = None, + verbose=True, + **kwargs, + ) -> CrawlResult: + try: + extraction_strategy = extraction_strategy or NoExtractionStrategy() + extraction_strategy.verbose = verbose + if not isinstance(extraction_strategy, ExtractionStrategy): + raise ValueError("Unsupported extraction strategy") + if not isinstance(chunking_strategy, ChunkingStrategy): + raise ValueError("Unsupported chunking strategy") - cached = None - screenshot_data = None - extracted_content = None - if not bypass_cache and not self.always_by_pass_cache: - cached = get_cached_url(url) - - if kwargs.get("warmup", True) and not self.ready: - return None - - if cached: - html = sanitize_input_encode(cached[1]) - extracted_content = sanitize_input_encode(cached[4]) - if screenshot: - screenshot_data = cached[9] - if not screenshot_data: - cached = None - - if not cached or not html: - if user_agent: - self.crawler_strategy.update_user_agent(user_agent) - t1 = time.time() - html = sanitize_input_encode(self.crawler_strategy.crawl(url, **kwargs)) - t2 = time.time() - if verbose: - print(f"[LOG] 🚀 Crawling done for {url}, success: {bool(html)}, time taken: {t2 - t1:.2f} seconds") - if screenshot: - screenshot_data = self.crawler_strategy.take_screenshot() + word_count_threshold = max(word_count_threshold, MIN_WORD_THRESHOLD) - - crawl_result = self.process_html(url, html, extracted_content, word_count_threshold, extraction_strategy, chunking_strategy, css_selector, screenshot_data, verbose, bool(cached), **kwargs) - crawl_result.success = bool(html) - return crawl_result - except Exception as e: - if not hasattr(e, "msg"): - e.msg = str(e) - print(f"[ERROR] 🚫 Failed to crawl {url}, error: {e.msg}") - return CrawlResult(url=url, html="", success=False, error_message=e.msg) + cached = None + screenshot_data = None + extracted_content = None + if not bypass_cache and not self.always_by_pass_cache: + cached = get_cached_url(url) + + if kwargs.get("warmup", True) and not self.ready: + return None + + if cached: + html = sanitize_input_encode(cached[1]) + extracted_content = sanitize_input_encode(cached[4]) + if screenshot: + screenshot_data = cached[9] + if not screenshot_data: + cached = None + + if not cached or not html: + if user_agent: + self.crawler_strategy.update_user_agent(user_agent) + t1 = time.time() + html = sanitize_input_encode(self.crawler_strategy.crawl(url, **kwargs)) + t2 = time.time() + if verbose: + print( + f"[LOG] 🚀 Crawling done for {url}, success: {bool(html)}, time taken: {t2 - t1:.2f} seconds" + ) + if screenshot: + screenshot_data = self.crawler_strategy.take_screenshot() + + crawl_result = self.process_html( + url, + html, + extracted_content, + word_count_threshold, + extraction_strategy, + chunking_strategy, + css_selector, + screenshot_data, + verbose, + bool(cached), + **kwargs, + ) + crawl_result.success = bool(html) + return crawl_result + except Exception as e: + if not hasattr(e, "msg"): + e.msg = str(e) + print(f"[ERROR] 🚫 Failed to crawl {url}, error: {e.msg}") + return CrawlResult(url=url, html="", success=False, error_message=e.msg) def process_html( - self, - url: str, - html: str, - extracted_content: str, - word_count_threshold: int, - extraction_strategy: ExtractionStrategy, - chunking_strategy: ChunkingStrategy, - css_selector: str, - screenshot: bool, - verbose: bool, - is_cached: bool, - **kwargs, - ) -> CrawlResult: - t = time.time() - # Extract content from HTML - try: - t1 = time.time() - scrapping_strategy = WebScrapingStrategy() - extra_params = {k: v for k, v in kwargs.items() if k not in ["only_text", "image_description_min_word_threshold"]} - result = scrapping_strategy.scrap( - url, - html, - word_count_threshold=word_count_threshold, - css_selector=css_selector, - only_text=kwargs.get("only_text", False), - image_description_min_word_threshold=kwargs.get( - "image_description_min_word_threshold", IMAGE_DESCRIPTION_MIN_WORD_THRESHOLD - ), - **extra_params, + self, + url: str, + html: str, + extracted_content: str, + word_count_threshold: int, + extraction_strategy: ExtractionStrategy, + chunking_strategy: ChunkingStrategy, + css_selector: str, + screenshot: bool, + verbose: bool, + is_cached: bool, + **kwargs, + ) -> CrawlResult: + t = time.time() + # Extract content from HTML + try: + t1 = time.time() + scrapping_strategy = WebScrapingStrategy() + extra_params = { + k: v + for k, v in kwargs.items() + if k not in ["only_text", "image_description_min_word_threshold"] + } + result = scrapping_strategy.scrap( + url, + html, + word_count_threshold=word_count_threshold, + css_selector=css_selector, + only_text=kwargs.get("only_text", False), + image_description_min_word_threshold=kwargs.get( + "image_description_min_word_threshold", + IMAGE_DESCRIPTION_MIN_WORD_THRESHOLD, + ), + **extra_params, + ) + + # result = get_content_of_website_optimized(url, html, word_count_threshold, css_selector=css_selector, only_text=kwargs.get("only_text", False)) + if verbose: + print( + f"[LOG] 🚀 Content extracted for {url}, success: True, time taken: {time.time() - t1:.2f} seconds" ) - - # result = get_content_of_website_optimized(url, html, word_count_threshold, css_selector=css_selector, only_text=kwargs.get("only_text", False)) - if verbose: - print(f"[LOG] 🚀 Content extracted for {url}, success: True, time taken: {time.time() - t1:.2f} seconds") - - if result is None: - raise ValueError(f"Failed to extract content from the website: {url}") - except InvalidCSSSelectorError as e: - raise ValueError(str(e)) - - cleaned_html = sanitize_input_encode(result.get("cleaned_html", "")) - markdown = sanitize_input_encode(result.get("markdown", "")) - media = result.get("media", []) - links = result.get("links", []) - metadata = result.get("metadata", {}) - - if extracted_content is None: - if verbose: - print(f"[LOG] 🔥 Extracting semantic blocks for {url}, Strategy: {extraction_strategy.name}") - sections = chunking_strategy.chunk(markdown) - extracted_content = extraction_strategy.run(url, sections) - extracted_content = json.dumps(extracted_content, indent=4, default=str, ensure_ascii=False) + if result is None: + raise ValueError(f"Failed to extract content from the website: {url}") + except InvalidCSSSelectorError as e: + raise ValueError(str(e)) - if verbose: - print(f"[LOG] 🚀 Extraction done for {url}, time taken: {time.time() - t:.2f} seconds.") - - screenshot = None if not screenshot else screenshot - - if not is_cached: - cache_url( - url, - html, - cleaned_html, - markdown, - extracted_content, - True, - json.dumps(media), - json.dumps(links), - json.dumps(metadata), - screenshot=screenshot, - ) - - return CrawlResult( - url=url, - html=html, - cleaned_html=format_html(cleaned_html), - markdown=markdown, - media=media, - links=links, - metadata=metadata, + cleaned_html = sanitize_input_encode(result.get("cleaned_html", "")) + markdown = sanitize_input_encode(result.get("markdown", "")) + media = result.get("media", []) + links = result.get("links", []) + metadata = result.get("metadata", {}) + + if extracted_content is None: + if verbose: + print( + f"[LOG] 🔥 Extracting semantic blocks for {url}, Strategy: {extraction_strategy.name}" + ) + + sections = chunking_strategy.chunk(markdown) + extracted_content = extraction_strategy.run(url, sections) + extracted_content = json.dumps( + extracted_content, indent=4, default=str, ensure_ascii=False + ) + + if verbose: + print( + f"[LOG] 🚀 Extraction done for {url}, time taken: {time.time() - t:.2f} seconds." + ) + + screenshot = None if not screenshot else screenshot + + if not is_cached: + cache_url( + url, + html, + cleaned_html, + markdown, + extracted_content, + True, + json.dumps(media), + json.dumps(links), + json.dumps(metadata), screenshot=screenshot, - extracted_content=extracted_content, - success=True, - error_message="", - ) \ No newline at end of file + ) + + return CrawlResult( + url=url, + html=html, + cleaned_html=format_html(cleaned_html), + markdown=markdown, + media=media, + links=links, + metadata=metadata, + screenshot=screenshot, + extracted_content=extracted_content, + success=True, + error_message="", + ) diff --git a/docs/examples/amazon_product_extraction_direct_url.py b/docs/examples/amazon_product_extraction_direct_url.py index 769c479e..ec734245 100644 --- a/docs/examples/amazon_product_extraction_direct_url.py +++ b/docs/examples/amazon_product_extraction_direct_url.py @@ -9,13 +9,11 @@ from crawl4ai.extraction_strategy import JsonCssExtractionStrategy from crawl4ai.async_configs import BrowserConfig, CrawlerRunConfig import json + async def extract_amazon_products(): # Initialize browser config - browser_config = BrowserConfig( - browser_type="chromium", - headless=True - ) - + browser_config = BrowserConfig(browser_type="chromium", headless=True) + # Initialize crawler config with JSON CSS extraction strategy crawler_config = CrawlerRunConfig( extraction_strategy=JsonCssExtractionStrategy( @@ -27,74 +25,70 @@ async def extract_amazon_products(): "name": "asin", "selector": "", "type": "attribute", - "attribute": "data-asin" - }, - { - "name": "title", - "selector": "h2 a span", - "type": "text" + "attribute": "data-asin", }, + {"name": "title", "selector": "h2 a span", "type": "text"}, { "name": "url", "selector": "h2 a", "type": "attribute", - "attribute": "href" + "attribute": "href", }, { "name": "image", "selector": ".s-image", "type": "attribute", - "attribute": "src" + "attribute": "src", }, { "name": "rating", "selector": ".a-icon-star-small .a-icon-alt", - "type": "text" + "type": "text", }, { "name": "reviews_count", "selector": "[data-csa-c-func-deps='aui-da-a-popover'] ~ span span", - "type": "text" + "type": "text", }, { "name": "price", "selector": ".a-price .a-offscreen", - "type": "text" + "type": "text", }, { "name": "original_price", "selector": ".a-price.a-text-price .a-offscreen", - "type": "text" + "type": "text", }, { "name": "sponsored", "selector": ".puis-sponsored-label-text", - "type": "exists" + "type": "exists", }, { "name": "delivery_info", "selector": "[data-cy='delivery-recipe'] .a-color-base", "type": "text", - "multiple": True - } - ] + "multiple": True, + }, + ], } ) ) # Example search URL (you should replace with your actual Amazon URL) url = "https://www.amazon.com/s?k=Samsung+Galaxy+Tab" - + # Use context manager for proper resource handling async with AsyncWebCrawler(config=browser_config) as crawler: # Extract the data result = await crawler.arun(url=url, config=crawler_config) - + # Process and print the results if result and result.extracted_content: # Parse the JSON string into a list of products products = json.loads(result.extracted_content) - + # Process each product in the list for product in products: print("\nProduct Details:") @@ -105,10 +99,12 @@ async def extract_amazon_products(): print(f"Rating: {product.get('rating')}") print(f"Reviews: {product.get('reviews_count')}") print(f"Sponsored: {'Yes' if product.get('sponsored') else 'No'}") - if product.get('delivery_info'): + if product.get("delivery_info"): print(f"Delivery: {' '.join(product['delivery_info'])}") print("-" * 80) + if __name__ == "__main__": import asyncio + asyncio.run(extract_amazon_products()) diff --git a/docs/examples/amazon_product_extraction_using_hooks.py b/docs/examples/amazon_product_extraction_using_hooks.py index a17d60c5..5118b5d9 100644 --- a/docs/examples/amazon_product_extraction_using_hooks.py +++ b/docs/examples/amazon_product_extraction_using_hooks.py @@ -10,17 +10,17 @@ from crawl4ai.async_configs import BrowserConfig, CrawlerRunConfig import json from playwright.async_api import Page, BrowserContext + async def extract_amazon_products(): # Initialize browser config browser_config = BrowserConfig( # browser_type="chromium", headless=True ) - + # Initialize crawler config with JSON CSS extraction strategy nav-search-submit-button crawler_config = CrawlerRunConfig( cache_mode=CacheMode.BYPASS, - extraction_strategy=JsonCssExtractionStrategy( schema={ "name": "Amazon Product Search Results", @@ -30,102 +30,105 @@ async def extract_amazon_products(): "name": "asin", "selector": "", "type": "attribute", - "attribute": "data-asin" - }, - { - "name": "title", - "selector": "h2 a span", - "type": "text" + "attribute": "data-asin", }, + {"name": "title", "selector": "h2 a span", "type": "text"}, { "name": "url", "selector": "h2 a", "type": "attribute", - "attribute": "href" + "attribute": "href", }, { "name": "image", "selector": ".s-image", "type": "attribute", - "attribute": "src" + "attribute": "src", }, { "name": "rating", "selector": ".a-icon-star-small .a-icon-alt", - "type": "text" + "type": "text", }, { "name": "reviews_count", "selector": "[data-csa-c-func-deps='aui-da-a-popover'] ~ span span", - "type": "text" + "type": "text", }, { "name": "price", "selector": ".a-price .a-offscreen", - "type": "text" + "type": "text", }, { "name": "original_price", "selector": ".a-price.a-text-price .a-offscreen", - "type": "text" + "type": "text", }, { "name": "sponsored", "selector": ".puis-sponsored-label-text", - "type": "exists" + "type": "exists", }, { "name": "delivery_info", "selector": "[data-cy='delivery-recipe'] .a-color-base", "type": "text", - "multiple": True - } - ] + "multiple": True, + }, + ], } - ) + ), ) url = "https://www.amazon.com/" - - async def after_goto(page: Page, context: BrowserContext, url: str, response: dict, **kwargs): + + async def after_goto( + page: Page, context: BrowserContext, url: str, response: dict, **kwargs + ): """Hook called after navigating to each URL""" print(f"[HOOK] after_goto - Successfully loaded: {url}") - + try: # Wait for search box to be available - search_box = await page.wait_for_selector('#twotabsearchtextbox', timeout=1000) - + search_box = await page.wait_for_selector( + "#twotabsearchtextbox", timeout=1000 + ) + # Type the search query - await search_box.fill('Samsung Galaxy Tab') - + await search_box.fill("Samsung Galaxy Tab") + # Get the search button and prepare for navigation - search_button = await page.wait_for_selector('#nav-search-submit-button', timeout=1000) - + search_button = await page.wait_for_selector( + "#nav-search-submit-button", timeout=1000 + ) + # Click with navigation waiting await search_button.click() - + # Wait for search results to load - await page.wait_for_selector('[data-component-type="s-search-result"]', timeout=10000) + await page.wait_for_selector( + '[data-component-type="s-search-result"]', timeout=10000 + ) print("[HOOK] Search completed and results loaded!") - + except Exception as e: print(f"[HOOK] Error during search operation: {str(e)}") - - return page - + + return page + # Use context manager for proper resource handling async with AsyncWebCrawler(config=browser_config) as crawler: - crawler.crawler_strategy.set_hook("after_goto", after_goto) - + # Extract the data result = await crawler.arun(url=url, config=crawler_config) - + # Process and print the results if result and result.extracted_content: # Parse the JSON string into a list of products products = json.loads(result.extracted_content) - + # Process each product in the list for product in products: print("\nProduct Details:") @@ -136,10 +139,12 @@ async def extract_amazon_products(): print(f"Rating: {product.get('rating')}") print(f"Reviews: {product.get('reviews_count')}") print(f"Sponsored: {'Yes' if product.get('sponsored') else 'No'}") - if product.get('delivery_info'): + if product.get("delivery_info"): print(f"Delivery: {' '.join(product['delivery_info'])}") print("-" * 80) + if __name__ == "__main__": import asyncio + asyncio.run(extract_amazon_products()) diff --git a/docs/examples/amazon_product_extraction_using_use_javascript.py b/docs/examples/amazon_product_extraction_using_use_javascript.py index 15e5d6f5..e412c931 100644 --- a/docs/examples/amazon_product_extraction_using_use_javascript.py +++ b/docs/examples/amazon_product_extraction_using_use_javascript.py @@ -8,7 +8,7 @@ from crawl4ai import AsyncWebCrawler, CacheMode from crawl4ai.extraction_strategy import JsonCssExtractionStrategy from crawl4ai.async_configs import BrowserConfig, CrawlerRunConfig import json -from playwright.async_api import Page, BrowserContext + async def extract_amazon_products(): # Initialize browser config @@ -16,7 +16,7 @@ async def extract_amazon_products(): # browser_type="chromium", headless=True ) - + js_code_to_search = """ const task = async () => { document.querySelector('#twotabsearchtextbox').value = 'Samsung Galaxy Tab'; @@ -30,7 +30,7 @@ async def extract_amazon_products(): """ crawler_config = CrawlerRunConfig( cache_mode=CacheMode.BYPASS, - js_code = js_code_to_search, + js_code=js_code_to_search, wait_for='css:[data-component-type="s-search-result"]', extraction_strategy=JsonCssExtractionStrategy( schema={ @@ -41,75 +41,70 @@ async def extract_amazon_products(): "name": "asin", "selector": "", "type": "attribute", - "attribute": "data-asin" - }, - { - "name": "title", - "selector": "h2 a span", - "type": "text" + "attribute": "data-asin", }, + {"name": "title", "selector": "h2 a span", "type": "text"}, { "name": "url", "selector": "h2 a", "type": "attribute", - "attribute": "href" + "attribute": "href", }, { "name": "image", "selector": ".s-image", "type": "attribute", - "attribute": "src" + "attribute": "src", }, { "name": "rating", "selector": ".a-icon-star-small .a-icon-alt", - "type": "text" + "type": "text", }, { "name": "reviews_count", "selector": "[data-csa-c-func-deps='aui-da-a-popover'] ~ span span", - "type": "text" + "type": "text", }, { "name": "price", "selector": ".a-price .a-offscreen", - "type": "text" + "type": "text", }, { "name": "original_price", "selector": ".a-price.a-text-price .a-offscreen", - "type": "text" + "type": "text", }, { "name": "sponsored", "selector": ".puis-sponsored-label-text", - "type": "exists" + "type": "exists", }, { "name": "delivery_info", "selector": "[data-cy='delivery-recipe'] .a-color-base", "type": "text", - "multiple": True - } - ] + "multiple": True, + }, + ], } - ) + ), ) # Example search URL (you should replace with your actual Amazon URL) url = "https://www.amazon.com/" - - + # Use context manager for proper resource handling async with AsyncWebCrawler(config=browser_config) as crawler: # Extract the data result = await crawler.arun(url=url, config=crawler_config) - + # Process and print the results if result and result.extracted_content: # Parse the JSON string into a list of products products = json.loads(result.extracted_content) - + # Process each product in the list for product in products: print("\nProduct Details:") @@ -120,10 +115,12 @@ async def extract_amazon_products(): print(f"Rating: {product.get('rating')}") print(f"Reviews: {product.get('reviews_count')}") print(f"Sponsored: {'Yes' if product.get('sponsored') else 'No'}") - if product.get('delivery_info'): + if product.get("delivery_info"): print(f"Delivery: {' '.join(product['delivery_info'])}") print("-" * 80) + if __name__ == "__main__": import asyncio + asyncio.run(extract_amazon_products()) diff --git a/docs/examples/async_webcrawler_multiple_urls_example.py b/docs/examples/async_webcrawler_multiple_urls_example.py index 1d63ac80..52309d13 100644 --- a/docs/examples/async_webcrawler_multiple_urls_example.py +++ b/docs/examples/async_webcrawler_multiple_urls_example.py @@ -1,12 +1,16 @@ # File: async_webcrawler_multiple_urls_example.py import os, sys + # append 2 parent directories to sys.path to import crawl4ai -parent_dir = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) +parent_dir = os.path.dirname( + os.path.dirname(os.path.dirname(os.path.abspath(__file__))) +) sys.path.append(parent_dir) import asyncio from crawl4ai import AsyncWebCrawler + async def main(): # Initialize the AsyncWebCrawler async with AsyncWebCrawler(verbose=True) as crawler: @@ -16,7 +20,7 @@ async def main(): "https://python.org", "https://github.com", "https://stackoverflow.com", - "https://news.ycombinator.com" + "https://news.ycombinator.com", ] # Set up crawling parameters @@ -27,7 +31,7 @@ async def main(): urls=urls, word_count_threshold=word_count_threshold, bypass_cache=True, - verbose=True + verbose=True, ) # Process the results @@ -36,7 +40,9 @@ async def main(): print(f"Successfully crawled: {result.url}") print(f"Title: {result.metadata.get('title', 'N/A')}") print(f"Word count: {len(result.markdown.split())}") - print(f"Number of links: {len(result.links.get('internal', [])) + len(result.links.get('external', []))}") + print( + f"Number of links: {len(result.links.get('internal', [])) + len(result.links.get('external', []))}" + ) print(f"Number of images: {len(result.media.get('images', []))}") print("---") else: @@ -44,5 +50,6 @@ async def main(): print(f"Error: {result.error_message}") print("---") + if __name__ == "__main__": - asyncio.run(main()) \ No newline at end of file + asyncio.run(main()) diff --git a/docs/examples/browser_optimization_example.py b/docs/examples/browser_optimization_example.py index f57dc147..73637a71 100644 --- a/docs/examples/browser_optimization_example.py +++ b/docs/examples/browser_optimization_example.py @@ -6,10 +6,8 @@ This example demonstrates optimal browser usage patterns in Crawl4AI: """ import asyncio -import os from typing import List from crawl4ai import AsyncWebCrawler, BrowserConfig, CrawlerRunConfig -from crawl4ai.content_filter_strategy import PruningContentFilter from crawl4ai.markdown_generation_strategy import DefaultMarkdownGenerator diff --git a/docs/examples/crawlai_vs_firecrawl.py b/docs/examples/crawlai_vs_firecrawl.py index b50b06da..f8b70dc7 100644 --- a/docs/examples/crawlai_vs_firecrawl.py +++ b/docs/examples/crawlai_vs_firecrawl.py @@ -1,31 +1,32 @@ import os, time + # append the path to the root of the project import sys import asyncio -sys.path.append(os.path.join(os.path.dirname(__file__), '..', '..')) + +sys.path.append(os.path.join(os.path.dirname(__file__), "..", "..")) from firecrawl import FirecrawlApp from crawl4ai import AsyncWebCrawler -__data__ = os.path.join(os.path.dirname(__file__), '..', '..') + '/.data' + +__data__ = os.path.join(os.path.dirname(__file__), "..", "..") + "/.data" + async def compare(): - app = FirecrawlApp(api_key=os.environ['FIRECRAWL_API_KEY']) + app = FirecrawlApp(api_key=os.environ["FIRECRAWL_API_KEY"]) # Tet Firecrawl with a simple crawl start = time.time() scrape_status = app.scrape_url( - 'https://www.nbcnews.com/business', - params={'formats': ['markdown', 'html']} + "https://www.nbcnews.com/business", params={"formats": ["markdown", "html"]} ) end = time.time() print(f"Time taken: {end - start} seconds") - print(len(scrape_status['markdown'])) + print(len(scrape_status["markdown"])) # save the markdown content with provider name with open(f"{__data__}/firecrawl_simple.md", "w") as f: - f.write(scrape_status['markdown']) + f.write(scrape_status["markdown"]) # Count how many "cldnry.s-nbcnews.com" are in the markdown - print(scrape_status['markdown'].count("cldnry.s-nbcnews.com")) - - + print(scrape_status["markdown"].count("cldnry.s-nbcnews.com")) async with AsyncWebCrawler() as crawler: start = time.time() @@ -33,13 +34,13 @@ async def compare(): url="https://www.nbcnews.com/business", # js_code=["const loadMoreButton = Array.from(document.querySelectorAll('button')).find(button => button.textContent.includes('Load More')); loadMoreButton && loadMoreButton.click();"], word_count_threshold=0, - bypass_cache=True, - verbose=False + bypass_cache=True, + verbose=False, ) end = time.time() print(f"Time taken: {end - start} seconds") print(len(result.markdown)) - # save the markdown content with provider name + # save the markdown content with provider name with open(f"{__data__}/crawl4ai_simple.md", "w") as f: f.write(result.markdown) # count how many "cldnry.s-nbcnews.com" are in the markdown @@ -48,10 +49,12 @@ async def compare(): start = time.time() result = await crawler.arun( url="https://www.nbcnews.com/business", - js_code=["const loadMoreButton = Array.from(document.querySelectorAll('button')).find(button => button.textContent.includes('Load More')); loadMoreButton && loadMoreButton.click();"], + js_code=[ + "const loadMoreButton = Array.from(document.querySelectorAll('button')).find(button => button.textContent.includes('Load More')); loadMoreButton && loadMoreButton.click();" + ], word_count_threshold=0, - bypass_cache=True, - verbose=False + bypass_cache=True, + verbose=False, ) end = time.time() print(f"Time taken: {end - start} seconds") @@ -61,7 +64,7 @@ async def compare(): f.write(result.markdown) # count how many "cldnry.s-nbcnews.com" are in the markdown print(result.markdown.count("cldnry.s-nbcnews.com")) - + + if __name__ == "__main__": asyncio.run(compare()) - \ No newline at end of file diff --git a/docs/examples/dispatcher_example.py b/docs/examples/dispatcher_example.py index 796ee078..c9708ccc 100644 --- a/docs/examples/dispatcher_example.py +++ b/docs/examples/dispatcher_example.py @@ -3,11 +3,18 @@ import time from rich import print from rich.table import Table from crawl4ai import ( - AsyncWebCrawler, BrowserConfig, CrawlerRunConfig, - MemoryAdaptiveDispatcher, SemaphoreDispatcher, - RateLimiter, CrawlerMonitor, DisplayMode, CacheMode + AsyncWebCrawler, + BrowserConfig, + CrawlerRunConfig, + MemoryAdaptiveDispatcher, + SemaphoreDispatcher, + RateLimiter, + CrawlerMonitor, + DisplayMode, + CacheMode, ) + async def memory_adaptive(urls, browser_config, run_config): """Memory adaptive crawler with monitoring""" start = time.perf_counter() @@ -16,14 +23,16 @@ async def memory_adaptive(urls, browser_config, run_config): memory_threshold_percent=70.0, max_session_permit=10, monitor=CrawlerMonitor( - max_visible_rows=15, - display_mode=DisplayMode.DETAILED - ) + max_visible_rows=15, display_mode=DisplayMode.DETAILED + ), + ) + results = await crawler.arun_many( + urls, config=run_config, dispatcher=dispatcher ) - results = await crawler.arun_many(urls, config=run_config, dispatcher=dispatcher) duration = time.perf_counter() - start return len(results), duration + async def memory_adaptive_with_rate_limit(urls, browser_config, run_config): """Memory adaptive crawler with rate limiting""" start = time.perf_counter() @@ -32,19 +41,19 @@ async def memory_adaptive_with_rate_limit(urls, browser_config, run_config): memory_threshold_percent=70.0, max_session_permit=10, rate_limiter=RateLimiter( - base_delay=(1.0, 2.0), - max_delay=30.0, - max_retries=2 + base_delay=(1.0, 2.0), max_delay=30.0, max_retries=2 ), monitor=CrawlerMonitor( - max_visible_rows=15, - display_mode=DisplayMode.DETAILED - ) + max_visible_rows=15, display_mode=DisplayMode.DETAILED + ), + ) + results = await crawler.arun_many( + urls, config=run_config, dispatcher=dispatcher ) - results = await crawler.arun_many(urls, config=run_config, dispatcher=dispatcher) duration = time.perf_counter() - start return len(results), duration + async def semaphore(urls, browser_config, run_config): """Basic semaphore crawler""" start = time.perf_counter() @@ -52,14 +61,16 @@ async def semaphore(urls, browser_config, run_config): dispatcher = SemaphoreDispatcher( semaphore_count=5, monitor=CrawlerMonitor( - max_visible_rows=15, - display_mode=DisplayMode.DETAILED - ) + max_visible_rows=15, display_mode=DisplayMode.DETAILED + ), + ) + results = await crawler.arun_many( + urls, config=run_config, dispatcher=dispatcher ) - results = await crawler.arun_many(urls, config=run_config, dispatcher=dispatcher) duration = time.perf_counter() - start return len(results), duration + async def semaphore_with_rate_limit(urls, browser_config, run_config): """Semaphore crawler with rate limiting""" start = time.perf_counter() @@ -67,19 +78,19 @@ async def semaphore_with_rate_limit(urls, browser_config, run_config): dispatcher = SemaphoreDispatcher( semaphore_count=5, rate_limiter=RateLimiter( - base_delay=(1.0, 2.0), - max_delay=30.0, - max_retries=2 + base_delay=(1.0, 2.0), max_delay=30.0, max_retries=2 ), monitor=CrawlerMonitor( - max_visible_rows=15, - display_mode=DisplayMode.DETAILED - ) + max_visible_rows=15, display_mode=DisplayMode.DETAILED + ), + ) + results = await crawler.arun_many( + urls, config=run_config, dispatcher=dispatcher ) - results = await crawler.arun_many(urls, config=run_config, dispatcher=dispatcher) duration = time.perf_counter() - start return len(results), duration + def create_performance_table(results): """Creates a rich table showing performance results""" table = Table(title="Crawler Strategy Performance Comparison") @@ -89,18 +100,16 @@ def create_performance_table(results): table.add_column("URLs/second", justify="right", style="magenta") sorted_results = sorted(results.items(), key=lambda x: x[1][1]) - + for strategy, (urls_crawled, duration) in sorted_results: urls_per_second = urls_crawled / duration table.add_row( - strategy, - str(urls_crawled), - f"{duration:.2f}", - f"{urls_per_second:.2f}" + strategy, str(urls_crawled), f"{duration:.2f}", f"{urls_per_second:.2f}" ) - + return table + async def main(): urls = [f"https://example.com/page{i}" for i in range(1, 20)] browser_config = BrowserConfig(headless=True, verbose=False) @@ -108,14 +117,19 @@ async def main(): results = { "Memory Adaptive": await memory_adaptive(urls, browser_config, run_config), - "Memory Adaptive + Rate Limit": await memory_adaptive_with_rate_limit(urls, browser_config, run_config), + "Memory Adaptive + Rate Limit": await memory_adaptive_with_rate_limit( + urls, browser_config, run_config + ), "Semaphore": await semaphore(urls, browser_config, run_config), - "Semaphore + Rate Limit": await semaphore_with_rate_limit(urls, browser_config, run_config), + "Semaphore + Rate Limit": await semaphore_with_rate_limit( + urls, browser_config, run_config + ), } table = create_performance_table(results) print("\nPerformance Summary:") print(table) + if __name__ == "__main__": - asyncio.run(main()) \ No newline at end of file + asyncio.run(main()) diff --git a/docs/examples/docker_example.py b/docs/examples/docker_example.py index 48acc809..fe1d0727 100644 --- a/docs/examples/docker_example.py +++ b/docs/examples/docker_example.py @@ -6,63 +6,80 @@ import base64 import os from typing import Dict, Any + class Crawl4AiTester: def __init__(self, base_url: str = "http://localhost:11235", api_token: str = None): self.base_url = base_url - self.api_token = api_token or os.getenv('CRAWL4AI_API_TOKEN') or "test_api_code" # Check environment variable as fallback - self.headers = {'Authorization': f'Bearer {self.api_token}'} if self.api_token else {} - - def submit_and_wait(self, request_data: Dict[str, Any], timeout: int = 300) -> Dict[str, Any]: + self.api_token = ( + api_token or os.getenv("CRAWL4AI_API_TOKEN") or "test_api_code" + ) # Check environment variable as fallback + self.headers = ( + {"Authorization": f"Bearer {self.api_token}"} if self.api_token else {} + ) + + def submit_and_wait( + self, request_data: Dict[str, Any], timeout: int = 300 + ) -> Dict[str, Any]: # Submit crawl job - response = requests.post(f"{self.base_url}/crawl", json=request_data, headers=self.headers) + response = requests.post( + f"{self.base_url}/crawl", json=request_data, headers=self.headers + ) if response.status_code == 403: raise Exception("API token is invalid or missing") task_id = response.json()["task_id"] print(f"Task ID: {task_id}") - + # Poll for result start_time = time.time() while True: if time.time() - start_time > timeout: - raise TimeoutError(f"Task {task_id} did not complete within {timeout} seconds") - - result = requests.get(f"{self.base_url}/task/{task_id}", headers=self.headers) + raise TimeoutError( + f"Task {task_id} did not complete within {timeout} seconds" + ) + + result = requests.get( + f"{self.base_url}/task/{task_id}", headers=self.headers + ) status = result.json() - + if status["status"] == "failed": print("Task failed:", status.get("error")) raise Exception(f"Task failed: {status.get('error')}") - + if status["status"] == "completed": return status - + time.sleep(2) - + def submit_sync(self, request_data: Dict[str, Any]) -> Dict[str, Any]: - response = requests.post(f"{self.base_url}/crawl_sync", json=request_data, headers=self.headers, timeout=60) + response = requests.post( + f"{self.base_url}/crawl_sync", + json=request_data, + headers=self.headers, + timeout=60, + ) if response.status_code == 408: raise TimeoutError("Task did not complete within server timeout") response.raise_for_status() return response.json() - + def crawl_direct(self, request_data: Dict[str, Any]) -> Dict[str, Any]: """Directly crawl without using task queue""" response = requests.post( - f"{self.base_url}/crawl_direct", - json=request_data, - headers=self.headers + f"{self.base_url}/crawl_direct", json=request_data, headers=self.headers ) response.raise_for_status() return response.json() + def test_docker_deployment(version="basic"): tester = Crawl4AiTester( - base_url="http://localhost:11235" , + base_url="http://localhost:11235", # base_url="https://api.crawl4ai.com" # just for example # api_token="test" # just for example ) print(f"Testing Crawl4AI Docker {version} version") - + # Health check with timeout and retry max_retries = 5 for i in range(max_retries): @@ -70,19 +87,19 @@ def test_docker_deployment(version="basic"): health = requests.get(f"{tester.base_url}/health", timeout=10) print("Health check:", health.json()) break - except requests.exceptions.RequestException as e: + except requests.exceptions.RequestException: if i == max_retries - 1: print(f"Failed to connect after {max_retries} attempts") sys.exit(1) print(f"Waiting for service to start (attempt {i+1}/{max_retries})...") time.sleep(5) - + # Test cases based on version test_basic_crawl_direct(tester) test_basic_crawl(tester) test_basic_crawl(tester) test_basic_crawl_sync(tester) - + if version in ["full", "transformer"]: test_cosine_extraction(tester) @@ -92,49 +109,52 @@ def test_docker_deployment(version="basic"): test_llm_extraction(tester) test_llm_with_ollama(tester) test_screenshot(tester) - + def test_basic_crawl(tester: Crawl4AiTester): print("\n=== Testing Basic Crawl ===") request = { "urls": "https://www.nbcnews.com/business", - "priority": 10, - "session_id": "test" + "priority": 10, + "session_id": "test", } - + result = tester.submit_and_wait(request) print(f"Basic crawl result length: {len(result['result']['markdown'])}") assert result["result"]["success"] assert len(result["result"]["markdown"]) > 0 + def test_basic_crawl_sync(tester: Crawl4AiTester): print("\n=== Testing Basic Crawl (Sync) ===") request = { "urls": "https://www.nbcnews.com/business", "priority": 10, - "session_id": "test" + "session_id": "test", } - + result = tester.submit_sync(request) print(f"Basic crawl result length: {len(result['result']['markdown'])}") - assert result['status'] == 'completed' - assert result['result']['success'] - assert len(result['result']['markdown']) > 0 - + assert result["status"] == "completed" + assert result["result"]["success"] + assert len(result["result"]["markdown"]) > 0 + + def test_basic_crawl_direct(tester: Crawl4AiTester): print("\n=== Testing Basic Crawl (Direct) ===") request = { "urls": "https://www.nbcnews.com/business", "priority": 10, # "session_id": "test" - "cache_mode": "bypass" # or "enabled", "disabled", "read_only", "write_only" + "cache_mode": "bypass", # or "enabled", "disabled", "read_only", "write_only" } - + result = tester.crawl_direct(request) print(f"Basic crawl result length: {len(result['result']['markdown'])}") - assert result['result']['success'] - assert len(result['result']['markdown']) > 0 - + assert result["result"]["success"] + assert len(result["result"]["markdown"]) > 0 + + def test_js_execution(tester: Crawl4AiTester): print("\n=== Testing JS Execution ===") request = { @@ -144,32 +164,29 @@ def test_js_execution(tester: Crawl4AiTester): "const loadMoreButton = Array.from(document.querySelectorAll('button')).find(button => button.textContent.includes('Load More')); loadMoreButton && loadMoreButton.click();" ], "wait_for": "article.tease-card:nth-child(10)", - "crawler_params": { - "headless": True - } + "crawler_params": {"headless": True}, } - + result = tester.submit_and_wait(request) print(f"JS execution result length: {len(result['result']['markdown'])}") assert result["result"]["success"] + def test_css_selector(tester: Crawl4AiTester): print("\n=== Testing CSS Selector ===") request = { "urls": "https://www.nbcnews.com/business", "priority": 7, "css_selector": ".wide-tease-item__description", - "crawler_params": { - "headless": True - }, - "extra": {"word_count_threshold": 10} - + "crawler_params": {"headless": True}, + "extra": {"word_count_threshold": 10}, } - + result = tester.submit_and_wait(request) print(f"CSS selector result length: {len(result['result']['markdown'])}") assert result["result"]["success"] + def test_structured_extraction(tester: Crawl4AiTester): print("\n=== Testing Structured Extraction ===") schema = { @@ -190,21 +207,16 @@ def test_structured_extraction(tester: Crawl4AiTester): "name": "price", "selector": "td:nth-child(2)", "type": "text", - } + }, ], } - + request = { "urls": "https://www.coinbase.com/explore", "priority": 9, - "extraction_config": { - "type": "json_css", - "params": { - "schema": schema - } - } + "extraction_config": {"type": "json_css", "params": {"schema": schema}}, } - + result = tester.submit_and_wait(request) extracted = json.loads(result["result"]["extracted_content"]) print(f"Extracted {len(extracted)} items") @@ -212,6 +224,7 @@ def test_structured_extraction(tester: Crawl4AiTester): assert result["result"]["success"] assert len(extracted) > 0 + def test_llm_extraction(tester: Crawl4AiTester): print("\n=== Testing LLM Extraction ===") schema = { @@ -219,20 +232,20 @@ def test_llm_extraction(tester: Crawl4AiTester): "properties": { "model_name": { "type": "string", - "description": "Name of the OpenAI model." + "description": "Name of the OpenAI model.", }, "input_fee": { "type": "string", - "description": "Fee for input token for the OpenAI model." + "description": "Fee for input token for the OpenAI model.", }, "output_fee": { "type": "string", - "description": "Fee for output token for the OpenAI model." - } + "description": "Fee for output token for the OpenAI model.", + }, }, - "required": ["model_name", "input_fee", "output_fee"] + "required": ["model_name", "input_fee", "output_fee"], } - + request = { "urls": "https://openai.com/api/pricing", "priority": 8, @@ -243,12 +256,12 @@ def test_llm_extraction(tester: Crawl4AiTester): "api_token": os.getenv("OPENAI_API_KEY"), "schema": schema, "extraction_type": "schema", - "instruction": """From the crawled content, extract all mentioned model names along with their fees for input and output tokens.""" - } + "instruction": """From the crawled content, extract all mentioned model names along with their fees for input and output tokens.""", + }, }, - "crawler_params": {"word_count_threshold": 1} + "crawler_params": {"word_count_threshold": 1}, } - + try: result = tester.submit_and_wait(request) extracted = json.loads(result["result"]["extracted_content"]) @@ -258,6 +271,7 @@ def test_llm_extraction(tester: Crawl4AiTester): except Exception as e: print(f"LLM extraction test failed (might be due to missing API key): {str(e)}") + def test_llm_with_ollama(tester: Crawl4AiTester): print("\n=== Testing LLM with Ollama ===") schema = { @@ -265,20 +279,20 @@ def test_llm_with_ollama(tester: Crawl4AiTester): "properties": { "article_title": { "type": "string", - "description": "The main title of the news article" + "description": "The main title of the news article", }, "summary": { "type": "string", - "description": "A brief summary of the article content" + "description": "A brief summary of the article content", }, "main_topics": { "type": "array", "items": {"type": "string"}, - "description": "Main topics or themes discussed in the article" - } - } + "description": "Main topics or themes discussed in the article", + }, + }, } - + request = { "urls": "https://www.nbcnews.com/business", "priority": 8, @@ -288,13 +302,13 @@ def test_llm_with_ollama(tester: Crawl4AiTester): "provider": "ollama/llama2", "schema": schema, "extraction_type": "schema", - "instruction": "Extract the main article information including title, summary, and main topics." - } + "instruction": "Extract the main article information including title, summary, and main topics.", + }, }, "extra": {"word_count_threshold": 1}, - "crawler_params": {"verbose": True} + "crawler_params": {"verbose": True}, } - + try: result = tester.submit_and_wait(request) extracted = json.loads(result["result"]["extracted_content"]) @@ -303,6 +317,7 @@ def test_llm_with_ollama(tester: Crawl4AiTester): except Exception as e: print(f"Ollama extraction test failed: {str(e)}") + def test_cosine_extraction(tester: Crawl4AiTester): print("\n=== Testing Cosine Extraction ===") request = { @@ -314,11 +329,11 @@ def test_cosine_extraction(tester: Crawl4AiTester): "semantic_filter": "business finance economy", "word_count_threshold": 10, "max_dist": 0.2, - "top_k": 3 - } - } + "top_k": 3, + }, + }, } - + try: result = tester.submit_and_wait(request) extracted = json.loads(result["result"]["extracted_content"]) @@ -328,30 +343,30 @@ def test_cosine_extraction(tester: Crawl4AiTester): except Exception as e: print(f"Cosine extraction test failed: {str(e)}") + def test_screenshot(tester: Crawl4AiTester): print("\n=== Testing Screenshot ===") request = { "urls": "https://www.nbcnews.com/business", "priority": 5, "screenshot": True, - "crawler_params": { - "headless": True - } + "crawler_params": {"headless": True}, } - + result = tester.submit_and_wait(request) print("Screenshot captured:", bool(result["result"]["screenshot"])) - + if result["result"]["screenshot"]: # Save screenshot screenshot_data = base64.b64decode(result["result"]["screenshot"]) with open("test_screenshot.jpg", "wb") as f: f.write(screenshot_data) print("Screenshot saved as test_screenshot.jpg") - + assert result["result"]["success"] + if __name__ == "__main__": version = sys.argv[1] if len(sys.argv) > 1 else "basic" # version = "full" - test_docker_deployment(version) \ No newline at end of file + test_docker_deployment(version) diff --git a/docs/examples/extraction_strategies_example.py b/docs/examples/extraction_strategies_example.py index 348b891e..658f7521 100644 --- a/docs/examples/extraction_strategies_example.py +++ b/docs/examples/extraction_strategies_example.py @@ -9,18 +9,17 @@ This example shows how to: import asyncio import os -from typing import Dict, Any from crawl4ai import AsyncWebCrawler, BrowserConfig, CrawlerRunConfig, CacheMode from crawl4ai.extraction_strategy import ( LLMExtractionStrategy, JsonCssExtractionStrategy, - JsonXPathExtractionStrategy + JsonXPathExtractionStrategy, ) -from crawl4ai.chunking_strategy import RegexChunking, IdentityChunking from crawl4ai.content_filter_strategy import PruningContentFilter from crawl4ai.markdown_generation_strategy import DefaultMarkdownGenerator + async def run_extraction(crawler: AsyncWebCrawler, url: str, strategy, name: str): """Helper function to run extraction with proper configuration""" try: @@ -30,78 +29,90 @@ async def run_extraction(crawler: AsyncWebCrawler, url: str, strategy, name: str extraction_strategy=strategy, markdown_generator=DefaultMarkdownGenerator( content_filter=PruningContentFilter() # For fit_markdown support - ) + ), ) - + # Run the crawler result = await crawler.arun(url=url, config=config) - + if result.success: print(f"\n=== {name} Results ===") print(f"Extracted Content: {result.extracted_content}") print(f"Raw Markdown Length: {len(result.markdown_v2.raw_markdown)}") - print(f"Citations Markdown Length: {len(result.markdown_v2.markdown_with_citations)}") + print( + f"Citations Markdown Length: {len(result.markdown_v2.markdown_with_citations)}" + ) else: print(f"Error in {name}: Crawl failed") - + except Exception as e: print(f"Error in {name}: {str(e)}") + async def main(): # Example URL (replace with actual URL) url = "https://example.com/product-page" - + # Configure browser settings - browser_config = BrowserConfig( - headless=True, - verbose=True - ) - + browser_config = BrowserConfig(headless=True, verbose=True) + # Initialize extraction strategies - + # 1. LLM Extraction with different input formats markdown_strategy = LLMExtractionStrategy( provider="openai/gpt-4o-mini", api_token=os.getenv("OPENAI_API_KEY"), - instruction="Extract product information including name, price, and description" + instruction="Extract product information including name, price, and description", ) - + html_strategy = LLMExtractionStrategy( input_format="html", provider="openai/gpt-4o-mini", api_token=os.getenv("OPENAI_API_KEY"), - instruction="Extract product information from HTML including structured data" + instruction="Extract product information from HTML including structured data", ) - + fit_markdown_strategy = LLMExtractionStrategy( input_format="fit_markdown", provider="openai/gpt-4o-mini", api_token=os.getenv("OPENAI_API_KEY"), - instruction="Extract product information from cleaned markdown" + instruction="Extract product information from cleaned markdown", ) - + # 2. JSON CSS Extraction (automatically uses HTML input) css_schema = { "baseSelector": ".product", "fields": [ {"name": "title", "selector": "h1.product-title", "type": "text"}, {"name": "price", "selector": ".price", "type": "text"}, - {"name": "description", "selector": ".description", "type": "text"} - ] + {"name": "description", "selector": ".description", "type": "text"}, + ], } css_strategy = JsonCssExtractionStrategy(schema=css_schema) - + # 3. JSON XPath Extraction (automatically uses HTML input) xpath_schema = { "baseSelector": "//div[@class='product']", "fields": [ - {"name": "title", "selector": ".//h1[@class='product-title']/text()", "type": "text"}, - {"name": "price", "selector": ".//span[@class='price']/text()", "type": "text"}, - {"name": "description", "selector": ".//div[@class='description']/text()", "type": "text"} - ] + { + "name": "title", + "selector": ".//h1[@class='product-title']/text()", + "type": "text", + }, + { + "name": "price", + "selector": ".//span[@class='price']/text()", + "type": "text", + }, + { + "name": "description", + "selector": ".//div[@class='description']/text()", + "type": "text", + }, + ], } xpath_strategy = JsonXPathExtractionStrategy(schema=xpath_schema) - + # Use context manager for proper resource handling async with AsyncWebCrawler(config=browser_config) as crawler: # Run all strategies @@ -111,5 +122,6 @@ async def main(): await run_extraction(crawler, url, css_strategy, "CSS Extraction") await run_extraction(crawler, url, xpath_strategy, "XPath Extraction") + if __name__ == "__main__": asyncio.run(main()) diff --git a/docs/examples/hello_world.py b/docs/examples/hello_world.py index 18534d0e..97a8187e 100644 --- a/docs/examples/hello_world.py +++ b/docs/examples/hello_world.py @@ -1,20 +1,23 @@ import asyncio from crawl4ai import * + async def main(): browser_config = BrowserConfig(headless=True, verbose=True) async with AsyncWebCrawler(config=browser_config) as crawler: crawler_config = CrawlerRunConfig( cache_mode=CacheMode.BYPASS, markdown_generator=DefaultMarkdownGenerator( - content_filter=PruningContentFilter(threshold=0.48, threshold_type="fixed", min_word_threshold=0) - ) + content_filter=PruningContentFilter( + threshold=0.48, threshold_type="fixed", min_word_threshold=0 + ) + ), ) result = await crawler.arun( - url="https://www.helloworld.org", - config=crawler_config + url="https://www.helloworld.org", config=crawler_config ) print(result.markdown_v2.raw_markdown[:500]) + if __name__ == "__main__": - asyncio.run(main()) \ No newline at end of file + asyncio.run(main()) diff --git a/docs/examples/hooks_example.py b/docs/examples/hooks_example.py index 06b509bd..de0aa6e1 100644 --- a/docs/examples/hooks_example.py +++ b/docs/examples/hooks_example.py @@ -1,19 +1,18 @@ from crawl4ai import AsyncWebCrawler, BrowserConfig, CrawlerRunConfig, CacheMode from playwright.async_api import Page, BrowserContext + async def main(): print("🔗 Hooks Example: Demonstrating different hook use cases") # Configure browser settings - browser_config = BrowserConfig( - headless=True - ) - + browser_config = BrowserConfig(headless=True) + # Configure crawler settings crawler_run_config = CrawlerRunConfig( js_code="window.scrollTo(0, document.body.scrollHeight);", wait_for="body", - cache_mode=CacheMode.BYPASS + cache_mode=CacheMode.BYPASS, ) # Create crawler instance @@ -30,16 +29,22 @@ async def main(): """Hook called after a new page and context are created""" print("[HOOK] on_page_context_created - New page created!") # Example: Set default viewport size - await context.add_cookies([{ - 'name': 'session_id', - 'value': 'example_session', - 'domain': '.example.com', - 'path': '/' - }]) + await context.add_cookies( + [ + { + "name": "session_id", + "value": "example_session", + "domain": ".example.com", + "path": "/", + } + ] + ) await page.set_viewport_size({"width": 1080, "height": 800}) return page - async def on_user_agent_updated(page: Page, context: BrowserContext, user_agent: str, **kwargs): + async def on_user_agent_updated( + page: Page, context: BrowserContext, user_agent: str, **kwargs + ): """Hook called when the user agent is updated""" print(f"[HOOK] on_user_agent_updated - New user agent: {user_agent}") return page @@ -53,17 +58,17 @@ async def main(): """Hook called before navigating to each URL""" print(f"[HOOK] before_goto - About to visit: {url}") # Example: Add custom headers for the request - await page.set_extra_http_headers({ - "Custom-Header": "my-value" - }) + await page.set_extra_http_headers({"Custom-Header": "my-value"}) return page - async def after_goto(page: Page, context: BrowserContext, url: str, response: dict, **kwargs): + async def after_goto( + page: Page, context: BrowserContext, url: str, response: dict, **kwargs + ): """Hook called after navigating to each URL""" print(f"[HOOK] after_goto - Successfully loaded: {url}") # Example: Wait for a specific element to be loaded try: - await page.wait_for_selector('.content', timeout=1000) + await page.wait_for_selector(".content", timeout=1000) print("Content element found!") except: print("Content element not found, continuing anyway") @@ -76,7 +81,9 @@ async def main(): await page.evaluate("window.scrollTo(0, document.body.scrollHeight);") return page - async def before_return_html(page: Page, context: BrowserContext, html:str, **kwargs): + async def before_return_html( + page: Page, context: BrowserContext, html: str, **kwargs + ): """Hook called before returning the HTML content""" print(f"[HOOK] before_return_html - Got HTML content (length: {len(html)})") # Example: You could modify the HTML content here if needed @@ -84,7 +91,9 @@ async def main(): # Set all the hooks crawler.crawler_strategy.set_hook("on_browser_created", on_browser_created) - crawler.crawler_strategy.set_hook("on_page_context_created", on_page_context_created) + crawler.crawler_strategy.set_hook( + "on_page_context_created", on_page_context_created + ) crawler.crawler_strategy.set_hook("on_user_agent_updated", on_user_agent_updated) crawler.crawler_strategy.set_hook("on_execution_started", on_execution_started) crawler.crawler_strategy.set_hook("before_goto", before_goto) @@ -95,13 +104,15 @@ async def main(): await crawler.start() # Example usage: crawl a simple website - url = 'https://example.com' + url = "https://example.com" result = await crawler.arun(url, config=crawler_run_config) print(f"\nCrawled URL: {result.url}") print(f"HTML length: {len(result.html)}") - + await crawler.close() + if __name__ == "__main__": import asyncio - asyncio.run(main()) \ No newline at end of file + + asyncio.run(main()) diff --git a/docs/examples/language_support_example.py b/docs/examples/language_support_example.py index b74a8402..712db2c4 100644 --- a/docs/examples/language_support_example.py +++ b/docs/examples/language_support_example.py @@ -1,6 +1,7 @@ import asyncio from crawl4ai import AsyncWebCrawler, AsyncPlaywrightCrawlerStrategy + async def main(): # Example 1: Setting language when creating the crawler crawler1 = AsyncWebCrawler( @@ -9,11 +10,15 @@ async def main(): ) ) result1 = await crawler1.arun("https://www.example.com") - print("Example 1 result:", result1.extracted_content[:100]) # Print first 100 characters + print( + "Example 1 result:", result1.extracted_content[:100] + ) # Print first 100 characters # Example 2: Setting language before crawling crawler2 = AsyncWebCrawler() - crawler2.crawler_strategy.headers["Accept-Language"] = "es-ES,es;q=0.9,en-US;q=0.8,en;q=0.7" + crawler2.crawler_strategy.headers[ + "Accept-Language" + ] = "es-ES,es;q=0.9,en-US;q=0.8,en;q=0.7" result2 = await crawler2.arun("https://www.example.com") print("Example 2 result:", result2.extracted_content[:100]) @@ -21,7 +26,7 @@ async def main(): crawler3 = AsyncWebCrawler() result3 = await crawler3.arun( "https://www.example.com", - headers={"Accept-Language": "de-DE,de;q=0.9,en-US;q=0.8,en;q=0.7"} + headers={"Accept-Language": "de-DE,de;q=0.9,en-US;q=0.8,en;q=0.7"}, ) print("Example 3 result:", result3.extracted_content[:100]) @@ -31,15 +36,15 @@ async def main(): ("https://www.example.org", "es-ES,es;q=0.9"), ("https://www.example.net", "de-DE,de;q=0.9"), ] - + crawler4 = AsyncWebCrawler() - results = await asyncio.gather(*[ - crawler4.arun(url, headers={"Accept-Language": lang}) - for url, lang in urls - ]) - + results = await asyncio.gather( + *[crawler4.arun(url, headers={"Accept-Language": lang}) for url, lang in urls] + ) + for url, result in zip([u for u, _ in urls], results): print(f"Result for {url}:", result.extracted_content[:100]) + if __name__ == "__main__": - asyncio.run(main()) \ No newline at end of file + asyncio.run(main()) diff --git a/docs/examples/llm_extraction_openai_pricing.py b/docs/examples/llm_extraction_openai_pricing.py index 5ae3d4d1..e9e90dd2 100644 --- a/docs/examples/llm_extraction_openai_pricing.py +++ b/docs/examples/llm_extraction_openai_pricing.py @@ -3,32 +3,37 @@ from crawl4ai.crawler_strategy import * import asyncio from pydantic import BaseModel, Field -url = r'https://openai.com/api/pricing/' +url = r"https://openai.com/api/pricing/" + class OpenAIModelFee(BaseModel): model_name: str = Field(..., description="Name of the OpenAI model.") input_fee: str = Field(..., description="Fee for input token for the OpenAI model.") - output_fee: str = Field(..., description="Fee for output token for the OpenAI model.") + output_fee: str = Field( + ..., description="Fee for output token for the OpenAI model." + ) + from crawl4ai import AsyncWebCrawler + async def main(): # Use AsyncWebCrawler async with AsyncWebCrawler() as crawler: result = await crawler.arun( url=url, word_count_threshold=1, - extraction_strategy= LLMExtractionStrategy( + extraction_strategy=LLMExtractionStrategy( # provider= "openai/gpt-4o", api_token = os.getenv('OPENAI_API_KEY'), - provider= "groq/llama-3.1-70b-versatile", api_token = os.getenv('GROQ_API_KEY'), + provider="groq/llama-3.1-70b-versatile", + api_token=os.getenv("GROQ_API_KEY"), schema=OpenAIModelFee.model_json_schema(), extraction_type="schema", - instruction="From the crawled content, extract all mentioned model names along with their " \ - "fees for input and output tokens. Make sure not to miss anything in the entire content. " \ - 'One extracted model JSON format should look like this: ' \ - '{ "model_name": "GPT-4", "input_fee": "US$10.00 / 1M tokens", "output_fee": "US$30.00 / 1M tokens" }' + instruction="From the crawled content, extract all mentioned model names along with their " + "fees for input and output tokens. Make sure not to miss anything in the entire content. " + "One extracted model JSON format should look like this: " + '{ "model_name": "GPT-4", "input_fee": "US$10.00 / 1M tokens", "output_fee": "US$30.00 / 1M tokens" }', ), - ) print("Success:", result.success) model_fees = json.loads(result.extracted_content) @@ -37,4 +42,5 @@ async def main(): with open(".data/data.json", "w", encoding="utf-8") as f: f.write(result.extracted_content) + asyncio.run(main()) diff --git a/docs/examples/quickstart_async.config.py b/docs/examples/quickstart_async.config.py index 4c4a9d86..a2a02da8 100644 --- a/docs/examples/quickstart_async.config.py +++ b/docs/examples/quickstart_async.config.py @@ -8,12 +8,12 @@ import asyncio import time import json import re -from typing import Dict, List +from typing import Dict from bs4 import BeautifulSoup from pydantic import BaseModel, Field from crawl4ai import AsyncWebCrawler, CacheMode, BrowserConfig, CrawlerRunConfig from crawl4ai.markdown_generation_strategy import DefaultMarkdownGenerator -from crawl4ai.content_filter_strategy import BM25ContentFilter, PruningContentFilter +from crawl4ai.content_filter_strategy import PruningContentFilter from crawl4ai.extraction_strategy import ( JsonCssExtractionStrategy, LLMExtractionStrategy, @@ -62,6 +62,7 @@ async def clean_content(): print(f"Full Markdown Length: {full_markdown_length}") print(f"Fit Markdown Length: {fit_markdown_length}") + async def link_analysis(): crawler_config = CrawlerRunConfig( cache_mode=CacheMode.ENABLED, @@ -76,9 +77,10 @@ async def link_analysis(): print(f"Found {len(result.links['internal'])} internal links") print(f"Found {len(result.links['external'])} external links") - for link in result.links['internal'][:5]: + for link in result.links["internal"][:5]: print(f"Href: {link['href']}\nText: {link['text']}\n") + # JavaScript Execution Example async def simple_example_with_running_js_code(): print("\n--- Executing JavaScript and Using CSS Selectors ---") @@ -112,25 +114,29 @@ async def simple_example_with_css_selector(): ) print(result.markdown[:500]) + async def media_handling(): - crawler_config = CrawlerRunConfig(cache_mode=CacheMode.BYPASS, exclude_external_images=True, screenshot=True) + crawler_config = CrawlerRunConfig( + cache_mode=CacheMode.BYPASS, exclude_external_images=True, screenshot=True + ) async with AsyncWebCrawler() as crawler: result = await crawler.arun( - url="https://www.nbcnews.com/business", - config=crawler_config + url="https://www.nbcnews.com/business", config=crawler_config ) - for img in result.media['images'][:5]: + for img in result.media["images"][:5]: print(f"Image URL: {img['src']}, Alt: {img['alt']}, Score: {img['score']}") + async def custom_hook_workflow(verbose=True): async with AsyncWebCrawler() as crawler: # Set a 'before_goto' hook to run custom code just before navigation - crawler.crawler_strategy.set_hook("before_goto", lambda page, context: print("[Hook] Preparing to navigate...")) + crawler.crawler_strategy.set_hook( + "before_goto", + lambda page, context: print("[Hook] Preparing to navigate..."), + ) # Perform the crawl operation - result = await crawler.arun( - url="https://crawl4ai.com" - ) + result = await crawler.arun(url="https://crawl4ai.com") print(result.markdown_v2.raw_markdown[:500].replace("\n", " -- ")) @@ -412,21 +418,22 @@ async def cosine_similarity_extraction(): cache_mode=CacheMode.BYPASS, extraction_strategy=CosineStrategy( word_count_threshold=10, - max_dist=0.2, # Maximum distance between two words - linkage_method="ward", # Linkage method for hierarchical clustering (ward, complete, average, single) - top_k=3, # Number of top keywords to extract - sim_threshold=0.3, # Similarity threshold for clustering - semantic_filter="McDonald's economic impact, American consumer trends", # Keywords to filter the content semantically using embeddings - verbose=True - ), + max_dist=0.2, # Maximum distance between two words + linkage_method="ward", # Linkage method for hierarchical clustering (ward, complete, average, single) + top_k=3, # Number of top keywords to extract + sim_threshold=0.3, # Similarity threshold for clustering + semantic_filter="McDonald's economic impact, American consumer trends", # Keywords to filter the content semantically using embeddings + verbose=True, + ), ) async with AsyncWebCrawler() as crawler: result = await crawler.arun( url="https://www.nbcnews.com/business/consumer/how-mcdonalds-e-coli-crisis-inflation-politics-reflect-american-story-rcna177156", - config=crawl_config + config=crawl_config, ) print(json.loads(result.extracted_content)[:5]) + # Browser Comparison async def crawl_custom_browser_type(): print("\n--- Browser Comparison ---") @@ -484,39 +491,42 @@ async def crawl_with_user_simulation(): result = await crawler.arun(url="YOUR-URL-HERE", config=crawler_config) print(result.markdown) + async def ssl_certification(): # Configure crawler to fetch SSL certificate config = CrawlerRunConfig( fetch_ssl_certificate=True, - cache_mode=CacheMode.BYPASS # Bypass cache to always get fresh certificates + cache_mode=CacheMode.BYPASS, # Bypass cache to always get fresh certificates ) async with AsyncWebCrawler() as crawler: - result = await crawler.arun( - url='https://example.com', - config=config - ) - + result = await crawler.arun(url="https://example.com", config=config) + if result.success and result.ssl_certificate: cert = result.ssl_certificate - + # 1. Access certificate properties directly print("\nCertificate Information:") print(f"Issuer: {cert.issuer.get('CN', '')}") print(f"Valid until: {cert.valid_until}") print(f"Fingerprint: {cert.fingerprint}") - + # 2. Export certificate in different formats cert.to_json(os.path.join(tmp_dir, "certificate.json")) # For analysis print("\nCertificate exported to:") print(f"- JSON: {os.path.join(tmp_dir, 'certificate.json')}") - - pem_data = cert.to_pem(os.path.join(tmp_dir, "certificate.pem")) # For web servers + + pem_data = cert.to_pem( + os.path.join(tmp_dir, "certificate.pem") + ) # For web servers print(f"- PEM: {os.path.join(tmp_dir, 'certificate.pem')}") - - der_data = cert.to_der(os.path.join(tmp_dir, "certificate.der")) # For Java apps + + der_data = cert.to_der( + os.path.join(tmp_dir, "certificate.der") + ) # For Java apps print(f"- DER: {os.path.join(tmp_dir, 'certificate.der')}") + # Speed Comparison async def speed_comparison(): print("\n--- Speed Comparison ---") diff --git a/docs/examples/quickstart_async.py b/docs/examples/quickstart_async.py index e640e6bd..1585ebea 100644 --- a/docs/examples/quickstart_async.py +++ b/docs/examples/quickstart_async.py @@ -1,6 +1,10 @@ import os, sys + # append parent directory to system path -sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))); os.environ['FIRECRAWL_API_KEY'] = "fc-84b370ccfad44beabc686b38f1769692"; +sys.path.append( + os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) +) +os.environ["FIRECRAWL_API_KEY"] = "fc-84b370ccfad44beabc686b38f1769692" import asyncio # import nest_asyncio @@ -15,7 +19,7 @@ from bs4 import BeautifulSoup from pydantic import BaseModel, Field from crawl4ai import AsyncWebCrawler, CacheMode from crawl4ai.markdown_generation_strategy import DefaultMarkdownGenerator -from crawl4ai.content_filter_strategy import BM25ContentFilter, PruningContentFilter +from crawl4ai.content_filter_strategy import PruningContentFilter from crawl4ai.extraction_strategy import ( JsonCssExtractionStrategy, LLMExtractionStrategy, @@ -32,9 +36,12 @@ print("Website: https://crawl4ai.com") async def simple_crawl(): print("\n--- Basic Usage ---") async with AsyncWebCrawler(verbose=True) as crawler: - result = await crawler.arun(url="https://www.nbcnews.com/business", cache_mode= CacheMode.BYPASS) + result = await crawler.arun( + url="https://www.nbcnews.com/business", cache_mode=CacheMode.BYPASS + ) print(result.markdown[:500]) # Print first 500 characters + async def simple_example_with_running_js_code(): print("\n--- Executing JavaScript and Using CSS Selectors ---") # New code to handle the wait_for parameter @@ -57,6 +64,7 @@ async def simple_example_with_running_js_code(): ) print(result.markdown[:500]) # Print first 500 characters + async def simple_example_with_css_selector(): print("\n--- Using CSS Selectors ---") async with AsyncWebCrawler(verbose=True) as crawler: @@ -67,42 +75,44 @@ async def simple_example_with_css_selector(): ) print(result.markdown[:500]) # Print first 500 characters + async def use_proxy(): print("\n--- Using a Proxy ---") print( "Note: Replace 'http://your-proxy-url:port' with a working proxy to run this example." ) # Uncomment and modify the following lines to use a proxy - async with AsyncWebCrawler(verbose=True, proxy="http://your-proxy-url:port") as crawler: + async with AsyncWebCrawler( + verbose=True, proxy="http://your-proxy-url:port" + ) as crawler: result = await crawler.arun( - url="https://www.nbcnews.com/business", - cache_mode= CacheMode.BYPASS + url="https://www.nbcnews.com/business", cache_mode=CacheMode.BYPASS ) if result.success: print(result.markdown[:500]) # Print first 500 characters + async def capture_and_save_screenshot(url: str, output_path: str): async with AsyncWebCrawler(verbose=True) as crawler: result = await crawler.arun( - url=url, - screenshot=True, - cache_mode= CacheMode.BYPASS + url=url, screenshot=True, cache_mode=CacheMode.BYPASS ) - + if result.success and result.screenshot: import base64 - + # Decode the base64 screenshot data screenshot_data = base64.b64decode(result.screenshot) - + # Save the screenshot as a JPEG file - with open(output_path, 'wb') as f: + with open(output_path, "wb") as f: f.write(screenshot_data) - + print(f"Screenshot saved successfully to {output_path}") else: print("Failed to capture screenshot") + class OpenAIModelFee(BaseModel): model_name: str = Field(..., description="Name of the OpenAI model.") input_fee: str = Field(..., description="Fee for input token for the OpenAI model.") @@ -110,16 +120,19 @@ class OpenAIModelFee(BaseModel): ..., description="Fee for output token for the OpenAI model." ) -async def extract_structured_data_using_llm(provider: str, api_token: str = None, extra_headers: Dict[str, str] = None): + +async def extract_structured_data_using_llm( + provider: str, api_token: str = None, extra_headers: Dict[str, str] = None +): print(f"\n--- Extracting Structured Data with {provider} ---") - + if api_token is None and provider != "ollama": print(f"API token is required for {provider}. Skipping this example.") return # extra_args = {} - extra_args={ - "temperature": 0, + extra_args = { + "temperature": 0, "top_p": 0.9, "max_tokens": 2000, # any other supported parameters for litellm @@ -139,52 +152,49 @@ async def extract_structured_data_using_llm(provider: str, api_token: str = None instruction="""From the crawled content, extract all mentioned model names along with their fees for input and output tokens. Do not miss any models in the entire content. One extracted model JSON format should look like this: {"model_name": "GPT-4", "input_fee": "US$10.00 / 1M tokens", "output_fee": "US$30.00 / 1M tokens"}.""", - extra_args=extra_args + extra_args=extra_args, ), cache_mode=CacheMode.BYPASS, ) print(result.extracted_content) + async def extract_structured_data_using_css_extractor(): print("\n--- Using JsonCssExtractionStrategy for Fast Structured Output ---") schema = { - "name": "KidoCode Courses", - "baseSelector": "section.charge-methodology .w-tab-content > div", - "fields": [ - { - "name": "section_title", - "selector": "h3.heading-50", - "type": "text", - }, - { - "name": "section_description", - "selector": ".charge-content", - "type": "text", - }, - { - "name": "course_name", - "selector": ".text-block-93", - "type": "text", - }, - { - "name": "course_description", - "selector": ".course-content-text", - "type": "text", - }, - { - "name": "course_icon", - "selector": ".image-92", - "type": "attribute", - "attribute": "src" - } - ] -} + "name": "KidoCode Courses", + "baseSelector": "section.charge-methodology .w-tab-content > div", + "fields": [ + { + "name": "section_title", + "selector": "h3.heading-50", + "type": "text", + }, + { + "name": "section_description", + "selector": ".charge-content", + "type": "text", + }, + { + "name": "course_name", + "selector": ".text-block-93", + "type": "text", + }, + { + "name": "course_description", + "selector": ".course-content-text", + "type": "text", + }, + { + "name": "course_icon", + "selector": ".image-92", + "type": "attribute", + "attribute": "src", + }, + ], + } - async with AsyncWebCrawler( - headless=True, - verbose=True - ) as crawler: - + async with AsyncWebCrawler(headless=True, verbose=True) as crawler: # Create the JavaScript that handles clicking multiple times js_click_tabs = """ (async () => { @@ -198,19 +208,20 @@ async def extract_structured_data_using_css_extractor(): await new Promise(r => setTimeout(r, 500)); } })(); - """ + """ result = await crawler.arun( url="https://www.kidocode.com/degrees/technology", extraction_strategy=JsonCssExtractionStrategy(schema, verbose=True), js_code=[js_click_tabs], - cache_mode=CacheMode.BYPASS + cache_mode=CacheMode.BYPASS, ) companies = json.loads(result.extracted_content) print(f"Successfully extracted {len(companies)} companies") print(json.dumps(companies[0], indent=2)) + # Advanced Session-Based Crawling with Dynamic Content 🔄 async def crawl_dynamic_content_pages_method_1(): print("\n--- Advanced Multi-Page Crawling with JavaScript Execution ---") @@ -267,6 +278,7 @@ async def crawl_dynamic_content_pages_method_1(): await crawler.crawler_strategy.kill_session(session_id) print(f"Successfully crawled {len(all_commits)} commits across 3 pages") + async def crawl_dynamic_content_pages_method_2(): print("\n--- Advanced Multi-Page Crawling with JavaScript Execution ---") @@ -334,8 +346,11 @@ async def crawl_dynamic_content_pages_method_2(): await crawler.crawler_strategy.kill_session(session_id) print(f"Successfully crawled {len(all_commits)} commits across 3 pages") + async def crawl_dynamic_content_pages_method_3(): - print("\n--- Advanced Multi-Page Crawling with JavaScript Execution using `wait_for` ---") + print( + "\n--- Advanced Multi-Page Crawling with JavaScript Execution using `wait_for` ---" + ) async with AsyncWebCrawler(verbose=True) as crawler: url = "https://github.com/microsoft/TypeScript/commits/main" @@ -357,7 +372,7 @@ async def crawl_dynamic_content_pages_method_3(): const firstCommit = commits[0].textContent.trim(); return firstCommit !== window.firstCommit; }""" - + schema = { "name": "Commit Extractor", "baseSelector": "li.Box-sc-g0xbh4-0", @@ -395,40 +410,53 @@ async def crawl_dynamic_content_pages_method_3(): await crawler.crawler_strategy.kill_session(session_id) print(f"Successfully crawled {len(all_commits)} commits across 3 pages") + async def crawl_custom_browser_type(): # Use Firefox start = time.time() - async with AsyncWebCrawler(browser_type="firefox", verbose=True, headless = True) as crawler: - result = await crawler.arun(url="https://www.example.com", cache_mode= CacheMode.BYPASS) + async with AsyncWebCrawler( + browser_type="firefox", verbose=True, headless=True + ) as crawler: + result = await crawler.arun( + url="https://www.example.com", cache_mode=CacheMode.BYPASS + ) print(result.markdown[:500]) print("Time taken: ", time.time() - start) # Use WebKit start = time.time() - async with AsyncWebCrawler(browser_type="webkit", verbose=True, headless = True) as crawler: - result = await crawler.arun(url="https://www.example.com", cache_mode= CacheMode.BYPASS) + async with AsyncWebCrawler( + browser_type="webkit", verbose=True, headless=True + ) as crawler: + result = await crawler.arun( + url="https://www.example.com", cache_mode=CacheMode.BYPASS + ) print(result.markdown[:500]) print("Time taken: ", time.time() - start) # Use Chromium (default) start = time.time() - async with AsyncWebCrawler(verbose=True, headless = True) as crawler: - result = await crawler.arun(url="https://www.example.com", cache_mode= CacheMode.BYPASS) + async with AsyncWebCrawler(verbose=True, headless=True) as crawler: + result = await crawler.arun( + url="https://www.example.com", cache_mode=CacheMode.BYPASS + ) print(result.markdown[:500]) print("Time taken: ", time.time() - start) + async def crawl_with_user_simultion(): async with AsyncWebCrawler(verbose=True, headless=True) as crawler: url = "YOUR-URL-HERE" result = await crawler.arun( - url=url, + url=url, cache_mode=CacheMode.BYPASS, - magic = True, # Automatically detects and removes overlays, popups, and other elements that block content + magic=True, # Automatically detects and removes overlays, popups, and other elements that block content # simulate_user = True,# Causes a series of random mouse movements and clicks to simulate user interaction # override_navigator = True # Overrides the navigator object to make it look like a real user ) - - print(result.markdown) + + print(result.markdown) + async def speed_comparison(): # print("\n--- Speed Comparison ---") @@ -439,18 +467,18 @@ async def speed_comparison(): # print() # Simulated Firecrawl performance from firecrawl import FirecrawlApp - app = FirecrawlApp(api_key=os.environ['FIRECRAWL_API_KEY']) + + app = FirecrawlApp(api_key=os.environ["FIRECRAWL_API_KEY"]) start = time.time() scrape_status = app.scrape_url( - 'https://www.nbcnews.com/business', - params={'formats': ['markdown', 'html']} + "https://www.nbcnews.com/business", params={"formats": ["markdown", "html"]} ) end = time.time() print("Firecrawl:") print(f"Time taken: {end - start:.2f} seconds") print(f"Content length: {len(scrape_status['markdown'])} characters") print(f"Images found: {scrape_status['markdown'].count('cldnry.s-nbcnews.com')}") - print() + print() async with AsyncWebCrawler() as crawler: # Crawl4AI simple crawl @@ -474,7 +502,9 @@ async def speed_comparison(): url="https://www.nbcnews.com/business", word_count_threshold=0, markdown_generator=DefaultMarkdownGenerator( - content_filter = PruningContentFilter(threshold=0.48, threshold_type="fixed", min_word_threshold=0) + content_filter=PruningContentFilter( + threshold=0.48, threshold_type="fixed", min_word_threshold=0 + ) # content_filter=BM25ContentFilter(user_query=None, bm25_threshold=1.0) ), cache_mode=CacheMode.BYPASS, @@ -498,7 +528,9 @@ async def speed_comparison(): word_count_threshold=0, cache_mode=CacheMode.BYPASS, markdown_generator=DefaultMarkdownGenerator( - content_filter = PruningContentFilter(threshold=0.48, threshold_type="fixed", min_word_threshold=0) + content_filter=PruningContentFilter( + threshold=0.48, threshold_type="fixed", min_word_threshold=0 + ) # content_filter=BM25ContentFilter(user_query=None, bm25_threshold=1.0) ), verbose=False, @@ -520,11 +552,12 @@ async def speed_comparison(): print("If you run these tests in an environment with better network conditions,") print("you may observe an even more significant speed advantage for Crawl4AI.") + async def generate_knowledge_graph(): class Entity(BaseModel): name: str description: str - + class Relationship(BaseModel): entity1: Entity entity2: Entity @@ -536,11 +569,11 @@ async def generate_knowledge_graph(): relationships: List[Relationship] extraction_strategy = LLMExtractionStrategy( - provider='openai/gpt-4o-mini', # Or any other provider, including Ollama and open source models - api_token=os.getenv('OPENAI_API_KEY'), # In case of Ollama just pass "no-token" - schema=KnowledgeGraph.model_json_schema(), - extraction_type="schema", - instruction="""Extract entities and relationships from the given text.""" + provider="openai/gpt-4o-mini", # Or any other provider, including Ollama and open source models + api_token=os.getenv("OPENAI_API_KEY"), # In case of Ollama just pass "no-token" + schema=KnowledgeGraph.model_json_schema(), + extraction_type="schema", + instruction="""Extract entities and relationships from the given text.""", ) async with AsyncWebCrawler() as crawler: url = "https://paulgraham.com/love.html" @@ -554,27 +587,22 @@ async def generate_knowledge_graph(): with open(os.path.join(__location__, "kb.json"), "w") as f: f.write(result.extracted_content) + async def fit_markdown_remove_overlay(): - async with AsyncWebCrawler( - headless=True, # Set to False to see what is happening - verbose=True, - user_agent_mode="random", - user_agent_generator_config={ - "device_type": "mobile", - "os_type": "android" - }, + headless=True, # Set to False to see what is happening + verbose=True, + user_agent_mode="random", + user_agent_generator_config={"device_type": "mobile", "os_type": "android"}, ) as crawler: result = await crawler.arun( - url='https://www.kidocode.com/degrees/technology', + url="https://www.kidocode.com/degrees/technology", cache_mode=CacheMode.BYPASS, markdown_generator=DefaultMarkdownGenerator( content_filter=PruningContentFilter( threshold=0.48, threshold_type="fixed", min_word_threshold=0 ), - options={ - "ignore_links": True - } + options={"ignore_links": True}, ), # markdown_generator=DefaultMarkdownGenerator( # content_filter=BM25ContentFilter(user_query="", bm25_threshold=1.0), @@ -583,31 +611,38 @@ async def fit_markdown_remove_overlay(): # } # ), ) - + if result.success: print(len(result.markdown_v2.raw_markdown)) print(len(result.markdown_v2.markdown_with_citations)) print(len(result.markdown_v2.fit_markdown)) - + # Save clean html with open(os.path.join(__location__, "output/cleaned_html.html"), "w") as f: f.write(result.cleaned_html) - - with open(os.path.join(__location__, "output/output_raw_markdown.md"), "w") as f: + + with open( + os.path.join(__location__, "output/output_raw_markdown.md"), "w" + ) as f: f.write(result.markdown_v2.raw_markdown) - - with open(os.path.join(__location__, "output/output_markdown_with_citations.md"), "w") as f: - f.write(result.markdown_v2.markdown_with_citations) - - with open(os.path.join(__location__, "output/output_fit_markdown.md"), "w") as f: + + with open( + os.path.join(__location__, "output/output_markdown_with_citations.md"), + "w", + ) as f: + f.write(result.markdown_v2.markdown_with_citations) + + with open( + os.path.join(__location__, "output/output_fit_markdown.md"), "w" + ) as f: f.write(result.markdown_v2.fit_markdown) - + print("Done") async def main(): # await extract_structured_data_using_llm("openai/gpt-4o", os.getenv("OPENAI_API_KEY")) - + # await simple_crawl() # await simple_example_with_running_js_code() # await simple_example_with_css_selector() @@ -618,7 +653,7 @@ async def main(): # LLM extraction examples # await extract_structured_data_using_llm() # await extract_structured_data_using_llm("huggingface/meta-llama/Meta-Llama-3.1-8B-Instruct", os.getenv("HUGGINGFACE_API_KEY")) - # await extract_structured_data_using_llm("ollama/llama3.2") + # await extract_structured_data_using_llm("ollama/llama3.2") # You always can pass custom headers to the extraction strategy # custom_headers = { @@ -626,13 +661,13 @@ async def main(): # "X-Custom-Header": "Some-Value" # } # await extract_structured_data_using_llm(extra_headers=custom_headers) - + # await crawl_dynamic_content_pages_method_1() # await crawl_dynamic_content_pages_method_2() await crawl_dynamic_content_pages_method_3() - + # await crawl_custom_browser_type() - + # await speed_comparison() diff --git a/docs/examples/quickstart_sync.py b/docs/examples/quickstart_sync.py index 89c63139..0248af29 100644 --- a/docs/examples/quickstart_sync.py +++ b/docs/examples/quickstart_sync.py @@ -10,15 +10,17 @@ from functools import lru_cache console = Console() + @lru_cache() def create_crawler(): crawler = WebCrawler(verbose=True) crawler.warmup() return crawler + def print_result(result): # Print each key in one line and just the first 10 characters of each one's value and three dots - console.print(f"\t[bold]Result:[/bold]") + console.print("\t[bold]Result:[/bold]") for key, value in result.model_dump().items(): if isinstance(value, str) and value: console.print(f"\t{key}: [green]{value[:20]}...[/green]") @@ -33,18 +35,27 @@ def cprint(message, press_any_key=False): console.print("Press any key to continue...", style="") input() + def basic_usage(crawler): - cprint("🛠️ [bold cyan]Basic Usage: Simply provide a URL and let Crawl4ai do the magic![/bold cyan]") - result = crawler.run(url="https://www.nbcnews.com/business", only_text = True) + cprint( + "🛠️ [bold cyan]Basic Usage: Simply provide a URL and let Crawl4ai do the magic![/bold cyan]" + ) + result = crawler.run(url="https://www.nbcnews.com/business", only_text=True) cprint("[LOG] 📦 [bold yellow]Basic crawl result:[/bold yellow]") print_result(result) + def basic_usage_some_params(crawler): - cprint("🛠️ [bold cyan]Basic Usage: Simply provide a URL and let Crawl4ai do the magic![/bold cyan]") - result = crawler.run(url="https://www.nbcnews.com/business", word_count_threshold=1, only_text = True) + cprint( + "🛠️ [bold cyan]Basic Usage: Simply provide a URL and let Crawl4ai do the magic![/bold cyan]" + ) + result = crawler.run( + url="https://www.nbcnews.com/business", word_count_threshold=1, only_text=True + ) cprint("[LOG] 📦 [bold yellow]Basic crawl result:[/bold yellow]") print_result(result) + def screenshot_usage(crawler): cprint("\n📸 [bold cyan]Let's take a screenshot of the page![/bold cyan]") result = crawler.run(url="https://www.nbcnews.com/business", screenshot=True) @@ -55,16 +66,23 @@ def screenshot_usage(crawler): cprint("Screenshot saved to 'screenshot.png'!") print_result(result) + def understanding_parameters(crawler): - cprint("\n🧠 [bold cyan]Understanding 'bypass_cache' and 'include_raw_html' parameters:[/bold cyan]") - cprint("By default, Crawl4ai caches the results of your crawls. This means that subsequent crawls of the same URL will be much faster! Let's see this in action.") - + cprint( + "\n🧠 [bold cyan]Understanding 'bypass_cache' and 'include_raw_html' parameters:[/bold cyan]" + ) + cprint( + "By default, Crawl4ai caches the results of your crawls. This means that subsequent crawls of the same URL will be much faster! Let's see this in action." + ) + # First crawl (reads from cache) cprint("1️⃣ First crawl (caches the result):", True) start_time = time.time() result = crawler.run(url="https://www.nbcnews.com/business") end_time = time.time() - cprint(f"[LOG] 📦 [bold yellow]First crawl took {end_time - start_time} seconds and result (from cache):[/bold yellow]") + cprint( + f"[LOG] 📦 [bold yellow]First crawl took {end_time - start_time} seconds and result (from cache):[/bold yellow]" + ) print_result(result) # Force to crawl again @@ -72,169 +90,232 @@ def understanding_parameters(crawler): start_time = time.time() result = crawler.run(url="https://www.nbcnews.com/business", bypass_cache=True) end_time = time.time() - cprint(f"[LOG] 📦 [bold yellow]Second crawl took {end_time - start_time} seconds and result (forced to crawl):[/bold yellow]") + cprint( + f"[LOG] 📦 [bold yellow]Second crawl took {end_time - start_time} seconds and result (forced to crawl):[/bold yellow]" + ) print_result(result) + def add_chunking_strategy(crawler): # Adding a chunking strategy: RegexChunking - cprint("\n🧩 [bold cyan]Let's add a chunking strategy: RegexChunking![/bold cyan]", True) - cprint("RegexChunking is a simple chunking strategy that splits the text based on a given regex pattern. Let's see it in action!") + cprint( + "\n🧩 [bold cyan]Let's add a chunking strategy: RegexChunking![/bold cyan]", + True, + ) + cprint( + "RegexChunking is a simple chunking strategy that splits the text based on a given regex pattern. Let's see it in action!" + ) result = crawler.run( url="https://www.nbcnews.com/business", - chunking_strategy=RegexChunking(patterns=["\n\n"]) + chunking_strategy=RegexChunking(patterns=["\n\n"]), ) cprint("[LOG] 📦 [bold yellow]RegexChunking result:[/bold yellow]") print_result(result) # Adding another chunking strategy: NlpSentenceChunking - cprint("\n🔍 [bold cyan]Time to explore another chunking strategy: NlpSentenceChunking![/bold cyan]", True) - cprint("NlpSentenceChunking uses NLP techniques to split the text into sentences. Let's see how it performs!") + cprint( + "\n🔍 [bold cyan]Time to explore another chunking strategy: NlpSentenceChunking![/bold cyan]", + True, + ) + cprint( + "NlpSentenceChunking uses NLP techniques to split the text into sentences. Let's see how it performs!" + ) result = crawler.run( - url="https://www.nbcnews.com/business", - chunking_strategy=NlpSentenceChunking() + url="https://www.nbcnews.com/business", chunking_strategy=NlpSentenceChunking() ) cprint("[LOG] 📦 [bold yellow]NlpSentenceChunking result:[/bold yellow]") print_result(result) + def add_extraction_strategy(crawler): # Adding an extraction strategy: CosineStrategy - cprint("\n🧠 [bold cyan]Let's get smarter with an extraction strategy: CosineStrategy![/bold cyan]", True) - cprint("CosineStrategy uses cosine similarity to extract semantically similar blocks of text. Let's see it in action!") + cprint( + "\n🧠 [bold cyan]Let's get smarter with an extraction strategy: CosineStrategy![/bold cyan]", + True, + ) + cprint( + "CosineStrategy uses cosine similarity to extract semantically similar blocks of text. Let's see it in action!" + ) result = crawler.run( url="https://www.nbcnews.com/business", - extraction_strategy=CosineStrategy(word_count_threshold=10, max_dist=0.2, linkage_method="ward", top_k=3, sim_threshold = 0.3, verbose=True) + extraction_strategy=CosineStrategy( + word_count_threshold=10, + max_dist=0.2, + linkage_method="ward", + top_k=3, + sim_threshold=0.3, + verbose=True, + ), ) cprint("[LOG] 📦 [bold yellow]CosineStrategy result:[/bold yellow]") print_result(result) - + # Using semantic_filter with CosineStrategy - cprint("You can pass other parameters like 'semantic_filter' to the CosineStrategy to extract semantically similar blocks of text. Let's see it in action!") + cprint( + "You can pass other parameters like 'semantic_filter' to the CosineStrategy to extract semantically similar blocks of text. Let's see it in action!" + ) result = crawler.run( url="https://www.nbcnews.com/business", extraction_strategy=CosineStrategy( semantic_filter="inflation rent prices", - ) + ), + ) + cprint( + "[LOG] 📦 [bold yellow]CosineStrategy result with semantic filter:[/bold yellow]" ) - cprint("[LOG] 📦 [bold yellow]CosineStrategy result with semantic filter:[/bold yellow]") print_result(result) + def add_llm_extraction_strategy(crawler): # Adding an LLM extraction strategy without instructions - cprint("\n🤖 [bold cyan]Time to bring in the big guns: LLMExtractionStrategy without instructions![/bold cyan]", True) - cprint("LLMExtractionStrategy uses a large language model to extract relevant information from the web page. Let's see it in action!") + cprint( + "\n🤖 [bold cyan]Time to bring in the big guns: LLMExtractionStrategy without instructions![/bold cyan]", + True, + ) + cprint( + "LLMExtractionStrategy uses a large language model to extract relevant information from the web page. Let's see it in action!" + ) result = crawler.run( url="https://www.nbcnews.com/business", - extraction_strategy=LLMExtractionStrategy(provider="openai/gpt-4o", api_token=os.getenv('OPENAI_API_KEY')) + extraction_strategy=LLMExtractionStrategy( + provider="openai/gpt-4o", api_token=os.getenv("OPENAI_API_KEY") + ), + ) + cprint( + "[LOG] 📦 [bold yellow]LLMExtractionStrategy (no instructions) result:[/bold yellow]" ) - cprint("[LOG] 📦 [bold yellow]LLMExtractionStrategy (no instructions) result:[/bold yellow]") print_result(result) - + # Adding an LLM extraction strategy with instructions - cprint("\n📜 [bold cyan]Let's make it even more interesting: LLMExtractionStrategy with instructions![/bold cyan]", True) - cprint("Let's say we are only interested in financial news. Let's see how LLMExtractionStrategy performs with instructions!") + cprint( + "\n📜 [bold cyan]Let's make it even more interesting: LLMExtractionStrategy with instructions![/bold cyan]", + True, + ) + cprint( + "Let's say we are only interested in financial news. Let's see how LLMExtractionStrategy performs with instructions!" + ) result = crawler.run( url="https://www.nbcnews.com/business", extraction_strategy=LLMExtractionStrategy( provider="openai/gpt-4o", - api_token=os.getenv('OPENAI_API_KEY'), - instruction="I am interested in only financial news" - ) + api_token=os.getenv("OPENAI_API_KEY"), + instruction="I am interested in only financial news", + ), + ) + cprint( + "[LOG] 📦 [bold yellow]LLMExtractionStrategy (with instructions) result:[/bold yellow]" ) - cprint("[LOG] 📦 [bold yellow]LLMExtractionStrategy (with instructions) result:[/bold yellow]") print_result(result) - + result = crawler.run( url="https://www.nbcnews.com/business", extraction_strategy=LLMExtractionStrategy( provider="openai/gpt-4o", - api_token=os.getenv('OPENAI_API_KEY'), - instruction="Extract only content related to technology" - ) + api_token=os.getenv("OPENAI_API_KEY"), + instruction="Extract only content related to technology", + ), + ) + cprint( + "[LOG] 📦 [bold yellow]LLMExtractionStrategy (with technology instruction) result:[/bold yellow]" ) - cprint("[LOG] 📦 [bold yellow]LLMExtractionStrategy (with technology instruction) result:[/bold yellow]") print_result(result) + def targeted_extraction(crawler): # Using a CSS selector to extract only H2 tags - cprint("\n🎯 [bold cyan]Targeted extraction: Let's use a CSS selector to extract only H2 tags![/bold cyan]", True) - result = crawler.run( - url="https://www.nbcnews.com/business", - css_selector="h2" + cprint( + "\n🎯 [bold cyan]Targeted extraction: Let's use a CSS selector to extract only H2 tags![/bold cyan]", + True, ) + result = crawler.run(url="https://www.nbcnews.com/business", css_selector="h2") cprint("[LOG] 📦 [bold yellow]CSS Selector (H2 tags) result:[/bold yellow]") print_result(result) + def interactive_extraction(crawler): # Passing JavaScript code to interact with the page - cprint("\n🖱️ [bold cyan]Let's get interactive: Passing JavaScript code to click 'Load More' button![/bold cyan]", True) - cprint("In this example we try to click the 'Load More' button on the page using JavaScript code.") + cprint( + "\n🖱️ [bold cyan]Let's get interactive: Passing JavaScript code to click 'Load More' button![/bold cyan]", + True, + ) + cprint( + "In this example we try to click the 'Load More' button on the page using JavaScript code." + ) js_code = """ const loadMoreButton = Array.from(document.querySelectorAll('button')).find(button => button.textContent.includes('Load More')); loadMoreButton && loadMoreButton.click(); """ # crawler_strategy = LocalSeleniumCrawlerStrategy(js_code=js_code) # crawler = WebCrawler(crawler_strategy=crawler_strategy, always_by_pass_cache=True) - result = crawler.run( - url="https://www.nbcnews.com/business", - js = js_code + result = crawler.run(url="https://www.nbcnews.com/business", js=js_code) + cprint( + "[LOG] 📦 [bold yellow]JavaScript Code (Load More button) result:[/bold yellow]" ) - cprint("[LOG] 📦 [bold yellow]JavaScript Code (Load More button) result:[/bold yellow]") print_result(result) + def multiple_scrip(crawler): # Passing JavaScript code to interact with the page - cprint("\n🖱️ [bold cyan]Let's get interactive: Passing JavaScript code to click 'Load More' button![/bold cyan]", True) - cprint("In this example we try to click the 'Load More' button on the page using JavaScript code.") - js_code = [""" + cprint( + "\n🖱️ [bold cyan]Let's get interactive: Passing JavaScript code to click 'Load More' button![/bold cyan]", + True, + ) + cprint( + "In this example we try to click the 'Load More' button on the page using JavaScript code." + ) + js_code = [ + """ const loadMoreButton = Array.from(document.querySelectorAll('button')).find(button => button.textContent.includes('Load More')); loadMoreButton && loadMoreButton.click(); - """] * 2 + """ + ] * 2 # crawler_strategy = LocalSeleniumCrawlerStrategy(js_code=js_code) # crawler = WebCrawler(crawler_strategy=crawler_strategy, always_by_pass_cache=True) - result = crawler.run( - url="https://www.nbcnews.com/business", - js = js_code + result = crawler.run(url="https://www.nbcnews.com/business", js=js_code) + cprint( + "[LOG] 📦 [bold yellow]JavaScript Code (Load More button) result:[/bold yellow]" ) - cprint("[LOG] 📦 [bold yellow]JavaScript Code (Load More button) result:[/bold yellow]") print_result(result) + def using_crawler_hooks(crawler): # Example usage of the hooks for authentication and setting a cookie def on_driver_created(driver): print("[HOOK] on_driver_created") # Example customization: maximize the window driver.maximize_window() - + # Example customization: logging in to a hypothetical website - driver.get('https://example.com/login') - + driver.get("https://example.com/login") + from selenium.webdriver.support.ui import WebDriverWait from selenium.webdriver.common.by import By from selenium.webdriver.support import expected_conditions as EC - + WebDriverWait(driver, 10).until( - EC.presence_of_element_located((By.NAME, 'username')) + EC.presence_of_element_located((By.NAME, "username")) ) - driver.find_element(By.NAME, 'username').send_keys('testuser') - driver.find_element(By.NAME, 'password').send_keys('password123') - driver.find_element(By.NAME, 'login').click() + driver.find_element(By.NAME, "username").send_keys("testuser") + driver.find_element(By.NAME, "password").send_keys("password123") + driver.find_element(By.NAME, "login").click() WebDriverWait(driver, 10).until( - EC.presence_of_element_located((By.ID, 'welcome')) + EC.presence_of_element_located((By.ID, "welcome")) ) # Add a custom cookie - driver.add_cookie({'name': 'test_cookie', 'value': 'cookie_value'}) - return driver - + driver.add_cookie({"name": "test_cookie", "value": "cookie_value"}) + return driver def before_get_url(driver): print("[HOOK] before_get_url") # Example customization: add a custom header # Enable Network domain for sending headers - driver.execute_cdp_cmd('Network.enable', {}) + driver.execute_cdp_cmd("Network.enable", {}) # Add a custom header - driver.execute_cdp_cmd('Network.setExtraHTTPHeaders', {'headers': {'X-Test-Header': 'test'}}) + driver.execute_cdp_cmd( + "Network.setExtraHTTPHeaders", {"headers": {"X-Test-Header": "test"}} + ) return driver - + def after_get_url(driver): print("[HOOK] after_get_url") # Example customization: log the URL @@ -246,48 +327,59 @@ def using_crawler_hooks(crawler): # Example customization: log the HTML print(len(html)) return driver - - cprint("\n🔗 [bold cyan]Using Crawler Hooks: Let's see how we can customize the crawler using hooks![/bold cyan]", True) - + + cprint( + "\n🔗 [bold cyan]Using Crawler Hooks: Let's see how we can customize the crawler using hooks![/bold cyan]", + True, + ) + crawler_strategy = LocalSeleniumCrawlerStrategy(verbose=True) - crawler_strategy.set_hook('on_driver_created', on_driver_created) - crawler_strategy.set_hook('before_get_url', before_get_url) - crawler_strategy.set_hook('after_get_url', after_get_url) - crawler_strategy.set_hook('before_return_html', before_return_html) - + crawler_strategy.set_hook("on_driver_created", on_driver_created) + crawler_strategy.set_hook("before_get_url", before_get_url) + crawler_strategy.set_hook("after_get_url", after_get_url) + crawler_strategy.set_hook("before_return_html", before_return_html) + crawler = WebCrawler(verbose=True, crawler_strategy=crawler_strategy) - crawler.warmup() + crawler.warmup() result = crawler.run(url="https://example.com") - + cprint("[LOG] 📦 [bold yellow]Crawler Hooks result:[/bold yellow]") - print_result(result= result) - + print_result(result=result) + + def using_crawler_hooks_dleay_example(crawler): def delay(driver): print("Delaying for 5 seconds...") time.sleep(5) print("Resuming...") - + def create_crawler(): crawler_strategy = LocalSeleniumCrawlerStrategy(verbose=True) - crawler_strategy.set_hook('after_get_url', delay) + crawler_strategy.set_hook("after_get_url", delay) crawler = WebCrawler(verbose=True, crawler_strategy=crawler_strategy) crawler.warmup() return crawler - cprint("\n🔗 [bold cyan]Using Crawler Hooks: Let's add a delay after fetching the url to make sure entire page is fetched.[/bold cyan]") + cprint( + "\n🔗 [bold cyan]Using Crawler Hooks: Let's add a delay after fetching the url to make sure entire page is fetched.[/bold cyan]" + ) crawler = create_crawler() - result = crawler.run(url="https://google.com", bypass_cache=True) - + result = crawler.run(url="https://google.com", bypass_cache=True) + cprint("[LOG] 📦 [bold yellow]Crawler Hooks result:[/bold yellow]") print_result(result) - - + def main(): - cprint("🌟 [bold green]Welcome to the Crawl4ai Quickstart Guide! Let's dive into some web crawling fun! 🌐[/bold green]") - cprint("⛳️ [bold cyan]First Step: Create an instance of WebCrawler and call the `warmup()` function.[/bold cyan]") - cprint("If this is the first time you're running Crawl4ai, this might take a few seconds to load required model files.") + cprint( + "🌟 [bold green]Welcome to the Crawl4ai Quickstart Guide! Let's dive into some web crawling fun! 🌐[/bold green]" + ) + cprint( + "⛳️ [bold cyan]First Step: Create an instance of WebCrawler and call the `warmup()` function.[/bold cyan]" + ) + cprint( + "If this is the first time you're running Crawl4ai, this might take a few seconds to load required model files." + ) crawler = create_crawler() @@ -295,7 +387,7 @@ def main(): basic_usage(crawler) # basic_usage_some_params(crawler) understanding_parameters(crawler) - + crawler.always_by_pass_cache = True screenshot_usage(crawler) add_chunking_strategy(crawler) @@ -305,8 +397,10 @@ def main(): interactive_extraction(crawler) multiple_scrip(crawler) - cprint("\n🎉 [bold green]Congratulations! You've made it through the Crawl4ai Quickstart Guide! Now go forth and crawl the web like a pro! 🕸️[/bold green]") + cprint( + "\n🎉 [bold green]Congratulations! You've made it through the Crawl4ai Quickstart Guide! Now go forth and crawl the web like a pro! 🕸️[/bold green]" + ) + if __name__ == "__main__": main() - diff --git a/docs/examples/research_assistant.py b/docs/examples/research_assistant.py index de35ce84..84ba3c76 100644 --- a/docs/examples/research_assistant.py +++ b/docs/examples/research_assistant.py @@ -11,7 +11,9 @@ from groq import Groq # Import threadpools to run the crawl_url function in a separate thread from concurrent.futures import ThreadPoolExecutor -client = AsyncOpenAI(base_url="https://api.groq.com/openai/v1", api_key=os.getenv("GROQ_API_KEY")) +client = AsyncOpenAI( + base_url="https://api.groq.com/openai/v1", api_key=os.getenv("GROQ_API_KEY") +) # Instrument the OpenAI client cl.instrument_openai() @@ -25,41 +27,39 @@ settings = { "presence_penalty": 0, } + def extract_urls(text): - url_pattern = re.compile(r'(https?://\S+)') + url_pattern = re.compile(r"(https?://\S+)") return url_pattern.findall(text) + def crawl_url(url): data = { "urls": [url], "include_raw_html": True, "word_count_threshold": 10, "extraction_strategy": "NoExtractionStrategy", - "chunking_strategy": "RegexChunking" + "chunking_strategy": "RegexChunking", } response = requests.post("https://crawl4ai.com/crawl", json=data) response_data = response.json() - response_data = response_data['results'][0] - return response_data['markdown'] + response_data = response_data["results"][0] + return response_data["markdown"] + @cl.on_chat_start async def on_chat_start(): - cl.user_session.set("session", { - "history": [], - "context": {} - }) - await cl.Message( - content="Welcome to the chat! How can I assist you today?" - ).send() + cl.user_session.set("session", {"history": [], "context": {}}) + await cl.Message(content="Welcome to the chat! How can I assist you today?").send() + @cl.on_message async def on_message(message: cl.Message): user_session = cl.user_session.get("session") - + # Extract URLs from the user's message urls = extract_urls(message.content) - - + futures = [] with ThreadPoolExecutor() as executor: for url in urls: @@ -69,16 +69,9 @@ async def on_message(message: cl.Message): for url, result in zip(urls, results): ref_number = f"REF_{len(user_session['context']) + 1}" - user_session["context"][ref_number] = { - "url": url, - "content": result - } + user_session["context"][ref_number] = {"url": url, "content": result} - - user_session["history"].append({ - "role": "user", - "content": message.content - }) + user_session["history"].append({"role": "user", "content": message.content}) # Create a system message that includes the context context_messages = [ @@ -95,26 +88,17 @@ async def on_message(message: cl.Message): "If not, there is no need to add a references section. " "At the end of your response, provide a reference section listing the URLs and their REF numbers only if sources from the appendices were used.\n\n" "\n\n".join(context_messages) - ) + ), } else: - system_message = { - "role": "system", - "content": "You are a helpful assistant." - } - + system_message = {"role": "system", "content": "You are a helpful assistant."} msg = cl.Message(content="") await msg.send() # Get response from the LLM stream = await client.chat.completions.create( - messages=[ - system_message, - *user_session["history"] - ], - stream=True, - **settings + messages=[system_message, *user_session["history"]], stream=True, **settings ) assistant_response = "" @@ -124,10 +108,7 @@ async def on_message(message: cl.Message): await msg.stream_token(token) # Add assistant message to the history - user_session["history"].append({ - "role": "assistant", - "content": assistant_response - }) + user_session["history"].append({"role": "assistant", "content": assistant_response}) await msg.update() # Append the reference section to the assistant's response @@ -154,10 +135,11 @@ async def on_audio_chunk(chunk: cl.AudioChunk): pass + @cl.step(type="tool") async def speech_to_text(audio_file): cli = Groq() - + response = await client.audio.transcriptions.create( model="whisper-large-v3", file=audio_file ) @@ -172,24 +154,19 @@ async def on_audio_end(elements: list[ElementBased]): audio_buffer.seek(0) # Move the file pointer to the beginning audio_file = audio_buffer.read() audio_mime_type: str = cl.user_session.get("audio_mime_type") - + start_time = time.time() whisper_input = (audio_buffer.name, audio_file, audio_mime_type) transcription = await speech_to_text(whisper_input) end_time = time.time() print(f"Transcription took {end_time - start_time} seconds") - - user_msg = cl.Message( - author="You", - type="user_message", - content=transcription - ) + + user_msg = cl.Message(author="You", type="user_message", content=transcription) await user_msg.send() await on_message(user_msg) if __name__ == "__main__": from chainlit.cli import run_chainlit + run_chainlit(__file__) - - diff --git a/docs/examples/rest_call.py b/docs/examples/rest_call.py index 465c6114..47c09435 100644 --- a/docs/examples/rest_call.py +++ b/docs/examples/rest_call.py @@ -1,4 +1,3 @@ - import requests, base64, os data = { @@ -6,59 +5,50 @@ data = { "screenshot": True, } -response = requests.post("https://crawl4ai.com/crawl", json=data) -result = response.json()['results'][0] +response = requests.post("https://crawl4ai.com/crawl", json=data) +result = response.json()["results"][0] print(result.keys()) -# dict_keys(['url', 'html', 'success', 'cleaned_html', 'media', -# 'links', 'screenshot', 'markdown', 'extracted_content', +# dict_keys(['url', 'html', 'success', 'cleaned_html', 'media', +# 'links', 'screenshot', 'markdown', 'extracted_content', # 'metadata', 'error_message']) with open("screenshot.png", "wb") as f: - f.write(base64.b64decode(result['screenshot'])) - + f.write(base64.b64decode(result["screenshot"])) + # Example of filtering the content using CSS selectors data = { - "urls": [ - "https://www.nbcnews.com/business" - ], + "urls": ["https://www.nbcnews.com/business"], "css_selector": "article", "screenshot": True, } # Example of executing a JS script on the page before extracting the content data = { - "urls": [ - "https://www.nbcnews.com/business" - ], + "urls": ["https://www.nbcnews.com/business"], "screenshot": True, - 'js' : [""" + "js": [ + """ const loadMoreButton = Array.from(document.querySelectorAll('button')). find(button => button.textContent.includes('Load More')); loadMoreButton && loadMoreButton.click(); - """] + """ + ], } # Example of using a custom extraction strategy data = { - "urls": [ - "https://www.nbcnews.com/business" - ], + "urls": ["https://www.nbcnews.com/business"], "extraction_strategy": "CosineStrategy", - "extraction_strategy_args": { - "semantic_filter": "inflation rent prices" - }, + "extraction_strategy_args": {"semantic_filter": "inflation rent prices"}, } # Example of using LLM to extract content data = { - "urls": [ - "https://www.nbcnews.com/business" - ], + "urls": ["https://www.nbcnews.com/business"], "extraction_strategy": "LLMExtractionStrategy", "extraction_strategy_args": { "provider": "groq/llama3-8b-8192", "api_token": os.environ.get("GROQ_API_KEY"), "instruction": """I am interested in only financial news, - and translate them in French.""" + and translate them in French.""", }, } - diff --git a/docs/examples/ssl_example.py b/docs/examples/ssl_example.py index 410e9485..7379862c 100644 --- a/docs/examples/ssl_example.py +++ b/docs/examples/ssl_example.py @@ -5,42 +5,47 @@ import os from crawl4ai import AsyncWebCrawler, CrawlerRunConfig, CacheMode # Create tmp directory if it doesn't exist -parent_dir = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) +parent_dir = os.path.dirname( + os.path.dirname(os.path.dirname(os.path.abspath(__file__))) +) tmp_dir = os.path.join(parent_dir, "tmp") os.makedirs(tmp_dir, exist_ok=True) + async def main(): # Configure crawler to fetch SSL certificate config = CrawlerRunConfig( fetch_ssl_certificate=True, - cache_mode=CacheMode.BYPASS # Bypass cache to always get fresh certificates + cache_mode=CacheMode.BYPASS, # Bypass cache to always get fresh certificates ) async with AsyncWebCrawler() as crawler: - result = await crawler.arun( - url='https://example.com', - config=config - ) - + result = await crawler.arun(url="https://example.com", config=config) + if result.success and result.ssl_certificate: cert = result.ssl_certificate - + # 1. Access certificate properties directly print("\nCertificate Information:") print(f"Issuer: {cert.issuer.get('CN', '')}") print(f"Valid until: {cert.valid_until}") print(f"Fingerprint: {cert.fingerprint}") - + # 2. Export certificate in different formats cert.to_json(os.path.join(tmp_dir, "certificate.json")) # For analysis print("\nCertificate exported to:") print(f"- JSON: {os.path.join(tmp_dir, 'certificate.json')}") - - pem_data = cert.to_pem(os.path.join(tmp_dir, "certificate.pem")) # For web servers + + pem_data = cert.to_pem( + os.path.join(tmp_dir, "certificate.pem") + ) # For web servers print(f"- PEM: {os.path.join(tmp_dir, 'certificate.pem')}") - - der_data = cert.to_der(os.path.join(tmp_dir, "certificate.der")) # For Java apps + + der_data = cert.to_der( + os.path.join(tmp_dir, "certificate.der") + ) # For Java apps print(f"- DER: {os.path.join(tmp_dir, 'certificate.der')}") + if __name__ == "__main__": asyncio.run(main()) diff --git a/docs/examples/summarize_page.py b/docs/examples/summarize_page.py index 85158999..da2bcd21 100644 --- a/docs/examples/summarize_page.py +++ b/docs/examples/summarize_page.py @@ -1,39 +1,41 @@ import os -import time import json from crawl4ai.web_crawler import WebCrawler from crawl4ai.chunking_strategy import * from crawl4ai.extraction_strategy import * from crawl4ai.crawler_strategy import * -url = r'https://marketplace.visualstudio.com/items?itemName=Unclecode.groqopilot' +url = r"https://marketplace.visualstudio.com/items?itemName=Unclecode.groqopilot" crawler = WebCrawler() crawler.warmup() from pydantic import BaseModel, Field + class PageSummary(BaseModel): title: str = Field(..., description="Title of the page.") summary: str = Field(..., description="Summary of the page.") brief_summary: str = Field(..., description="Brief summary of the page.") keywords: list = Field(..., description="Keywords assigned to the page.") + result = crawler.run( url=url, word_count_threshold=1, - extraction_strategy= LLMExtractionStrategy( - provider= "openai/gpt-4o", api_token = os.getenv('OPENAI_API_KEY'), + extraction_strategy=LLMExtractionStrategy( + provider="openai/gpt-4o", + api_token=os.getenv("OPENAI_API_KEY"), schema=PageSummary.model_json_schema(), extraction_type="schema", - apply_chunking =False, - instruction="From the crawled content, extract the following details: "\ - "1. Title of the page "\ - "2. Summary of the page, which is a detailed summary "\ - "3. Brief summary of the page, which is a paragraph text "\ - "4. Keywords assigned to the page, which is a list of keywords. "\ - 'The extracted JSON format should look like this: '\ - '{ "title": "Page Title", "summary": "Detailed summary of the page.", "brief_summary": "Brief summary in a paragraph.", "keywords": ["keyword1", "keyword2", "keyword3"] }' + apply_chunking=False, + instruction="From the crawled content, extract the following details: " + "1. Title of the page " + "2. Summary of the page, which is a detailed summary " + "3. Brief summary of the page, which is a paragraph text " + "4. Keywords assigned to the page, which is a list of keywords. " + "The extracted JSON format should look like this: " + '{ "title": "Page Title", "summary": "Detailed summary of the page.", "brief_summary": "Brief summary in a paragraph.", "keywords": ["keyword1", "keyword2", "keyword3"] }', ), bypass_cache=True, ) diff --git a/docs/examples/v0.3.74.overview.py b/docs/examples/v0.3.74.overview.py index 362ae8fc..4938db7b 100644 --- a/docs/examples/v0.3.74.overview.py +++ b/docs/examples/v0.3.74.overview.py @@ -1,4 +1,5 @@ import os, sys + # append the parent directory to the sys.path parent_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) sys.path.append(parent_dir) @@ -13,19 +14,18 @@ import json from crawl4ai import AsyncWebCrawler, CacheMode from crawl4ai.content_filter_strategy import BM25ContentFilter + # 1. File Download Processing Example async def download_example(): """Example of downloading files from Python.org""" # downloads_path = os.path.join(os.getcwd(), "downloads") downloads_path = os.path.join(Path.home(), ".crawl4ai", "downloads") os.makedirs(downloads_path, exist_ok=True) - + print(f"Downloads will be saved to: {downloads_path}") - + async with AsyncWebCrawler( - accept_downloads=True, - downloads_path=downloads_path, - verbose=True + accept_downloads=True, downloads_path=downloads_path, verbose=True ) as crawler: result = await crawler.arun( url="https://www.python.org/downloads/", @@ -40,9 +40,9 @@ async def download_example(): } """, delay_before_return_html=1, # Wait 5 seconds to ensure download starts - cache_mode=CacheMode.BYPASS + cache_mode=CacheMode.BYPASS, ) - + if result.downloaded_files: print("\nDownload successful!") print("Downloaded files:") @@ -52,25 +52,26 @@ async def download_example(): else: print("\nNo files were downloaded") + # 2. Local File and Raw HTML Processing Example async def local_and_raw_html_example(): """Example of processing local files and raw HTML""" # Create a sample HTML file sample_file = os.path.join(__data__, "sample.html") with open(sample_file, "w") as f: - f.write(""" + f.write( + """

Test Content

This is a test paragraph.

- """) - + """ + ) + async with AsyncWebCrawler(verbose=True) as crawler: # Process local file - local_result = await crawler.arun( - url=f"file://{os.path.abspath(sample_file)}" - ) - + local_result = await crawler.arun(url=f"file://{os.path.abspath(sample_file)}") + # Process raw HTML raw_html = """ @@ -78,16 +79,15 @@ async def local_and_raw_html_example():

This is a test of raw HTML processing.

""" - raw_result = await crawler.arun( - url=f"raw:{raw_html}" - ) - + raw_result = await crawler.arun(url=f"raw:{raw_html}") + # Clean up os.remove(sample_file) - + print("Local file content:", local_result.markdown) print("\nRaw HTML content:", raw_result.markdown) + # 3. Enhanced Markdown Generation Example async def markdown_generation_example(): """Example of enhanced markdown generation with citations and LLM-friendly features""" @@ -97,58 +97,66 @@ async def markdown_generation_example(): # user_query="History and cultivation", bm25_threshold=1.0 ) - + result = await crawler.arun( url="https://en.wikipedia.org/wiki/Apple", css_selector="main div#bodyContent", content_filter=content_filter, - cache_mode=CacheMode.BYPASS + cache_mode=CacheMode.BYPASS, ) - - from crawl4ai import AsyncWebCrawler + from crawl4ai.content_filter_strategy import BM25ContentFilter - + result = await crawler.arun( url="https://en.wikipedia.org/wiki/Apple", css_selector="main div#bodyContent", - content_filter=BM25ContentFilter() + content_filter=BM25ContentFilter(), ) print(result.markdown_v2.fit_markdown) - + print("\nMarkdown Generation Results:") print(f"1. Original markdown length: {len(result.markdown)}") - print(f"2. New markdown versions (markdown_v2):") + print("2. New markdown versions (markdown_v2):") print(f" - Raw markdown length: {len(result.markdown_v2.raw_markdown)}") - print(f" - Citations markdown length: {len(result.markdown_v2.markdown_with_citations)}") - print(f" - References section length: {len(result.markdown_v2.references_markdown)}") + print( + f" - Citations markdown length: {len(result.markdown_v2.markdown_with_citations)}" + ) + print( + f" - References section length: {len(result.markdown_v2.references_markdown)}" + ) if result.markdown_v2.fit_markdown: - print(f" - Filtered markdown length: {len(result.markdown_v2.fit_markdown)}") - + print( + f" - Filtered markdown length: {len(result.markdown_v2.fit_markdown)}" + ) + # Save examples to files output_dir = os.path.join(__data__, "markdown_examples") os.makedirs(output_dir, exist_ok=True) - + # Save different versions with open(os.path.join(output_dir, "1_raw_markdown.md"), "w") as f: f.write(result.markdown_v2.raw_markdown) - + with open(os.path.join(output_dir, "2_citations_markdown.md"), "w") as f: f.write(result.markdown_v2.markdown_with_citations) - + with open(os.path.join(output_dir, "3_references.md"), "w") as f: f.write(result.markdown_v2.references_markdown) - + if result.markdown_v2.fit_markdown: with open(os.path.join(output_dir, "4_filtered_markdown.md"), "w") as f: f.write(result.markdown_v2.fit_markdown) - + print(f"\nMarkdown examples saved to: {output_dir}") - + # Show a sample of citations and references print("\nSample of markdown with citations:") print(result.markdown_v2.markdown_with_citations[:500] + "...\n") print("Sample of references:") - print('\n'.join(result.markdown_v2.references_markdown.split('\n')[:10]) + "...") + print( + "\n".join(result.markdown_v2.references_markdown.split("\n")[:10]) + "..." + ) + # 4. Browser Management Example async def browser_management_example(): @@ -156,38 +164,38 @@ async def browser_management_example(): # Use the specified user directory path user_data_dir = os.path.join(Path.home(), ".crawl4ai", "browser_profile") os.makedirs(user_data_dir, exist_ok=True) - + print(f"Browser profile will be saved to: {user_data_dir}") - + async with AsyncWebCrawler( use_managed_browser=True, user_data_dir=user_data_dir, headless=False, - verbose=True + verbose=True, ) as crawler: - result = await crawler.arun( url="https://crawl4ai.com", # session_id="persistent_session_1", - cache_mode=CacheMode.BYPASS - ) + cache_mode=CacheMode.BYPASS, + ) # Use GitHub as an example - it's a good test for browser management # because it requires proper browser handling result = await crawler.arun( url="https://github.com/trending", # session_id="persistent_session_1", - cache_mode=CacheMode.BYPASS + cache_mode=CacheMode.BYPASS, ) - + print("\nBrowser session result:", result.success) if result.success: - print("Page title:", result.metadata.get('title', 'No title found')) + print("Page title:", result.metadata.get("title", "No title found")) + # 5. API Usage Example async def api_example(): """Example of using the new API endpoints""" - api_token = os.getenv('CRAWL4AI_API_TOKEN') or "test_api_code" - headers = {'Authorization': f'Bearer {api_token}'} + api_token = os.getenv("CRAWL4AI_API_TOKEN") or "test_api_code" + headers = {"Authorization": f"Bearer {api_token}"} async with aiohttp.ClientSession() as session: # Submit crawl job crawl_request = { @@ -199,25 +207,17 @@ async def api_example(): "name": "Hacker News Articles", "baseSelector": ".athing", "fields": [ - { - "name": "title", - "selector": ".title a", - "type": "text" - }, - { - "name": "score", - "selector": ".score", - "type": "text" - }, + {"name": "title", "selector": ".title a", "type": "text"}, + {"name": "score", "selector": ".score", "type": "text"}, { "name": "url", "selector": ".title a", "type": "attribute", - "attribute": "href" - } - ] + "attribute": "href", + }, + ], } - } + }, }, "crawler_params": { "headless": True, @@ -227,51 +227,50 @@ async def api_example(): # "screenshot": True, # "magic": True } - + async with session.post( - "http://localhost:11235/crawl", - json=crawl_request, - headers=headers + "http://localhost:11235/crawl", json=crawl_request, headers=headers ) as response: task_data = await response.json() task_id = task_data["task_id"] - + # Check task status while True: async with session.get( - f"http://localhost:11235/task/{task_id}", - headers=headers + f"http://localhost:11235/task/{task_id}", headers=headers ) as status_response: result = await status_response.json() print(f"Task status: {result['status']}") - + if result["status"] == "completed": print("Task completed!") print("Results:") - news = json.loads(result["results"][0]['extracted_content']) + news = json.loads(result["results"][0]["extracted_content"]) print(json.dumps(news[:4], indent=2)) break else: await asyncio.sleep(1) + # Main execution async def main(): # print("Running Crawl4AI feature examples...") - + # print("\n1. Running Download Example:") # await download_example() - + # print("\n2. Running Markdown Generation Example:") # await markdown_generation_example() - + # # print("\n3. Running Local and Raw HTML Example:") # await local_and_raw_html_example() - + # # print("\n4. Running Browser Management Example:") await browser_management_example() - + # print("\n5. Running API Example:") await api_example() + if __name__ == "__main__": - asyncio.run(main()) \ No newline at end of file + asyncio.run(main()) diff --git a/docs/examples/v0_4_24_walkthrough.py b/docs/examples/v0_4_24_walkthrough.py index 135ac29c..996e7b04 100644 --- a/docs/examples/v0_4_24_walkthrough.py +++ b/docs/examples/v0_4_24_walkthrough.py @@ -10,18 +10,17 @@ import asyncio import os import json import re -from typing import List, Optional, Dict, Any -from pydantic import BaseModel, Field +from typing import List from crawl4ai import ( AsyncWebCrawler, BrowserConfig, CrawlerRunConfig, CacheMode, LLMExtractionStrategy, - JsonCssExtractionStrategy + JsonCssExtractionStrategy, ) from crawl4ai.content_filter_strategy import RelevantContentFilter -from crawl4ai.markdown_generation_strategy import DefaultMarkdownGenerator +from crawl4ai.markdown_generation_strategy import DefaultMarkdownGenerator from bs4 import BeautifulSoup # Sample HTML for demonstrations @@ -52,17 +51,18 @@ SAMPLE_HTML = """ """ + async def demo_ssl_features(): """ Enhanced SSL & Security Features Demo ----------------------------------- - + This example demonstrates the new SSL certificate handling and security features: 1. Custom certificate paths 2. SSL verification options 3. HTTPS error handling 4. Certificate validation configurations - + These features are particularly useful when: - Working with self-signed certificates - Dealing with corporate proxies @@ -76,14 +76,11 @@ async def demo_ssl_features(): run_config = CrawlerRunConfig( cache_mode=CacheMode.BYPASS, - fetch_ssl_certificate=True # Enable SSL certificate fetching + fetch_ssl_certificate=True, # Enable SSL certificate fetching ) async with AsyncWebCrawler(config=browser_config) as crawler: - result = await crawler.arun( - url="https://example.com", - config=run_config - ) + result = await crawler.arun(url="https://example.com", config=run_config) print(f"SSL Crawl Success: {result.success}") result.ssl_certificate.to_json( os.path.join(os.getcwd(), "ssl_certificate.json") @@ -91,11 +88,12 @@ async def demo_ssl_features(): if not result.success: print(f"SSL Error: {result.error_message}") + async def demo_content_filtering(): """ Smart Content Filtering Demo ---------------------- - + Demonstrates advanced content filtering capabilities: 1. Custom filter to identify and extract specific content 2. Integration with markdown generation @@ -110,87 +108,90 @@ async def demo_content_filtering(): super().__init__() # Add news-specific patterns self.negative_patterns = re.compile( - r'nav|footer|header|sidebar|ads|comment|share|related|recommended|popular|trending', - re.I + r"nav|footer|header|sidebar|ads|comment|share|related|recommended|popular|trending", + re.I, ) self.min_word_count = 30 # Higher threshold for news content - def filter_content(self, html: str, min_word_threshold: int = None) -> List[str]: + def filter_content( + self, html: str, min_word_threshold: int = None + ) -> List[str]: """ Implements news-specific content filtering logic. - + Args: html (str): HTML content to be filtered min_word_threshold (int, optional): Minimum word count threshold - + Returns: List[str]: List of filtered HTML content blocks """ if not html or not isinstance(html, str): return [] - - soup = BeautifulSoup(html, 'lxml') + + soup = BeautifulSoup(html, "lxml") if not soup.body: - soup = BeautifulSoup(f'{html}', 'lxml') - - body = soup.find('body') - + soup = BeautifulSoup(f"{html}", "lxml") + + body = soup.find("body") + # Extract chunks with metadata - chunks = self.extract_text_chunks(body, min_word_threshold or self.min_word_count) - + chunks = self.extract_text_chunks( + body, min_word_threshold or self.min_word_count + ) + # Filter chunks based on news-specific criteria filtered_chunks = [] for _, text, tag_type, element in chunks: # Skip if element has negative class/id if self.is_excluded(element): continue - + # Headers are important in news articles - if tag_type == 'header': + if tag_type == "header": filtered_chunks.append(self.clean_element(element)) continue - + # For content, check word count and link density text = element.get_text(strip=True) if len(text.split()) >= (min_word_threshold or self.min_word_count): # Calculate link density - links_text = ' '.join(a.get_text(strip=True) for a in element.find_all('a')) + links_text = " ".join( + a.get_text(strip=True) for a in element.find_all("a") + ) link_density = len(links_text) / len(text) if text else 1 - + # Accept if link density is reasonable if link_density < 0.5: filtered_chunks.append(self.clean_element(element)) - + return filtered_chunks # Create markdown generator with custom filter - markdown_gen = DefaultMarkdownGenerator( - content_filter=CustomNewsFilter() - ) + markdown_gen = DefaultMarkdownGenerator(content_filter=CustomNewsFilter()) run_config = CrawlerRunConfig( - markdown_generator=markdown_gen, - cache_mode=CacheMode.BYPASS + markdown_generator=markdown_gen, cache_mode=CacheMode.BYPASS ) async with AsyncWebCrawler() as crawler: result = await crawler.arun( - url="https://news.ycombinator.com", - config=run_config + url="https://news.ycombinator.com", config=run_config ) print("Filtered Content Sample:") print(result.markdown[:500]) # Show first 500 chars + async def demo_json_extraction(): """ Improved JSON Extraction Demo --------------------------- - + Demonstrates the enhanced JSON extraction capabilities: 1. Base element attributes extraction 2. Complex nested structures 3. Multiple extraction patterns - + Key features shown: - Extracting attributes from base elements (href, data-* attributes) - Processing repeated patterns @@ -206,7 +207,7 @@ async def demo_json_extraction(): "baseSelector": "div.article-list", "baseFields": [ {"name": "list_id", "type": "attribute", "attribute": "data-list-id"}, - {"name": "category", "type": "attribute", "attribute": "data-category"} + {"name": "category", "type": "attribute", "attribute": "data-category"}, ], "fields": [ { @@ -214,8 +215,16 @@ async def demo_json_extraction(): "selector": "article.post", "type": "nested_list", "baseFields": [ - {"name": "post_id", "type": "attribute", "attribute": "data-post-id"}, - {"name": "author_id", "type": "attribute", "attribute": "data-author"} + { + "name": "post_id", + "type": "attribute", + "attribute": "data-post-id", + }, + { + "name": "author_id", + "type": "attribute", + "attribute": "data-author", + }, ], "fields": [ { @@ -223,60 +232,68 @@ async def demo_json_extraction(): "selector": "h2.title a", "type": "text", "baseFields": [ - {"name": "url", "type": "attribute", "attribute": "href"} - ] + { + "name": "url", + "type": "attribute", + "attribute": "href", + } + ], }, { "name": "author", "selector": "div.meta a.author", "type": "text", "baseFields": [ - {"name": "profile_url", "type": "attribute", "attribute": "href"} - ] - }, - { - "name": "date", - "selector": "span.date", - "type": "text" + { + "name": "profile_url", + "type": "attribute", + "attribute": "href", + } + ], }, + {"name": "date", "selector": "span.date", "type": "text"}, { "name": "read_more", "selector": "a.read-more", "type": "nested", "fields": [ {"name": "text", "type": "text"}, - {"name": "url", "type": "attribute", "attribute": "href"} - ] - } - ] + { + "name": "url", + "type": "attribute", + "attribute": "href", + }, + ], + }, + ], } - ] + ], } ) # Demonstrate extraction from raw HTML run_config = CrawlerRunConfig( - extraction_strategy=json_strategy, - cache_mode=CacheMode.BYPASS + extraction_strategy=json_strategy, cache_mode=CacheMode.BYPASS ) async with AsyncWebCrawler() as crawler: result = await crawler.arun( url="raw:" + SAMPLE_HTML, # Use raw: prefix for raw HTML - config=run_config + config=run_config, ) print("Extracted Content:") print(result.extracted_content) + async def demo_input_formats(): """ Input Format Handling Demo ---------------------- - + Demonstrates how LLM extraction can work with different input formats: 1. Markdown (default) - Good for simple text extraction 2. HTML - Better when you need structure and attributes - + This example shows how HTML input can be beneficial when: - You need to understand the DOM structure - You want to extract both visible text and HTML attributes @@ -350,7 +367,7 @@ async def demo_input_formats(): """ - + # Use raw:// prefix to pass HTML content directly url = f"raw://{dummy_html}" @@ -359,18 +376,30 @@ async def demo_input_formats(): # Define our schema using Pydantic class JobRequirement(BaseModel): - category: str = Field(description="Category of the requirement (e.g., Technical, Soft Skills)") - items: List[str] = Field(description="List of specific requirements in this category") - priority: str = Field(description="Priority level (Required/Preferred) based on the HTML class or context") + category: str = Field( + description="Category of the requirement (e.g., Technical, Soft Skills)" + ) + items: List[str] = Field( + description="List of specific requirements in this category" + ) + priority: str = Field( + description="Priority level (Required/Preferred) based on the HTML class or context" + ) class JobPosting(BaseModel): title: str = Field(description="Job title") department: str = Field(description="Department or team") location: str = Field(description="Job location, including remote options") salary_range: Optional[str] = Field(description="Salary range if specified") - requirements: List[JobRequirement] = Field(description="Categorized job requirements") - application_deadline: Optional[str] = Field(description="Application deadline if specified") - contact_info: Optional[dict] = Field(description="Contact information from footer or contact section") + requirements: List[JobRequirement] = Field( + description="Categorized job requirements" + ) + application_deadline: Optional[str] = Field( + description="Application deadline if specified" + ) + contact_info: Optional[dict] = Field( + description="Contact information from footer or contact section" + ) # First try with markdown (default) markdown_strategy = LLMExtractionStrategy( @@ -382,7 +411,7 @@ async def demo_input_formats(): Extract job posting details into structured data. Focus on the visible text content and organize requirements into categories. """, - input_format="markdown" # default + input_format="markdown", # default ) # Then with HTML for better structure understanding @@ -400,34 +429,25 @@ async def demo_input_formats(): Use HTML attributes and classes to enhance extraction accuracy. """, - input_format="html" # explicitly use HTML + input_format="html", # explicitly use HTML ) async with AsyncWebCrawler() as crawler: # Try with markdown first - markdown_config = CrawlerRunConfig( - extraction_strategy=markdown_strategy - ) - markdown_result = await crawler.arun( - url=url, - config=markdown_config - ) + markdown_config = CrawlerRunConfig(extraction_strategy=markdown_strategy) + markdown_result = await crawler.arun(url=url, config=markdown_config) print("\nMarkdown-based Extraction Result:") items = json.loads(markdown_result.extracted_content) print(json.dumps(items, indent=2)) # Then with HTML for better structure understanding - html_config = CrawlerRunConfig( - extraction_strategy=html_strategy - ) - html_result = await crawler.arun( - url=url, - config=html_config - ) + html_config = CrawlerRunConfig(extraction_strategy=html_strategy) + html_result = await crawler.arun(url=url, config=html_config) print("\nHTML-based Extraction Result:") items = json.loads(html_result.extracted_content) print(json.dumps(items, indent=2)) + # Main execution async def main(): print("Crawl4AI v0.4.24 Feature Walkthrough") @@ -439,5 +459,6 @@ async def main(): await demo_json_extraction() # await demo_input_formats() + if __name__ == "__main__": asyncio.run(main()) diff --git a/main.py b/main.py index 21e411d0..1f9e01a3 100644 --- a/main.py +++ b/main.py @@ -1,14 +1,9 @@ import asyncio, os -from fastapi import FastAPI, HTTPException, BackgroundTasks, Request -from fastapi.responses import JSONResponse -from fastapi import FastAPI, HTTPException, Request -from fastapi.responses import HTMLResponse, JSONResponse +from fastapi import FastAPI, HTTPException +from fastapi import FastAPI, HTTPException from fastapi.staticfiles import StaticFiles -from fastapi.middleware.cors import CORSMiddleware +from fastapi.middleware.cors import CORSMiddleware from fastapi.templating import Jinja2Templates -from fastapi.exceptions import RequestValidationError -from starlette.middleware.base import BaseHTTPMiddleware -from starlette.responses import FileResponse from fastapi.responses import RedirectResponse from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials from fastapi import Depends, Security @@ -18,13 +13,10 @@ from typing import Optional, List, Dict, Any, Union import psutil import time import uuid -from collections import defaultdict -from urllib.parse import urlparse import math import logging from enum import Enum from dataclasses import dataclass -import json from crawl4ai import AsyncWebCrawler, CrawlResult, CacheMode from crawl4ai.config import MIN_WORD_THRESHOLD from crawl4ai.extraction_strategy import ( @@ -38,30 +30,36 @@ __location__ = os.path.realpath(os.path.join(os.getcwd(), os.path.dirname(__file logging.basicConfig(level=logging.INFO) logger = logging.getLogger(__name__) + class TaskStatus(str, Enum): PENDING = "pending" PROCESSING = "processing" COMPLETED = "completed" FAILED = "failed" + class CrawlerType(str, Enum): BASIC = "basic" LLM = "llm" COSINE = "cosine" JSON_CSS = "json_css" + class ExtractionConfig(BaseModel): type: CrawlerType params: Dict[str, Any] = {} + class ChunkingStrategy(BaseModel): type: str params: Dict[str, Any] = {} + class ContentFilter(BaseModel): type: str = "bm25" params: Dict[str, Any] = {} + class CrawlRequest(BaseModel): urls: Union[HttpUrl, List[HttpUrl]] word_count_threshold: int = MIN_WORD_THRESHOLD @@ -77,9 +75,10 @@ class CrawlRequest(BaseModel): session_id: Optional[str] = None cache_mode: Optional[CacheMode] = CacheMode.ENABLED priority: int = Field(default=5, ge=1, le=10) - ttl: Optional[int] = 3600 + ttl: Optional[int] = 3600 crawler_params: Dict[str, Any] = {} + @dataclass class TaskInfo: id: str @@ -89,6 +88,7 @@ class TaskInfo: created_at: float = time.time() ttl: int = 3600 + class ResourceMonitor: def __init__(self, max_concurrent_tasks: int = 10): self.max_concurrent_tasks = max_concurrent_tasks @@ -106,7 +106,9 @@ class ResourceMonitor: mem_usage = psutil.virtual_memory().percent / 100 cpu_usage = psutil.cpu_percent() / 100 - memory_factor = max(0, (self.memory_threshold - mem_usage) / self.memory_threshold) + memory_factor = max( + 0, (self.memory_threshold - mem_usage) / self.memory_threshold + ) cpu_factor = max(0, (self.cpu_threshold - cpu_usage) / self.cpu_threshold) self._last_available_slots = math.floor( @@ -116,6 +118,7 @@ class ResourceMonitor: return self._last_available_slots + class TaskManager: def __init__(self, cleanup_interval: int = 300): self.tasks: Dict[str, TaskInfo] = {} @@ -149,12 +152,16 @@ class TaskManager: except asyncio.TimeoutError: try: # Then try low priority - _, task_id = await asyncio.wait_for(self.low_priority.get(), timeout=0.1) + _, task_id = await asyncio.wait_for( + self.low_priority.get(), timeout=0.1 + ) return task_id except asyncio.TimeoutError: return None - def update_task(self, task_id: str, status: TaskStatus, result: Any = None, error: str = None): + def update_task( + self, task_id: str, status: TaskStatus, result: Any = None, error: str = None + ): if task_id in self.tasks: task_info = self.tasks[task_id] task_info.status = status @@ -180,6 +187,7 @@ class TaskManager: except Exception as e: logger.error(f"Error in cleanup loop: {e}") + class CrawlerPool: def __init__(self, max_size: int = 10): self.max_size = max_size @@ -222,6 +230,7 @@ class CrawlerPool: await crawler.__aexit__(None, None, None) self.active_crawlers.clear() + class CrawlerService: def __init__(self, max_concurrent_tasks: int = 10): self.resource_monitor = ResourceMonitor(max_concurrent_tasks) @@ -258,10 +267,10 @@ class CrawlerService: async def submit_task(self, request: CrawlRequest) -> str: task_id = str(uuid.uuid4()) await self.task_manager.add_task(task_id, request.priority, request.ttl or 3600) - + # Store request data with task self.task_manager.tasks[task_id].request = request - + return task_id async def _process_queue(self): @@ -286,9 +295,11 @@ class CrawlerService: try: crawler = await self.crawler_pool.acquire(**request.crawler_params) - - extraction_strategy = self._create_extraction_strategy(request.extraction_config) - + + extraction_strategy = self._create_extraction_strategy( + request.extraction_config + ) + if isinstance(request.urls, list): results = await crawler.arun_many( urls=[str(url) for url in request.urls], @@ -318,16 +329,21 @@ class CrawlerService: ) await self.crawler_pool.release(crawler) - self.task_manager.update_task(task_id, TaskStatus.COMPLETED, results) + self.task_manager.update_task( + task_id, TaskStatus.COMPLETED, results + ) except Exception as e: logger.error(f"Error processing task {task_id}: {str(e)}") - self.task_manager.update_task(task_id, TaskStatus.FAILED, error=str(e)) + self.task_manager.update_task( + task_id, TaskStatus.FAILED, error=str(e) + ) except Exception as e: logger.error(f"Error in queue processing: {str(e)}") await asyncio.sleep(1) + app = FastAPI(title="Crawl4AI API") # CORS configuration @@ -344,6 +360,7 @@ app.add_middleware( security = HTTPBearer() CRAWL4AI_API_TOKEN = os.getenv("CRAWL4AI_API_TOKEN") + async def verify_token(credentials: HTTPAuthorizationCredentials = Security(security)): if not CRAWL4AI_API_TOKEN: return credentials # No token verification if CRAWL4AI_API_TOKEN is not set @@ -351,10 +368,12 @@ async def verify_token(credentials: HTTPAuthorizationCredentials = Security(secu raise HTTPException(status_code=401, detail="Invalid token") return credentials + def secure_endpoint(): """Returns security dependency only if CRAWL4AI_API_TOKEN is set""" return Depends(verify_token) if CRAWL4AI_API_TOKEN else None + # Check if site directory exists if os.path.exists(__location__ + "/site"): # Mount the site directory as a static directory @@ -364,14 +383,17 @@ site_templates = Jinja2Templates(directory=__location__ + "/site") crawler_service = CrawlerService() + @app.on_event("startup") async def startup_event(): await crawler_service.start() + @app.on_event("shutdown") async def shutdown_event(): await crawler_service.stop() + @app.get("/") def read_root(): if os.path.exists(__location__ + "/site"): @@ -379,12 +401,16 @@ def read_root(): # Return a json response return {"message": "Crawl4AI API service is running"} + @app.post("/crawl", dependencies=[secure_endpoint()] if CRAWL4AI_API_TOKEN else []) async def crawl(request: CrawlRequest) -> Dict[str, str]: task_id = await crawler_service.submit_task(request) return {"task_id": task_id} -@app.get("/task/{task_id}", dependencies=[secure_endpoint()] if CRAWL4AI_API_TOKEN else []) + +@app.get( + "/task/{task_id}", dependencies=[secure_endpoint()] if CRAWL4AI_API_TOKEN else [] +) async def get_task_status(task_id: str): task_info = crawler_service.task_manager.get_task(task_id) if not task_info: @@ -406,36 +432,45 @@ async def get_task_status(task_id: str): return response + @app.post("/crawl_sync", dependencies=[secure_endpoint()] if CRAWL4AI_API_TOKEN else []) async def crawl_sync(request: CrawlRequest) -> Dict[str, Any]: task_id = await crawler_service.submit_task(request) - + # Wait up to 60 seconds for task completion for _ in range(60): task_info = crawler_service.task_manager.get_task(task_id) if not task_info: raise HTTPException(status_code=404, detail="Task not found") - + if task_info.status == TaskStatus.COMPLETED: # Return same format as /task/{task_id} endpoint if isinstance(task_info.result, list): - return {"status": task_info.status, "results": [result.dict() for result in task_info.result]} + return { + "status": task_info.status, + "results": [result.dict() for result in task_info.result], + } return {"status": task_info.status, "result": task_info.result.dict()} - + if task_info.status == TaskStatus.FAILED: raise HTTPException(status_code=500, detail=task_info.error) - + await asyncio.sleep(1) - + # If we get here, task didn't complete within timeout raise HTTPException(status_code=408, detail="Task timed out") -@app.post("/crawl_direct", dependencies=[secure_endpoint()] if CRAWL4AI_API_TOKEN else []) + +@app.post( + "/crawl_direct", dependencies=[secure_endpoint()] if CRAWL4AI_API_TOKEN else [] +) async def crawl_direct(request: CrawlRequest) -> Dict[str, Any]: try: crawler = await crawler_service.crawler_pool.acquire(**request.crawler_params) - extraction_strategy = crawler_service._create_extraction_strategy(request.extraction_config) - + extraction_strategy = crawler_service._create_extraction_strategy( + request.extraction_config + ) + try: if isinstance(request.urls, list): results = await crawler.arun_many( @@ -470,7 +505,8 @@ async def crawl_direct(request: CrawlRequest) -> Dict[str, Any]: except Exception as e: logger.error(f"Error in direct crawl: {str(e)}") raise HTTPException(status_code=500, detail=str(e)) - + + @app.get("/health") async def health_check(): available_slots = await crawler_service.resource_monitor.get_available_slots() @@ -482,6 +518,8 @@ async def health_check(): "cpu_usage": psutil.cpu_percent(), } + if __name__ == "__main__": import uvicorn - uvicorn.run(app, host="0.0.0.0", port=11235) \ No newline at end of file + + uvicorn.run(app, host="0.0.0.0", port=11235) diff --git a/setup.py b/setup.py index dad3199d..16b1b53c 100644 --- a/setup.py +++ b/setup.py @@ -51,9 +51,7 @@ setup( author_email="unclecode@kidocode.com", license="MIT", packages=find_packages(), - package_data={ - 'crawl4ai': ['js_snippet/*.js'] - }, + package_data={"crawl4ai": ["js_snippet/*.js"]}, classifiers=[ "Development Status :: 3 - Alpha", "Intended Audience :: Developers", diff --git a/tests/async/test_0.4.2_browser_manager.py b/tests/async/test_0.4.2_browser_manager.py index 9bb19582..21b4be11 100644 --- a/tests/async/test_0.4.2_browser_manager.py +++ b/tests/async/test_0.4.2_browser_manager.py @@ -1,17 +1,18 @@ -import os, sys -parent_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) -sys.path.append(parent_dir) -__location__ = os.path.realpath( os.path.join(os.getcwd(), os.path.dirname(__file__))) - -import os, sys +import os +import sys import asyncio from crawl4ai import AsyncWebCrawler, CacheMode -from crawl4ai.content_filter_strategy import PruningContentFilter from crawl4ai.markdown_generation_strategy import DefaultMarkdownGenerator -# Assuming that the changes made allow different configurations +parent_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) +sys.path.append(parent_dir) +__location__ = os.path.realpath(os.path.join(os.getcwd(), os.path.dirname(__file__))) + + +# Assuming that the changes made allow different configurations # for managed browser, persistent context, and so forth. + async def test_default_headless(): async with AsyncWebCrawler( headless=True, @@ -24,13 +25,14 @@ async def test_default_headless(): # Testing normal ephemeral context ) as crawler: result = await crawler.arun( - url='https://www.kidocode.com/degrees/technology', + url="https://www.kidocode.com/degrees/technology", cache_mode=CacheMode.BYPASS, markdown_generator=DefaultMarkdownGenerator(options={"ignore_links": True}), ) print("[test_default_headless] success:", result.success) print("HTML length:", len(result.html if result.html else "")) - + + async def test_managed_browser_persistent(): # Treating use_persistent_context=True as managed_browser scenario. async with AsyncWebCrawler( @@ -44,13 +46,14 @@ async def test_managed_browser_persistent(): # This should store and reuse profile data across runs ) as crawler: result = await crawler.arun( - url='https://www.google.com', + url="https://www.google.com", cache_mode=CacheMode.BYPASS, - markdown_generator=DefaultMarkdownGenerator(options={"ignore_links": True}) + markdown_generator=DefaultMarkdownGenerator(options={"ignore_links": True}), ) print("[test_managed_browser_persistent] success:", result.success) print("HTML length:", len(result.html if result.html else "")) + async def test_session_reuse(): # Test creating a session, using it for multiple calls session_id = "my_session" @@ -62,25 +65,25 @@ async def test_session_reuse(): use_managed_browser=False, use_persistent_context=False, ) as crawler: - # First call: create session result1 = await crawler.arun( - url='https://www.example.com', + url="https://www.example.com", cache_mode=CacheMode.BYPASS, session_id=session_id, - markdown_generator=DefaultMarkdownGenerator(options={"ignore_links": True}) + markdown_generator=DefaultMarkdownGenerator(options={"ignore_links": True}), ) print("[test_session_reuse first call] success:", result1.success) - + # Second call: same session, possibly cookie retained result2 = await crawler.arun( - url='https://www.example.com/about', + url="https://www.example.com/about", cache_mode=CacheMode.BYPASS, session_id=session_id, - markdown_generator=DefaultMarkdownGenerator(options={"ignore_links": True}) + markdown_generator=DefaultMarkdownGenerator(options={"ignore_links": True}), ) print("[test_session_reuse second call] success:", result2.success) + async def test_magic_mode(): # Test magic mode with override_navigator and simulate_user async with AsyncWebCrawler( @@ -95,13 +98,14 @@ async def test_magic_mode(): simulate_user=True, ) as crawler: result = await crawler.arun( - url='https://www.kidocode.com/degrees/business', + url="https://www.kidocode.com/degrees/business", cache_mode=CacheMode.BYPASS, - markdown_generator=DefaultMarkdownGenerator(options={"ignore_links": True}) + markdown_generator=DefaultMarkdownGenerator(options={"ignore_links": True}), ) print("[test_magic_mode] success:", result.success) print("HTML length:", len(result.html if result.html else "")) + async def test_proxy_settings(): # Test with a proxy (if available) to ensure code runs with proxy async with AsyncWebCrawler( @@ -113,14 +117,15 @@ async def test_proxy_settings(): use_persistent_context=False, ) as crawler: result = await crawler.arun( - url='https://httpbin.org/ip', + url="https://httpbin.org/ip", cache_mode=CacheMode.BYPASS, - markdown_generator=DefaultMarkdownGenerator(options={"ignore_links": True}) + markdown_generator=DefaultMarkdownGenerator(options={"ignore_links": True}), ) print("[test_proxy_settings] success:", result.success) if result.success: print("HTML preview:", result.html[:200] if result.html else "") + async def test_ignore_https_errors(): # Test ignore HTTPS errors with a self-signed or invalid cert domain # This is just conceptual, the domain should be one that triggers SSL error. @@ -134,12 +139,13 @@ async def test_ignore_https_errors(): use_persistent_context=False, ) as crawler: result = await crawler.arun( - url='https://self-signed.badssl.com/', + url="https://self-signed.badssl.com/", cache_mode=CacheMode.BYPASS, - markdown_generator=DefaultMarkdownGenerator(options={"ignore_links": True}) + markdown_generator=DefaultMarkdownGenerator(options={"ignore_links": True}), ) print("[test_ignore_https_errors] success:", result.success) + async def main(): print("Running tests...") # await test_default_headless() @@ -149,5 +155,6 @@ async def main(): # await test_proxy_settings() await test_ignore_https_errors() + if __name__ == "__main__": asyncio.run(main()) diff --git a/tests/async/test_0.4.2_config_params.py b/tests/async/test_0.4.2_config_params.py index 623ac3ab..9a15f864 100644 --- a/tests/async/test_0.4.2_config_params.py +++ b/tests/async/test_0.4.2_config_params.py @@ -1,15 +1,16 @@ import os, sys + parent_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) sys.path.append(parent_dir) __location__ = os.path.realpath(os.path.join(os.getcwd(), os.path.dirname(__file__))) import asyncio from crawl4ai import AsyncWebCrawler, CacheMode -from crawl4ai.async_configs import BrowserConfig, CrawlerRunConfig +from crawl4ai.async_configs import BrowserConfig, CrawlerRunConfig from crawl4ai.content_filter_strategy import PruningContentFilter from crawl4ai.extraction_strategy import JsonCssExtractionStrategy from crawl4ai.chunking_strategy import RegexChunking -from crawl4ai.markdown_generation_strategy import DefaultMarkdownGenerator + # Category 1: Browser Configuration Tests async def test_browser_config_object(): @@ -21,29 +22,31 @@ async def test_browser_config_object(): viewport_height=1080, use_managed_browser=True, user_agent_mode="random", - user_agent_generator_config={"device_type": "desktop", "os_type": "windows"} + user_agent_generator_config={"device_type": "desktop", "os_type": "windows"}, ) - + async with AsyncWebCrawler(config=browser_config, verbose=True) as crawler: - result = await crawler.arun('https://example.com', cache_mode=CacheMode.BYPASS) + result = await crawler.arun("https://example.com", cache_mode=CacheMode.BYPASS) assert result.success, "Browser config crawl failed" assert len(result.html) > 0, "No HTML content retrieved" + async def test_browser_performance_config(): """Test browser configurations focused on performance""" browser_config = BrowserConfig( text_mode=True, light_mode=True, - extra_args=['--disable-gpu', '--disable-software-rasterizer'], + extra_args=["--disable-gpu", "--disable-software-rasterizer"], ignore_https_errors=True, - java_script_enabled=False + java_script_enabled=False, ) - + async with AsyncWebCrawler(config=browser_config) as crawler: - result = await crawler.arun('https://example.com') + result = await crawler.arun("https://example.com") assert result.success, "Performance optimized crawl failed" assert result.status_code == 200, "Unexpected status code" + # Category 2: Content Processing Tests async def test_content_extraction_config(): """Test content extraction with various strategies""" @@ -53,24 +56,20 @@ async def test_content_extraction_config(): schema={ "name": "article", "baseSelector": "div", - "fields": [{ - "name": "title", - "selector": "h1", - "type": "text" - }] + "fields": [{"name": "title", "selector": "h1", "type": "text"}], } ), chunking_strategy=RegexChunking(), - content_filter=PruningContentFilter() + content_filter=PruningContentFilter(), ) - + async with AsyncWebCrawler() as crawler: result = await crawler.arun( - 'https://example.com/article', - config=crawler_config + "https://example.com/article", config=crawler_config ) assert result.extracted_content is not None, "Content extraction failed" - assert 'title' in result.extracted_content, "Missing expected content field" + assert "title" in result.extracted_content, "Missing expected content field" + # Category 3: Cache and Session Management Tests async def test_cache_and_session_management(): @@ -79,25 +78,20 @@ async def test_cache_and_session_management(): crawler_config = CrawlerRunConfig( cache_mode=CacheMode.WRITE_ONLY, process_iframes=True, - remove_overlay_elements=True + remove_overlay_elements=True, ) - + async with AsyncWebCrawler(config=browser_config) as crawler: # First request - should write to cache - result1 = await crawler.arun( - 'https://example.com', - config=crawler_config - ) - + result1 = await crawler.arun("https://example.com", config=crawler_config) + # Second request - should use fresh fetch due to WRITE_ONLY mode - result2 = await crawler.arun( - 'https://example.com', - config=crawler_config - ) - + result2 = await crawler.arun("https://example.com", config=crawler_config) + assert result1.success and result2.success, "Cache mode crawl failed" assert result1.html == result2.html, "Inconsistent results between requests" + # Category 4: Media Handling Tests async def test_media_handling_config(): """Test configurations related to media handling""" @@ -107,24 +101,22 @@ async def test_media_handling_config(): viewport_width=1920, viewport_height=1080, accept_downloads=True, - downloads_path= os.path.expanduser("~/.crawl4ai/downloads") + downloads_path=os.path.expanduser("~/.crawl4ai/downloads"), ) crawler_config = CrawlerRunConfig( screenshot=True, pdf=True, adjust_viewport_to_content=True, wait_for_images=True, - screenshot_height_threshold=20000 + screenshot_height_threshold=20000, ) - + async with AsyncWebCrawler(config=browser_config) as crawler: - result = await crawler.arun( - 'https://example.com', - config=crawler_config - ) + result = await crawler.arun("https://example.com", config=crawler_config) assert result.screenshot is not None, "Screenshot capture failed" assert result.pdf is not None, "PDF generation failed" + # Category 5: Anti-Bot and Site Interaction Tests async def test_antibot_config(): """Test configurations for handling anti-bot measures""" @@ -135,76 +127,64 @@ async def test_antibot_config(): wait_for="js:()=>document.querySelector('body')", delay_before_return_html=1.0, log_console=True, - cache_mode=CacheMode.BYPASS + cache_mode=CacheMode.BYPASS, ) - + async with AsyncWebCrawler() as crawler: - result = await crawler.arun( - 'https://example.com', - config=crawler_config - ) + result = await crawler.arun("https://example.com", config=crawler_config) assert result.success, "Anti-bot measure handling failed" + # Category 6: Parallel Processing Tests async def test_parallel_processing(): """Test parallel processing capabilities""" - crawler_config = CrawlerRunConfig( - mean_delay=0.5, - max_range=1.0, - semaphore_count=5 - ) - - urls = [ - 'https://example.com/1', - 'https://example.com/2', - 'https://example.com/3' - ] - + crawler_config = CrawlerRunConfig(mean_delay=0.5, max_range=1.0, semaphore_count=5) + + urls = ["https://example.com/1", "https://example.com/2", "https://example.com/3"] + async with AsyncWebCrawler() as crawler: - results = await crawler.arun_many( - urls, - config=crawler_config - ) + results = await crawler.arun_many(urls, config=crawler_config) assert len(results) == len(urls), "Not all URLs were processed" assert all(r.success for r in results), "Some parallel requests failed" + # Category 7: Backwards Compatibility Tests async def test_legacy_parameter_support(): """Test that legacy parameters still work""" async with AsyncWebCrawler( - headless=True, - browser_type="chromium", - viewport_width=1024, - viewport_height=768 + headless=True, browser_type="chromium", viewport_width=1024, viewport_height=768 ) as crawler: result = await crawler.arun( - 'https://example.com', + "https://example.com", screenshot=True, word_count_threshold=200, bypass_cache=True, - css_selector=".main-content" + css_selector=".main-content", ) assert result.success, "Legacy parameter support failed" + # Category 8: Mixed Configuration Tests async def test_mixed_config_usage(): """Test mixing new config objects with legacy parameters""" browser_config = BrowserConfig(headless=True) crawler_config = CrawlerRunConfig(screenshot=True) - + async with AsyncWebCrawler( config=browser_config, - verbose=True # legacy parameter + verbose=True, # legacy parameter ) as crawler: result = await crawler.arun( - 'https://example.com', + "https://example.com", config=crawler_config, cache_mode=CacheMode.BYPASS, # legacy parameter - css_selector="body" # legacy parameter + css_selector="body", # legacy parameter ) assert result.success, "Mixed configuration usage failed" + if __name__ == "__main__": + async def run_tests(): test_functions = [ test_browser_config_object, @@ -217,7 +197,7 @@ if __name__ == "__main__": # test_legacy_parameter_support, # test_mixed_config_usage ] - + for test in test_functions: print(f"\nRunning {test.__name__}...") try: @@ -227,5 +207,5 @@ if __name__ == "__main__": print(f"✗ {test.__name__} failed: {str(e)}") except Exception as e: print(f"✗ {test.__name__} error: {str(e)}") - - asyncio.run(run_tests()) \ No newline at end of file + + asyncio.run(run_tests()) diff --git a/tests/async/test_async_doanloader.py b/tests/async/test_async_doanloader.py index 4798b4ca..055886cb 100644 --- a/tests/async/test_async_doanloader.py +++ b/tests/async/test_async_doanloader.py @@ -4,7 +4,6 @@ import asyncio import shutil from typing import List import tempfile -import time # Add the parent directory to the Python path parent_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) @@ -12,28 +11,27 @@ sys.path.append(parent_dir) from crawl4ai.async_webcrawler import AsyncWebCrawler + class TestDownloads: def __init__(self): self.temp_dir = tempfile.mkdtemp(prefix="crawl4ai_test_") self.download_dir = os.path.join(self.temp_dir, "downloads") os.makedirs(self.download_dir, exist_ok=True) self.results: List[str] = [] - + def cleanup(self): shutil.rmtree(self.temp_dir) - + def log_result(self, test_name: str, success: bool, message: str = ""): result = f"{'✅' if success else '❌'} {test_name}: {message}" self.results.append(result) print(result) - + async def test_basic_download(self): """Test basic file download functionality""" try: async with AsyncWebCrawler( - accept_downloads=True, - downloads_path=self.download_dir, - verbose=True + accept_downloads=True, downloads_path=self.download_dir, verbose=True ) as crawler: # Python.org downloads page typically has stable download links result = await crawler.arun( @@ -42,14 +40,19 @@ class TestDownloads: // Click first download link const downloadLink = document.querySelector('a[href$=".exe"]'); if (downloadLink) downloadLink.click(); - """ + """, + ) + + success = ( + result.downloaded_files is not None + and len(result.downloaded_files) > 0 ) - - success = result.downloaded_files is not None and len(result.downloaded_files) > 0 self.log_result( "Basic Download", success, - f"Downloaded {len(result.downloaded_files or [])} files" if success else "No files downloaded" + f"Downloaded {len(result.downloaded_files or [])} files" + if success + else "No files downloaded", ) except Exception as e: self.log_result("Basic Download", False, str(e)) @@ -59,27 +62,32 @@ class TestDownloads: try: user_data_dir = os.path.join(self.temp_dir, "user_data") os.makedirs(user_data_dir, exist_ok=True) - + async with AsyncWebCrawler( accept_downloads=True, downloads_path=self.download_dir, use_persistent_context=True, user_data_dir=user_data_dir, - verbose=True + verbose=True, ) as crawler: result = await crawler.arun( url="https://www.python.org/downloads/", js_code=""" const downloadLink = document.querySelector('a[href$=".exe"]'); if (downloadLink) downloadLink.click(); - """ + """, + ) + + success = ( + result.downloaded_files is not None + and len(result.downloaded_files) > 0 ) - - success = result.downloaded_files is not None and len(result.downloaded_files) > 0 self.log_result( "Persistent Context Download", success, - f"Downloaded {len(result.downloaded_files or [])} files" if success else "No files downloaded" + f"Downloaded {len(result.downloaded_files or [])} files" + if success + else "No files downloaded", ) except Exception as e: self.log_result("Persistent Context Download", False, str(e)) @@ -88,9 +96,7 @@ class TestDownloads: """Test multiple simultaneous downloads""" try: async with AsyncWebCrawler( - accept_downloads=True, - downloads_path=self.download_dir, - verbose=True + accept_downloads=True, downloads_path=self.download_dir, verbose=True ) as crawler: result = await crawler.arun( url="https://www.python.org/downloads/", @@ -98,14 +104,19 @@ class TestDownloads: // Click multiple download links const downloadLinks = document.querySelectorAll('a[href$=".exe"]'); downloadLinks.forEach(link => link.click()); - """ + """, + ) + + success = ( + result.downloaded_files is not None + and len(result.downloaded_files) > 1 ) - - success = result.downloaded_files is not None and len(result.downloaded_files) > 1 self.log_result( "Multiple Downloads", success, - f"Downloaded {len(result.downloaded_files or [])} files" if success else "Not enough files downloaded" + f"Downloaded {len(result.downloaded_files or [])} files" + if success + else "Not enough files downloaded", ) except Exception as e: self.log_result("Multiple Downloads", False, str(e)) @@ -113,49 +124,51 @@ class TestDownloads: async def test_different_browsers(self): """Test downloads across different browser types""" browsers = ["chromium", "firefox", "webkit"] - + for browser_type in browsers: try: async with AsyncWebCrawler( accept_downloads=True, downloads_path=self.download_dir, browser_type=browser_type, - verbose=True + verbose=True, ) as crawler: result = await crawler.arun( url="https://www.python.org/downloads/", js_code=""" const downloadLink = document.querySelector('a[href$=".exe"]'); if (downloadLink) downloadLink.click(); - """ + """, + ) + + success = ( + result.downloaded_files is not None + and len(result.downloaded_files) > 0 ) - - success = result.downloaded_files is not None and len(result.downloaded_files) > 0 self.log_result( f"{browser_type.title()} Download", success, - f"Downloaded {len(result.downloaded_files or [])} files" if success else "No files downloaded" + f"Downloaded {len(result.downloaded_files or [])} files" + if success + else "No files downloaded", ) except Exception as e: self.log_result(f"{browser_type.title()} Download", False, str(e)) async def test_edge_cases(self): """Test various edge cases""" - + # Test 1: Downloads without specifying download path try: - async with AsyncWebCrawler( - accept_downloads=True, - verbose=True - ) as crawler: + async with AsyncWebCrawler(accept_downloads=True, verbose=True) as crawler: result = await crawler.arun( url="https://www.python.org/downloads/", - js_code="document.querySelector('a[href$=\".exe\"]').click()" + js_code="document.querySelector('a[href$=\".exe\"]').click()", ) self.log_result( "Default Download Path", True, - f"Downloaded to default path: {result.downloaded_files[0] if result.downloaded_files else 'None'}" + f"Downloaded to default path: {result.downloaded_files[0] if result.downloaded_files else 'None'}", ) except Exception as e: self.log_result("Default Download Path", False, str(e)) @@ -165,31 +178,34 @@ class TestDownloads: async with AsyncWebCrawler( accept_downloads=True, downloads_path="/invalid/path/that/doesnt/exist", - verbose=True + verbose=True, ) as crawler: result = await crawler.arun( url="https://www.python.org/downloads/", - js_code="document.querySelector('a[href$=\".exe\"]').click()" + js_code="document.querySelector('a[href$=\".exe\"]').click()", ) - self.log_result("Invalid Download Path", False, "Should have raised an error") - except Exception as e: - self.log_result("Invalid Download Path", True, "Correctly handled invalid path") + self.log_result( + "Invalid Download Path", False, "Should have raised an error" + ) + except Exception: + self.log_result( + "Invalid Download Path", True, "Correctly handled invalid path" + ) # Test 3: Download with accept_downloads=False try: - async with AsyncWebCrawler( - accept_downloads=False, - verbose=True - ) as crawler: + async with AsyncWebCrawler(accept_downloads=False, verbose=True) as crawler: result = await crawler.arun( url="https://www.python.org/downloads/", - js_code="document.querySelector('a[href$=\".exe\"]').click()" + js_code="document.querySelector('a[href$=\".exe\"]').click()", ) success = result.downloaded_files is None self.log_result( "Disabled Downloads", success, - "Correctly ignored downloads" if success else "Unexpectedly downloaded files" + "Correctly ignored downloads" + if success + else "Unexpectedly downloaded files", ) except Exception as e: self.log_result("Disabled Downloads", False, str(e)) @@ -197,33 +213,35 @@ class TestDownloads: async def run_all_tests(self): """Run all test cases""" print("\n🧪 Running Download Tests...\n") - + test_methods = [ self.test_basic_download, self.test_persistent_context_download, self.test_multiple_downloads, self.test_different_browsers, - self.test_edge_cases + self.test_edge_cases, ] - + for test in test_methods: print(f"\n📝 Running {test.__doc__}...") await test() await asyncio.sleep(2) # Brief pause between tests - + print("\n📊 Test Results Summary:") for result in self.results: print(result) - - successes = len([r for r in self.results if '✅' in r]) + + successes = len([r for r in self.results if "✅" in r]) total = len(self.results) print(f"\nTotal: {successes}/{total} tests passed") - + self.cleanup() + async def main(): tester = TestDownloads() await tester.run_all_tests() + if __name__ == "__main__": - asyncio.run(main()) \ No newline at end of file + asyncio.run(main()) diff --git a/tests/async/test_basic_crawling.py b/tests/async/test_basic_crawling.py index ce38ac2f..ee4bb633 100644 --- a/tests/async/test_basic_crawling.py +++ b/tests/async/test_basic_crawling.py @@ -1,15 +1,17 @@ import os import sys import pytest -import asyncio import time # Add the parent directory to the Python path -parent_dir = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) +parent_dir = os.path.dirname( + os.path.dirname(os.path.dirname(os.path.abspath(__file__))) +) sys.path.append(parent_dir) from crawl4ai.async_webcrawler import AsyncWebCrawler + @pytest.mark.asyncio async def test_successful_crawl(): async with AsyncWebCrawler(verbose=True) as crawler: @@ -21,6 +23,7 @@ async def test_successful_crawl(): assert result.markdown assert result.cleaned_html + @pytest.mark.asyncio async def test_invalid_url(): async with AsyncWebCrawler(verbose=True) as crawler: @@ -29,19 +32,21 @@ async def test_invalid_url(): assert not result.success assert result.error_message + @pytest.mark.asyncio async def test_multiple_urls(): async with AsyncWebCrawler(verbose=True) as crawler: urls = [ "https://www.nbcnews.com/business", "https://www.example.com", - "https://www.python.org" + "https://www.python.org", ] results = await crawler.arun_many(urls=urls, bypass_cache=True) assert len(results) == len(urls) assert all(result.success for result in results) assert all(result.html for result in results) + @pytest.mark.asyncio async def test_javascript_execution(): async with AsyncWebCrawler(verbose=True) as crawler: @@ -51,6 +56,7 @@ async def test_javascript_execution(): assert result.success assert "

Modified by JS

" in result.html + @pytest.mark.asyncio async def test_concurrent_crawling_performance(): async with AsyncWebCrawler(verbose=True) as crawler: @@ -59,23 +65,26 @@ async def test_concurrent_crawling_performance(): "https://www.example.com", "https://www.python.org", "https://www.github.com", - "https://www.stackoverflow.com" + "https://www.stackoverflow.com", ] - + start_time = time.time() results = await crawler.arun_many(urls=urls, bypass_cache=True) end_time = time.time() - + total_time = end_time - start_time print(f"Total time for concurrent crawling: {total_time:.2f} seconds") - + assert all(result.success for result in results) assert len(results) == len(urls) - + # Assert that concurrent crawling is faster than sequential # This multiplier may need adjustment based on the number of URLs and their complexity - assert total_time < len(urls) * 5, f"Concurrent crawling not significantly faster: {total_time:.2f} seconds" + assert ( + total_time < len(urls) * 5 + ), f"Concurrent crawling not significantly faster: {total_time:.2f} seconds" + # Entry point for debugging if __name__ == "__main__": - pytest.main([__file__, "-v"]) \ No newline at end of file + pytest.main([__file__, "-v"]) diff --git a/tests/async/test_caching.py b/tests/async/test_caching.py index 589beca9..d7f6efb5 100644 --- a/tests/async/test_caching.py +++ b/tests/async/test_caching.py @@ -9,74 +9,79 @@ sys.path.append(parent_dir) from crawl4ai.async_webcrawler import AsyncWebCrawler + @pytest.mark.asyncio async def test_caching(): async with AsyncWebCrawler(verbose=True) as crawler: url = "https://www.nbcnews.com/business" - + # First crawl (should not use cache) start_time = asyncio.get_event_loop().time() result1 = await crawler.arun(url=url, bypass_cache=True) end_time = asyncio.get_event_loop().time() time_taken1 = end_time - start_time - + assert result1.success - + # Second crawl (should use cache) start_time = asyncio.get_event_loop().time() result2 = await crawler.arun(url=url, bypass_cache=False) end_time = asyncio.get_event_loop().time() time_taken2 = end_time - start_time - + assert result2.success assert time_taken2 < time_taken1 # Cached result should be faster + @pytest.mark.asyncio async def test_bypass_cache(): async with AsyncWebCrawler(verbose=True) as crawler: url = "https://www.nbcnews.com/business" - + # First crawl result1 = await crawler.arun(url=url, bypass_cache=False) assert result1.success - + # Second crawl with bypass_cache=True result2 = await crawler.arun(url=url, bypass_cache=True) assert result2.success - + # Content should be different (or at least, not guaranteed to be the same) assert result1.html != result2.html or result1.markdown != result2.markdown + @pytest.mark.asyncio async def test_clear_cache(): async with AsyncWebCrawler(verbose=True) as crawler: url = "https://www.nbcnews.com/business" - + # Crawl and cache await crawler.arun(url=url, bypass_cache=False) - + # Clear cache await crawler.aclear_cache() - + # Check cache size cache_size = await crawler.aget_cache_size() assert cache_size == 0 + @pytest.mark.asyncio async def test_flush_cache(): async with AsyncWebCrawler(verbose=True) as crawler: url = "https://www.nbcnews.com/business" - + # Crawl and cache await crawler.arun(url=url, bypass_cache=False) - + # Flush cache await crawler.aflush_cache() - + # Check cache size cache_size = await crawler.aget_cache_size() assert cache_size == 0 + # Entry point for debugging if __name__ == "__main__": - pytest.main([__file__, "-v"]) \ No newline at end of file + pytest.main([__file__, "-v"]) diff --git a/tests/async/test_chunking_and_extraction_strategies.py b/tests/async/test_chunking_and_extraction_strategies.py index af1c9fbd..ab9daddc 100644 --- a/tests/async/test_chunking_and_extraction_strategies.py +++ b/tests/async/test_chunking_and_extraction_strategies.py @@ -1,7 +1,6 @@ import os import sys import pytest -import asyncio import json # Add the parent directory to the Python path @@ -9,8 +8,9 @@ parent_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) sys.path.append(parent_dir) from crawl4ai.async_webcrawler import AsyncWebCrawler -from crawl4ai.chunking_strategy import RegexChunking, NlpSentenceChunking -from crawl4ai.extraction_strategy import CosineStrategy, LLMExtractionStrategy +from crawl4ai.chunking_strategy import RegexChunking +from crawl4ai.extraction_strategy import LLMExtractionStrategy + @pytest.mark.asyncio async def test_regex_chunking(): @@ -18,15 +18,14 @@ async def test_regex_chunking(): url = "https://www.nbcnews.com/business" chunking_strategy = RegexChunking(patterns=["\n\n"]) result = await crawler.arun( - url=url, - chunking_strategy=chunking_strategy, - bypass_cache=True + url=url, chunking_strategy=chunking_strategy, bypass_cache=True ) assert result.success assert result.extracted_content chunks = json.loads(result.extracted_content) assert len(chunks) > 1 # Ensure multiple chunks were created + # @pytest.mark.asyncio # async def test_cosine_strategy(): # async with AsyncWebCrawler(verbose=True) as crawler: @@ -43,25 +42,25 @@ async def test_regex_chunking(): # assert len(extracted_data) > 0 # assert all('tags' in item for item in extracted_data) + @pytest.mark.asyncio async def test_llm_extraction_strategy(): async with AsyncWebCrawler(verbose=True) as crawler: url = "https://www.nbcnews.com/business" extraction_strategy = LLMExtractionStrategy( provider="openai/gpt-4o-mini", - api_token=os.getenv('OPENAI_API_KEY'), - instruction="Extract only content related to technology" + api_token=os.getenv("OPENAI_API_KEY"), + instruction="Extract only content related to technology", ) result = await crawler.arun( - url=url, - extraction_strategy=extraction_strategy, - bypass_cache=True + url=url, extraction_strategy=extraction_strategy, bypass_cache=True ) assert result.success assert result.extracted_content extracted_data = json.loads(result.extracted_content) assert len(extracted_data) > 0 - assert all('content' in item for item in extracted_data) + assert all("content" in item for item in extracted_data) + # @pytest.mark.asyncio # async def test_combined_chunking_and_extraction(): @@ -84,4 +83,4 @@ async def test_llm_extraction_strategy(): # Entry point for debugging if __name__ == "__main__": - pytest.main([__file__, "-v"]) \ No newline at end of file + pytest.main([__file__, "-v"]) diff --git a/tests/async/test_content_extraction.py b/tests/async/test_content_extraction.py index 7604db20..9372387a 100644 --- a/tests/async/test_content_extraction.py +++ b/tests/async/test_content_extraction.py @@ -1,8 +1,6 @@ import os import sys import pytest -import asyncio -import json # Add the parent directory to the Python path parent_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) @@ -10,6 +8,7 @@ sys.path.append(parent_dir) from crawl4ai.async_webcrawler import AsyncWebCrawler + @pytest.mark.asyncio async def test_extract_markdown(): async with AsyncWebCrawler(verbose=True) as crawler: @@ -20,6 +19,7 @@ async def test_extract_markdown(): assert isinstance(result.markdown, str) assert len(result.markdown) > 0 + @pytest.mark.asyncio async def test_extract_cleaned_html(): async with AsyncWebCrawler(verbose=True) as crawler: @@ -30,6 +30,7 @@ async def test_extract_cleaned_html(): assert isinstance(result.cleaned_html, str) assert len(result.cleaned_html) > 0 + @pytest.mark.asyncio async def test_extract_media(): async with AsyncWebCrawler(verbose=True) as crawler: @@ -46,6 +47,7 @@ async def test_extract_media(): assert "alt" in image assert "type" in image + @pytest.mark.asyncio async def test_extract_links(): async with AsyncWebCrawler(verbose=True) as crawler: @@ -63,6 +65,7 @@ async def test_extract_links(): assert "href" in link assert "text" in link + @pytest.mark.asyncio async def test_extract_metadata(): async with AsyncWebCrawler(verbose=True) as crawler: @@ -75,16 +78,20 @@ async def test_extract_metadata(): assert "title" in metadata assert isinstance(metadata["title"], str) + @pytest.mark.asyncio async def test_css_selector_extraction(): async with AsyncWebCrawler(verbose=True) as crawler: url = "https://www.nbcnews.com/business" css_selector = "h1, h2, h3" - result = await crawler.arun(url=url, bypass_cache=True, css_selector=css_selector) + result = await crawler.arun( + url=url, bypass_cache=True, css_selector=css_selector + ) assert result.success assert result.markdown assert all(heading in result.markdown for heading in ["#", "##", "###"]) + # Entry point for debugging if __name__ == "__main__": - pytest.main([__file__, "-v"]) \ No newline at end of file + pytest.main([__file__, "-v"]) diff --git a/tests/async/test_content_filter_bm25.py b/tests/async/test_content_filter_bm25.py index a873c414..f05a8af7 100644 --- a/tests/async/test_content_filter_bm25.py +++ b/tests/async/test_content_filter_bm25.py @@ -1,7 +1,6 @@ import os, sys import pytest from bs4 import BeautifulSoup -from typing import List # Add the parent directory to the Python path parent_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) @@ -9,6 +8,7 @@ sys.path.append(parent_dir) from crawl4ai.content_filter_strategy import BM25ContentFilter + @pytest.fixture def basic_html(): return """ @@ -28,6 +28,7 @@ def basic_html(): """ + @pytest.fixture def wiki_html(): return """ @@ -46,6 +47,7 @@ def wiki_html(): """ + @pytest.fixture def no_meta_html(): return """ @@ -57,26 +59,27 @@ def no_meta_html(): """ + class TestBM25ContentFilter: def test_basic_extraction(self, basic_html): """Test basic content extraction functionality""" filter = BM25ContentFilter() contents = filter.filter_content(basic_html) - + assert contents, "Should extract content" assert len(contents) >= 1, "Should extract at least one content block" - assert "long paragraph" in ' '.join(contents).lower() - assert "navigation" not in ' '.join(contents).lower() + assert "long paragraph" in " ".join(contents).lower() + assert "navigation" not in " ".join(contents).lower() def test_user_query_override(self, basic_html): """Test that user query overrides metadata extraction""" user_query = "specific test query" filter = BM25ContentFilter(user_query=user_query) - + # Access internal state to verify query usage - soup = BeautifulSoup(basic_html, 'lxml') - extracted_query = filter.extract_page_query(soup.find('head')) - + soup = BeautifulSoup(basic_html, "lxml") + extracted_query = filter.extract_page_query(soup.find("head")) + assert extracted_query == user_query assert "Test description" not in extracted_query @@ -84,8 +87,8 @@ class TestBM25ContentFilter: """Test that headers are properly extracted despite length""" filter = BM25ContentFilter() contents = filter.filter_content(wiki_html) - - combined_content = ' '.join(contents).lower() + + combined_content = " ".join(contents).lower() assert "section 1" in combined_content, "Should include section header" assert "article title" in combined_content, "Should include main title" @@ -93,9 +96,11 @@ class TestBM25ContentFilter: """Test fallback behavior when no metadata is present""" filter = BM25ContentFilter() contents = filter.filter_content(no_meta_html) - + assert contents, "Should extract content even without metadata" - assert "First paragraph" in ' '.join(contents), "Should use first paragraph content" + assert "First paragraph" in " ".join( + contents + ), "Should use first paragraph content" def test_empty_input(self): """Test handling of empty input""" @@ -108,29 +113,30 @@ class TestBM25ContentFilter: malformed_html = "

Unclosed paragraph

Nested content

" filter = BM25ContentFilter() contents = filter.filter_content(malformed_html) - + assert isinstance(contents, list), "Should return list even with malformed HTML" - + def test_threshold_behavior(self, basic_html): """Test different BM25 threshold values""" strict_filter = BM25ContentFilter(bm25_threshold=2.0) lenient_filter = BM25ContentFilter(bm25_threshold=0.5) - + strict_contents = strict_filter.filter_content(basic_html) lenient_contents = lenient_filter.filter_content(basic_html) - - assert len(strict_contents) <= len(lenient_contents), \ - "Strict threshold should extract fewer elements" + + assert len(strict_contents) <= len( + lenient_contents + ), "Strict threshold should extract fewer elements" def test_html_cleaning(self, basic_html): """Test HTML cleaning functionality""" filter = BM25ContentFilter() contents = filter.filter_content(basic_html) - - cleaned_content = ' '.join(contents) - assert 'class=' not in cleaned_content, "Should remove class attributes" - assert 'style=' not in cleaned_content, "Should remove style attributes" - assert ' """ + @pytest.fixture def link_heavy_html(): return """ @@ -40,6 +41,7 @@ def link_heavy_html(): """ + @pytest.fixture def mixed_content_html(): return """ @@ -60,13 +62,14 @@ def mixed_content_html(): """ + class TestPruningContentFilter: def test_basic_pruning(self, basic_html): """Test basic content pruning functionality""" filter = PruningContentFilter(min_word_threshold=5) contents = filter.filter_content(basic_html) - - combined_content = ' '.join(contents).lower() + + combined_content = " ".join(contents).lower() assert "high-quality paragraph" in combined_content assert "sidebar content" not in combined_content assert "share buttons" not in combined_content @@ -75,40 +78,42 @@ class TestPruningContentFilter: """Test minimum word threshold filtering""" filter = PruningContentFilter(min_word_threshold=10) contents = filter.filter_content(mixed_content_html) - - combined_content = ' '.join(contents).lower() + + combined_content = " ".join(contents).lower() assert "short summary" not in combined_content assert "long high-quality paragraph" in combined_content assert "short comment" not in combined_content def test_threshold_types(self, basic_html): """Test fixed vs dynamic thresholds""" - fixed_filter = PruningContentFilter(threshold_type='fixed', threshold=0.48) - dynamic_filter = PruningContentFilter(threshold_type='dynamic', threshold=0.45) - + fixed_filter = PruningContentFilter(threshold_type="fixed", threshold=0.48) + dynamic_filter = PruningContentFilter(threshold_type="dynamic", threshold=0.45) + fixed_contents = fixed_filter.filter_content(basic_html) dynamic_contents = dynamic_filter.filter_content(basic_html) - - assert len(fixed_contents) != len(dynamic_contents), \ - "Fixed and dynamic thresholds should yield different results" + + assert len(fixed_contents) != len( + dynamic_contents + ), "Fixed and dynamic thresholds should yield different results" def test_link_density_impact(self, link_heavy_html): """Test handling of link-heavy content""" - filter = PruningContentFilter(threshold_type='dynamic') + filter = PruningContentFilter(threshold_type="dynamic") contents = filter.filter_content(link_heavy_html) - - combined_content = ' '.join(contents).lower() + + combined_content = " ".join(contents).lower() assert "good content paragraph" in combined_content - assert len([c for c in contents if 'href' in c]) < 2, \ - "Should prune link-heavy sections" + assert ( + len([c for c in contents if "href" in c]) < 2 + ), "Should prune link-heavy sections" def test_tag_importance(self, mixed_content_html): """Test tag importance in scoring""" - filter = PruningContentFilter(threshold_type='dynamic') + filter = PruningContentFilter(threshold_type="dynamic") contents = filter.filter_content(mixed_content_html) - - has_article = any('article' in c.lower() for c in contents) - has_h1 = any('h1' in c.lower() for c in contents) + + has_article = any("article" in c.lower() for c in contents) + has_h1 = any("h1" in c.lower() for c in contents) assert has_article or has_h1, "Should retain important tags" def test_empty_input(self): @@ -127,26 +132,31 @@ class TestPruningContentFilter: def test_performance(self, basic_html): """Test performance with timer""" filter = PruningContentFilter() - + import time + start = time.perf_counter() filter.filter_content(basic_html) duration = time.perf_counter() - start - + # Extra strict on performance since you mentioned milliseconds matter assert duration < 0.1, f"Processing took too long: {duration:.3f} seconds" - @pytest.mark.parametrize("threshold,expected_count", [ - (0.3, 4), # Very lenient - (0.48, 2), # Default - (0.7, 1), # Very strict - ]) + @pytest.mark.parametrize( + "threshold,expected_count", + [ + (0.3, 4), # Very lenient + (0.48, 2), # Default + (0.7, 1), # Very strict + ], + ) def test_threshold_levels(self, mixed_content_html, threshold, expected_count): """Test different threshold levels""" - filter = PruningContentFilter(threshold_type='fixed', threshold=threshold) + filter = PruningContentFilter(threshold_type="fixed", threshold=threshold) contents = filter.filter_content(mixed_content_html) - assert len(contents) <= expected_count, \ - f"Expected {expected_count} or fewer elements with threshold {threshold}" + assert ( + len(contents) <= expected_count + ), f"Expected {expected_count} or fewer elements with threshold {threshold}" def test_consistent_output(self, basic_html): """Test output consistency across multiple runs""" @@ -155,5 +165,6 @@ class TestPruningContentFilter: second_run = filter.filter_content(basic_html) assert first_run == second_run, "Output should be consistent" + if __name__ == "__main__": - pytest.main([__file__]) \ No newline at end of file + pytest.main([__file__]) diff --git a/tests/async/test_content_scraper_strategy.py b/tests/async/test_content_scraper_strategy.py index 62c49148..e6caf240 100644 --- a/tests/async/test_content_scraper_strategy.py +++ b/tests/async/test_content_scraper_strategy.py @@ -1,22 +1,24 @@ -import asyncio -from bs4 import BeautifulSoup -from typing import Dict, Any import os import sys import time import csv from tabulate import tabulate from dataclasses import dataclass -from typing import List, Dict +from typing import List -parent_dir = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) +parent_dir = os.path.dirname( + os.path.dirname(os.path.dirname(os.path.abspath(__file__))) +) sys.path.append(parent_dir) __location__ = os.path.realpath(os.path.join(os.getcwd(), os.path.dirname(__file__))) from crawl4ai.content_scraping_strategy import WebScrapingStrategy -from crawl4ai.content_scraping_strategy import WebScrapingStrategy as WebScrapingStrategyCurrent +from crawl4ai.content_scraping_strategy import ( + WebScrapingStrategy as WebScrapingStrategyCurrent, +) # from crawl4ai.content_scrapping_strategy_current import WebScrapingStrategy as WebScrapingStrategyCurrent + @dataclass class TestResult: name: str @@ -27,69 +29,71 @@ class TestResult: markdown_length: int execution_time: float + class StrategyTester: def __init__(self): self.new_scraper = WebScrapingStrategy() self.current_scraper = WebScrapingStrategyCurrent() - with open(__location__ + '/sample_wikipedia.html', 'r', encoding='utf-8') as f: + with open(__location__ + "/sample_wikipedia.html", "r", encoding="utf-8") as f: self.WIKI_HTML = f.read() - self.results = {'new': [], 'current': []} - + self.results = {"new": [], "current": []} + def run_test(self, name: str, **kwargs) -> tuple[TestResult, TestResult]: results = [] for scraper in [self.new_scraper, self.current_scraper]: start_time = time.time() result = scraper._get_content_of_website_optimized( - url="https://en.wikipedia.org/wiki/Test", - html=self.WIKI_HTML, - **kwargs + url="https://en.wikipedia.org/wiki/Test", html=self.WIKI_HTML, **kwargs ) execution_time = time.time() - start_time - + test_result = TestResult( name=name, - success=result['success'], - images=len(result['media']['images']), - internal_links=len(result['links']['internal']), - external_links=len(result['links']['external']), - markdown_length=len(result['markdown']), - execution_time=execution_time + success=result["success"], + images=len(result["media"]["images"]), + internal_links=len(result["links"]["internal"]), + external_links=len(result["links"]["external"]), + markdown_length=len(result["markdown"]), + execution_time=execution_time, ) results.append(test_result) - + return results[0], results[1] # new, current def run_all_tests(self): test_cases = [ ("Basic Extraction", {}), - ("Exclude Tags", {'excluded_tags': ['table', 'div.infobox', 'div.navbox']}), - ("Word Threshold", {'word_count_threshold': 50}), - ("CSS Selector", {'css_selector': 'div.mw-parser-output > p'}), - ("Link Exclusions", { - 'exclude_external_links': True, - 'exclude_social_media_links': True, - 'exclude_domains': ['facebook.com', 'twitter.com'] - }), - ("Media Handling", { - 'exclude_external_images': True, - 'image_description_min_word_threshold': 20 - }), - ("Text Only", { - 'only_text': True, - 'remove_forms': True - }), - ("HTML Cleaning", { - 'clean_html': True, - 'keep_data_attributes': True - }), - ("HTML2Text Options", { - 'html2text': { - 'skip_internal_links': True, - 'single_line_break': True, - 'mark_code': True, - 'preserve_tags': ['pre', 'code'] - } - }) + ("Exclude Tags", {"excluded_tags": ["table", "div.infobox", "div.navbox"]}), + ("Word Threshold", {"word_count_threshold": 50}), + ("CSS Selector", {"css_selector": "div.mw-parser-output > p"}), + ( + "Link Exclusions", + { + "exclude_external_links": True, + "exclude_social_media_links": True, + "exclude_domains": ["facebook.com", "twitter.com"], + }, + ), + ( + "Media Handling", + { + "exclude_external_images": True, + "image_description_min_word_threshold": 20, + }, + ), + ("Text Only", {"only_text": True, "remove_forms": True}), + ("HTML Cleaning", {"clean_html": True, "keep_data_attributes": True}), + ( + "HTML2Text Options", + { + "html2text": { + "skip_internal_links": True, + "single_line_break": True, + "mark_code": True, + "preserve_tags": ["pre", "code"], + } + }, + ), ] all_results = [] @@ -99,64 +103,117 @@ class StrategyTester: all_results.append((name, new_result, current_result)) except Exception as e: print(f"Error in {name}: {str(e)}") - + self.save_results_to_csv(all_results) self.print_comparison_table(all_results) def save_results_to_csv(self, all_results: List[tuple]): - csv_file = os.path.join(__location__, 'strategy_comparison_results.csv') - with open(csv_file, 'w', newline='') as f: + csv_file = os.path.join(__location__, "strategy_comparison_results.csv") + with open(csv_file, "w", newline="") as f: writer = csv.writer(f) - writer.writerow(['Test Name', 'Strategy', 'Success', 'Images', 'Internal Links', - 'External Links', 'Markdown Length', 'Execution Time']) - + writer.writerow( + [ + "Test Name", + "Strategy", + "Success", + "Images", + "Internal Links", + "External Links", + "Markdown Length", + "Execution Time", + ] + ) + for name, new_result, current_result in all_results: - writer.writerow([name, 'New', new_result.success, new_result.images, - new_result.internal_links, new_result.external_links, - new_result.markdown_length, f"{new_result.execution_time:.3f}"]) - writer.writerow([name, 'Current', current_result.success, current_result.images, - current_result.internal_links, current_result.external_links, - current_result.markdown_length, f"{current_result.execution_time:.3f}"]) + writer.writerow( + [ + name, + "New", + new_result.success, + new_result.images, + new_result.internal_links, + new_result.external_links, + new_result.markdown_length, + f"{new_result.execution_time:.3f}", + ] + ) + writer.writerow( + [ + name, + "Current", + current_result.success, + current_result.images, + current_result.internal_links, + current_result.external_links, + current_result.markdown_length, + f"{current_result.execution_time:.3f}", + ] + ) def print_comparison_table(self, all_results: List[tuple]): table_data = [] - headers = ['Test Name', 'Strategy', 'Success', 'Images', 'Internal Links', - 'External Links', 'Markdown Length', 'Time (s)'] + headers = [ + "Test Name", + "Strategy", + "Success", + "Images", + "Internal Links", + "External Links", + "Markdown Length", + "Time (s)", + ] for name, new_result, current_result in all_results: # Check for differences differences = [] - if new_result.images != current_result.images: differences.append('images') - if new_result.internal_links != current_result.internal_links: differences.append('internal_links') - if new_result.external_links != current_result.external_links: differences.append('external_links') - if new_result.markdown_length != current_result.markdown_length: differences.append('markdown') - + if new_result.images != current_result.images: + differences.append("images") + if new_result.internal_links != current_result.internal_links: + differences.append("internal_links") + if new_result.external_links != current_result.external_links: + differences.append("external_links") + if new_result.markdown_length != current_result.markdown_length: + differences.append("markdown") + # Add row for new strategy new_row = [ - name, 'New', new_result.success, new_result.images, - new_result.internal_links, new_result.external_links, - new_result.markdown_length, f"{new_result.execution_time:.3f}" + name, + "New", + new_result.success, + new_result.images, + new_result.internal_links, + new_result.external_links, + new_result.markdown_length, + f"{new_result.execution_time:.3f}", ] table_data.append(new_row) - + # Add row for current strategy current_row = [ - '', 'Current', current_result.success, current_result.images, - current_result.internal_links, current_result.external_links, - current_result.markdown_length, f"{current_result.execution_time:.3f}" + "", + "Current", + current_result.success, + current_result.images, + current_result.internal_links, + current_result.external_links, + current_result.markdown_length, + f"{current_result.execution_time:.3f}", ] table_data.append(current_row) - + # Add difference summary if any if differences: - table_data.append(['', '⚠️ Differences', ', '.join(differences), '', '', '', '', '']) - + table_data.append( + ["", "⚠️ Differences", ", ".join(differences), "", "", "", "", ""] + ) + # Add empty row for better readability - table_data.append([''] * len(headers)) + table_data.append([""] * len(headers)) print("\nStrategy Comparison Results:") - print(tabulate(table_data, headers=headers, tablefmt='grid')) + print(tabulate(table_data, headers=headers, tablefmt="grid")) + if __name__ == "__main__": tester = StrategyTester() - tester.run_all_tests() \ No newline at end of file + tester.run_all_tests() diff --git a/tests/async/test_crawler_strategy.py b/tests/async/test_crawler_strategy.py index a507058d..337b5aaa 100644 --- a/tests/async/test_crawler_strategy.py +++ b/tests/async/test_crawler_strategy.py @@ -1,14 +1,13 @@ import os import sys import pytest -import asyncio # Add the parent directory to the Python path parent_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) sys.path.append(parent_dir) from crawl4ai.async_webcrawler import AsyncWebCrawler -from crawl4ai.async_crawler_strategy import AsyncPlaywrightCrawlerStrategy + @pytest.mark.asyncio async def test_custom_user_agent(): @@ -20,6 +19,7 @@ async def test_custom_user_agent(): assert result.success assert custom_user_agent in result.html + @pytest.mark.asyncio async def test_custom_headers(): async with AsyncWebCrawler(verbose=True) as crawler: @@ -31,6 +31,7 @@ async def test_custom_headers(): assert "X-Test-Header" in result.html assert "TestValue" in result.html + @pytest.mark.asyncio async def test_javascript_execution(): async with AsyncWebCrawler(verbose=True) as crawler: @@ -40,19 +41,22 @@ async def test_javascript_execution(): assert result.success assert "

Modified by JS

" in result.html + @pytest.mark.asyncio async def test_hook_execution(): async with AsyncWebCrawler(verbose=True) as crawler: + async def test_hook(page): await page.evaluate("document.body.style.backgroundColor = 'red';") return page - crawler.crawler_strategy.set_hook('after_goto', test_hook) + crawler.crawler_strategy.set_hook("after_goto", test_hook) url = "https://www.example.com" result = await crawler.arun(url=url, bypass_cache=True) assert result.success assert "background-color: red" in result.html + @pytest.mark.asyncio async def test_screenshot(): async with AsyncWebCrawler(verbose=True) as crawler: @@ -63,6 +67,7 @@ async def test_screenshot(): assert isinstance(result.screenshot, str) assert len(result.screenshot) > 0 + # Entry point for debugging if __name__ == "__main__": - pytest.main([__file__, "-v"]) \ No newline at end of file + pytest.main([__file__, "-v"]) diff --git a/tests/async/test_database_operations.py b/tests/async/test_database_operations.py index 90a09ff0..db0d328e 100644 --- a/tests/async/test_database_operations.py +++ b/tests/async/test_database_operations.py @@ -1,8 +1,6 @@ import os import sys import pytest -import asyncio -import json # Add the parent directory to the Python path parent_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) @@ -10,6 +8,7 @@ sys.path.append(parent_dir) from crawl4ai.async_webcrawler import AsyncWebCrawler + @pytest.mark.asyncio async def test_cache_url(): async with AsyncWebCrawler(verbose=True) as crawler: @@ -23,6 +22,7 @@ async def test_cache_url(): assert result2.success assert result2.html == result1.html + @pytest.mark.asyncio async def test_bypass_cache(): async with AsyncWebCrawler(verbose=True) as crawler: @@ -34,25 +34,29 @@ async def test_bypass_cache(): # Second run bypassing cache result2 = await crawler.arun(url=url, bypass_cache=True) assert result2.success - assert result2.html != result1.html # Content might be different due to dynamic nature of websites + assert ( + result2.html != result1.html + ) # Content might be different due to dynamic nature of websites + @pytest.mark.asyncio async def test_cache_size(): async with AsyncWebCrawler(verbose=True) as crawler: initial_size = await crawler.aget_cache_size() - + url = "https://www.nbcnews.com/business" await crawler.arun(url=url, bypass_cache=True) - + new_size = await crawler.aget_cache_size() assert new_size == initial_size + 1 + @pytest.mark.asyncio async def test_clear_cache(): async with AsyncWebCrawler(verbose=True) as crawler: url = "https://www.example.org" await crawler.arun(url=url, bypass_cache=True) - + initial_size = await crawler.aget_cache_size() assert initial_size > 0 @@ -60,12 +64,13 @@ async def test_clear_cache(): new_size = await crawler.aget_cache_size() assert new_size == 0 + @pytest.mark.asyncio async def test_flush_cache(): async with AsyncWebCrawler(verbose=True) as crawler: url = "https://www.example.net" await crawler.arun(url=url, bypass_cache=True) - + initial_size = await crawler.aget_cache_size() assert initial_size > 0 @@ -75,8 +80,11 @@ async def test_flush_cache(): # Try to retrieve the previously cached URL result = await crawler.arun(url=url, bypass_cache=False) - assert result.success # The crawler should still succeed, but it will fetch the content anew + assert ( + result.success + ) # The crawler should still succeed, but it will fetch the content anew + # Entry point for debugging if __name__ == "__main__": - pytest.main([__file__, "-v"]) \ No newline at end of file + pytest.main([__file__, "-v"]) diff --git a/tests/async/test_dispatchers.py b/tests/async/test_dispatchers.py index 5c3788e5..99cf4a98 100644 --- a/tests/async/test_dispatchers.py +++ b/tests/async/test_dispatchers.py @@ -1,114 +1,133 @@ import pytest -import asyncio, time +import time from crawl4ai import ( - AsyncWebCrawler, BrowserConfig, CrawlerRunConfig, - MemoryAdaptiveDispatcher, SemaphoreDispatcher, - RateLimiter, CrawlerMonitor, DisplayMode, CacheMode + AsyncWebCrawler, + BrowserConfig, + CrawlerRunConfig, + MemoryAdaptiveDispatcher, + SemaphoreDispatcher, + RateLimiter, + CrawlerMonitor, + DisplayMode, + CacheMode, ) + @pytest.fixture def browser_config(): - return BrowserConfig( - headless=True, - verbose=False - ) + return BrowserConfig(headless=True, verbose=False) + @pytest.fixture def run_config(): - return CrawlerRunConfig( - cache_mode=CacheMode.BYPASS, - verbose=False - ) + return CrawlerRunConfig(cache_mode=CacheMode.BYPASS, verbose=False) + @pytest.fixture def test_urls(): return [ "http://example.com", "http://example.com/page1", - "http://example.com/page2" + "http://example.com/page2", ] + @pytest.mark.asyncio class TestDispatchStrategies: - async def test_memory_adaptive_basic(self, browser_config, run_config, test_urls): async with AsyncWebCrawler(config=browser_config) as crawler: dispatcher = MemoryAdaptiveDispatcher( - memory_threshold_percent=70.0, - max_session_permit=2, - check_interval=0.1 + memory_threshold_percent=70.0, max_session_permit=2, check_interval=0.1 + ) + results = await crawler.arun_many( + test_urls, config=run_config, dispatcher=dispatcher ) - results = await crawler.arun_many(test_urls, config=run_config, dispatcher=dispatcher) assert len(results) == len(test_urls) assert all(r.success for r in results) - async def test_memory_adaptive_with_rate_limit(self, browser_config, run_config, test_urls): + async def test_memory_adaptive_with_rate_limit( + self, browser_config, run_config, test_urls + ): async with AsyncWebCrawler(config=browser_config) as crawler: dispatcher = MemoryAdaptiveDispatcher( memory_threshold_percent=70.0, max_session_permit=2, check_interval=0.1, rate_limiter=RateLimiter( - base_delay=(0.1, 0.2), - max_delay=1.0, - max_retries=2 - ) + base_delay=(0.1, 0.2), max_delay=1.0, max_retries=2 + ), + ) + results = await crawler.arun_many( + test_urls, config=run_config, dispatcher=dispatcher ) - results = await crawler.arun_many(test_urls, config=run_config, dispatcher=dispatcher) assert len(results) == len(test_urls) assert all(r.success for r in results) async def test_semaphore_basic(self, browser_config, run_config, test_urls): async with AsyncWebCrawler(config=browser_config) as crawler: - dispatcher = SemaphoreDispatcher( - semaphore_count=2 + dispatcher = SemaphoreDispatcher(semaphore_count=2) + results = await crawler.arun_many( + test_urls, config=run_config, dispatcher=dispatcher ) - results = await crawler.arun_many(test_urls, config=run_config, dispatcher=dispatcher) assert len(results) == len(test_urls) assert all(r.success for r in results) - async def test_semaphore_with_rate_limit(self, browser_config, run_config, test_urls): + async def test_semaphore_with_rate_limit( + self, browser_config, run_config, test_urls + ): async with AsyncWebCrawler(config=browser_config) as crawler: dispatcher = SemaphoreDispatcher( semaphore_count=2, rate_limiter=RateLimiter( - base_delay=(0.1, 0.2), - max_delay=1.0, - max_retries=2 - ) + base_delay=(0.1, 0.2), max_delay=1.0, max_retries=2 + ), + ) + results = await crawler.arun_many( + test_urls, config=run_config, dispatcher=dispatcher ) - results = await crawler.arun_many(test_urls, config=run_config, dispatcher=dispatcher) assert len(results) == len(test_urls) assert all(r.success for r in results) - async def test_memory_adaptive_memory_error(self, browser_config, run_config, test_urls): + async def test_memory_adaptive_memory_error( + self, browser_config, run_config, test_urls + ): async with AsyncWebCrawler(config=browser_config) as crawler: dispatcher = MemoryAdaptiveDispatcher( memory_threshold_percent=1.0, # Set unrealistically low threshold max_session_permit=2, check_interval=0.1, - memory_wait_timeout=1.0 # Short timeout for testing + memory_wait_timeout=1.0, # Short timeout for testing ) with pytest.raises(MemoryError): - await crawler.arun_many(test_urls, config=run_config, dispatcher=dispatcher) + await crawler.arun_many( + test_urls, config=run_config, dispatcher=dispatcher + ) async def test_empty_urls(self, browser_config, run_config): async with AsyncWebCrawler(config=browser_config) as crawler: dispatcher = MemoryAdaptiveDispatcher(max_session_permit=2) - results = await crawler.arun_many([], config=run_config, dispatcher=dispatcher) + results = await crawler.arun_many( + [], config=run_config, dispatcher=dispatcher + ) assert len(results) == 0 async def test_single_url(self, browser_config, run_config): async with AsyncWebCrawler(config=browser_config) as crawler: dispatcher = MemoryAdaptiveDispatcher(max_session_permit=2) - results = await crawler.arun_many(["http://example.com"], config=run_config, dispatcher=dispatcher) + results = await crawler.arun_many( + ["http://example.com"], config=run_config, dispatcher=dispatcher + ) assert len(results) == 1 assert results[0].success async def test_invalid_urls(self, browser_config, run_config): async with AsyncWebCrawler(config=browser_config) as crawler: dispatcher = MemoryAdaptiveDispatcher(max_session_permit=2) - results = await crawler.arun_many(["http://invalid.url.that.doesnt.exist"], config=run_config, dispatcher=dispatcher) + results = await crawler.arun_many( + ["http://invalid.url.that.doesnt.exist"], + config=run_config, + dispatcher=dispatcher, + ) assert len(results) == 1 assert not results[0].success @@ -121,27 +140,31 @@ class TestDispatchStrategies: base_delay=(0.1, 0.2), max_delay=1.0, max_retries=2, - rate_limit_codes=[200] # Force rate limiting for testing - ) + rate_limit_codes=[200], # Force rate limiting for testing + ), ) start_time = time.time() - results = await crawler.arun_many(urls, config=run_config, dispatcher=dispatcher) + results = await crawler.arun_many( + urls, config=run_config, dispatcher=dispatcher + ) duration = time.time() - start_time assert len(results) == len(urls) assert duration > 1.0 # Ensure rate limiting caused delays async def test_monitor_integration(self, browser_config, run_config, test_urls): async with AsyncWebCrawler(config=browser_config) as crawler: - monitor = CrawlerMonitor(max_visible_rows=5, display_mode=DisplayMode.DETAILED) - dispatcher = MemoryAdaptiveDispatcher( - max_session_permit=2, - monitor=monitor + monitor = CrawlerMonitor( + max_visible_rows=5, display_mode=DisplayMode.DETAILED + ) + dispatcher = MemoryAdaptiveDispatcher(max_session_permit=2, monitor=monitor) + results = await crawler.arun_many( + test_urls, config=run_config, dispatcher=dispatcher ) - results = await crawler.arun_many(test_urls, config=run_config, dispatcher=dispatcher) assert len(results) == len(test_urls) # Check monitor stats assert len(monitor.stats) == len(test_urls) assert all(stat.end_time is not None for stat in monitor.stats.values()) + if __name__ == "__main__": - pytest.main([__file__, "-v", "--asyncio-mode=auto"]) \ No newline at end of file + pytest.main([__file__, "-v", "--asyncio-mode=auto"]) diff --git a/tests/async/test_edge_cases.py b/tests/async/test_edge_cases.py index 34fadb1e..d3adb53c 100644 --- a/tests/async/test_edge_cases.py +++ b/tests/async/test_edge_cases.py @@ -2,9 +2,9 @@ import os import re import sys import pytest -import json from bs4 import BeautifulSoup import asyncio + # Add the parent directory to the Python path parent_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) sys.path.append(parent_dir) @@ -59,19 +59,21 @@ from crawl4ai.async_webcrawler import AsyncWebCrawler # assert result.success # assert "github" in result.html.lower() + # Add this test to your existing test file @pytest.mark.asyncio async def test_typescript_commits_multi_page(): first_commit = "" + async def on_execution_started(page): - nonlocal first_commit + nonlocal first_commit try: # Check if the page firct commit h4 text is different from the first commit (use document.querySelector('li.Box-sc-g0xbh4-0 h4')) while True: - await page.wait_for_selector('li.Box-sc-g0xbh4-0 h4') - commit = await page.query_selector('li.Box-sc-g0xbh4-0 h4') - commit = await commit.evaluate('(element) => element.textContent') - commit = re.sub(r'\s+', '', commit) + await page.wait_for_selector("li.Box-sc-g0xbh4-0 h4") + commit = await page.query_selector("li.Box-sc-g0xbh4-0 h4") + commit = await commit.evaluate("(element) => element.textContent") + commit = re.sub(r"\s+", "", commit) if commit and commit != first_commit: first_commit = commit break @@ -79,9 +81,8 @@ async def test_typescript_commits_multi_page(): except Exception as e: print(f"Warning: New content didn't appear after JavaScript execution: {e}") - async with AsyncWebCrawler(verbose=True) as crawler: - crawler.crawler_strategy.set_hook('on_execution_started', on_execution_started) + crawler.crawler_strategy.set_hook("on_execution_started", on_execution_started) url = "https://github.com/microsoft/TypeScript/commits/main" session_id = "typescript_commits_session" @@ -97,19 +98,21 @@ async def test_typescript_commits_multi_page(): url=url, # Only use URL for the first page session_id=session_id, css_selector="li.Box-sc-g0xbh4-0", - js=js_next_page if page > 0 else None, # Don't click 'next' on the first page + js=js_next_page + if page > 0 + else None, # Don't click 'next' on the first page bypass_cache=True, - js_only=page > 0 # Use js_only for subsequent pages + js_only=page > 0, # Use js_only for subsequent pages ) assert result.success, f"Failed to crawl page {page + 1}" # Parse the HTML and extract commits - soup = BeautifulSoup(result.cleaned_html, 'html.parser') + soup = BeautifulSoup(result.cleaned_html, "html.parser") commits = soup.select("li") # Take first commit find h4 extract text first_commit = commits[0].find("h4").text - first_commit = re.sub(r'\s+', '', first_commit) + first_commit = re.sub(r"\s+", "", first_commit) all_commits.extend(commits) print(f"Page {page + 1}: Found {len(commits)} commits") @@ -118,10 +121,13 @@ async def test_typescript_commits_multi_page(): await crawler.crawler_strategy.kill_session(session_id) # Assertions - assert len(all_commits) >= 90, f"Expected at least 90 commits, but got {len(all_commits)}" - - print(f"Successfully crawled {len(all_commits)} commits across 3 pages") + assert ( + len(all_commits) >= 90 + ), f"Expected at least 90 commits, but got {len(all_commits)}" + + print(f"Successfully crawled {len(all_commits)} commits across 3 pages") + # Entry point for debugging if __name__ == "__main__": - pytest.main([__file__, "-v"]) \ No newline at end of file + pytest.main([__file__, "-v"]) diff --git a/tests/async/test_error_handling.py b/tests/async/test_error_handling.py index 3015edbd..ae4af6c8 100644 --- a/tests/async/test_error_handling.py +++ b/tests/async/test_error_handling.py @@ -75,4 +75,4 @@ # # Entry point for debugging # if __name__ == "__main__": -# pytest.main([__file__, "-v"]) \ No newline at end of file +# pytest.main([__file__, "-v"]) diff --git a/tests/async/test_evaluation_scraping_methods_performance.configs.py b/tests/async/test_evaluation_scraping_methods_performance.configs.py index e6305736..797cf681 100644 --- a/tests/async/test_evaluation_scraping_methods_performance.configs.py +++ b/tests/async/test_evaluation_scraping_methods_performance.configs.py @@ -1,11 +1,15 @@ import json import time from bs4 import BeautifulSoup -from crawl4ai.content_scraping_strategy import WebScrapingStrategy, LXMLWebScrapingStrategy -from typing import Dict, Any, List, Tuple +from crawl4ai.content_scraping_strategy import ( + WebScrapingStrategy, + LXMLWebScrapingStrategy, +) +from typing import Dict, List, Tuple import difflib from lxml import html as lhtml, etree + def normalize_dom(element): """ Recursively normalizes an lxml HTML element: @@ -15,7 +19,7 @@ def normalize_dom(element): Returns the same element (mutated). """ # Remove comment nodes - comments = element.xpath('//comment()') + comments = element.xpath("//comment()") for c in comments: p = c.getparent() if p is not None: @@ -45,7 +49,7 @@ def strip_html_body(root): """ If 'root' is , find its child and move all of 's children into a new
. Return that
. - + If 'root' is , similarly move all of its children into a new
and return it. Otherwise, return 'root' as-is. @@ -53,8 +57,8 @@ def strip_html_body(root): tag_name = (root.tag or "").lower() # Case 1: The root is - if tag_name == 'html': - bodies = root.xpath('./body') + if tag_name == "html": + bodies = root.xpath("./body") if bodies: body = bodies[0] new_div = lhtml.Element("div") @@ -66,7 +70,7 @@ def strip_html_body(root): return root # Case 2: The root is - elif tag_name == 'body': + elif tag_name == "body": new_div = lhtml.Element("div") for child in root: new_div.append(child) @@ -92,7 +96,9 @@ def compare_nodes(node1, node2, differences, path="/"): attrs1 = list(node1.attrib.items()) attrs2 = list(node2.attrib.items()) if attrs1 != attrs2: - differences.append(f"Attribute mismatch at {path}/{node1.tag}: {attrs1} vs. {attrs2}") + differences.append( + f"Attribute mismatch at {path}/{node1.tag}: {attrs1} vs. {attrs2}" + ) # 3) Compare text (trim or unify whitespace as needed) text1 = (node1.text or "").strip() @@ -102,7 +108,9 @@ def compare_nodes(node1, node2, differences, path="/"): text2 = " ".join(text2.split()) if text1 != text2: # If you prefer ignoring newlines or multiple whitespace, do a more robust cleanup - differences.append(f"Text mismatch at {path}/{node1.tag}: '{text1}' vs. '{text2}'") + differences.append( + f"Text mismatch at {path}/{node1.tag}: '{text1}' vs. '{text2}'" + ) # 4) Compare number of children children1 = list(node1) @@ -123,7 +131,9 @@ def compare_nodes(node1, node2, differences, path="/"): tail1 = (node1.tail or "").strip() tail2 = (node2.tail or "").strip() if tail1 != tail2: - differences.append(f"Tail mismatch after {path}/{node1.tag}: '{tail1}' vs. '{tail2}'") + differences.append( + f"Tail mismatch after {path}/{node1.tag}: '{tail1}' vs. '{tail2}'" + ) def compare_html_structurally(html1, html2): @@ -156,11 +166,11 @@ def compare_html_structurally(html1, html2): return differences - def generate_large_html(n_elements=1000): - html = [''] + html = [""] for i in range(n_elements): - html.append(f''' + html.append( + f"""

Heading {i}

This is paragraph {i} with some content and a link

@@ -170,13 +180,15 @@ def generate_large_html(n_elements=1000):
  • List item {i}.2
  • - ''') - html.append('') - return ''.join(html) + """ + ) + html.append("") + return "".join(html) + def generate_complicated_html(): """ - HTML with multiple domains, forms, data attributes, + HTML with multiple domains, forms, data attributes, various images, comments, style, and noscript to test all parameter toggles. """ return """ @@ -258,7 +270,7 @@ def generate_complicated_html(): def get_test_scenarios(): """ Returns a dictionary of parameter sets (test scenarios) for the scraper. - Each scenario name maps to a dictionary of keyword arguments + Each scenario name maps to a dictionary of keyword arguments that will be passed into scrap() for testing various features. """ TEST_SCENARIOS = { @@ -341,7 +353,7 @@ def get_test_scenarios(): # "exclude_external_links": True # }, # "comprehensive_removal": { - # # Exclude multiple tags, remove forms & comments, + # # Exclude multiple tags, remove forms & comments, # # and also remove targeted selectors # "excluded_tags": ["aside", "noscript", "script"], # "excluded_selector": "#promo-section, .social-widget", @@ -352,19 +364,18 @@ def get_test_scenarios(): return TEST_SCENARIOS - class ScraperEquivalenceTester: def __init__(self): self.test_cases = { - 'basic': self.generate_basic_html(), - 'complex': self.generate_complex_html(), - 'malformed': self.generate_malformed_html(), + "basic": self.generate_basic_html(), + "complex": self.generate_complex_html(), + "malformed": self.generate_malformed_html(), # 'real_world': self.load_real_samples() } - + def generate_basic_html(self): return generate_large_html(1000) # Your existing function - + def generate_complex_html(self): return """ @@ -384,7 +395,7 @@ class ScraperEquivalenceTester:
    """ - + def generate_malformed_html(self): return """
    Unclosed div @@ -395,139 +406,139 @@ class ScraperEquivalenceTester: """ - + def load_real_samples(self): # Load some real-world HTML samples you've collected samples = { - 'article': open('tests/samples/article.html').read(), - 'product': open('tests/samples/product.html').read(), - 'blog': open('tests/samples/blog.html').read() + "article": open("tests/samples/article.html").read(), + "product": open("tests/samples/product.html").read(), + "blog": open("tests/samples/blog.html").read(), } return samples - def deep_compare_links(self, old_links: Dict, new_links: Dict) -> List[str]: """Detailed comparison of link structures""" differences = [] - - for category in ['internal', 'external']: - old_urls = {link['href'] for link in old_links[category]} - new_urls = {link['href'] for link in new_links[category]} - + + for category in ["internal", "external"]: + old_urls = {link["href"] for link in old_links[category]} + new_urls = {link["href"] for link in new_links[category]} + missing = old_urls - new_urls extra = new_urls - old_urls - + if missing: differences.append(f"Missing {category} links: {missing}") if extra: differences.append(f"Extra {category} links: {extra}") - + # Compare link attributes for common URLs common = old_urls & new_urls for url in common: - old_link = next(l for l in old_links[category] if l['href'] == url) - new_link = next(l for l in new_links[category] if l['href'] == url) - - for attr in ['text', 'title']: + old_link = next(l for l in old_links[category] if l["href"] == url) + new_link = next(l for l in new_links[category] if l["href"] == url) + + for attr in ["text", "title"]: if old_link[attr] != new_link[attr]: differences.append( f"Link attribute mismatch for {url} - {attr}:" f" old='{old_link[attr]}' vs new='{new_link[attr]}'" ) - + return differences def deep_compare_media(self, old_media: Dict, new_media: Dict) -> List[str]: """Detailed comparison of media elements""" differences = [] - - for media_type in ['images', 'videos', 'audios']: - old_srcs = {item['src'] for item in old_media[media_type]} - new_srcs = {item['src'] for item in new_media[media_type]} - + + for media_type in ["images", "videos", "audios"]: + old_srcs = {item["src"] for item in old_media[media_type]} + new_srcs = {item["src"] for item in new_media[media_type]} + missing = old_srcs - new_srcs extra = new_srcs - old_srcs - + if missing: differences.append(f"Missing {media_type}: {missing}") if extra: differences.append(f"Extra {media_type}: {extra}") - + # Compare media attributes for common sources common = old_srcs & new_srcs for src in common: - old_item = next(m for m in old_media[media_type] if m['src'] == src) - new_item = next(m for m in new_media[media_type] if m['src'] == src) - - for attr in ['alt', 'description']: + old_item = next(m for m in old_media[media_type] if m["src"] == src) + new_item = next(m for m in new_media[media_type] if m["src"] == src) + + for attr in ["alt", "description"]: if old_item.get(attr) != new_item.get(attr): differences.append( f"{media_type} attribute mismatch for {src} - {attr}:" f" old='{old_item.get(attr)}' vs new='{new_item.get(attr)}'" ) - + return differences def compare_html_content(self, old_html: str, new_html: str) -> List[str]: """Compare HTML content structure and text""" # return compare_html_structurally(old_html, new_html) differences = [] - + def normalize_html(html: str) -> Tuple[str, str]: - soup = BeautifulSoup(html, 'lxml') + soup = BeautifulSoup(html, "lxml") # Get both structure and text - structure = ' '.join(tag.name for tag in soup.find_all()) - text = ' '.join(soup.get_text().split()) + structure = " ".join(tag.name for tag in soup.find_all()) + text = " ".join(soup.get_text().split()) return structure, text - + old_structure, old_text = normalize_html(old_html) new_structure, new_text = normalize_html(new_html) - + # Compare structure if abs(len(old_structure) - len(new_structure)) > 100: - # if old_structure != new_structure: + # if old_structure != new_structure: diff = difflib.unified_diff( - old_structure.split(), - new_structure.split(), - lineterm='' + old_structure.split(), new_structure.split(), lineterm="" ) - differences.append("HTML structure differences:\n" + '\n'.join(diff)) - + differences.append("HTML structure differences:\n" + "\n".join(diff)) + # Compare text content if abs(len(old_text) - len(new_text)) > 100: - # if old_text != new_text: + # if old_text != new_text: # Show detailed text differences text_diff = difflib.unified_diff( - old_text.split(), - new_text.split(), - lineterm='' + old_text.split(), new_text.split(), lineterm="" ) - differences.append("Text content differences:\n" + '\n'.join(text_diff)) - + differences.append("Text content differences:\n" + "\n".join(text_diff)) + return differences - def compare_results(self, old_result: Dict, new_result: Dict) -> Dict[str, List[str]]: + def compare_results( + self, old_result: Dict, new_result: Dict + ) -> Dict[str, List[str]]: """Comprehensive comparison of scraper outputs""" differences = {} - + # Compare links - link_differences = self.deep_compare_links(old_result['links'], new_result['links']) + link_differences = self.deep_compare_links( + old_result["links"], new_result["links"] + ) if link_differences: - differences['links'] = link_differences - + differences["links"] = link_differences + # Compare media - media_differences = self.deep_compare_media(old_result['media'], new_result['media']) + media_differences = self.deep_compare_media( + old_result["media"], new_result["media"] + ) if media_differences: - differences['media'] = media_differences - + differences["media"] = media_differences + # Compare HTML html_differences = self.compare_html_content( - old_result['cleaned_html'], - new_result['cleaned_html'] + old_result["cleaned_html"], new_result["cleaned_html"] ) if html_differences: - differences['html'] = html_differences - + differences["html"] = html_differences + return differences def run_tests(self) -> Dict: @@ -535,52 +546,49 @@ class ScraperEquivalenceTester: # We'll still keep some "test_cases" logic from above (basic, complex, malformed). # But we add a new section for the complicated HTML scenarios. - results = { - 'tests': [], - 'summary': {'passed': 0, 'failed': 0} - } + results = {"tests": [], "summary": {"passed": 0, "failed": 0}} # 1) First, run the existing 3 built-in test cases (basic, complex, malformed). # for case_name, html in self.test_cases.items(): # print(f"\nTesting built-in case: {case_name}...") - + # original = WebScrapingStrategy() # lxml = LXMLWebScrapingStrategy() - + # start = time.time() # orig_result = original.scrap("http://test.com", html) # orig_time = time.time() - start - + # print("\nOriginal Mode:") # print(f"Cleaned HTML size: {len(orig_result['cleaned_html'])/1024:.2f} KB") # print(f"Images: {len(orig_result['media']['images'])}") # print(f"External links: {len(orig_result['links']['external'])}") # print(f"Times - Original: {orig_time:.3f}s") - + # start = time.time() # lxml_result = lxml.scrap("http://test.com", html) # lxml_time = time.time() - start - + # print("\nLXML Mode:") # print(f"Cleaned HTML size: {len(lxml_result['cleaned_html'])/1024:.2f} KB") # print(f"Images: {len(lxml_result['media']['images'])}") # print(f"External links: {len(lxml_result['links']['external'])}") # print(f"Times - LXML: {lxml_time:.3f}s") - + # # Compare # diffs = {} # link_diff = self.deep_compare_links(orig_result['links'], lxml_result['links']) # if link_diff: # diffs['links'] = link_diff - + # media_diff = self.deep_compare_media(orig_result['media'], lxml_result['media']) # if media_diff: # diffs['media'] = media_diff - + # html_diff = self.compare_html_content(orig_result['cleaned_html'], lxml_result['cleaned_html']) # if html_diff: # diffs['html'] = html_diff - + # test_result = { # 'case': case_name, # 'lxml_mode': { @@ -590,7 +598,7 @@ class ScraperEquivalenceTester: # 'original_time': orig_time # } # results['tests'].append(test_result) - + # if not diffs: # results['summary']['passed'] += 1 # else: @@ -599,50 +607,55 @@ class ScraperEquivalenceTester: # 2) Now, run the complicated HTML with multiple parameter scenarios. complicated_html = generate_complicated_html() print("\n=== Testing complicated HTML with multiple parameter scenarios ===") - + # Create the scrapers once (or you can re-create if needed) original = WebScrapingStrategy() lxml = LXMLWebScrapingStrategy() for scenario_name, params in get_test_scenarios().items(): print(f"\nScenario: {scenario_name}") - + start = time.time() orig_result = original.scrap("http://test.com", complicated_html, **params) orig_time = time.time() - start - + start = time.time() lxml_result = lxml.scrap("http://test.com", complicated_html, **params) lxml_time = time.time() - start - + diffs = {} - link_diff = self.deep_compare_links(orig_result['links'], lxml_result['links']) + link_diff = self.deep_compare_links( + orig_result["links"], lxml_result["links"] + ) if link_diff: - diffs['links'] = link_diff + diffs["links"] = link_diff - media_diff = self.deep_compare_media(orig_result['media'], lxml_result['media']) + media_diff = self.deep_compare_media( + orig_result["media"], lxml_result["media"] + ) if media_diff: - diffs['media'] = media_diff + diffs["media"] = media_diff - html_diff = self.compare_html_content(orig_result['cleaned_html'], lxml_result['cleaned_html']) + html_diff = self.compare_html_content( + orig_result["cleaned_html"], lxml_result["cleaned_html"] + ) if html_diff: - diffs['html'] = html_diff - + diffs["html"] = html_diff + test_result = { - 'case': f"complicated_{scenario_name}", - 'lxml_mode': { - 'differences': diffs, - 'execution_time': lxml_time - }, - 'original_time': orig_time + "case": f"complicated_{scenario_name}", + "lxml_mode": {"differences": diffs, "execution_time": lxml_time}, + "original_time": orig_time, } - results['tests'].append(test_result) - + results["tests"].append(test_result) + if not diffs: - results['summary']['passed'] += 1 - print(f"✅ [OK] No differences found. Time(Orig: {orig_time:.3f}s, LXML: {lxml_time:.3f}s)") + results["summary"]["passed"] += 1 + print( + f"✅ [OK] No differences found. Time(Orig: {orig_time:.3f}s, LXML: {lxml_time:.3f}s)" + ) else: - results['summary']['failed'] += 1 + results["summary"]["failed"] += 1 print("❌ Differences found:") for category, dlist in diffs.items(): print(f" {category}:") @@ -657,20 +670,22 @@ class ScraperEquivalenceTester: print(f"Total Cases: {len(results['tests'])}") print(f"Passed: {results['summary']['passed']}") print(f"Failed: {results['summary']['failed']}") - - for test in results['tests']: + + for test in results["tests"]: print(f"\nTest Case: {test['case']}") - - if not test['lxml_mode']['differences']: + + if not test["lxml_mode"]["differences"]: print("✅ All implementations produced identical results") - print(f"Times - Original: {test['original_time']:.3f}s, " - f"LXML: {test['lxml_mode']['execution_time']:.3f}s") + print( + f"Times - Original: {test['original_time']:.3f}s, " + f"LXML: {test['lxml_mode']['execution_time']:.3f}s" + ) else: print("❌ Differences found:") - - if test['lxml_mode']['differences']: + + if test["lxml_mode"]["differences"]: print("\nLXML Mode Differences:") - for category, diffs in test['lxml_mode']['differences'].items(): + for category, diffs in test["lxml_mode"]["differences"].items(): print(f"\n{category}:") for diff in diffs: print(f" - {diff}") @@ -680,11 +695,11 @@ def main(): tester = ScraperEquivalenceTester() results = tester.run_tests() tester.print_report(results) - + # Save detailed results for debugging - with open('scraper_equivalence_results.json', 'w') as f: + with open("scraper_equivalence_results.json", "w") as f: json.dump(results, f, indent=2) if __name__ == "__main__": - main() \ No newline at end of file + main() diff --git a/tests/async/test_markdown_genertor.py b/tests/async/test_markdown_genertor.py index 2b1102ab..7eaf5d85 100644 --- a/tests/async/test_markdown_genertor.py +++ b/tests/async/test_markdown_genertor.py @@ -4,10 +4,10 @@ # - **State:** open import os, sys, time + parent_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) sys.path.append(parent_dir) -__location__ = os.path.realpath( os.path.join(os.getcwd(), os.path.dirname(__file__))) -import asyncio +__location__ = os.path.realpath(os.path.join(os.getcwd(), os.path.dirname(__file__))) import os import time from typing import Dict, Any @@ -16,18 +16,18 @@ from crawl4ai.markdown_generation_strategy import DefaultMarkdownGenerator # Get current directory __location__ = os.path.realpath(os.path.join(os.getcwd(), os.path.dirname(__file__))) + def print_test_result(name: str, result: Dict[str, Any], execution_time: float): """Helper function to print test results.""" print(f"\n{'='*20} {name} {'='*20}") print(f"Execution time: {execution_time:.4f} seconds") - - + # Save markdown to files for key, content in result.items(): if isinstance(content, str): with open(__location__ + f"/output/{name.lower()}_{key}.md", "w") as f: f.write(content) - + # # Print first few lines of each markdown version # for key, content in result.items(): # if isinstance(content, str): @@ -36,32 +36,39 @@ def print_test_result(name: str, result: Dict[str, Any], execution_time: float): # print(preview) # print(f"Total length: {len(content)} characters") + def test_basic_markdown_conversion(): """Test basic markdown conversion with links.""" with open(__location__ + "/data/wikipedia.html", "r") as f: cleaned_html = f.read() generator = DefaultMarkdownGenerator() - + start_time = time.perf_counter() result = generator.generate_markdown( - cleaned_html=cleaned_html, - base_url="https://en.wikipedia.org" + cleaned_html=cleaned_html, base_url="https://en.wikipedia.org" ) execution_time = time.perf_counter() - start_time - - print_test_result("Basic Markdown Conversion", { - 'raw': result.raw_markdown, - 'with_citations': result.markdown_with_citations, - 'references': result.references_markdown - }, execution_time) - + + print_test_result( + "Basic Markdown Conversion", + { + "raw": result.raw_markdown, + "with_citations": result.markdown_with_citations, + "references": result.references_markdown, + }, + execution_time, + ) + # Basic assertions assert result.raw_markdown, "Raw markdown should not be empty" assert result.markdown_with_citations, "Markdown with citations should not be empty" assert result.references_markdown, "References should not be empty" assert "⟨" in result.markdown_with_citations, "Citations should use ⟨⟩ brackets" - assert "## References" in result.references_markdown, "Should contain references section" + assert ( + "## References" in result.references_markdown + ), "Should contain references section" + def test_relative_links(): """Test handling of relative links with base URL.""" @@ -69,97 +76,106 @@ def test_relative_links(): Here's a [relative link](/wiki/Apple) and an [absolute link](https://example.com). Also an [image](/images/test.png) and another [page](/wiki/Banana). """ - + generator = DefaultMarkdownGenerator() result = generator.generate_markdown( - cleaned_html=markdown, - base_url="https://en.wikipedia.org" + cleaned_html=markdown, base_url="https://en.wikipedia.org" ) - + assert "https://en.wikipedia.org/wiki/Apple" in result.references_markdown assert "https://example.com" in result.references_markdown assert "https://en.wikipedia.org/images/test.png" in result.references_markdown + def test_duplicate_links(): """Test handling of duplicate links.""" markdown = """ Here's a [link](/test) and another [link](/test) and a [different link](/other). """ - + generator = DefaultMarkdownGenerator() result = generator.generate_markdown( - cleaned_html=markdown, - base_url="https://example.com" + cleaned_html=markdown, base_url="https://example.com" ) - + # Count citations in markdown citations = result.markdown_with_citations.count("⟨1⟩") assert citations == 2, "Same link should use same citation number" + def test_link_descriptions(): """Test handling of link titles and descriptions.""" markdown = """ Here's a [link with title](/test "Test Title") and a [link with description](/other) to test. """ - + generator = DefaultMarkdownGenerator() result = generator.generate_markdown( - cleaned_html=markdown, - base_url="https://example.com" + cleaned_html=markdown, base_url="https://example.com" ) - - assert "Test Title" in result.references_markdown, "Link title should be in references" - assert "link with description" in result.references_markdown, "Link text should be in references" + + assert ( + "Test Title" in result.references_markdown + ), "Link title should be in references" + assert ( + "link with description" in result.references_markdown + ), "Link text should be in references" + def test_performance_large_document(): """Test performance with large document.""" with open(__location__ + "/data/wikipedia.md", "r") as f: markdown = f.read() - + # Test with multiple iterations iterations = 5 times = [] - + generator = DefaultMarkdownGenerator() - + for i in range(iterations): start_time = time.perf_counter() result = generator.generate_markdown( - cleaned_html=markdown, - base_url="https://en.wikipedia.org" + cleaned_html=markdown, base_url="https://en.wikipedia.org" ) end_time = time.perf_counter() times.append(end_time - start_time) - + avg_time = sum(times) / len(times) print(f"\n{'='*20} Performance Test {'='*20}") - print(f"Average execution time over {iterations} iterations: {avg_time:.4f} seconds") + print( + f"Average execution time over {iterations} iterations: {avg_time:.4f} seconds" + ) print(f"Min time: {min(times):.4f} seconds") print(f"Max time: {max(times):.4f} seconds") + def test_image_links(): """Test handling of image links.""" markdown = """ Here's an ![image](/image.png "Image Title") and another ![image](/other.jpg). And a regular [link](/page). """ - + generator = DefaultMarkdownGenerator() result = generator.generate_markdown( - cleaned_html=markdown, - base_url="https://example.com" + cleaned_html=markdown, base_url="https://example.com" ) - - assert "![" in result.markdown_with_citations, "Image markdown syntax should be preserved" - assert "Image Title" in result.references_markdown, "Image title should be in references" + + assert ( + "![" in result.markdown_with_citations + ), "Image markdown syntax should be preserved" + assert ( + "Image Title" in result.references_markdown + ), "Image title should be in references" + if __name__ == "__main__": print("Running markdown generation strategy tests...") - + test_basic_markdown_conversion() test_relative_links() test_duplicate_links() test_link_descriptions() test_performance_large_document() test_image_links() - \ No newline at end of file diff --git a/tests/async/test_parameters_and_options.py b/tests/async/test_parameters_and_options.py index 8ae7c1d3..e153fbd3 100644 --- a/tests/async/test_parameters_and_options.py +++ b/tests/async/test_parameters_and_options.py @@ -1,8 +1,6 @@ import os import sys import pytest -import asyncio -import json # Add the parent directory to the Python path parent_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) @@ -10,24 +8,37 @@ sys.path.append(parent_dir) from crawl4ai.async_webcrawler import AsyncWebCrawler + @pytest.mark.asyncio async def test_word_count_threshold(): async with AsyncWebCrawler(verbose=True) as crawler: url = "https://www.nbcnews.com/business" - result_no_threshold = await crawler.arun(url=url, word_count_threshold=0, bypass_cache=True) - result_with_threshold = await crawler.arun(url=url, word_count_threshold=50, bypass_cache=True) - + result_no_threshold = await crawler.arun( + url=url, word_count_threshold=0, bypass_cache=True + ) + result_with_threshold = await crawler.arun( + url=url, word_count_threshold=50, bypass_cache=True + ) + assert len(result_no_threshold.markdown) > len(result_with_threshold.markdown) + @pytest.mark.asyncio async def test_css_selector(): async with AsyncWebCrawler(verbose=True) as crawler: url = "https://www.nbcnews.com/business" css_selector = "h1, h2, h3" - result = await crawler.arun(url=url, css_selector=css_selector, bypass_cache=True) - + result = await crawler.arun( + url=url, css_selector=css_selector, bypass_cache=True + ) + assert result.success - assert " button.textContent.includes('Load More')); loadMoreButton && loadMoreButton.click();"] + + js_code = [ + "const loadMoreButton = Array.from(document.querySelectorAll('button')).find(button => button.textContent.includes('Load More')); loadMoreButton && loadMoreButton.click();" + ] result_with_more = await crawler.arun(url=url, js=js_code, bypass_cache=True) - + assert result_with_more.success assert len(result_with_more.markdown) > len(result_without_more.markdown) + @pytest.mark.asyncio async def test_screenshot(): async with AsyncWebCrawler(verbose=True) as crawler: url = "https://www.nbcnews.com/business" result = await crawler.arun(url=url, screenshot=True, bypass_cache=True) - + assert result.success assert result.screenshot assert isinstance(result.screenshot, str) # Should be a base64 encoded string + @pytest.mark.asyncio async def test_custom_user_agent(): async with AsyncWebCrawler(verbose=True) as crawler: url = "https://www.nbcnews.com/business" custom_user_agent = "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.124 Safari/537.36 Crawl4AI/1.0" - result = await crawler.arun(url=url, user_agent=custom_user_agent, bypass_cache=True) - + result = await crawler.arun( + url=url, user_agent=custom_user_agent, bypass_cache=True + ) + assert result.success # Note: We can't directly verify the user agent in the result, but we can check if the crawl was successful + @pytest.mark.asyncio async def test_extract_media_and_links(): async with AsyncWebCrawler(verbose=True) as crawler: url = "https://www.nbcnews.com/business" result = await crawler.arun(url=url, bypass_cache=True) - + assert result.success assert result.media assert isinstance(result.media, dict) - assert 'images' in result.media + assert "images" in result.media assert result.links assert isinstance(result.links, dict) - assert 'internal' in result.links and 'external' in result.links + assert "internal" in result.links and "external" in result.links + @pytest.mark.asyncio async def test_metadata_extraction(): async with AsyncWebCrawler(verbose=True) as crawler: url = "https://www.nbcnews.com/business" result = await crawler.arun(url=url, bypass_cache=True) - + assert result.success assert result.metadata assert isinstance(result.metadata, dict) # Check for common metadata fields - assert any(key in result.metadata for key in ['title', 'description', 'keywords']) + assert any( + key in result.metadata for key in ["title", "description", "keywords"] + ) + # Entry point for debugging if __name__ == "__main__": - pytest.main([__file__, "-v"]) \ No newline at end of file + pytest.main([__file__, "-v"]) diff --git a/tests/async/test_performance.py b/tests/async/test_performance.py index 9528b5ab..a35e2bee 100644 --- a/tests/async/test_performance.py +++ b/tests/async/test_performance.py @@ -1,7 +1,6 @@ import os import sys import pytest -import asyncio import time # Add the parent directory to the Python path @@ -10,6 +9,7 @@ sys.path.append(parent_dir) from crawl4ai.async_webcrawler import AsyncWebCrawler + @pytest.mark.asyncio async def test_crawl_speed(): async with AsyncWebCrawler(verbose=True) as crawler: @@ -17,13 +17,14 @@ async def test_crawl_speed(): start_time = time.time() result = await crawler.arun(url=url, bypass_cache=True) end_time = time.time() - + assert result.success crawl_time = end_time - start_time print(f"Crawl time: {crawl_time:.2f} seconds") - + assert crawl_time < 10, f"Crawl took too long: {crawl_time:.2f} seconds" + @pytest.mark.asyncio async def test_concurrent_crawling_performance(): async with AsyncWebCrawler(verbose=True) as crawler: @@ -32,41 +33,47 @@ async def test_concurrent_crawling_performance(): "https://www.example.com", "https://www.python.org", "https://www.github.com", - "https://www.stackoverflow.com" + "https://www.stackoverflow.com", ] - + start_time = time.time() results = await crawler.arun_many(urls=urls, bypass_cache=True) end_time = time.time() - + total_time = end_time - start_time print(f"Total time for concurrent crawling: {total_time:.2f} seconds") - + assert all(result.success for result in results) assert len(results) == len(urls) - - assert total_time < len(urls) * 5, f"Concurrent crawling not significantly faster: {total_time:.2f} seconds" + + assert ( + total_time < len(urls) * 5 + ), f"Concurrent crawling not significantly faster: {total_time:.2f} seconds" + @pytest.mark.asyncio async def test_crawl_speed_with_caching(): async with AsyncWebCrawler(verbose=True) as crawler: url = "https://www.nbcnews.com/business" - + start_time = time.time() result1 = await crawler.arun(url=url, bypass_cache=True) end_time = time.time() first_crawl_time = end_time - start_time - + start_time = time.time() result2 = await crawler.arun(url=url, bypass_cache=False) end_time = time.time() second_crawl_time = end_time - start_time - + assert result1.success and result2.success print(f"First crawl time: {first_crawl_time:.2f} seconds") print(f"Second crawl time (cached): {second_crawl_time:.2f} seconds") - - assert second_crawl_time < first_crawl_time / 2, "Cached crawl not significantly faster" + + assert ( + second_crawl_time < first_crawl_time / 2 + ), "Cached crawl not significantly faster" + if __name__ == "__main__": - pytest.main([__file__, "-v"]) \ No newline at end of file + pytest.main([__file__, "-v"]) diff --git a/tests/async/test_screenshot.py b/tests/async/test_screenshot.py index 0c4439f6..36c6c0aa 100644 --- a/tests/async/test_screenshot.py +++ b/tests/async/test_screenshot.py @@ -1,7 +1,6 @@ import os import sys import pytest -import asyncio import base64 from PIL import Image import io @@ -12,113 +11,112 @@ sys.path.append(parent_dir) from crawl4ai.async_webcrawler import AsyncWebCrawler + @pytest.mark.asyncio async def test_basic_screenshot(): async with AsyncWebCrawler(verbose=True) as crawler: url = "https://example.com" # A static website result = await crawler.arun(url=url, bypass_cache=True, screenshot=True) - + assert result.success assert result.screenshot is not None - + # Verify the screenshot is a valid image image_data = base64.b64decode(result.screenshot) image = Image.open(io.BytesIO(image_data)) assert image.format == "PNG" + @pytest.mark.asyncio async def test_screenshot_with_wait_for(): async with AsyncWebCrawler(verbose=True) as crawler: # Using a website with dynamic content url = "https://www.youtube.com" wait_for = "css:#content" # Wait for the main content to load - + result = await crawler.arun( - url=url, - bypass_cache=True, - screenshot=True, - wait_for=wait_for + url=url, bypass_cache=True, screenshot=True, wait_for=wait_for ) - + assert result.success assert result.screenshot is not None - + # Verify the screenshot is a valid image image_data = base64.b64decode(result.screenshot) image = Image.open(io.BytesIO(image_data)) assert image.format == "PNG" - + # You might want to add more specific checks here, like image dimensions # or even use image recognition to verify certain elements are present + @pytest.mark.asyncio async def test_screenshot_with_js_wait_for(): async with AsyncWebCrawler(verbose=True) as crawler: url = "https://www.amazon.com" wait_for = "js:() => document.querySelector('#nav-logo-sprites') !== null" - + result = await crawler.arun( - url=url, - bypass_cache=True, - screenshot=True, - wait_for=wait_for + url=url, bypass_cache=True, screenshot=True, wait_for=wait_for ) - + assert result.success assert result.screenshot is not None - + image_data = base64.b64decode(result.screenshot) image = Image.open(io.BytesIO(image_data)) assert image.format == "PNG" + @pytest.mark.asyncio async def test_screenshot_without_wait_for(): async with AsyncWebCrawler(verbose=True) as crawler: url = "https://www.nytimes.com" # A website with lots of dynamic content - + result = await crawler.arun(url=url, bypass_cache=True, screenshot=True) - + assert result.success assert result.screenshot is not None - + image_data = base64.b64decode(result.screenshot) image = Image.open(io.BytesIO(image_data)) assert image.format == "PNG" + @pytest.mark.asyncio async def test_screenshot_comparison(): async with AsyncWebCrawler(verbose=True) as crawler: url = "https://www.reddit.com" wait_for = "css:#SHORTCUT_FOCUSABLE_DIV" - + # Take screenshot without wait_for result_without_wait = await crawler.arun( - url=url, - bypass_cache=True, - screenshot=True + url=url, bypass_cache=True, screenshot=True ) - + # Take screenshot with wait_for result_with_wait = await crawler.arun( - url=url, - bypass_cache=True, - screenshot=True, - wait_for=wait_for + url=url, bypass_cache=True, screenshot=True, wait_for=wait_for ) - + assert result_without_wait.success and result_with_wait.success assert result_without_wait.screenshot is not None assert result_with_wait.screenshot is not None - + # Compare the two screenshots - image_without_wait = Image.open(io.BytesIO(base64.b64decode(result_without_wait.screenshot))) - image_with_wait = Image.open(io.BytesIO(base64.b64decode(result_with_wait.screenshot))) - + image_without_wait = Image.open( + io.BytesIO(base64.b64decode(result_without_wait.screenshot)) + ) + image_with_wait = Image.open( + io.BytesIO(base64.b64decode(result_with_wait.screenshot)) + ) + # This is a simple size comparison. In a real-world scenario, you might want to use # more sophisticated image comparison techniques. assert image_with_wait.size[0] >= image_without_wait.size[0] assert image_with_wait.size[1] >= image_without_wait.size[1] + # Entry point for debugging if __name__ == "__main__": - pytest.main([__file__, "-v"]) \ No newline at end of file + pytest.main([__file__, "-v"]) diff --git a/tests/docker_example.py b/tests/docker_example.py index 658e80fd..336ca52f 100644 --- a/tests/docker_example.py +++ b/tests/docker_example.py @@ -6,53 +6,72 @@ import base64 import os from typing import Dict, Any + class Crawl4AiTester: def __init__(self, base_url: str = "http://localhost:11235", api_token: str = None): self.base_url = base_url - self.api_token = api_token or os.getenv('CRAWL4AI_API_TOKEN') # Check environment variable as fallback - self.headers = {'Authorization': f'Bearer {self.api_token}'} if self.api_token else {} - - def submit_and_wait(self, request_data: Dict[str, Any], timeout: int = 300) -> Dict[str, Any]: + self.api_token = api_token or os.getenv( + "CRAWL4AI_API_TOKEN" + ) # Check environment variable as fallback + self.headers = ( + {"Authorization": f"Bearer {self.api_token}"} if self.api_token else {} + ) + + def submit_and_wait( + self, request_data: Dict[str, Any], timeout: int = 300 + ) -> Dict[str, Any]: # Submit crawl job - response = requests.post(f"{self.base_url}/crawl", json=request_data, headers=self.headers) + response = requests.post( + f"{self.base_url}/crawl", json=request_data, headers=self.headers + ) if response.status_code == 403: raise Exception("API token is invalid or missing") task_id = response.json()["task_id"] print(f"Task ID: {task_id}") - + # Poll for result start_time = time.time() while True: if time.time() - start_time > timeout: - raise TimeoutError(f"Task {task_id} did not complete within {timeout} seconds") - - result = requests.get(f"{self.base_url}/task/{task_id}", headers=self.headers) + raise TimeoutError( + f"Task {task_id} did not complete within {timeout} seconds" + ) + + result = requests.get( + f"{self.base_url}/task/{task_id}", headers=self.headers + ) status = result.json() - + if status["status"] == "failed": print("Task failed:", status.get("error")) raise Exception(f"Task failed: {status.get('error')}") - + if status["status"] == "completed": return status - + time.sleep(2) - + def submit_sync(self, request_data: Dict[str, Any]) -> Dict[str, Any]: - response = requests.post(f"{self.base_url}/crawl_sync", json=request_data, headers=self.headers, timeout=60) + response = requests.post( + f"{self.base_url}/crawl_sync", + json=request_data, + headers=self.headers, + timeout=60, + ) if response.status_code == 408: raise TimeoutError("Task did not complete within server timeout") response.raise_for_status() return response.json() + def test_docker_deployment(version="basic"): tester = Crawl4AiTester( # base_url="http://localhost:11235" , base_url="https://crawl4ai-sby74.ondigitalocean.app", - api_token="test" + api_token="test", ) print(f"Testing Crawl4AI Docker {version} version") - + # Health check with timeout and retry max_retries = 5 for i in range(max_retries): @@ -60,18 +79,18 @@ def test_docker_deployment(version="basic"): health = requests.get(f"{tester.base_url}/health", timeout=10) print("Health check:", health.json()) break - except requests.exceptions.RequestException as e: + except requests.exceptions.RequestException: if i == max_retries - 1: print(f"Failed to connect after {max_retries} attempts") sys.exit(1) print(f"Waiting for service to start (attempt {i+1}/{max_retries})...") time.sleep(5) - + # Test cases based on version test_basic_crawl(tester) test_basic_crawl(tester) test_basic_crawl_sync(tester) - + # if version in ["full", "transformer"]: # test_cosine_extraction(tester) @@ -81,35 +100,37 @@ def test_docker_deployment(version="basic"): # test_llm_extraction(tester) # test_llm_with_ollama(tester) # test_screenshot(tester) - + def test_basic_crawl(tester: Crawl4AiTester): print("\n=== Testing Basic Crawl ===") request = { "urls": "https://www.nbcnews.com/business", - "priority": 10, - "session_id": "test" + "priority": 10, + "session_id": "test", } - + result = tester.submit_and_wait(request) print(f"Basic crawl result length: {len(result['result']['markdown'])}") assert result["result"]["success"] assert len(result["result"]["markdown"]) > 0 + def test_basic_crawl_sync(tester: Crawl4AiTester): print("\n=== Testing Basic Crawl (Sync) ===") request = { "urls": "https://www.nbcnews.com/business", "priority": 10, - "session_id": "test" + "session_id": "test", } - + result = tester.submit_sync(request) print(f"Basic crawl result length: {len(result['result']['markdown'])}") - assert result['status'] == 'completed' - assert result['result']['success'] - assert len(result['result']['markdown']) > 0 - + assert result["status"] == "completed" + assert result["result"]["success"] + assert len(result["result"]["markdown"]) > 0 + + def test_js_execution(tester: Crawl4AiTester): print("\n=== Testing JS Execution ===") request = { @@ -119,32 +140,29 @@ def test_js_execution(tester: Crawl4AiTester): "const loadMoreButton = Array.from(document.querySelectorAll('button')).find(button => button.textContent.includes('Load More')); loadMoreButton && loadMoreButton.click();" ], "wait_for": "article.tease-card:nth-child(10)", - "crawler_params": { - "headless": True - } + "crawler_params": {"headless": True}, } - + result = tester.submit_and_wait(request) print(f"JS execution result length: {len(result['result']['markdown'])}") assert result["result"]["success"] + def test_css_selector(tester: Crawl4AiTester): print("\n=== Testing CSS Selector ===") request = { "urls": "https://www.nbcnews.com/business", "priority": 7, "css_selector": ".wide-tease-item__description", - "crawler_params": { - "headless": True - }, - "extra": {"word_count_threshold": 10} - + "crawler_params": {"headless": True}, + "extra": {"word_count_threshold": 10}, } - + result = tester.submit_and_wait(request) print(f"CSS selector result length: {len(result['result']['markdown'])}") assert result["result"]["success"] + def test_structured_extraction(tester: Crawl4AiTester): print("\n=== Testing Structured Extraction ===") schema = { @@ -165,21 +183,16 @@ def test_structured_extraction(tester: Crawl4AiTester): "name": "price", "selector": "td:nth-child(2)", "type": "text", - } + }, ], } - + request = { "urls": "https://www.coinbase.com/explore", "priority": 9, - "extraction_config": { - "type": "json_css", - "params": { - "schema": schema - } - } + "extraction_config": {"type": "json_css", "params": {"schema": schema}}, } - + result = tester.submit_and_wait(request) extracted = json.loads(result["result"]["extracted_content"]) print(f"Extracted {len(extracted)} items") @@ -187,6 +200,7 @@ def test_structured_extraction(tester: Crawl4AiTester): assert result["result"]["success"] assert len(extracted) > 0 + def test_llm_extraction(tester: Crawl4AiTester): print("\n=== Testing LLM Extraction ===") schema = { @@ -194,20 +208,20 @@ def test_llm_extraction(tester: Crawl4AiTester): "properties": { "model_name": { "type": "string", - "description": "Name of the OpenAI model." + "description": "Name of the OpenAI model.", }, "input_fee": { "type": "string", - "description": "Fee for input token for the OpenAI model." + "description": "Fee for input token for the OpenAI model.", }, "output_fee": { "type": "string", - "description": "Fee for output token for the OpenAI model." - } + "description": "Fee for output token for the OpenAI model.", + }, }, - "required": ["model_name", "input_fee", "output_fee"] + "required": ["model_name", "input_fee", "output_fee"], } - + request = { "urls": "https://openai.com/api/pricing", "priority": 8, @@ -218,12 +232,12 @@ def test_llm_extraction(tester: Crawl4AiTester): "api_token": os.getenv("OPENAI_API_KEY"), "schema": schema, "extraction_type": "schema", - "instruction": """From the crawled content, extract all mentioned model names along with their fees for input and output tokens.""" - } + "instruction": """From the crawled content, extract all mentioned model names along with their fees for input and output tokens.""", + }, }, - "crawler_params": {"word_count_threshold": 1} + "crawler_params": {"word_count_threshold": 1}, } - + try: result = tester.submit_and_wait(request) extracted = json.loads(result["result"]["extracted_content"]) @@ -233,6 +247,7 @@ def test_llm_extraction(tester: Crawl4AiTester): except Exception as e: print(f"LLM extraction test failed (might be due to missing API key): {str(e)}") + def test_llm_with_ollama(tester: Crawl4AiTester): print("\n=== Testing LLM with Ollama ===") schema = { @@ -240,20 +255,20 @@ def test_llm_with_ollama(tester: Crawl4AiTester): "properties": { "article_title": { "type": "string", - "description": "The main title of the news article" + "description": "The main title of the news article", }, "summary": { "type": "string", - "description": "A brief summary of the article content" + "description": "A brief summary of the article content", }, "main_topics": { "type": "array", "items": {"type": "string"}, - "description": "Main topics or themes discussed in the article" - } - } + "description": "Main topics or themes discussed in the article", + }, + }, } - + request = { "urls": "https://www.nbcnews.com/business", "priority": 8, @@ -263,13 +278,13 @@ def test_llm_with_ollama(tester: Crawl4AiTester): "provider": "ollama/llama2", "schema": schema, "extraction_type": "schema", - "instruction": "Extract the main article information including title, summary, and main topics." - } + "instruction": "Extract the main article information including title, summary, and main topics.", + }, }, "extra": {"word_count_threshold": 1}, - "crawler_params": {"verbose": True} + "crawler_params": {"verbose": True}, } - + try: result = tester.submit_and_wait(request) extracted = json.loads(result["result"]["extracted_content"]) @@ -278,6 +293,7 @@ def test_llm_with_ollama(tester: Crawl4AiTester): except Exception as e: print(f"Ollama extraction test failed: {str(e)}") + def test_cosine_extraction(tester: Crawl4AiTester): print("\n=== Testing Cosine Extraction ===") request = { @@ -289,11 +305,11 @@ def test_cosine_extraction(tester: Crawl4AiTester): "semantic_filter": "business finance economy", "word_count_threshold": 10, "max_dist": 0.2, - "top_k": 3 - } - } + "top_k": 3, + }, + }, } - + try: result = tester.submit_and_wait(request) extracted = json.loads(result["result"]["extracted_content"]) @@ -303,30 +319,30 @@ def test_cosine_extraction(tester: Crawl4AiTester): except Exception as e: print(f"Cosine extraction test failed: {str(e)}") + def test_screenshot(tester: Crawl4AiTester): print("\n=== Testing Screenshot ===") request = { "urls": "https://www.nbcnews.com/business", "priority": 5, "screenshot": True, - "crawler_params": { - "headless": True - } + "crawler_params": {"headless": True}, } - + result = tester.submit_and_wait(request) print("Screenshot captured:", bool(result["result"]["screenshot"])) - + if result["result"]["screenshot"]: # Save screenshot screenshot_data = base64.b64decode(result["result"]["screenshot"]) with open("test_screenshot.jpg", "wb") as f: f.write(screenshot_data) print("Screenshot saved as test_screenshot.jpg") - + assert result["result"]["success"] + if __name__ == "__main__": version = sys.argv[1] if len(sys.argv) > 1 else "basic" # version = "full" - test_docker_deployment(version) \ No newline at end of file + test_docker_deployment(version) diff --git a/tests/test_cli_docs.py b/tests/test_cli_docs.py index 9d2a7841..6941f20d 100644 --- a/tests/test_cli_docs.py +++ b/tests/test_cli_docs.py @@ -1,13 +1,13 @@ import asyncio -from pathlib import Path from crawl4ai.docs_manager import DocsManager from click.testing import CliRunner from crawl4ai.cli import cli + def test_cli(): """Test all CLI commands""" runner = CliRunner() - + print("\n1. Testing docs update...") # Use sync version for testing docs_manager = DocsManager() @@ -27,17 +27,18 @@ def test_cli(): # print("\n3. Testing search...") # result = runner.invoke(cli, ['docs', 'search', 'how to use crawler', '--build-index']) # print(f"Status: {'✅' if result.exit_code == 0 else '❌'}") - # print(f"First 200 chars: {result.output[:200]}...") - + # print(f"First 200 chars: {result.output[:200]}...") + # print("\n4. Testing combine with sections...") # result = runner.invoke(cli, ['docs', 'combine', 'chunking_strategies', 'extraction_strategies', '--mode', 'extended']) # print(f"Status: {'✅' if result.exit_code == 0 else '❌'}") # print(f"First 200 chars: {result.output[:200]}...") print("\n5. Testing combine all sections...") - result = runner.invoke(cli, ['docs', 'combine', '--mode', 'condensed']) + result = runner.invoke(cli, ["docs", "combine", "--mode", "condensed"]) print(f"Status: {'✅' if result.exit_code == 0 else '❌'}") print(f"First 200 chars: {result.output[:200]}...") + if __name__ == "__main__": - test_cli() \ No newline at end of file + test_cli() diff --git a/tests/test_docker.py b/tests/test_docker.py index c22acd55..3570d608 100644 --- a/tests/test_docker.py +++ b/tests/test_docker.py @@ -6,38 +6,44 @@ import base64 import os from typing import Dict, Any + class Crawl4AiTester: def __init__(self, base_url: str = "http://localhost:11235"): self.base_url = base_url - - def submit_and_wait(self, request_data: Dict[str, Any], timeout: int = 300) -> Dict[str, Any]: + + def submit_and_wait( + self, request_data: Dict[str, Any], timeout: int = 300 + ) -> Dict[str, Any]: # Submit crawl job response = requests.post(f"{self.base_url}/crawl", json=request_data) task_id = response.json()["task_id"] print(f"Task ID: {task_id}") - + # Poll for result start_time = time.time() while True: if time.time() - start_time > timeout: - raise TimeoutError(f"Task {task_id} did not complete within {timeout} seconds") - + raise TimeoutError( + f"Task {task_id} did not complete within {timeout} seconds" + ) + result = requests.get(f"{self.base_url}/task/{task_id}") status = result.json() - + if status["status"] == "failed": print("Task failed:", status.get("error")) raise Exception(f"Task failed: {status.get('error')}") - + if status["status"] == "completed": return status - + time.sleep(2) + def test_docker_deployment(version="basic"): tester = Crawl4AiTester() print(f"Testing Crawl4AI Docker {version} version") - + # Health check with timeout and retry max_retries = 5 for i in range(max_retries): @@ -45,16 +51,16 @@ def test_docker_deployment(version="basic"): health = requests.get(f"{tester.base_url}/health", timeout=10) print("Health check:", health.json()) break - except requests.exceptions.RequestException as e: + except requests.exceptions.RequestException: if i == max_retries - 1: print(f"Failed to connect after {max_retries} attempts") sys.exit(1) print(f"Waiting for service to start (attempt {i+1}/{max_retries})...") time.sleep(5) - + # Test cases based on version test_basic_crawl(tester) - + # if version in ["full", "transformer"]: # test_cosine_extraction(tester) @@ -64,20 +70,18 @@ def test_docker_deployment(version="basic"): # test_llm_extraction(tester) # test_llm_with_ollama(tester) # test_screenshot(tester) - + def test_basic_crawl(tester: Crawl4AiTester): print("\n=== Testing Basic Crawl ===") - request = { - "urls": "https://www.nbcnews.com/business", - "priority": 10 - } - + request = {"urls": "https://www.nbcnews.com/business", "priority": 10} + result = tester.submit_and_wait(request) print(f"Basic crawl result length: {len(result['result']['markdown'])}") assert result["result"]["success"] assert len(result["result"]["markdown"]) > 0 + def test_js_execution(tester: Crawl4AiTester): print("\n=== Testing JS Execution ===") request = { @@ -87,32 +91,29 @@ def test_js_execution(tester: Crawl4AiTester): "const loadMoreButton = Array.from(document.querySelectorAll('button')).find(button => button.textContent.includes('Load More')); loadMoreButton && loadMoreButton.click();" ], "wait_for": "article.tease-card:nth-child(10)", - "crawler_params": { - "headless": True - } + "crawler_params": {"headless": True}, } - + result = tester.submit_and_wait(request) print(f"JS execution result length: {len(result['result']['markdown'])}") assert result["result"]["success"] + def test_css_selector(tester: Crawl4AiTester): print("\n=== Testing CSS Selector ===") request = { "urls": "https://www.nbcnews.com/business", "priority": 7, "css_selector": ".wide-tease-item__description", - "crawler_params": { - "headless": True - }, - "extra": {"word_count_threshold": 10} - + "crawler_params": {"headless": True}, + "extra": {"word_count_threshold": 10}, } - + result = tester.submit_and_wait(request) print(f"CSS selector result length: {len(result['result']['markdown'])}") assert result["result"]["success"] + def test_structured_extraction(tester: Crawl4AiTester): print("\n=== Testing Structured Extraction ===") schema = { @@ -133,21 +134,16 @@ def test_structured_extraction(tester: Crawl4AiTester): "name": "price", "selector": "td:nth-child(2)", "type": "text", - } + }, ], } - + request = { "urls": "https://www.coinbase.com/explore", "priority": 9, - "extraction_config": { - "type": "json_css", - "params": { - "schema": schema - } - } + "extraction_config": {"type": "json_css", "params": {"schema": schema}}, } - + result = tester.submit_and_wait(request) extracted = json.loads(result["result"]["extracted_content"]) print(f"Extracted {len(extracted)} items") @@ -155,6 +151,7 @@ def test_structured_extraction(tester: Crawl4AiTester): assert result["result"]["success"] assert len(extracted) > 0 + def test_llm_extraction(tester: Crawl4AiTester): print("\n=== Testing LLM Extraction ===") schema = { @@ -162,20 +159,20 @@ def test_llm_extraction(tester: Crawl4AiTester): "properties": { "model_name": { "type": "string", - "description": "Name of the OpenAI model." + "description": "Name of the OpenAI model.", }, "input_fee": { "type": "string", - "description": "Fee for input token for the OpenAI model." + "description": "Fee for input token for the OpenAI model.", }, "output_fee": { "type": "string", - "description": "Fee for output token for the OpenAI model." - } + "description": "Fee for output token for the OpenAI model.", + }, }, - "required": ["model_name", "input_fee", "output_fee"] + "required": ["model_name", "input_fee", "output_fee"], } - + request = { "urls": "https://openai.com/api/pricing", "priority": 8, @@ -186,12 +183,12 @@ def test_llm_extraction(tester: Crawl4AiTester): "api_token": os.getenv("OPENAI_API_KEY"), "schema": schema, "extraction_type": "schema", - "instruction": """From the crawled content, extract all mentioned model names along with their fees for input and output tokens.""" - } + "instruction": """From the crawled content, extract all mentioned model names along with their fees for input and output tokens.""", + }, }, - "crawler_params": {"word_count_threshold": 1} + "crawler_params": {"word_count_threshold": 1}, } - + try: result = tester.submit_and_wait(request) extracted = json.loads(result["result"]["extracted_content"]) @@ -201,6 +198,7 @@ def test_llm_extraction(tester: Crawl4AiTester): except Exception as e: print(f"LLM extraction test failed (might be due to missing API key): {str(e)}") + def test_llm_with_ollama(tester: Crawl4AiTester): print("\n=== Testing LLM with Ollama ===") schema = { @@ -208,20 +206,20 @@ def test_llm_with_ollama(tester: Crawl4AiTester): "properties": { "article_title": { "type": "string", - "description": "The main title of the news article" + "description": "The main title of the news article", }, "summary": { "type": "string", - "description": "A brief summary of the article content" + "description": "A brief summary of the article content", }, "main_topics": { "type": "array", "items": {"type": "string"}, - "description": "Main topics or themes discussed in the article" - } - } + "description": "Main topics or themes discussed in the article", + }, + }, } - + request = { "urls": "https://www.nbcnews.com/business", "priority": 8, @@ -231,13 +229,13 @@ def test_llm_with_ollama(tester: Crawl4AiTester): "provider": "ollama/llama2", "schema": schema, "extraction_type": "schema", - "instruction": "Extract the main article information including title, summary, and main topics." - } + "instruction": "Extract the main article information including title, summary, and main topics.", + }, }, "extra": {"word_count_threshold": 1}, - "crawler_params": {"verbose": True} + "crawler_params": {"verbose": True}, } - + try: result = tester.submit_and_wait(request) extracted = json.loads(result["result"]["extracted_content"]) @@ -246,6 +244,7 @@ def test_llm_with_ollama(tester: Crawl4AiTester): except Exception as e: print(f"Ollama extraction test failed: {str(e)}") + def test_cosine_extraction(tester: Crawl4AiTester): print("\n=== Testing Cosine Extraction ===") request = { @@ -257,11 +256,11 @@ def test_cosine_extraction(tester: Crawl4AiTester): "semantic_filter": "business finance economy", "word_count_threshold": 10, "max_dist": 0.2, - "top_k": 3 - } - } + "top_k": 3, + }, + }, } - + try: result = tester.submit_and_wait(request) extracted = json.loads(result["result"]["extracted_content"]) @@ -271,30 +270,30 @@ def test_cosine_extraction(tester: Crawl4AiTester): except Exception as e: print(f"Cosine extraction test failed: {str(e)}") + def test_screenshot(tester: Crawl4AiTester): print("\n=== Testing Screenshot ===") request = { "urls": "https://www.nbcnews.com/business", "priority": 5, "screenshot": True, - "crawler_params": { - "headless": True - } + "crawler_params": {"headless": True}, } - + result = tester.submit_and_wait(request) print("Screenshot captured:", bool(result["result"]["screenshot"])) - + if result["result"]["screenshot"]: # Save screenshot screenshot_data = base64.b64decode(result["result"]["screenshot"]) with open("test_screenshot.jpg", "wb") as f: f.write(screenshot_data) print("Screenshot saved as test_screenshot.jpg") - + assert result["result"]["success"] + if __name__ == "__main__": version = sys.argv[1] if len(sys.argv) > 1 else "basic" # version = "full" - test_docker_deployment(version) \ No newline at end of file + test_docker_deployment(version) diff --git a/tests/test_llmtxt.py b/tests/test_llmtxt.py index bdbe5c27..2cdb0271 100644 --- a/tests/test_llmtxt.py +++ b/tests/test_llmtxt.py @@ -3,20 +3,21 @@ from crawl4ai.async_logger import AsyncLogger from pathlib import Path import asyncio + async def main(): current_file = Path(__file__).resolve() # base_dir = current_file.parent.parent / "local/_docs/llm.txt/test_docs" base_dir = current_file.parent.parent / "local/_docs/llm.txt" docs_dir = base_dir - + # Create directory if it doesn't exist docs_dir.mkdir(parents=True, exist_ok=True) - + # Initialize logger logger = AsyncLogger() # Updated initialization with default batching params # manager = AsyncLLMTextManager(docs_dir, logger, max_concurrent_calls=3, batch_size=2) - manager = AsyncLLMTextManager(docs_dir, logger, batch_size=2) + manager = AsyncLLMTextManager(docs_dir, logger, batch_size=2) # Let's first check what files we have print("\nAvailable files:") @@ -26,8 +27,7 @@ async def main(): # Generate index files print("\nGenerating index files...") await manager.generate_index_files( - force_generate_facts=False, - clear_bm25_cache=False + force_generate_facts=False, clear_bm25_cache=False ) # Test some relevant queries about Crawl4AI @@ -41,9 +41,12 @@ async def main(): results = manager.search(query, top_k=2) print(f"Results length: {len(results)} characters") if results: - print("First 200 chars of results:", results[:200].replace('\n', ' '), "...") + print( + "First 200 chars of results:", results[:200].replace("\n", " "), "..." + ) else: print("No results found") + if __name__ == "__main__": - asyncio.run(main()) \ No newline at end of file + asyncio.run(main()) diff --git a/tests/test_main.py b/tests/test_main.py index 19f938c8..0e938f59 100644 --- a/tests/test_main.py +++ b/tests/test_main.py @@ -3,8 +3,8 @@ import aiohttp import json import time import os -from typing import Optional, Dict, Any -from pydantic import BaseModel, HttpUrl +from typing import Dict, Any + class NBCNewsAPITest: def __init__(self, base_url: str = "http://localhost:8000"): @@ -20,7 +20,9 @@ class NBCNewsAPITest: await self.session.close() async def submit_crawl(self, request_data: Dict[str, Any]) -> str: - async with self.session.post(f"{self.base_url}/crawl", json=request_data) as response: + async with self.session.post( + f"{self.base_url}/crawl", json=request_data + ) as response: result = await response.json() return result["task_id"] @@ -28,11 +30,15 @@ class NBCNewsAPITest: async with self.session.get(f"{self.base_url}/task/{task_id}") as response: return await response.json() - async def wait_for_task(self, task_id: str, timeout: int = 300, poll_interval: int = 2) -> Dict[str, Any]: + async def wait_for_task( + self, task_id: str, timeout: int = 300, poll_interval: int = 2 + ) -> Dict[str, Any]: start_time = time.time() while True: if time.time() - start_time > timeout: - raise TimeoutError(f"Task {task_id} did not complete within {timeout} seconds") + raise TimeoutError( + f"Task {task_id} did not complete within {timeout} seconds" + ) status = await self.get_task_status(task_id) if status["status"] in ["completed", "failed"]: @@ -44,13 +50,11 @@ class NBCNewsAPITest: async with self.session.get(f"{self.base_url}/health") as response: return await response.json() + async def test_basic_crawl(): print("\n=== Testing Basic Crawl ===") async with NBCNewsAPITest() as api: - request = { - "urls": "https://www.nbcnews.com/business", - "priority": 10 - } + request = {"urls": "https://www.nbcnews.com/business", "priority": 10} task_id = await api.submit_crawl(request) result = await api.wait_for_task(task_id) print(f"Basic crawl result length: {len(result['result']['markdown'])}") @@ -58,6 +62,7 @@ async def test_basic_crawl(): assert "result" in result assert result["result"]["success"] + async def test_js_execution(): print("\n=== Testing JS Execution ===") async with NBCNewsAPITest() as api: @@ -68,9 +73,7 @@ async def test_js_execution(): "const loadMoreButton = Array.from(document.querySelectorAll('button')).find(button => button.textContent.includes('Load More')); loadMoreButton && loadMoreButton.click();" ], "wait_for": "article.tease-card:nth-child(10)", - "crawler_params": { - "headless": True - } + "crawler_params": {"headless": True}, } task_id = await api.submit_crawl(request) result = await api.wait_for_task(task_id) @@ -78,13 +81,14 @@ async def test_js_execution(): assert result["status"] == "completed" assert result["result"]["success"] + async def test_css_selector(): print("\n=== Testing CSS Selector ===") async with NBCNewsAPITest() as api: request = { "urls": "https://www.nbcnews.com/business", "priority": 7, - "css_selector": ".wide-tease-item__description" + "css_selector": ".wide-tease-item__description", } task_id = await api.submit_crawl(request) result = await api.wait_for_task(task_id) @@ -92,6 +96,7 @@ async def test_css_selector(): assert result["status"] == "completed" assert result["result"]["success"] + async def test_structured_extraction(): print("\n=== Testing Structured Extraction ===") async with NBCNewsAPITest() as api: @@ -99,34 +104,25 @@ async def test_structured_extraction(): "name": "NBC News Articles", "baseSelector": "article.tease-card", "fields": [ - { - "name": "title", - "selector": "h2", - "type": "text" - }, + {"name": "title", "selector": "h2", "type": "text"}, { "name": "description", "selector": ".tease-card__description", - "type": "text" + "type": "text", }, { "name": "link", "selector": "a", "type": "attribute", - "attribute": "href" - } - ] + "attribute": "href", + }, + ], } - + request = { "urls": "https://www.nbcnews.com/business", "priority": 9, - "extraction_config": { - "type": "json_css", - "params": { - "schema": schema - } - } + "extraction_config": {"type": "json_css", "params": {"schema": schema}}, } task_id = await api.submit_crawl(request) result = await api.wait_for_task(task_id) @@ -136,6 +132,7 @@ async def test_structured_extraction(): assert result["result"]["success"] assert len(extracted) > 0 + async def test_batch_crawl(): print("\n=== Testing Batch Crawl ===") async with NBCNewsAPITest() as api: @@ -143,12 +140,10 @@ async def test_batch_crawl(): "urls": [ "https://www.nbcnews.com/business", "https://www.nbcnews.com/business/consumer", - "https://www.nbcnews.com/business/economy" + "https://www.nbcnews.com/business/economy", ], "priority": 6, - "crawler_params": { - "headless": True - } + "crawler_params": {"headless": True}, } task_id = await api.submit_crawl(request) result = await api.wait_for_task(task_id) @@ -157,6 +152,7 @@ async def test_batch_crawl(): assert "results" in result assert len(result["results"]) == 3 + async def test_llm_extraction(): print("\n=== Testing LLM Extraction with Ollama ===") async with NBCNewsAPITest() as api: @@ -165,19 +161,19 @@ async def test_llm_extraction(): "properties": { "article_title": { "type": "string", - "description": "The main title of the news article" + "description": "The main title of the news article", }, "summary": { "type": "string", - "description": "A brief summary of the article content" + "description": "A brief summary of the article content", }, "main_topics": { "type": "array", "items": {"type": "string"}, - "description": "Main topics or themes discussed in the article" - } + "description": "Main topics or themes discussed in the article", + }, }, - "required": ["article_title", "summary", "main_topics"] + "required": ["article_title", "summary", "main_topics"], } request = { @@ -191,26 +187,24 @@ async def test_llm_extraction(): "schema": schema, "extraction_type": "schema", "instruction": """Extract the main article information including title, a brief summary, and main topics discussed. - Focus on the primary business news article on the page.""" - } + Focus on the primary business news article on the page.""", + }, }, - "crawler_params": { - "headless": True, - "word_count_threshold": 1 - } + "crawler_params": {"headless": True, "word_count_threshold": 1}, } - + task_id = await api.submit_crawl(request) result = await api.wait_for_task(task_id) - + if result["status"] == "completed": extracted = json.loads(result["result"]["extracted_content"]) - print(f"Extracted article analysis:") + print("Extracted article analysis:") print(json.dumps(extracted, indent=2)) - + assert result["status"] == "completed" assert result["result"]["success"] + async def test_screenshot(): print("\n=== Testing Screenshot ===") async with NBCNewsAPITest() as api: @@ -218,9 +212,7 @@ async def test_screenshot(): "urls": "https://www.nbcnews.com/business", "priority": 5, "screenshot": True, - "crawler_params": { - "headless": True - } + "crawler_params": {"headless": True}, } task_id = await api.submit_crawl(request) result = await api.wait_for_task(task_id) @@ -229,6 +221,7 @@ async def test_screenshot(): assert result["result"]["success"] assert result["result"]["screenshot"] is not None + async def test_priority_handling(): print("\n=== Testing Priority Handling ===") async with NBCNewsAPITest() as api: @@ -236,7 +229,7 @@ async def test_priority_handling(): low_priority = { "urls": "https://www.nbcnews.com/business", "priority": 1, - "crawler_params": {"headless": True} + "crawler_params": {"headless": True}, } low_task_id = await api.submit_crawl(low_priority) @@ -244,7 +237,7 @@ async def test_priority_handling(): high_priority = { "urls": "https://www.nbcnews.com/business/consumer", "priority": 10, - "crawler_params": {"headless": True} + "crawler_params": {"headless": True}, } high_task_id = await api.submit_crawl(high_priority) @@ -256,6 +249,7 @@ async def test_priority_handling(): assert high_result["status"] == "completed" assert low_result["status"] == "completed" + async def main(): try: # Start with health check @@ -277,5 +271,6 @@ async def main(): print(f"Test failed: {str(e)}") raise + if __name__ == "__main__": - asyncio.run(main()) \ No newline at end of file + asyncio.run(main()) diff --git a/tests/test_scraping_strategy.py b/tests/test_scraping_strategy.py index 6d742182..425d02c9 100644 --- a/tests/test_scraping_strategy.py +++ b/tests/test_scraping_strategy.py @@ -1,21 +1,26 @@ import nest_asyncio + nest_asyncio.apply() import asyncio -from crawl4ai import AsyncWebCrawler, CrawlerRunConfig, LXMLWebScrapingStrategy, CacheMode +from crawl4ai import ( + AsyncWebCrawler, + CrawlerRunConfig, + LXMLWebScrapingStrategy, + CacheMode, +) + async def main(): config = CrawlerRunConfig( cache_mode=CacheMode.BYPASS, - scraping_strategy=LXMLWebScrapingStrategy() # Faster alternative to default BeautifulSoup + scraping_strategy=LXMLWebScrapingStrategy(), # Faster alternative to default BeautifulSoup ) async with AsyncWebCrawler() as crawler: - result = await crawler.arun( - url="https://example.com", - config=config - ) + result = await crawler.arun(url="https://example.com", config=config) print(f"Success: {result.success}") print(f"Markdown length: {len(result.markdown_v2.raw_markdown)}") + if __name__ == "__main__": asyncio.run(main()) diff --git a/tests/test_web_crawler.py b/tests/test_web_crawler.py index 99360f42..d6eddfdc 100644 --- a/tests/test_web_crawler.py +++ b/tests/test_web_crawler.py @@ -1,79 +1,105 @@ import unittest, os from crawl4ai.web_crawler import WebCrawler -from crawl4ai.chunking_strategy import RegexChunking, FixedLengthWordChunking, SlidingWindowChunking -from crawl4ai.extraction_strategy import CosineStrategy, LLMExtractionStrategy, TopicExtractionStrategy, NoExtractionStrategy +from crawl4ai.chunking_strategy import ( + RegexChunking, + FixedLengthWordChunking, + SlidingWindowChunking, +) +from crawl4ai.extraction_strategy import ( + CosineStrategy, + LLMExtractionStrategy, + TopicExtractionStrategy, + NoExtractionStrategy, +) + class TestWebCrawler(unittest.TestCase): - def setUp(self): self.crawler = WebCrawler() - + def test_warmup(self): self.crawler.warmup() self.assertTrue(self.crawler.ready, "WebCrawler failed to warm up") - + def test_run_default_strategies(self): result = self.crawler.run( - url='https://www.nbcnews.com/business', + url="https://www.nbcnews.com/business", word_count_threshold=5, chunking_strategy=RegexChunking(), - extraction_strategy=CosineStrategy(), bypass_cache=True + extraction_strategy=CosineStrategy(), + bypass_cache=True, ) - self.assertTrue(result.success, "Failed to crawl and extract using default strategies") - + self.assertTrue( + result.success, "Failed to crawl and extract using default strategies" + ) + def test_run_different_strategies(self): - url = 'https://www.nbcnews.com/business' - + url = "https://www.nbcnews.com/business" + # Test with FixedLengthWordChunking and LLMExtractionStrategy result = self.crawler.run( url=url, word_count_threshold=5, chunking_strategy=FixedLengthWordChunking(chunk_size=100), - extraction_strategy=LLMExtractionStrategy(provider="openai/gpt-3.5-turbo", api_token=os.getenv('OPENAI_API_KEY')), bypass_cache=True + extraction_strategy=LLMExtractionStrategy( + provider="openai/gpt-3.5-turbo", api_token=os.getenv("OPENAI_API_KEY") + ), + bypass_cache=True, ) - self.assertTrue(result.success, "Failed to crawl and extract with FixedLengthWordChunking and LLMExtractionStrategy") - + self.assertTrue( + result.success, + "Failed to crawl and extract with FixedLengthWordChunking and LLMExtractionStrategy", + ) + # Test with SlidingWindowChunking and TopicExtractionStrategy result = self.crawler.run( url=url, word_count_threshold=5, chunking_strategy=SlidingWindowChunking(window_size=100, step=50), - extraction_strategy=TopicExtractionStrategy(num_keywords=5), bypass_cache=True + extraction_strategy=TopicExtractionStrategy(num_keywords=5), + bypass_cache=True, ) - self.assertTrue(result.success, "Failed to crawl and extract with SlidingWindowChunking and TopicExtractionStrategy") - + self.assertTrue( + result.success, + "Failed to crawl and extract with SlidingWindowChunking and TopicExtractionStrategy", + ) + def test_invalid_url(self): with self.assertRaises(Exception) as context: - self.crawler.run(url='invalid_url', bypass_cache=True) + self.crawler.run(url="invalid_url", bypass_cache=True) self.assertIn("Invalid URL", str(context.exception)) - + def test_unsupported_extraction_strategy(self): with self.assertRaises(Exception) as context: - self.crawler.run(url='https://www.nbcnews.com/business', extraction_strategy="UnsupportedStrategy", bypass_cache=True) + self.crawler.run( + url="https://www.nbcnews.com/business", + extraction_strategy="UnsupportedStrategy", + bypass_cache=True, + ) self.assertIn("Unsupported extraction strategy", str(context.exception)) - + def test_invalid_css_selector(self): with self.assertRaises(ValueError) as context: - self.crawler.run(url='https://www.nbcnews.com/business', css_selector="invalid_selector", bypass_cache=True) + self.crawler.run( + url="https://www.nbcnews.com/business", + css_selector="invalid_selector", + bypass_cache=True, + ) self.assertIn("Invalid CSS selector", str(context.exception)) - def test_crawl_with_cache_and_bypass_cache(self): - url = 'https://www.nbcnews.com/business' - + url = "https://www.nbcnews.com/business" + # First crawl with cache enabled result = self.crawler.run(url=url, bypass_cache=False) self.assertTrue(result.success, "Failed to crawl and cache the result") - + # Second crawl with bypass_cache=True result = self.crawler.run(url=url, bypass_cache=True) self.assertTrue(result.success, "Failed to bypass cache and fetch fresh data") - + def test_fetch_multiple_pages(self): - urls = [ - 'https://www.nbcnews.com/business', - 'https://www.bbc.com/news' - ] + urls = ["https://www.nbcnews.com/business", "https://www.bbc.com/news"] results = [] for url in urls: result = self.crawler.run( @@ -81,31 +107,42 @@ class TestWebCrawler(unittest.TestCase): word_count_threshold=5, chunking_strategy=RegexChunking(), extraction_strategy=CosineStrategy(), - bypass_cache=True + bypass_cache=True, ) results.append(result) - + self.assertEqual(len(results), 2, "Failed to crawl and extract multiple pages") for result in results: - self.assertTrue(result.success, "Failed to crawl and extract a page in the list") - + self.assertTrue( + result.success, "Failed to crawl and extract a page in the list" + ) + def test_run_fixed_length_word_chunking_and_no_extraction(self): result = self.crawler.run( - url='https://www.nbcnews.com/business', + url="https://www.nbcnews.com/business", word_count_threshold=5, chunking_strategy=FixedLengthWordChunking(chunk_size=100), - extraction_strategy=NoExtractionStrategy(), bypass_cache=True + extraction_strategy=NoExtractionStrategy(), + bypass_cache=True, + ) + self.assertTrue( + result.success, + "Failed to crawl and extract with FixedLengthWordChunking and NoExtractionStrategy", ) - self.assertTrue(result.success, "Failed to crawl and extract with FixedLengthWordChunking and NoExtractionStrategy") def test_run_sliding_window_and_no_extraction(self): result = self.crawler.run( - url='https://www.nbcnews.com/business', + url="https://www.nbcnews.com/business", word_count_threshold=5, chunking_strategy=SlidingWindowChunking(window_size=100, step=50), - extraction_strategy=NoExtractionStrategy(), bypass_cache=True + extraction_strategy=NoExtractionStrategy(), + bypass_cache=True, + ) + self.assertTrue( + result.success, + "Failed to crawl and extract with SlidingWindowChunking and NoExtractionStrategy", ) - self.assertTrue(result.success, "Failed to crawl and extract with SlidingWindowChunking and NoExtractionStrategy") -if __name__ == '__main__': + +if __name__ == "__main__": unittest.main()