Apply Ruff Corrections

This commit is contained in:
UncleCode
2025-01-13 19:19:58 +08:00
parent c3370ec5da
commit 8ec12d7d68
84 changed files with 6861 additions and 5076 deletions

8
.pre-commit-config.yaml Normal file
View File

@@ -0,0 +1,8 @@
# .pre-commit-config.yaml
repos:
- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.1.11
hooks:
- id: ruff
args: [--fix]
- id: ruff-format

View File

@@ -2,14 +2,28 @@
from .async_webcrawler import AsyncWebCrawler, CacheMode from .async_webcrawler import AsyncWebCrawler, CacheMode
from .async_configs import BrowserConfig, CrawlerRunConfig from .async_configs import BrowserConfig, CrawlerRunConfig
from .content_scraping_strategy import ContentScrapingStrategy, WebScrapingStrategy, LXMLWebScrapingStrategy from .content_scraping_strategy import (
from .extraction_strategy import ExtractionStrategy, LLMExtractionStrategy, CosineStrategy, JsonCssExtractionStrategy ContentScrapingStrategy,
WebScrapingStrategy,
LXMLWebScrapingStrategy,
)
from .extraction_strategy import (
ExtractionStrategy,
LLMExtractionStrategy,
CosineStrategy,
JsonCssExtractionStrategy,
)
from .chunking_strategy import ChunkingStrategy, RegexChunking from .chunking_strategy import ChunkingStrategy, RegexChunking
from .markdown_generation_strategy import DefaultMarkdownGenerator from .markdown_generation_strategy import DefaultMarkdownGenerator
from .content_filter_strategy import PruningContentFilter, BM25ContentFilter from .content_filter_strategy import PruningContentFilter, BM25ContentFilter
from .models import CrawlResult, MarkdownGenerationResult from .models import CrawlResult, MarkdownGenerationResult
from .async_dispatcher import MemoryAdaptiveDispatcher, SemaphoreDispatcher, RateLimiter, CrawlerMonitor, DisplayMode from .async_dispatcher import (
from .__version__ import __version__ MemoryAdaptiveDispatcher,
SemaphoreDispatcher,
RateLimiter,
CrawlerMonitor,
DisplayMode,
)
__all__ = [ __all__ = [
"AsyncWebCrawler", "AsyncWebCrawler",
@@ -18,39 +32,44 @@ __all__ = [
"ContentScrapingStrategy", "ContentScrapingStrategy",
"WebScrapingStrategy", "WebScrapingStrategy",
"LXMLWebScrapingStrategy", "LXMLWebScrapingStrategy",
'BrowserConfig', "BrowserConfig",
'CrawlerRunConfig', "CrawlerRunConfig",
'ExtractionStrategy', "ExtractionStrategy",
'LLMExtractionStrategy', "LLMExtractionStrategy",
'CosineStrategy', "CosineStrategy",
'JsonCssExtractionStrategy', "JsonCssExtractionStrategy",
'ChunkingStrategy', "ChunkingStrategy",
'RegexChunking', "RegexChunking",
'DefaultMarkdownGenerator', "DefaultMarkdownGenerator",
'PruningContentFilter', "PruningContentFilter",
'BM25ContentFilter', "BM25ContentFilter",
'MemoryAdaptiveDispatcher', "MemoryAdaptiveDispatcher",
'SemaphoreDispatcher', "SemaphoreDispatcher",
'RateLimiter', "RateLimiter",
'CrawlerMonitor', "CrawlerMonitor",
'DisplayMode', "DisplayMode",
'MarkdownGenerationResult', "MarkdownGenerationResult",
] ]
def is_sync_version_installed(): def is_sync_version_installed():
try: try:
import selenium import selenium
return True return True
except ImportError: except ImportError:
return False return False
if is_sync_version_installed(): if is_sync_version_installed():
try: try:
from .web_crawler import WebCrawler from .web_crawler import WebCrawler
__all__.append("WebCrawler") __all__.append("WebCrawler")
except ImportError: except ImportError:
import warnings print(
print("Warning: Failed to import WebCrawler even though selenium is installed. This might be due to other missing dependencies.") "Warning: Failed to import WebCrawler even though selenium is installed. This might be due to other missing dependencies."
)
else: else:
WebCrawler = None WebCrawler = None
# import warnings # import warnings

View File

@@ -5,7 +5,6 @@ from .config import (
PAGE_TIMEOUT, PAGE_TIMEOUT,
IMAGE_SCORE_THRESHOLD, IMAGE_SCORE_THRESHOLD,
SOCIAL_MEDIA_DOMAINS, SOCIAL_MEDIA_DOMAINS,
) )
from .user_agent_generator import UserAgentGenerator from .user_agent_generator import UserAgentGenerator
from .extraction_strategy import ExtractionStrategy from .extraction_strategy import ExtractionStrategy
@@ -14,6 +13,7 @@ from .markdown_generation_strategy import MarkdownGenerationStrategy
from .content_scraping_strategy import ContentScrapingStrategy, WebScrapingStrategy from .content_scraping_strategy import ContentScrapingStrategy, WebScrapingStrategy
from typing import Union, List from typing import Union, List
class BrowserConfig: class BrowserConfig:
""" """
Configuration class for setting up a browser instance and its context in AsyncPlaywrightCrawlerStrategy. Configuration class for setting up a browser instance and its context in AsyncPlaywrightCrawlerStrategy.
@@ -335,10 +335,8 @@ class CrawlerRunConfig:
prettiify: bool = False, prettiify: bool = False,
parser_type: str = "lxml", parser_type: str = "lxml",
scraping_strategy: ContentScrapingStrategy = None, scraping_strategy: ContentScrapingStrategy = None,
# SSL Parameters # SSL Parameters
fetch_ssl_certificate: bool = False, fetch_ssl_certificate: bool = False,
# Caching Parameters # Caching Parameters
cache_mode=None, cache_mode=None,
session_id: str = None, session_id: str = None,
@@ -346,7 +344,6 @@ class CrawlerRunConfig:
disable_cache: bool = False, disable_cache: bool = False,
no_cache_read: bool = False, no_cache_read: bool = False,
no_cache_write: bool = False, no_cache_write: bool = False,
# Page Navigation and Timing Parameters # Page Navigation and Timing Parameters
wait_until: str = "domcontentloaded", wait_until: str = "domcontentloaded",
page_timeout: int = PAGE_TIMEOUT, page_timeout: int = PAGE_TIMEOUT,
@@ -356,7 +353,6 @@ class CrawlerRunConfig:
mean_delay: float = 0.1, mean_delay: float = 0.1,
max_range: float = 0.3, max_range: float = 0.3,
semaphore_count: int = 5, semaphore_count: int = 5,
# Page Interaction Parameters # Page Interaction Parameters
js_code: Union[str, List[str]] = None, js_code: Union[str, List[str]] = None,
js_only: bool = False, js_only: bool = False,
@@ -369,7 +365,6 @@ class CrawlerRunConfig:
override_navigator: bool = False, override_navigator: bool = False,
magic: bool = False, magic: bool = False,
adjust_viewport_to_content: bool = False, adjust_viewport_to_content: bool = False,
# Media Handling Parameters # Media Handling Parameters
screenshot: bool = False, screenshot: bool = False,
screenshot_wait_for: float = None, screenshot_wait_for: float = None,
@@ -378,17 +373,14 @@ class CrawlerRunConfig:
image_description_min_word_threshold: int = IMAGE_DESCRIPTION_MIN_WORD_THRESHOLD, image_description_min_word_threshold: int = IMAGE_DESCRIPTION_MIN_WORD_THRESHOLD,
image_score_threshold: int = IMAGE_SCORE_THRESHOLD, image_score_threshold: int = IMAGE_SCORE_THRESHOLD,
exclude_external_images: bool = False, exclude_external_images: bool = False,
# Link and Domain Handling Parameters # Link and Domain Handling Parameters
exclude_social_media_domains: list = None, exclude_social_media_domains: list = None,
exclude_external_links: bool = False, exclude_external_links: bool = False,
exclude_social_media_links: bool = False, exclude_social_media_links: bool = False,
exclude_domains: list = None, exclude_domains: list = None,
# Debugging and Logging Parameters # Debugging and Logging Parameters
verbose: bool = True, verbose: bool = True,
log_console: bool = False, log_console: bool = False,
url: str = None, url: str = None,
): ):
self.url = url self.url = url
@@ -453,7 +445,9 @@ class CrawlerRunConfig:
self.exclude_external_images = exclude_external_images self.exclude_external_images = exclude_external_images
# Link and Domain Handling Parameters # Link and Domain Handling Parameters
self.exclude_social_media_domains = exclude_social_media_domains or SOCIAL_MEDIA_DOMAINS self.exclude_social_media_domains = (
exclude_social_media_domains or SOCIAL_MEDIA_DOMAINS
)
self.exclude_external_links = exclude_external_links self.exclude_external_links = exclude_external_links
self.exclude_social_media_links = exclude_social_media_links self.exclude_social_media_links = exclude_social_media_links
self.exclude_domains = exclude_domains or [] self.exclude_domains = exclude_domains or []
@@ -466,11 +460,15 @@ class CrawlerRunConfig:
if self.extraction_strategy is not None and not isinstance( if self.extraction_strategy is not None and not isinstance(
self.extraction_strategy, ExtractionStrategy self.extraction_strategy, ExtractionStrategy
): ):
raise ValueError("extraction_strategy must be an instance of ExtractionStrategy") raise ValueError(
"extraction_strategy must be an instance of ExtractionStrategy"
)
if self.chunking_strategy is not None and not isinstance( if self.chunking_strategy is not None and not isinstance(
self.chunking_strategy, ChunkingStrategy self.chunking_strategy, ChunkingStrategy
): ):
raise ValueError("chunking_strategy must be an instance of ChunkingStrategy") raise ValueError(
"chunking_strategy must be an instance of ChunkingStrategy"
)
# Set default chunking strategy if None # Set default chunking strategy if None
if self.chunking_strategy is None: if self.chunking_strategy is None:
@@ -494,10 +492,8 @@ class CrawlerRunConfig:
prettiify=kwargs.get("prettiify", False), prettiify=kwargs.get("prettiify", False),
parser_type=kwargs.get("parser_type", "lxml"), parser_type=kwargs.get("parser_type", "lxml"),
scraping_strategy=kwargs.get("scraping_strategy"), scraping_strategy=kwargs.get("scraping_strategy"),
# SSL Parameters # SSL Parameters
fetch_ssl_certificate=kwargs.get("fetch_ssl_certificate", False), fetch_ssl_certificate=kwargs.get("fetch_ssl_certificate", False),
# Caching Parameters # Caching Parameters
cache_mode=kwargs.get("cache_mode"), cache_mode=kwargs.get("cache_mode"),
session_id=kwargs.get("session_id"), session_id=kwargs.get("session_id"),
@@ -505,7 +501,6 @@ class CrawlerRunConfig:
disable_cache=kwargs.get("disable_cache", False), disable_cache=kwargs.get("disable_cache", False),
no_cache_read=kwargs.get("no_cache_read", False), no_cache_read=kwargs.get("no_cache_read", False),
no_cache_write=kwargs.get("no_cache_write", False), no_cache_write=kwargs.get("no_cache_write", False),
# Page Navigation and Timing Parameters # Page Navigation and Timing Parameters
wait_until=kwargs.get("wait_until", "domcontentloaded"), wait_until=kwargs.get("wait_until", "domcontentloaded"),
page_timeout=kwargs.get("page_timeout", 60000), page_timeout=kwargs.get("page_timeout", 60000),
@@ -515,7 +510,6 @@ class CrawlerRunConfig:
mean_delay=kwargs.get("mean_delay", 0.1), mean_delay=kwargs.get("mean_delay", 0.1),
max_range=kwargs.get("max_range", 0.3), max_range=kwargs.get("max_range", 0.3),
semaphore_count=kwargs.get("semaphore_count", 5), semaphore_count=kwargs.get("semaphore_count", 5),
# Page Interaction Parameters # Page Interaction Parameters
js_code=kwargs.get("js_code"), js_code=kwargs.get("js_code"),
js_only=kwargs.get("js_only", False), js_only=kwargs.get("js_only", False),
@@ -528,26 +522,31 @@ class CrawlerRunConfig:
override_navigator=kwargs.get("override_navigator", False), override_navigator=kwargs.get("override_navigator", False),
magic=kwargs.get("magic", False), magic=kwargs.get("magic", False),
adjust_viewport_to_content=kwargs.get("adjust_viewport_to_content", False), adjust_viewport_to_content=kwargs.get("adjust_viewport_to_content", False),
# Media Handling Parameters # Media Handling Parameters
screenshot=kwargs.get("screenshot", False), screenshot=kwargs.get("screenshot", False),
screenshot_wait_for=kwargs.get("screenshot_wait_for"), screenshot_wait_for=kwargs.get("screenshot_wait_for"),
screenshot_height_threshold=kwargs.get("screenshot_height_threshold", SCREENSHOT_HEIGHT_TRESHOLD), screenshot_height_threshold=kwargs.get(
"screenshot_height_threshold", SCREENSHOT_HEIGHT_TRESHOLD
),
pdf=kwargs.get("pdf", False), pdf=kwargs.get("pdf", False),
image_description_min_word_threshold=kwargs.get("image_description_min_word_threshold", IMAGE_DESCRIPTION_MIN_WORD_THRESHOLD), image_description_min_word_threshold=kwargs.get(
image_score_threshold=kwargs.get("image_score_threshold", IMAGE_SCORE_THRESHOLD), "image_description_min_word_threshold",
IMAGE_DESCRIPTION_MIN_WORD_THRESHOLD,
),
image_score_threshold=kwargs.get(
"image_score_threshold", IMAGE_SCORE_THRESHOLD
),
exclude_external_images=kwargs.get("exclude_external_images", False), exclude_external_images=kwargs.get("exclude_external_images", False),
# Link and Domain Handling Parameters # Link and Domain Handling Parameters
exclude_social_media_domains=kwargs.get("exclude_social_media_domains", SOCIAL_MEDIA_DOMAINS), exclude_social_media_domains=kwargs.get(
"exclude_social_media_domains", SOCIAL_MEDIA_DOMAINS
),
exclude_external_links=kwargs.get("exclude_external_links", False), exclude_external_links=kwargs.get("exclude_external_links", False),
exclude_social_media_links=kwargs.get("exclude_social_media_links", False), exclude_social_media_links=kwargs.get("exclude_social_media_links", False),
exclude_domains=kwargs.get("exclude_domains", []), exclude_domains=kwargs.get("exclude_domains", []),
# Debugging and Logging Parameters # Debugging and Logging Parameters
verbose=kwargs.get("verbose", True), verbose=kwargs.get("verbose", True),
log_console=kwargs.get("log_console", False), log_console=kwargs.get("log_console", False),
url=kwargs.get("url"), url=kwargs.get("url"),
) )

View File

@@ -2,27 +2,25 @@ import asyncio
import base64 import base64
import time import time
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from typing import Callable, Dict, Any, List, Optional, Awaitable, Union from typing import Callable, Dict, Any, List, Optional, Union
import os, sys, shutil import os
import tempfile, subprocess import sys
from playwright.async_api import async_playwright, Page, Browser, Error, BrowserContext import shutil
import tempfile
import subprocess
from playwright.async_api import Page, Error, BrowserContext
from playwright.async_api import TimeoutError as PlaywrightTimeoutError from playwright.async_api import TimeoutError as PlaywrightTimeoutError
from io import BytesIO from io import BytesIO
from PIL import Image, ImageDraw, ImageFont from PIL import Image, ImageDraw, ImageFont
from pathlib import Path
from playwright.async_api import ProxySettings
from pydantic import BaseModel
import hashlib import hashlib
import json
import uuid import uuid
from .js_snippet import load_js_script from .js_snippet import load_js_script
from .models import AsyncCrawlResponse from .models import AsyncCrawlResponse
from .utils import get_error_context
from .user_agent_generator import UserAgentGenerator from .user_agent_generator import UserAgentGenerator
from .config import SCREENSHOT_HEIGHT_TRESHOLD, DOWNLOAD_PAGE_TIMEOUT from .config import SCREENSHOT_HEIGHT_TRESHOLD, DOWNLOAD_PAGE_TIMEOUT
from .async_configs import BrowserConfig, CrawlerRunConfig from .async_configs import BrowserConfig, CrawlerRunConfig
from .async_logger import AsyncLogger from .async_logger import AsyncLogger
from playwright_stealth import StealthConfig, stealth_async from playwright_stealth import StealthConfig
from .ssl_certificate import SSLCertificate from .ssl_certificate import SSLCertificate
stealth_config = StealthConfig( stealth_config = StealthConfig(
@@ -94,6 +92,7 @@ class ManagedBrowser:
temp_dir: str temp_dir: str
debugging_port: int debugging_port: int
host: str host: str
def __init__( def __init__(
self, self,
browser_type: str = "chromium", browser_type: str = "chromium",
@@ -139,7 +138,7 @@ class ManagedBrowser:
self.user_data_dir = self.temp_dir self.user_data_dir = self.temp_dir
# Get browser path and args based on OS and browser type # Get browser path and args based on OS and browser type
browser_path = self._get_browser_path() # browser_path = self._get_browser_path()
args = self._get_browser_args() args = self._get_browser_args()
# Start browser process # Start browser process
@@ -300,6 +299,7 @@ class BrowserManager:
sessions (dict): Dictionary to store session information sessions (dict): Dictionary to store session information
session_ttl (int): Session timeout in seconds session_ttl (int): Session timeout in seconds
""" """
def __init__(self, browser_config: BrowserConfig, logger=None): def __init__(self, browser_config: BrowserConfig, logger=None):
""" """
Initialize the BrowserManager with a browser configuration. Initialize the BrowserManager with a browser configuration.
@@ -453,7 +453,12 @@ class BrowserManager:
return browser_args return browser_args
async def setup_context(self, context: BrowserContext, crawlerRunConfig: CrawlerRunConfig = None, is_default=False): async def setup_context(
self,
context: BrowserContext,
crawlerRunConfig: CrawlerRunConfig = None,
is_default=False,
):
""" """
Set up a browser context with the configured options. Set up a browser context with the configured options.
@@ -496,9 +501,9 @@ class BrowserManager:
context.set_default_navigation_timeout(DOWNLOAD_PAGE_TIMEOUT) context.set_default_navigation_timeout(DOWNLOAD_PAGE_TIMEOUT)
if self.config.downloads_path: if self.config.downloads_path:
context._impl_obj._options["accept_downloads"] = True context._impl_obj._options["accept_downloads"] = True
context._impl_obj._options["downloads_path"] = ( context._impl_obj._options[
self.config.downloads_path "downloads_path"
) ] = self.config.downloads_path
# Handle user agent and browser hints # Handle user agent and browser hints
if self.config.user_agent: if self.config.user_agent:
@@ -511,7 +516,15 @@ class BrowserManager:
# Add default cookie # Add default cookie
await context.add_cookies( await context.add_cookies(
[{"name": "cookiesEnabled", "value": "true", "url": crawlerRunConfig.url if crawlerRunConfig else "https://crawl4ai.com/"}] [
{
"name": "cookiesEnabled",
"value": "true",
"url": crawlerRunConfig.url
if crawlerRunConfig
else "https://crawl4ai.com/",
}
]
) )
# Handle navigator overrides # Handle navigator overrides
@@ -541,20 +554,57 @@ class BrowserManager:
blocked_extensions = [ blocked_extensions = [
# Images # Images
'jpg', 'jpeg', 'png', 'gif', 'webp', 'svg', 'ico', 'bmp', 'tiff', 'psd', "jpg",
"jpeg",
"png",
"gif",
"webp",
"svg",
"ico",
"bmp",
"tiff",
"psd",
# Fonts # Fonts
'woff', 'woff2', 'ttf', 'otf', 'eot', "woff",
"woff2",
"ttf",
"otf",
"eot",
# Styles # Styles
# 'css', 'less', 'scss', 'sass', # 'css', 'less', 'scss', 'sass',
# Media # Media
'mp4', 'webm', 'ogg', 'avi', 'mov', 'wmv', 'flv', 'm4v', "mp4",
'mp3', 'wav', 'aac', 'm4a', 'opus', 'flac', "webm",
"ogg",
"avi",
"mov",
"wmv",
"flv",
"m4v",
"mp3",
"wav",
"aac",
"m4a",
"opus",
"flac",
# Documents # Documents
'pdf', 'doc', 'docx', 'xls', 'xlsx', 'ppt', 'pptx', "pdf",
"doc",
"docx",
"xls",
"xlsx",
"ppt",
"pptx",
# Archives # Archives
'zip', 'rar', '7z', 'tar', 'gz', "zip",
"rar",
"7z",
"tar",
"gz",
# Scripts and data # Scripts and data
'xml', 'swf', 'wasm' "xml",
"swf",
"wasm",
] ]
# Common context settings # Common context settings
@@ -672,12 +722,12 @@ class AsyncCrawlerStrategy(ABC):
Abstract base class for crawler strategies. Abstract base class for crawler strategies.
Subclasses must implement the crawl method. Subclasses must implement the crawl method.
""" """
@abstractmethod @abstractmethod
async def crawl(self, url: str, **kwargs) -> AsyncCrawlResponse: async def crawl(self, url: str, **kwargs) -> AsyncCrawlResponse:
pass # 4 + 3 pass # 4 + 3
class AsyncPlaywrightCrawlerStrategy(AsyncCrawlerStrategy): class AsyncPlaywrightCrawlerStrategy(AsyncCrawlerStrategy):
""" """
Crawler strategy using Playwright. Crawler strategy using Playwright.
@@ -706,6 +756,7 @@ class AsyncPlaywrightCrawlerStrategy(AsyncCrawlerStrategy):
Run the crawler for a single URL. Run the crawler for a single URL.
""" """
def __init__( def __init__(
self, browser_config: BrowserConfig = None, logger: AsyncLogger = None, **kwargs self, browser_config: BrowserConfig = None, logger: AsyncLogger = None, **kwargs
): ):
@@ -917,7 +968,9 @@ class AsyncPlaywrightCrawlerStrategy(AsyncCrawlerStrategy):
"or explicitly prefixed with 'js:' or 'css:'." "or explicitly prefixed with 'js:' or 'css:'."
) )
async def csp_compliant_wait( self, page: Page, user_wait_function: str, timeout: float = 30000 ): async def csp_compliant_wait(
self, page: Page, user_wait_function: str, timeout: float = 30000
):
""" """
Wait for a condition in a CSP-compliant way. Wait for a condition in a CSP-compliant way.
@@ -1045,7 +1098,9 @@ class AsyncPlaywrightCrawlerStrategy(AsyncCrawlerStrategy):
page, context = await self.browser_manager.get_page(session_id, user_agent) page, context = await self.browser_manager.get_page(session_id, user_agent)
return session_id return session_id
async def crawl( self, url: str, config: CrawlerRunConfig, **kwargs ) -> AsyncCrawlResponse: async def crawl(
self, url: str, config: CrawlerRunConfig, **kwargs
) -> AsyncCrawlResponse:
""" """
Crawls a given URL or processes raw HTML/local file content based on the URL prefix. Crawls a given URL or processes raw HTML/local file content based on the URL prefix.
@@ -1104,7 +1159,9 @@ class AsyncPlaywrightCrawlerStrategy(AsyncCrawlerStrategy):
"URL must start with 'http://', 'https://', 'file://', or 'raw:'" "URL must start with 'http://', 'https://', 'file://', or 'raw:'"
) )
async def _crawl_web( self, url: str, config: CrawlerRunConfig ) -> AsyncCrawlResponse: async def _crawl_web(
self, url: str, config: CrawlerRunConfig
) -> AsyncCrawlResponse:
""" """
Internal method to crawl web URLs with the specified configuration. Internal method to crawl web URLs with the specified configuration.
@@ -1190,9 +1247,11 @@ class AsyncPlaywrightCrawlerStrategy(AsyncCrawlerStrategy):
nonce = hashlib.sha256(os.urandom(32)).hexdigest() nonce = hashlib.sha256(os.urandom(32)).hexdigest()
# Add CSP headers to the request # Add CSP headers to the request
await page.set_extra_http_headers({ await page.set_extra_http_headers(
'Content-Security-Policy': f"default-src 'self'; script-src 'self' 'nonce-{nonce}' 'strict-dynamic'" {
}) "Content-Security-Policy": f"default-src 'self'; script-src 'self' 'nonce-{nonce}' 'strict-dynamic'"
}
)
response = await page.goto( response = await page.goto(
url, wait_until=config.wait_until, timeout=config.page_timeout url, wait_until=config.wait_until, timeout=config.page_timeout
@@ -1200,7 +1259,9 @@ class AsyncPlaywrightCrawlerStrategy(AsyncCrawlerStrategy):
except Error as e: except Error as e:
raise RuntimeError(f"Failed on navigating ACS-GOTO:\n{str(e)}") raise RuntimeError(f"Failed on navigating ACS-GOTO:\n{str(e)}")
await self.execute_hook("after_goto", page, context=context, url=url, response=response) await self.execute_hook(
"after_goto", page, context=context, url=url, response=response
)
if response is None: if response is None:
status_code = 200 status_code = 200
@@ -1229,14 +1290,14 @@ class AsyncPlaywrightCrawlerStrategy(AsyncCrawlerStrategy):
style.opacity !== '0'; style.opacity !== '0';
return isVisible; return isVisible;
}""", }""",
timeout=30000 timeout=30000,
) )
if not is_visible and not config.ignore_body_visibility: if not is_visible and not config.ignore_body_visibility:
visibility_info = await self.check_visibility(page) visibility_info = await self.check_visibility(page)
raise Error(f"Body element is hidden: {visibility_info}") raise Error(f"Body element is hidden: {visibility_info}")
except Error as e: except Error:
visibility_info = await self.check_visibility(page) visibility_info = await self.check_visibility(page)
if self.config.verbose: if self.config.verbose:
@@ -1249,7 +1310,6 @@ class AsyncPlaywrightCrawlerStrategy(AsyncCrawlerStrategy):
if not config.ignore_body_visibility: if not config.ignore_body_visibility:
raise Error(f"Body element is hidden: {visibility_info}") raise Error(f"Body element is hidden: {visibility_info}")
# try: # try:
# await page.wait_for_selector("body", state="attached", timeout=30000) # await page.wait_for_selector("body", state="attached", timeout=30000)
@@ -1303,7 +1363,7 @@ class AsyncPlaywrightCrawlerStrategy(AsyncCrawlerStrategy):
images_loaded = await self.csp_compliant_wait( images_loaded = await self.csp_compliant_wait(
page, page,
"() => Array.from(document.getElementsByTagName('img')).every(img => img.complete)", "() => Array.from(document.getElementsByTagName('img')).every(img => img.complete)",
timeout=1000 timeout=1000,
) )
if not images_loaded and self.logger: if not images_loaded and self.logger:
@@ -1316,8 +1376,8 @@ class AsyncPlaywrightCrawlerStrategy(AsyncCrawlerStrategy):
if not self.browser_config.text_mode and config.adjust_viewport_to_content: if not self.browser_config.text_mode and config.adjust_viewport_to_content:
try: try:
dimensions = await self.get_page_dimensions(page) dimensions = await self.get_page_dimensions(page)
page_height = dimensions['height'] page_height = dimensions["height"]
page_width = dimensions['width'] page_width = dimensions["width"]
# page_width = await page.evaluate( # page_width = await page.evaluate(
# "document.documentElement.scrollWidth" # "document.documentElement.scrollWidth"
# ) # )
@@ -1364,12 +1424,14 @@ class AsyncPlaywrightCrawlerStrategy(AsyncCrawlerStrategy):
if config.js_code: if config.js_code:
# execution_result = await self.execute_user_script(page, config.js_code) # execution_result = await self.execute_user_script(page, config.js_code)
execution_result = await self.robust_execute_user_script(page, config.js_code) execution_result = await self.robust_execute_user_script(
page, config.js_code
)
if not execution_result["success"]: if not execution_result["success"]:
self.logger.warning( self.logger.warning(
message="User script execution had issues: {error}", message="User script execution had issues: {error}",
tag="JS_EXEC", tag="JS_EXEC",
params={"error": execution_result.get("error")} params={"error": execution_result.get("error")},
) )
await self.execute_hook("on_execution_started", page, context=context) await self.execute_hook("on_execution_started", page, context=context)
@@ -1425,7 +1487,9 @@ class AsyncPlaywrightCrawlerStrategy(AsyncCrawlerStrategy):
# Get final HTML content # Get final HTML content
html = await page.content() html = await page.content()
await self.execute_hook("before_return_html", page = page, html = html, context=context) await self.execute_hook(
"before_return_html", page=page, html=html, context=context
)
# Handle PDF and screenshot generation # Handle PDF and screenshot generation
start_export_time = time.perf_counter() start_export_time = time.perf_counter()
@@ -1511,7 +1575,7 @@ class AsyncPlaywrightCrawlerStrategy(AsyncCrawlerStrategy):
# total_height = await page.evaluate("document.documentElement.scrollHeight") # total_height = await page.evaluate("document.documentElement.scrollHeight")
dimensions = await self.get_page_dimensions(page) dimensions = await self.get_page_dimensions(page)
total_height = dimensions['height'] total_height = dimensions["height"]
while current_position < total_height: while current_position < total_height:
current_position = min(current_position + viewport_height, total_height) current_position = min(current_position + viewport_height, total_height)
@@ -1521,7 +1585,7 @@ class AsyncPlaywrightCrawlerStrategy(AsyncCrawlerStrategy):
# new_height = await page.evaluate("document.documentElement.scrollHeight") # new_height = await page.evaluate("document.documentElement.scrollHeight")
dimensions = await self.get_page_dimensions(page) dimensions = await self.get_page_dimensions(page)
new_height = dimensions['height'] new_height = dimensions["height"]
if new_height > total_height: if new_height > total_height:
total_height = new_height total_height = new_height
@@ -1598,7 +1662,8 @@ class AsyncPlaywrightCrawlerStrategy(AsyncCrawlerStrategy):
remove_overlays_js = load_js_script("remove_overlay_elements") remove_overlays_js = load_js_script("remove_overlay_elements")
try: try:
await page.evaluate(f""" await page.evaluate(
f"""
(() => {{ (() => {{
try {{ try {{
{remove_overlays_js} {remove_overlays_js}
@@ -1611,7 +1676,8 @@ class AsyncPlaywrightCrawlerStrategy(AsyncCrawlerStrategy):
}}; }};
}} }}
}})() }})()
""") """
)
await page.wait_for_timeout(500) # Wait for any animations to complete await page.wait_for_timeout(500) # Wait for any animations to complete
except Exception as e: except Exception as e:
self.logger.warning( self.logger.warning(
@@ -1707,8 +1773,8 @@ class AsyncPlaywrightCrawlerStrategy(AsyncCrawlerStrategy):
try: try:
# Get page height # Get page height
dimensions = await self.get_page_dimensions(page) dimensions = await self.get_page_dimensions(page)
page_width = dimensions['width'] page_width = dimensions["width"]
page_height = dimensions['height'] page_height = dimensions["height"]
# page_height = await page.evaluate("document.documentElement.scrollHeight") # page_height = await page.evaluate("document.documentElement.scrollHeight")
# page_width = await page.evaluate("document.documentElement.scrollWidth") # page_width = await page.evaluate("document.documentElement.scrollWidth")
@@ -1826,7 +1892,9 @@ class AsyncPlaywrightCrawlerStrategy(AsyncCrawlerStrategy):
tag="WARNING", tag="WARNING",
) )
async def robust_execute_user_script(self, page: Page, js_code: Union[str, List[str]]) -> Dict[str, Any]: async def robust_execute_user_script(
self, page: Page, js_code: Union[str, List[str]]
) -> Dict[str, Any]:
""" """
Executes user-provided JavaScript code with proper error handling and context, Executes user-provided JavaScript code with proper error handling and context,
supporting both synchronous and async user code, plus navigations. supporting both synchronous and async user code, plus navigations.
@@ -1846,7 +1914,7 @@ class AsyncPlaywrightCrawlerStrategy(AsyncCrawlerStrategy):
Dict[str, Any]: The results of the execution Dict[str, Any]: The results of the execution
""" """
try: try:
await page.wait_for_load_state('domcontentloaded') await page.wait_for_load_state("domcontentloaded")
if isinstance(js_code, str): if isinstance(js_code, str):
scripts = [js_code] scripts = [js_code]
@@ -1861,7 +1929,8 @@ class AsyncPlaywrightCrawlerStrategy(AsyncCrawlerStrategy):
# then wait for the new page to load before continuing # then wait for the new page to load before continuing
result = None result = None
try: try:
result = await page.evaluate(f""" result = await page.evaluate(
f"""
(async () => {{ (async () => {{
try {{ try {{
{script} {script}
@@ -1870,51 +1939,60 @@ class AsyncPlaywrightCrawlerStrategy(AsyncCrawlerStrategy):
return {{ success: false, error: err.toString(), stack: err.stack }}; return {{ success: false, error: err.toString(), stack: err.stack }};
}} }}
}})(); }})();
""") """
)
except Error as e: except Error as e:
# If it's due to navigation destroying the context, handle gracefully # If it's due to navigation destroying the context, handle gracefully
if "Execution context was destroyed" in str(e): if "Execution context was destroyed" in str(e):
self.logger.info("Navigation triggered by script, waiting for load state", tag="JS_EXEC") self.logger.info(
"Navigation triggered by script, waiting for load state",
tag="JS_EXEC",
)
try: try:
await page.wait_for_load_state('load', timeout=30000) await page.wait_for_load_state("load", timeout=30000)
except Error as nav_err: except Error as nav_err:
self.logger.warning( self.logger.warning(
message="Navigation wait failed: {error}", message="Navigation wait failed: {error}",
tag="JS_EXEC", tag="JS_EXEC",
params={"error": str(nav_err)} params={"error": str(nav_err)},
) )
try: try:
await page.wait_for_load_state('networkidle', timeout=30000) await page.wait_for_load_state(
"networkidle", timeout=30000
)
except Error as nav_err: except Error as nav_err:
self.logger.warning( self.logger.warning(
message="Network idle wait failed: {error}", message="Network idle wait failed: {error}",
tag="JS_EXEC", tag="JS_EXEC",
params={"error": str(nav_err)} params={"error": str(nav_err)},
) )
# Return partial success, or adapt as you see fit # Return partial success, or adapt as you see fit
result = { result = {
"success": True, "success": True,
"info": "Navigation triggered, ignoring context destroyed error" "info": "Navigation triggered, ignoring context destroyed error",
} }
else: else:
# It's some other error, log and continue # It's some other error, log and continue
self.logger.error( self.logger.error(
message="Playwright execution error: {error}", message="Playwright execution error: {error}",
tag="JS_EXEC", tag="JS_EXEC",
params={"error": str(e)} params={"error": str(e)},
) )
result = {"success": False, "error": str(e)} result = {"success": False, "error": str(e)}
# If we made it this far with no repeated error, do post-load waits # If we made it this far with no repeated error, do post-load waits
t1 = time.time() t1 = time.time()
try: try:
await page.wait_for_load_state('domcontentloaded', timeout=5000) await page.wait_for_load_state("domcontentloaded", timeout=5000)
print("DOM content loaded after script execution in", time.time() - t1) print(
"DOM content loaded after script execution in",
time.time() - t1,
)
except Error as e: except Error as e:
self.logger.warning( self.logger.warning(
message="DOM content load timeout: {error}", message="DOM content load timeout: {error}",
tag="JS_EXEC", tag="JS_EXEC",
params={"error": str(e)} params={"error": str(e)},
) )
# t1 = time.time() # t1 = time.time()
@@ -1935,7 +2013,7 @@ class AsyncPlaywrightCrawlerStrategy(AsyncCrawlerStrategy):
self.logger.error( self.logger.error(
message="Script chunk failed: {error}", message="Script chunk failed: {error}",
tag="JS_EXEC", tag="JS_EXEC",
params={"error": str(e)} params={"error": str(e)},
) )
results.append({"success": False, "error": str(e)}) results.append({"success": False, "error": str(e)})
@@ -1945,11 +2023,13 @@ class AsyncPlaywrightCrawlerStrategy(AsyncCrawlerStrategy):
self.logger.error( self.logger.error(
message="Script execution failed: {error}", message="Script execution failed: {error}",
tag="JS_EXEC", tag="JS_EXEC",
params={"error": str(e)} params={"error": str(e)},
) )
return {"success": False, "error": str(e)} return {"success": False, "error": str(e)}
async def execute_user_script(self, page: Page, js_code: Union[str, List[str]]) -> Dict[str, Any]: async def execute_user_script(
self, page: Page, js_code: Union[str, List[str]]
) -> Dict[str, Any]:
""" """
Executes user-provided JavaScript code with proper error handling and context. Executes user-provided JavaScript code with proper error handling and context.
@@ -1962,7 +2042,7 @@ class AsyncPlaywrightCrawlerStrategy(AsyncCrawlerStrategy):
""" """
try: try:
# Ensure the page is ready for script execution # Ensure the page is ready for script execution
await page.wait_for_load_state('domcontentloaded') await page.wait_for_load_state("domcontentloaded")
# Handle single script or multiple scripts # Handle single script or multiple scripts
if isinstance(js_code, str): if isinstance(js_code, str):
@@ -1974,7 +2054,8 @@ class AsyncPlaywrightCrawlerStrategy(AsyncCrawlerStrategy):
for script in scripts: for script in scripts:
try: try:
# Execute the script and wait for network idle # Execute the script and wait for network idle
result = await page.evaluate(f""" result = await page.evaluate(
f"""
(() => {{ (() => {{
return new Promise((resolve) => {{ return new Promise((resolve) => {{
try {{ try {{
@@ -2007,15 +2088,18 @@ class AsyncPlaywrightCrawlerStrategy(AsyncCrawlerStrategy):
}} }}
}}); }});
}})() }})()
""") """
)
# Wait for network idle after script execution # Wait for network idle after script execution
t1 = time.time() t1 = time.time()
await page.wait_for_load_state('domcontentloaded', timeout=5000) await page.wait_for_load_state("domcontentloaded", timeout=5000)
print("DOM content loaded after script execution in", time.time() - t1) print(
"DOM content loaded after script execution in", time.time() - t1
)
t1 = time.time() t1 = time.time()
await page.wait_for_load_state('networkidle', timeout=5000) await page.wait_for_load_state("networkidle", timeout=5000)
print("Network idle after script execution in", time.time() - t1) print("Network idle after script execution in", time.time() - t1)
results.append(result if result else {"success": True}) results.append(result if result else {"success": True})
@@ -2025,7 +2109,7 @@ class AsyncPlaywrightCrawlerStrategy(AsyncCrawlerStrategy):
self.logger.error( self.logger.error(
message="Playwright execution error: {error}", message="Playwright execution error: {error}",
tag="JS_EXEC", tag="JS_EXEC",
params={"error": str(e)} params={"error": str(e)},
) )
results.append({"success": False, "error": str(e)}) results.append({"success": False, "error": str(e)})
@@ -2035,7 +2119,7 @@ class AsyncPlaywrightCrawlerStrategy(AsyncCrawlerStrategy):
self.logger.error( self.logger.error(
message="Script execution failed: {error}", message="Script execution failed: {error}",
tag="JS_EXEC", tag="JS_EXEC",
params={"error": str(e)} params={"error": str(e)},
) )
return {"success": False, "error": str(e)} return {"success": False, "error": str(e)}
@@ -2043,7 +2127,7 @@ class AsyncPlaywrightCrawlerStrategy(AsyncCrawlerStrategy):
self.logger.error( self.logger.error(
message="Script execution failed: {error}", message="Script execution failed: {error}",
tag="JS_EXEC", tag="JS_EXEC",
params={"error": str(e)} params={"error": str(e)},
) )
return {"success": False, "error": str(e)} return {"success": False, "error": str(e)}
@@ -2057,7 +2141,8 @@ class AsyncPlaywrightCrawlerStrategy(AsyncCrawlerStrategy):
Returns: Returns:
Boolean indicating visibility Boolean indicating visibility
""" """
return await page.evaluate(""" return await page.evaluate(
"""
() => { () => {
const element = document.body; const element = document.body;
if (!element) return false; if (!element) return false;
@@ -2067,7 +2152,8 @@ class AsyncPlaywrightCrawlerStrategy(AsyncCrawlerStrategy):
style.opacity !== '0'; style.opacity !== '0';
return isVisible; return isVisible;
} }
""") """
)
async def safe_scroll(self, page: Page, x: int, y: int, delay: float = 0.1): async def safe_scroll(self, page: Page, x: int, y: int, delay: float = 0.1):
""" """
@@ -2079,7 +2165,7 @@ class AsyncPlaywrightCrawlerStrategy(AsyncCrawlerStrategy):
y: Vertical scroll position y: Vertical scroll position
""" """
result = await self.csp_scroll_to(page, x, y) result = await self.csp_scroll_to(page, x, y)
if result['success']: if result["success"]:
await page.wait_for_timeout(delay * 1000) await page.wait_for_timeout(delay * 1000)
return result return result
@@ -2126,11 +2212,11 @@ class AsyncPlaywrightCrawlerStrategy(AsyncCrawlerStrategy):
}}""" }}"""
) )
if not result['success']: if not result["success"]:
self.logger.warning( self.logger.warning(
message="Scroll operation failed: {error}", message="Scroll operation failed: {error}",
tag="SCROLL", tag="SCROLL",
params={"error": result.get('error')} params={"error": result.get("error")},
) )
return result return result
@@ -2139,12 +2225,9 @@ class AsyncPlaywrightCrawlerStrategy(AsyncCrawlerStrategy):
self.logger.error( self.logger.error(
message="Failed to execute scroll: {error}", message="Failed to execute scroll: {error}",
tag="SCROLL", tag="SCROLL",
params={"error": str(e)} params={"error": str(e)},
) )
return { return {"success": False, "error": str(e)}
"success": False,
"error": str(e)
}
async def get_page_dimensions(self, page: Page): async def get_page_dimensions(self, page: Page):
""" """
@@ -2156,12 +2239,14 @@ class AsyncPlaywrightCrawlerStrategy(AsyncCrawlerStrategy):
Returns: Returns:
Dict containing width and height of the page Dict containing width and height of the page
""" """
return await page.evaluate(""" return await page.evaluate(
"""
() => { () => {
const {scrollWidth, scrollHeight} = document.documentElement; const {scrollWidth, scrollHeight} = document.documentElement;
return {width: scrollWidth, height: scrollHeight}; return {width: scrollWidth, height: scrollHeight};
} }
""") """
)
async def page_need_scroll(self, page: Page) -> bool: async def page_need_scroll(self, page: Page) -> bool:
""" """
@@ -2174,18 +2259,20 @@ class AsyncPlaywrightCrawlerStrategy(AsyncCrawlerStrategy):
bool: True if page needs scrolling bool: True if page needs scrolling
""" """
try: try:
need_scroll = await page.evaluate(""" need_scroll = await page.evaluate(
"""
() => { () => {
const scrollHeight = document.documentElement.scrollHeight; const scrollHeight = document.documentElement.scrollHeight;
const viewportHeight = window.innerHeight; const viewportHeight = window.innerHeight;
return scrollHeight > viewportHeight; return scrollHeight > viewportHeight;
} }
""") """
)
return need_scroll return need_scroll
except Exception as e: except Exception as e:
self.logger.warning( self.logger.warning(
message="Failed to check scroll need: {error}. Defaulting to True for safety.", message="Failed to check scroll need: {error}. Defaulting to True for safety.",
tag="SCROLL", tag="SCROLL",
params={"error": str(e)} params={"error": str(e)},
) )
return True # Default to scrolling if check fails return True # Default to scrolling if check fails

View File

@@ -1,27 +1,29 @@
import os, sys import os
from pathlib import Path from pathlib import Path
import aiosqlite import aiosqlite
import asyncio import asyncio
from typing import Optional, Tuple, Dict from typing import Optional, Dict
from contextlib import asynccontextmanager from contextlib import asynccontextmanager
import logging import logging
import json # Added for serialization/deserialization import json # Added for serialization/deserialization
from .utils import ensure_content_dirs, generate_content_hash from .utils import ensure_content_dirs, generate_content_hash
from .models import CrawlResult, MarkdownGenerationResult from .models import CrawlResult, MarkdownGenerationResult
import xxhash
import aiofiles import aiofiles
from .config import NEED_MIGRATION
from .version_manager import VersionManager from .version_manager import VersionManager
from .async_logger import AsyncLogger from .async_logger import AsyncLogger
from .utils import get_error_context, create_box_message from .utils import get_error_context, create_box_message
# Set up logging # Set up logging
logging.basicConfig(level=logging.INFO) logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
base_directory = DB_PATH = os.path.join(os.getenv("CRAWL4_AI_BASE_DIRECTORY", Path.home()), ".crawl4ai") base_directory = DB_PATH = os.path.join(
os.getenv("CRAWL4_AI_BASE_DIRECTORY", Path.home()), ".crawl4ai"
)
os.makedirs(DB_PATH, exist_ok=True) os.makedirs(DB_PATH, exist_ok=True)
DB_PATH = os.path.join(base_directory, "crawl4ai.db") DB_PATH = os.path.join(base_directory, "crawl4ai.db")
class AsyncDatabaseManager: class AsyncDatabaseManager:
def __init__(self, pool_size: int = 10, max_retries: int = 3): def __init__(self, pool_size: int = 10, max_retries: int = 3):
self.db_path = DB_PATH self.db_path = DB_PATH
@@ -37,10 +39,9 @@ class AsyncDatabaseManager:
self.logger = AsyncLogger( self.logger = AsyncLogger(
log_file=os.path.join(base_directory, ".crawl4ai", "crawler_db.log"), log_file=os.path.join(base_directory, ".crawl4ai", "crawler_db.log"),
verbose=False, verbose=False,
tag_width=10 tag_width=10,
) )
async def initialize(self): async def initialize(self):
"""Initialize the database and connection pool""" """Initialize the database and connection pool"""
try: try:
@@ -67,28 +68,32 @@ class AsyncDatabaseManager:
if needs_update: if needs_update:
self.logger.info("New version detected, running updates", tag="INIT") self.logger.info("New version detected, running updates", tag="INIT")
await self.update_db_schema() await self.update_db_schema()
from .migrations import run_migration # Import here to avoid circular imports from .migrations import (
run_migration,
) # Import here to avoid circular imports
await run_migration() await run_migration()
self.version_manager.update_version() # Update stored version after successful migration self.version_manager.update_version() # Update stored version after successful migration
self.logger.success("Version update completed successfully", tag="COMPLETE") self.logger.success(
"Version update completed successfully", tag="COMPLETE"
)
else: else:
self.logger.success("Database initialization completed successfully", tag="COMPLETE") self.logger.success(
"Database initialization completed successfully", tag="COMPLETE"
)
except Exception as e: except Exception as e:
self.logger.error( self.logger.error(
message="Database initialization error: {error}", message="Database initialization error: {error}",
tag="ERROR", tag="ERROR",
params={"error": str(e)} params={"error": str(e)},
) )
self.logger.info( self.logger.info(
message="Database will be initialized on first use", message="Database will be initialized on first use", tag="INIT"
tag="INIT"
) )
raise raise
async def cleanup(self): async def cleanup(self):
"""Cleanup connections when shutting down""" """Cleanup connections when shutting down"""
async with self.pool_lock: async with self.pool_lock:
@@ -107,6 +112,7 @@ class AsyncDatabaseManager:
self._initialized = True self._initialized = True
except Exception as e: except Exception as e:
import sys import sys
error_context = get_error_context(sys.exc_info()) error_context = get_error_context(sys.exc_info())
self.logger.error( self.logger.error(
message="Database initialization failed:\n{error}\n\nContext:\n{context}\n\nTraceback:\n{traceback}", message="Database initialization failed:\n{error}\n\nContext:\n{context}\n\nTraceback:\n{traceback}",
@@ -115,8 +121,8 @@ class AsyncDatabaseManager:
params={ params={
"error": str(e), "error": str(e),
"context": error_context["code_context"], "context": error_context["code_context"],
"traceback": error_context["full_traceback"] "traceback": error_context["full_traceback"],
} },
) )
raise raise
@@ -127,29 +133,40 @@ class AsyncDatabaseManager:
async with self.pool_lock: async with self.pool_lock:
if task_id not in self.connection_pool: if task_id not in self.connection_pool:
try: try:
conn = await aiosqlite.connect( conn = await aiosqlite.connect(self.db_path, timeout=30.0)
self.db_path, await conn.execute("PRAGMA journal_mode = WAL")
timeout=30.0 await conn.execute("PRAGMA busy_timeout = 5000")
)
await conn.execute('PRAGMA journal_mode = WAL')
await conn.execute('PRAGMA busy_timeout = 5000')
# Verify database structure # Verify database structure
async with conn.execute("PRAGMA table_info(crawled_data)") as cursor: async with conn.execute(
"PRAGMA table_info(crawled_data)"
) as cursor:
columns = await cursor.fetchall() columns = await cursor.fetchall()
column_names = [col[1] for col in columns] column_names = [col[1] for col in columns]
expected_columns = { expected_columns = {
'url', 'html', 'cleaned_html', 'markdown', 'extracted_content', "url",
'success', 'media', 'links', 'metadata', 'screenshot', "html",
'response_headers', 'downloaded_files' "cleaned_html",
"markdown",
"extracted_content",
"success",
"media",
"links",
"metadata",
"screenshot",
"response_headers",
"downloaded_files",
} }
missing_columns = expected_columns - set(column_names) missing_columns = expected_columns - set(column_names)
if missing_columns: if missing_columns:
raise ValueError(f"Database missing columns: {missing_columns}") raise ValueError(
f"Database missing columns: {missing_columns}"
)
self.connection_pool[task_id] = conn self.connection_pool[task_id] = conn
except Exception as e: except Exception as e:
import sys import sys
error_context = get_error_context(sys.exc_info()) error_context = get_error_context(sys.exc_info())
error_message = ( error_message = (
f"Unexpected error in db get_connection at line {error_context['line_no']} " f"Unexpected error in db get_connection at line {error_context['line_no']} "
@@ -167,6 +184,7 @@ class AsyncDatabaseManager:
except Exception as e: except Exception as e:
import sys import sys
error_context = get_error_context(sys.exc_info()) error_context = get_error_context(sys.exc_info())
error_message = ( error_message = (
f"Unexpected error in db get_connection at line {error_context['line_no']} " f"Unexpected error in db get_connection at line {error_context['line_no']} "
@@ -185,7 +203,6 @@ class AsyncDatabaseManager:
del self.connection_pool[task_id] del self.connection_pool[task_id]
self.connection_semaphore.release() self.connection_semaphore.release()
async def execute_with_retry(self, operation, *args): async def execute_with_retry(self, operation, *args):
"""Execute database operations with retry logic""" """Execute database operations with retry logic"""
for attempt in range(self.max_retries): for attempt in range(self.max_retries):
@@ -200,10 +217,7 @@ class AsyncDatabaseManager:
message="Operation failed after {retries} attempts: {error}", message="Operation failed after {retries} attempts: {error}",
tag="ERROR", tag="ERROR",
force_verbose=True, force_verbose=True,
params={ params={"retries": self.max_retries, "error": str(e)},
"retries": self.max_retries,
"error": str(e)
}
) )
raise raise
await asyncio.sleep(1 * (attempt + 1)) # Exponential backoff await asyncio.sleep(1 * (attempt + 1)) # Exponential backoff
@@ -211,7 +225,8 @@ class AsyncDatabaseManager:
async def ainit_db(self): async def ainit_db(self):
"""Initialize database schema""" """Initialize database schema"""
async with aiosqlite.connect(self.db_path, timeout=30.0) as db: async with aiosqlite.connect(self.db_path, timeout=30.0) as db:
await db.execute(''' await db.execute(
"""
CREATE TABLE IF NOT EXISTS crawled_data ( CREATE TABLE IF NOT EXISTS crawled_data (
url TEXT PRIMARY KEY, url TEXT PRIMARY KEY,
html TEXT, html TEXT,
@@ -226,11 +241,10 @@ class AsyncDatabaseManager:
response_headers TEXT DEFAULT "{}", response_headers TEXT DEFAULT "{}",
downloaded_files TEXT DEFAULT "{}" -- New column added downloaded_files TEXT DEFAULT "{}" -- New column added
) )
''') """
)
await db.commit() await db.commit()
async def update_db_schema(self): async def update_db_schema(self):
"""Update database schema if needed""" """Update database schema if needed"""
async with aiosqlite.connect(self.db_path, timeout=30.0) as db: async with aiosqlite.connect(self.db_path, timeout=30.0) as db:
@@ -239,7 +253,14 @@ class AsyncDatabaseManager:
column_names = [column[1] for column in columns] column_names = [column[1] for column in columns]
# List of new columns to add # List of new columns to add
new_columns = ['media', 'links', 'metadata', 'screenshot', 'response_headers', 'downloaded_files'] new_columns = [
"media",
"links",
"metadata",
"screenshot",
"response_headers",
"downloaded_files",
]
for column in new_columns: for column in new_columns:
if column not in column_names: if column not in column_names:
@@ -248,22 +269,26 @@ class AsyncDatabaseManager:
async def aalter_db_add_column(self, new_column: str, db): async def aalter_db_add_column(self, new_column: str, db):
"""Add new column to the database""" """Add new column to the database"""
if new_column == 'response_headers': if new_column == "response_headers":
await db.execute(f'ALTER TABLE crawled_data ADD COLUMN {new_column} TEXT DEFAULT "{{}}"') await db.execute(
f'ALTER TABLE crawled_data ADD COLUMN {new_column} TEXT DEFAULT "{{}}"'
)
else: else:
await db.execute(f'ALTER TABLE crawled_data ADD COLUMN {new_column} TEXT DEFAULT ""') await db.execute(
f'ALTER TABLE crawled_data ADD COLUMN {new_column} TEXT DEFAULT ""'
)
self.logger.info( self.logger.info(
message="Added column '{column}' to the database", message="Added column '{column}' to the database",
tag="INIT", tag="INIT",
params={"column": new_column} params={"column": new_column},
) )
async def aget_cached_url(self, url: str) -> Optional[CrawlResult]: async def aget_cached_url(self, url: str) -> Optional[CrawlResult]:
"""Retrieve cached URL data as CrawlResult""" """Retrieve cached URL data as CrawlResult"""
async def _get(db): async def _get(db):
async with db.execute( async with db.execute(
'SELECT * FROM crawled_data WHERE url = ?', (url,) "SELECT * FROM crawled_data WHERE url = ?", (url,)
) as cursor: ) as cursor:
row = await cursor.fetchone() row = await cursor.fetchone()
if not row: if not row:
@@ -276,42 +301,54 @@ class AsyncDatabaseManager:
# Load content from files using stored hashes # Load content from files using stored hashes
content_fields = { content_fields = {
'html': row_dict['html'], "html": row_dict["html"],
'cleaned_html': row_dict['cleaned_html'], "cleaned_html": row_dict["cleaned_html"],
'markdown': row_dict['markdown'], "markdown": row_dict["markdown"],
'extracted_content': row_dict['extracted_content'], "extracted_content": row_dict["extracted_content"],
'screenshot': row_dict['screenshot'], "screenshot": row_dict["screenshot"],
'screenshots': row_dict['screenshot'], "screenshots": row_dict["screenshot"],
} }
for field, hash_value in content_fields.items(): for field, hash_value in content_fields.items():
if hash_value: if hash_value:
content = await self._load_content( content = await self._load_content(
hash_value, hash_value,
field.split('_')[0] # Get content type from field name field.split("_")[0], # Get content type from field name
) )
row_dict[field] = content or "" row_dict[field] = content or ""
else: else:
row_dict[field] = "" row_dict[field] = ""
# Parse JSON fields # Parse JSON fields
json_fields = ['media', 'links', 'metadata', 'response_headers', 'markdown'] json_fields = [
"media",
"links",
"metadata",
"response_headers",
"markdown",
]
for field in json_fields: for field in json_fields:
try: try:
row_dict[field] = json.loads(row_dict[field]) if row_dict[field] else {} row_dict[field] = (
json.loads(row_dict[field]) if row_dict[field] else {}
)
except json.JSONDecodeError: except json.JSONDecodeError:
row_dict[field] = {} row_dict[field] = {}
if isinstance(row_dict['markdown'], Dict): if isinstance(row_dict["markdown"], Dict):
row_dict['markdown_v2'] = row_dict['markdown'] row_dict["markdown_v2"] = row_dict["markdown"]
if row_dict['markdown'].get('raw_markdown'): if row_dict["markdown"].get("raw_markdown"):
row_dict['markdown'] = row_dict['markdown']['raw_markdown'] row_dict["markdown"] = row_dict["markdown"]["raw_markdown"]
# Parse downloaded_files # Parse downloaded_files
try: try:
row_dict['downloaded_files'] = json.loads(row_dict['downloaded_files']) if row_dict['downloaded_files'] else [] row_dict["downloaded_files"] = (
json.loads(row_dict["downloaded_files"])
if row_dict["downloaded_files"]
else []
)
except json.JSONDecodeError: except json.JSONDecodeError:
row_dict['downloaded_files'] = [] row_dict["downloaded_files"] = []
# Remove any fields not in CrawlResult model # Remove any fields not in CrawlResult model
valid_fields = CrawlResult.__annotations__.keys() valid_fields = CrawlResult.__annotations__.keys()
@@ -326,7 +363,7 @@ class AsyncDatabaseManager:
message="Error retrieving cached URL: {error}", message="Error retrieving cached URL: {error}",
tag="ERROR", tag="ERROR",
force_verbose=True, force_verbose=True,
params={"error": str(e)} params={"error": str(e)},
) )
return None return None
@@ -334,37 +371,52 @@ class AsyncDatabaseManager:
"""Cache CrawlResult data""" """Cache CrawlResult data"""
# Store content files and get hashes # Store content files and get hashes
content_map = { content_map = {
'html': (result.html, 'html'), "html": (result.html, "html"),
'cleaned_html': (result.cleaned_html or "", 'cleaned'), "cleaned_html": (result.cleaned_html or "", "cleaned"),
'markdown': None, "markdown": None,
'extracted_content': (result.extracted_content or "", 'extracted'), "extracted_content": (result.extracted_content or "", "extracted"),
'screenshot': (result.screenshot or "", 'screenshots') "screenshot": (result.screenshot or "", "screenshots"),
} }
try: try:
if isinstance(result.markdown, MarkdownGenerationResult): if isinstance(result.markdown, MarkdownGenerationResult):
content_map['markdown'] = (result.markdown.model_dump_json(), 'markdown') content_map["markdown"] = (
elif hasattr(result, 'markdown_v2'): result.markdown.model_dump_json(),
content_map['markdown'] = (result.markdown_v2.model_dump_json(), 'markdown') "markdown",
)
elif hasattr(result, "markdown_v2"):
content_map["markdown"] = (
result.markdown_v2.model_dump_json(),
"markdown",
)
elif isinstance(result.markdown, str): elif isinstance(result.markdown, str):
markdown_result = MarkdownGenerationResult(raw_markdown=result.markdown) markdown_result = MarkdownGenerationResult(raw_markdown=result.markdown)
content_map['markdown'] = (markdown_result.model_dump_json(), 'markdown') content_map["markdown"] = (
markdown_result.model_dump_json(),
"markdown",
)
else: else:
content_map['markdown'] = (MarkdownGenerationResult().model_dump_json(), 'markdown') content_map["markdown"] = (
MarkdownGenerationResult().model_dump_json(),
"markdown",
)
except Exception as e: except Exception as e:
self.logger.warning( self.logger.warning(
message=f"Error processing markdown content: {str(e)}", message=f"Error processing markdown content: {str(e)}", tag="WARNING"
tag="WARNING"
) )
# Fallback to empty markdown result # Fallback to empty markdown result
content_map['markdown'] = (MarkdownGenerationResult().model_dump_json(), 'markdown') content_map["markdown"] = (
MarkdownGenerationResult().model_dump_json(),
"markdown",
)
content_hashes = {} content_hashes = {}
for field, (content, content_type) in content_map.items(): for field, (content, content_type) in content_map.items():
content_hashes[field] = await self._store_content(content, content_type) content_hashes[field] = await self._store_content(content, content_type)
async def _cache(db): async def _cache(db):
await db.execute(''' await db.execute(
"""
INSERT INTO crawled_data ( INSERT INTO crawled_data (
url, html, cleaned_html, markdown, url, html, cleaned_html, markdown,
extracted_content, success, media, links, metadata, extracted_content, success, media, links, metadata,
@@ -383,20 +435,22 @@ class AsyncDatabaseManager:
screenshot = excluded.screenshot, screenshot = excluded.screenshot,
response_headers = excluded.response_headers, response_headers = excluded.response_headers,
downloaded_files = excluded.downloaded_files downloaded_files = excluded.downloaded_files
''', ( """,
(
result.url, result.url,
content_hashes['html'], content_hashes["html"],
content_hashes['cleaned_html'], content_hashes["cleaned_html"],
content_hashes['markdown'], content_hashes["markdown"],
content_hashes['extracted_content'], content_hashes["extracted_content"],
result.success, result.success,
json.dumps(result.media), json.dumps(result.media),
json.dumps(result.links), json.dumps(result.links),
json.dumps(result.metadata or {}), json.dumps(result.metadata or {}),
content_hashes['screenshot'], content_hashes["screenshot"],
json.dumps(result.response_headers or {}), json.dumps(result.response_headers or {}),
json.dumps(result.downloaded_files or []) json.dumps(result.downloaded_files or []),
)) ),
)
try: try:
await self.execute_with_retry(_cache) await self.execute_with_retry(_cache)
@@ -405,14 +459,14 @@ class AsyncDatabaseManager:
message="Error caching URL: {error}", message="Error caching URL: {error}",
tag="ERROR", tag="ERROR",
force_verbose=True, force_verbose=True,
params={"error": str(e)} params={"error": str(e)},
) )
async def aget_total_count(self) -> int: async def aget_total_count(self) -> int:
"""Get total number of cached URLs""" """Get total number of cached URLs"""
async def _count(db): async def _count(db):
async with db.execute('SELECT COUNT(*) FROM crawled_data') as cursor: async with db.execute("SELECT COUNT(*) FROM crawled_data") as cursor:
result = await cursor.fetchone() result = await cursor.fetchone()
return result[0] if result else 0 return result[0] if result else 0
@@ -423,14 +477,15 @@ class AsyncDatabaseManager:
message="Error getting total count: {error}", message="Error getting total count: {error}",
tag="ERROR", tag="ERROR",
force_verbose=True, force_verbose=True,
params={"error": str(e)} params={"error": str(e)},
) )
return 0 return 0
async def aclear_db(self): async def aclear_db(self):
"""Clear all data from the database""" """Clear all data from the database"""
async def _clear(db): async def _clear(db):
await db.execute('DELETE FROM crawled_data') await db.execute("DELETE FROM crawled_data")
try: try:
await self.execute_with_retry(_clear) await self.execute_with_retry(_clear)
@@ -439,13 +494,14 @@ class AsyncDatabaseManager:
message="Error clearing database: {error}", message="Error clearing database: {error}",
tag="ERROR", tag="ERROR",
force_verbose=True, force_verbose=True,
params={"error": str(e)} params={"error": str(e)},
) )
async def aflush_db(self): async def aflush_db(self):
"""Drop the entire table""" """Drop the entire table"""
async def _flush(db): async def _flush(db):
await db.execute('DROP TABLE IF EXISTS crawled_data') await db.execute("DROP TABLE IF EXISTS crawled_data")
try: try:
await self.execute_with_retry(_flush) await self.execute_with_retry(_flush)
@@ -454,10 +510,9 @@ class AsyncDatabaseManager:
message="Error flushing database: {error}", message="Error flushing database: {error}",
tag="ERROR", tag="ERROR",
force_verbose=True, force_verbose=True,
params={"error": str(e)} params={"error": str(e)},
) )
async def _store_content(self, content: str, content_type: str) -> str: async def _store_content(self, content: str, content_type: str) -> str:
"""Store content in filesystem and return hash""" """Store content in filesystem and return hash"""
if not content: if not content:
@@ -468,28 +523,31 @@ class AsyncDatabaseManager:
# Only write if file doesn't exist # Only write if file doesn't exist
if not os.path.exists(file_path): if not os.path.exists(file_path):
async with aiofiles.open(file_path, 'w', encoding='utf-8') as f: async with aiofiles.open(file_path, "w", encoding="utf-8") as f:
await f.write(content) await f.write(content)
return content_hash return content_hash
async def _load_content(self, content_hash: str, content_type: str) -> Optional[str]: async def _load_content(
self, content_hash: str, content_type: str
) -> Optional[str]:
"""Load content from filesystem by hash""" """Load content from filesystem by hash"""
if not content_hash: if not content_hash:
return None return None
file_path = os.path.join(self.content_paths[content_type], content_hash) file_path = os.path.join(self.content_paths[content_type], content_hash)
try: try:
async with aiofiles.open(file_path, 'r', encoding='utf-8') as f: async with aiofiles.open(file_path, "r", encoding="utf-8") as f:
return await f.read() return await f.read()
except: except:
self.logger.error( self.logger.error(
message="Failed to load content: {file_path}", message="Failed to load content: {file_path}",
tag="ERROR", tag="ERROR",
force_verbose=True, force_verbose=True,
params={"file_path": file_path} params={"file_path": file_path},
) )
return None return None
# Create a singleton instance # Create a singleton instance
async_db_manager = AsyncDatabaseManager() async_db_manager = AsyncDatabaseManager()

View File

@@ -1,14 +1,19 @@
from typing import Dict, Optional, List from typing import Dict, Optional, List, Tuple
from .async_configs import * from .async_configs import CrawlerRunConfig
from .models import * from .models import (
CrawlResult,
CrawlerTaskResult,
CrawlStatus,
DisplayMode,
CrawlStats,
DomainState,
)
from rich.live import Live from rich.live import Live
from rich.table import Table from rich.table import Table
from rich.console import Console from rich.console import Console
from rich.style import Style
from rich import box from rich import box
from datetime import datetime, timedelta from datetime import datetime, timedelta
from dataclasses import dataclass
import time import time
import psutil import psutil
@@ -26,7 +31,7 @@ class RateLimiter:
base_delay: Tuple[float, float] = (1.0, 3.0), base_delay: Tuple[float, float] = (1.0, 3.0),
max_delay: float = 60.0, max_delay: float = 60.0,
max_retries: int = 3, max_retries: int = 3,
rate_limit_codes: List[int] = None rate_limit_codes: List[int] = None,
): ):
self.base_delay = base_delay self.base_delay = base_delay
self.max_delay = max_delay self.max_delay = max_delay
@@ -68,21 +73,24 @@ class RateLimiter:
# Exponential backoff with random jitter # Exponential backoff with random jitter
state.current_delay = min( state.current_delay = min(
state.current_delay * 2 * random.uniform(0.75, 1.25), state.current_delay * 2 * random.uniform(0.75, 1.25), self.max_delay
self.max_delay
) )
else: else:
# Gradually reduce delay on success # Gradually reduce delay on success
state.current_delay = max( state.current_delay = max(
random.uniform(*self.base_delay), random.uniform(*self.base_delay), state.current_delay * 0.75
state.current_delay * 0.75
) )
state.fail_count = 0 state.fail_count = 0
return True return True
class CrawlerMonitor: class CrawlerMonitor:
def __init__(self, max_visible_rows: int = 15, display_mode: DisplayMode = DisplayMode.DETAILED): def __init__(
self,
max_visible_rows: int = 15,
display_mode: DisplayMode = DisplayMode.DETAILED,
):
self.console = Console() self.console = Console()
self.max_visible_rows = max_visible_rows self.max_visible_rows = max_visible_rows
self.display_mode = display_mode self.display_mode = display_mode
@@ -98,7 +106,9 @@ class CrawlerMonitor:
self.live.stop() self.live.stop()
def add_task(self, task_id: str, url: str): def add_task(self, task_id: str, url: str):
self.stats[task_id] = CrawlStats(task_id=task_id, url=url, status=CrawlStatus.QUEUED) self.stats[task_id] = CrawlStats(
task_id=task_id, url=url, status=CrawlStatus.QUEUED
)
self.live.update(self._create_table()) self.live.update(self._create_table())
def update_task(self, task_id: str, **kwargs): def update_task(self, task_id: str, **kwargs):
@@ -114,20 +124,30 @@ class CrawlerMonitor:
title="Crawler Status Overview", title="Crawler Status Overview",
title_style="bold magenta", title_style="bold magenta",
header_style="bold blue", header_style="bold blue",
show_lines=True show_lines=True,
) )
# Calculate statistics # Calculate statistics
total_tasks = len(self.stats) total_tasks = len(self.stats)
queued = sum(1 for stat in self.stats.values() if stat.status == CrawlStatus.QUEUED) queued = sum(
in_progress = sum(1 for stat in self.stats.values() if stat.status == CrawlStatus.IN_PROGRESS) 1 for stat in self.stats.values() if stat.status == CrawlStatus.QUEUED
completed = sum(1 for stat in self.stats.values() if stat.status == CrawlStatus.COMPLETED) )
failed = sum(1 for stat in self.stats.values() if stat.status == CrawlStatus.FAILED) in_progress = sum(
1 for stat in self.stats.values() if stat.status == CrawlStatus.IN_PROGRESS
)
completed = sum(
1 for stat in self.stats.values() if stat.status == CrawlStatus.COMPLETED
)
failed = sum(
1 for stat in self.stats.values() if stat.status == CrawlStatus.FAILED
)
# Memory statistics # Memory statistics
current_memory = self.process.memory_info().rss / (1024 * 1024) current_memory = self.process.memory_info().rss / (1024 * 1024)
total_task_memory = sum(stat.memory_usage for stat in self.stats.values()) total_task_memory = sum(stat.memory_usage for stat in self.stats.values())
peak_memory = max((stat.peak_memory for stat in self.stats.values()), default=0.0) peak_memory = max(
(stat.peak_memory for stat in self.stats.values()), default=0.0
)
# Duration # Duration
duration = datetime.now() - self.start_time duration = datetime.now() - self.start_time
@@ -137,53 +157,43 @@ class CrawlerMonitor:
table.add_column("Count", justify="right") table.add_column("Count", justify="right")
table.add_column("Percentage", justify="right") table.add_column("Percentage", justify="right")
table.add_row( table.add_row("Total Tasks", str(total_tasks), "100%")
"Total Tasks",
str(total_tasks),
"100%"
)
table.add_row( table.add_row(
"[yellow]In Queue[/yellow]", "[yellow]In Queue[/yellow]",
str(queued), str(queued),
f"{(queued/total_tasks*100):.1f}%" if total_tasks > 0 else "0%" f"{(queued/total_tasks*100):.1f}%" if total_tasks > 0 else "0%",
) )
table.add_row( table.add_row(
"[blue]In Progress[/blue]", "[blue]In Progress[/blue]",
str(in_progress), str(in_progress),
f"{(in_progress/total_tasks*100):.1f}%" if total_tasks > 0 else "0%" f"{(in_progress/total_tasks*100):.1f}%" if total_tasks > 0 else "0%",
) )
table.add_row( table.add_row(
"[green]Completed[/green]", "[green]Completed[/green]",
str(completed), str(completed),
f"{(completed/total_tasks*100):.1f}%" if total_tasks > 0 else "0%" f"{(completed/total_tasks*100):.1f}%" if total_tasks > 0 else "0%",
) )
table.add_row( table.add_row(
"[red]Failed[/red]", "[red]Failed[/red]",
str(failed), str(failed),
f"{(failed/total_tasks*100):.1f}%" if total_tasks > 0 else "0%" f"{(failed/total_tasks*100):.1f}%" if total_tasks > 0 else "0%",
) )
# Add memory information # Add memory information
table.add_section() table.add_section()
table.add_row( table.add_row(
"[magenta]Current Memory[/magenta]", "[magenta]Current Memory[/magenta]", f"{current_memory:.1f} MB", ""
f"{current_memory:.1f} MB",
""
) )
table.add_row( table.add_row(
"[magenta]Total Task Memory[/magenta]", "[magenta]Total Task Memory[/magenta]", f"{total_task_memory:.1f} MB", ""
f"{total_task_memory:.1f} MB",
""
) )
table.add_row( table.add_row(
"[magenta]Peak Task Memory[/magenta]", "[magenta]Peak Task Memory[/magenta]", f"{peak_memory:.1f} MB", ""
f"{peak_memory:.1f} MB",
""
) )
table.add_row( table.add_row(
"[yellow]Runtime[/yellow]", "[yellow]Runtime[/yellow]",
str(timedelta(seconds=int(duration.total_seconds()))), str(timedelta(seconds=int(duration.total_seconds()))),
"" "",
) )
return table return table
@@ -193,7 +203,7 @@ class CrawlerMonitor:
box=box.ROUNDED, box=box.ROUNDED,
title="Crawler Performance Monitor", title="Crawler Performance Monitor",
title_style="bold magenta", title_style="bold magenta",
header_style="bold blue" header_style="bold blue",
) )
# Add columns # Add columns
@@ -207,12 +217,15 @@ class CrawlerMonitor:
# Add summary row # Add summary row
total_memory = sum(stat.memory_usage for stat in self.stats.values()) total_memory = sum(stat.memory_usage for stat in self.stats.values())
active_count = sum(1 for stat in self.stats.values() active_count = sum(
if stat.status == CrawlStatus.IN_PROGRESS) 1 for stat in self.stats.values() if stat.status == CrawlStatus.IN_PROGRESS
completed_count = sum(1 for stat in self.stats.values() )
if stat.status == CrawlStatus.COMPLETED) completed_count = sum(
failed_count = sum(1 for stat in self.stats.values() 1 for stat in self.stats.values() if stat.status == CrawlStatus.COMPLETED
if stat.status == CrawlStatus.FAILED) )
failed_count = sum(
1 for stat in self.stats.values() if stat.status == CrawlStatus.FAILED
)
table.add_row( table.add_row(
"[bold yellow]SUMMARY", "[bold yellow]SUMMARY",
@@ -220,9 +233,13 @@ class CrawlerMonitor:
f"Active: {active_count}", f"Active: {active_count}",
f"{total_memory:.1f}", f"{total_memory:.1f}",
f"{self.process.memory_info().rss / (1024 * 1024):.1f}", f"{self.process.memory_info().rss / (1024 * 1024):.1f}",
str(timedelta(seconds=int((datetime.now() - self.start_time).total_seconds()))), str(
timedelta(
seconds=int((datetime.now() - self.start_time).total_seconds())
)
),
f"{completed_count}{failed_count}", f"{completed_count}{failed_count}",
style="bold" style="bold",
) )
table.add_section() table.add_section()
@@ -233,8 +250,8 @@ class CrawlerMonitor:
key=lambda x: ( key=lambda x: (
x.status != CrawlStatus.IN_PROGRESS, x.status != CrawlStatus.IN_PROGRESS,
x.status != CrawlStatus.QUEUED, x.status != CrawlStatus.QUEUED,
x.end_time or datetime.max x.end_time or datetime.max,
) ),
)[: self.max_visible_rows] )[: self.max_visible_rows]
for stat in visible_stats: for stat in visible_stats:
@@ -242,7 +259,7 @@ class CrawlerMonitor:
CrawlStatus.QUEUED: "white", CrawlStatus.QUEUED: "white",
CrawlStatus.IN_PROGRESS: "yellow", CrawlStatus.IN_PROGRESS: "yellow",
CrawlStatus.COMPLETED: "green", CrawlStatus.COMPLETED: "green",
CrawlStatus.FAILED: "red" CrawlStatus.FAILED: "red",
}[stat.status] }[stat.status]
table.add_row( table.add_row(
@@ -252,7 +269,7 @@ class CrawlerMonitor:
f"{stat.memory_usage:.1f}", f"{stat.memory_usage:.1f}",
f"{stat.peak_memory:.1f}", f"{stat.peak_memory:.1f}",
stat.duration, stat.duration,
stat.error_message[:40] if stat.error_message else "" stat.error_message[:40] if stat.error_message else "",
) )
return table return table
@@ -268,7 +285,7 @@ class BaseDispatcher(ABC):
def __init__( def __init__(
self, self,
rate_limiter: Optional[RateLimiter] = None, rate_limiter: Optional[RateLimiter] = None,
monitor: Optional[CrawlerMonitor] = None monitor: Optional[CrawlerMonitor] = None,
): ):
self.crawler = None self.crawler = None
self._domain_last_hit: Dict[str, float] = {} self._domain_last_hit: Dict[str, float] = {}
@@ -282,7 +299,7 @@ class BaseDispatcher(ABC):
url: str, url: str,
config: CrawlerRunConfig, config: CrawlerRunConfig,
task_id: str, task_id: str,
monitor: Optional[CrawlerMonitor] = None monitor: Optional[CrawlerMonitor] = None,
) -> CrawlerTaskResult: ) -> CrawlerTaskResult:
pass pass
@@ -290,12 +307,13 @@ class BaseDispatcher(ABC):
async def run_urls( async def run_urls(
self, self,
urls: List[str], urls: List[str],
crawler: "AsyncWebCrawler", crawler: "AsyncWebCrawler", # noqa: F821
config: CrawlerRunConfig, config: CrawlerRunConfig,
monitor: Optional[CrawlerMonitor] = None monitor: Optional[CrawlerMonitor] = None,
) -> List[CrawlerTaskResult]: ) -> List[CrawlerTaskResult]:
pass pass
class MemoryAdaptiveDispatcher(BaseDispatcher): class MemoryAdaptiveDispatcher(BaseDispatcher):
def __init__( def __init__(
self, self,
@@ -304,7 +322,7 @@ class MemoryAdaptiveDispatcher(BaseDispatcher):
max_session_permit: int = 20, max_session_permit: int = 20,
memory_wait_timeout: float = 300.0, # 5 minutes default timeout memory_wait_timeout: float = 300.0, # 5 minutes default timeout
rate_limiter: Optional[RateLimiter] = None, rate_limiter: Optional[RateLimiter] = None,
monitor: Optional[CrawlerMonitor] = None monitor: Optional[CrawlerMonitor] = None,
): ):
super().__init__(rate_limiter, monitor) super().__init__(rate_limiter, monitor)
self.memory_threshold_percent = memory_threshold_percent self.memory_threshold_percent = memory_threshold_percent
@@ -324,7 +342,9 @@ class MemoryAdaptiveDispatcher(BaseDispatcher):
try: try:
if self.monitor: if self.monitor:
self.monitor.update_task(task_id, status=CrawlStatus.IN_PROGRESS, start_time=start_time) self.monitor.update_task(
task_id, status=CrawlStatus.IN_PROGRESS, start_time=start_time
)
self.concurrent_sessions += 1 self.concurrent_sessions += 1
if self.rate_limiter: if self.rate_limiter:
@@ -350,7 +370,7 @@ class MemoryAdaptiveDispatcher(BaseDispatcher):
peak_memory=peak_memory, peak_memory=peak_memory,
start_time=start_time, start_time=start_time,
end_time=datetime.now(), end_time=datetime.now(),
error_message=error_message error_message=error_message,
) )
if not result.success: if not result.success:
@@ -364,7 +384,9 @@ class MemoryAdaptiveDispatcher(BaseDispatcher):
error_message = str(e) error_message = str(e)
if self.monitor: if self.monitor:
self.monitor.update_task(task_id, status=CrawlStatus.FAILED) self.monitor.update_task(task_id, status=CrawlStatus.FAILED)
result = CrawlResult(url=url, html="", metadata={}, success=False, error_message=str(e)) result = CrawlResult(
url=url, html="", metadata={}, success=False, error_message=str(e)
)
finally: finally:
end_time = datetime.now() end_time = datetime.now()
@@ -374,7 +396,7 @@ class MemoryAdaptiveDispatcher(BaseDispatcher):
end_time=end_time, end_time=end_time,
memory_usage=memory_usage, memory_usage=memory_usage,
peak_memory=peak_memory, peak_memory=peak_memory,
error_message=error_message error_message=error_message,
) )
self.concurrent_sessions -= 1 self.concurrent_sessions -= 1
@@ -386,13 +408,13 @@ class MemoryAdaptiveDispatcher(BaseDispatcher):
peak_memory=peak_memory, peak_memory=peak_memory,
start_time=start_time, start_time=start_time,
end_time=end_time, end_time=end_time,
error_message=error_message error_message=error_message,
) )
async def run_urls( async def run_urls(
self, self,
urls: List[str], urls: List[str],
crawler: "AsyncWebCrawler", crawler: "AsyncWebCrawler", # noqa: F821
config: CrawlerRunConfig, config: CrawlerRunConfig,
) -> List[CrawlerTaskResult]: ) -> List[CrawlerTaskResult]:
self.crawler = crawler self.crawler = crawler
@@ -417,7 +439,9 @@ class MemoryAdaptiveDispatcher(BaseDispatcher):
if psutil.virtual_memory().percent >= self.memory_threshold_percent: if psutil.virtual_memory().percent >= self.memory_threshold_percent:
# Check if we've exceeded the timeout # Check if we've exceeded the timeout
if time.time() - wait_start_time > self.memory_wait_timeout: if time.time() - wait_start_time > self.memory_wait_timeout:
raise MemoryError(f"Memory usage above threshold ({self.memory_threshold_percent}%) for more than {self.memory_wait_timeout} seconds") raise MemoryError(
f"Memory usage above threshold ({self.memory_threshold_percent}%) for more than {self.memory_wait_timeout} seconds"
)
await asyncio.sleep(self.check_interval) await asyncio.sleep(self.check_interval)
continue continue
@@ -430,8 +454,7 @@ class MemoryAdaptiveDispatcher(BaseDispatcher):
continue continue
done, pending = await asyncio.wait( done, pending = await asyncio.wait(
active_tasks, active_tasks, return_when=asyncio.FIRST_COMPLETED
return_when=asyncio.FIRST_COMPLETED
) )
pending_tasks.extend(done) pending_tasks.extend(done)
@@ -442,13 +465,14 @@ class MemoryAdaptiveDispatcher(BaseDispatcher):
if self.monitor: if self.monitor:
self.monitor.stop() self.monitor.stop()
class SemaphoreDispatcher(BaseDispatcher): class SemaphoreDispatcher(BaseDispatcher):
def __init__( def __init__(
self, self,
semaphore_count: int = 5, semaphore_count: int = 5,
max_session_permit: int = 20, max_session_permit: int = 20,
rate_limiter: Optional[RateLimiter] = None, rate_limiter: Optional[RateLimiter] = None,
monitor: Optional[CrawlerMonitor] = None monitor: Optional[CrawlerMonitor] = None,
): ):
super().__init__(rate_limiter, monitor) super().__init__(rate_limiter, monitor)
self.semaphore_count = semaphore_count self.semaphore_count = semaphore_count
@@ -459,7 +483,7 @@ class SemaphoreDispatcher(BaseDispatcher):
url: str, url: str,
config: CrawlerRunConfig, config: CrawlerRunConfig,
task_id: str, task_id: str,
semaphore: asyncio.Semaphore = None semaphore: asyncio.Semaphore = None,
) -> CrawlerTaskResult: ) -> CrawlerTaskResult:
start_time = datetime.now() start_time = datetime.now()
error_message = "" error_message = ""
@@ -467,7 +491,9 @@ class SemaphoreDispatcher(BaseDispatcher):
try: try:
if self.monitor: if self.monitor:
self.monitor.update_task(task_id, status=CrawlStatus.IN_PROGRESS, start_time=start_time) self.monitor.update_task(
task_id, status=CrawlStatus.IN_PROGRESS, start_time=start_time
)
if self.rate_limiter: if self.rate_limiter:
await self.rate_limiter.wait_if_needed(url) await self.rate_limiter.wait_if_needed(url)
@@ -493,7 +519,7 @@ class SemaphoreDispatcher(BaseDispatcher):
peak_memory=peak_memory, peak_memory=peak_memory,
start_time=start_time, start_time=start_time,
end_time=datetime.now(), end_time=datetime.now(),
error_message=error_message error_message=error_message,
) )
if not result.success: if not result.success:
@@ -507,7 +533,9 @@ class SemaphoreDispatcher(BaseDispatcher):
error_message = str(e) error_message = str(e)
if self.monitor: if self.monitor:
self.monitor.update_task(task_id, status=CrawlStatus.FAILED) self.monitor.update_task(task_id, status=CrawlStatus.FAILED)
result = CrawlResult(url=url, html="", metadata={}, success=False, error_message=str(e)) result = CrawlResult(
url=url, html="", metadata={}, success=False, error_message=str(e)
)
finally: finally:
end_time = datetime.now() end_time = datetime.now()
@@ -517,7 +545,7 @@ class SemaphoreDispatcher(BaseDispatcher):
end_time=end_time, end_time=end_time,
memory_usage=memory_usage, memory_usage=memory_usage,
peak_memory=peak_memory, peak_memory=peak_memory,
error_message=error_message error_message=error_message,
) )
return CrawlerTaskResult( return CrawlerTaskResult(
@@ -528,12 +556,12 @@ class SemaphoreDispatcher(BaseDispatcher):
peak_memory=peak_memory, peak_memory=peak_memory,
start_time=start_time, start_time=start_time,
end_time=end_time, end_time=end_time,
error_message=error_message error_message=error_message,
) )
async def run_urls( async def run_urls(
self, self,
crawler: "AsyncWebCrawler", crawler: "AsyncWebCrawler", # noqa: F821
urls: List[str], urls: List[str],
config: CrawlerRunConfig, config: CrawlerRunConfig,
) -> List[CrawlerTaskResult]: ) -> List[CrawlerTaskResult]:

View File

@@ -1,10 +1,10 @@
from enum import Enum from enum import Enum
from typing import Optional, Dict, Any, Union from typing import Optional, Dict, Any
from colorama import Fore, Back, Style, init from colorama import Fore, Style, init
import time
import os import os
from datetime import datetime from datetime import datetime
class LogLevel(Enum): class LogLevel(Enum):
DEBUG = 1 DEBUG = 1
INFO = 2 INFO = 2
@@ -12,6 +12,7 @@ class LogLevel(Enum):
WARNING = 4 WARNING = 4
ERROR = 5 ERROR = 5
class AsyncLogger: class AsyncLogger:
""" """
Asynchronous logger with support for colored console output and file logging. Asynchronous logger with support for colored console output and file logging.
@@ -19,16 +20,16 @@ class AsyncLogger:
""" """
DEFAULT_ICONS = { DEFAULT_ICONS = {
'INIT': '', "INIT": "",
'READY': '', "READY": "",
'FETCH': '', "FETCH": "",
'SCRAPE': '', "SCRAPE": "",
'EXTRACT': '', "EXTRACT": "",
'COMPLETE': '', "COMPLETE": "",
'ERROR': '×', "ERROR": "×",
'DEBUG': '', "DEBUG": "",
'INFO': '', "INFO": "",
'WARNING': '', "WARNING": "",
} }
DEFAULT_COLORS = { DEFAULT_COLORS = {
@@ -46,7 +47,7 @@ class AsyncLogger:
tag_width: int = 10, tag_width: int = 10,
icons: Optional[Dict[str, str]] = None, icons: Optional[Dict[str, str]] = None,
colors: Optional[Dict[LogLevel, str]] = None, colors: Optional[Dict[LogLevel, str]] = None,
verbose: bool = True verbose: bool = True,
): ):
""" """
Initialize the logger. Initialize the logger.
@@ -77,18 +78,20 @@ class AsyncLogger:
def _get_icon(self, tag: str) -> str: def _get_icon(self, tag: str) -> str:
"""Get the icon for a tag, defaulting to info icon if not found.""" """Get the icon for a tag, defaulting to info icon if not found."""
return self.icons.get(tag, self.icons['INFO']) return self.icons.get(tag, self.icons["INFO"])
def _write_to_file(self, message: str): def _write_to_file(self, message: str):
"""Write a message to the log file if configured.""" """Write a message to the log file if configured."""
if self.log_file: if self.log_file:
timestamp = datetime.now().strftime('%Y-%m-%d %H:%M:%S.%f')[:-3] timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S.%f")[:-3]
with open(self.log_file, 'a', encoding='utf-8') as f: with open(self.log_file, "a", encoding="utf-8") as f:
# Strip ANSI color codes for file output # Strip ANSI color codes for file output
clean_message = message.replace(Fore.RESET, '').replace(Style.RESET_ALL, '') clean_message = message.replace(Fore.RESET, "").replace(
Style.RESET_ALL, ""
)
for color in vars(Fore).values(): for color in vars(Fore).values():
if isinstance(color, str): if isinstance(color, str):
clean_message = clean_message.replace(color, '') clean_message = clean_message.replace(color, "")
f.write(f"[{timestamp}] {clean_message}\n") f.write(f"[{timestamp}] {clean_message}\n")
def _log( def _log(
@@ -99,7 +102,7 @@ class AsyncLogger:
params: Optional[Dict[str, Any]] = None, params: Optional[Dict[str, Any]] = None,
colors: Optional[Dict[str, str]] = None, colors: Optional[Dict[str, str]] = None,
base_color: Optional[str] = None, base_color: Optional[str] = None,
**kwargs **kwargs,
): ):
""" """
Core logging method that handles message formatting and output. Core logging method that handles message formatting and output.
@@ -128,12 +131,13 @@ class AsyncLogger:
if key in params: if key in params:
value_str = str(params[key]) value_str = str(params[key])
formatted_message = formatted_message.replace( formatted_message = formatted_message.replace(
value_str, value_str, f"{color}{value_str}{Style.RESET_ALL}"
f"{color}{value_str}{Style.RESET_ALL}"
) )
except KeyError as e: except KeyError as e:
formatted_message = f"LOGGING ERROR: Missing parameter {e} in message template" formatted_message = (
f"LOGGING ERROR: Missing parameter {e} in message template"
)
level = LogLevel.ERROR level = LogLevel.ERROR
else: else:
formatted_message = message formatted_message = message
@@ -175,7 +179,7 @@ class AsyncLogger:
success: bool, success: bool,
timing: float, timing: float,
tag: str = "FETCH", tag: str = "FETCH",
url_length: int = 50 url_length: int = 50,
): ):
""" """
Convenience method for logging URL fetch status. Convenience method for logging URL fetch status.
@@ -195,20 +199,16 @@ class AsyncLogger:
"url": url, "url": url,
"url_length": url_length, "url_length": url_length,
"status": success, "status": success,
"timing": timing "timing": timing,
}, },
colors={ colors={
"status": Fore.GREEN if success else Fore.RED, "status": Fore.GREEN if success else Fore.RED,
"timing": Fore.YELLOW "timing": Fore.YELLOW,
} },
) )
def error_status( def error_status(
self, self, url: str, error: str, tag: str = "ERROR", url_length: int = 50
url: str,
error: str,
tag: str = "ERROR",
url_length: int = 50
): ):
""" """
Convenience method for logging error status. Convenience method for logging error status.
@@ -223,9 +223,5 @@ class AsyncLogger:
level=LogLevel.ERROR, level=LogLevel.ERROR,
message="{url:.{url_length}}... | Error: {error}", message="{url:.{url_length}}... | Error: {error}",
tag=tag, tag=tag,
params={ params={"url": url, "url_length": url_length, "error": error},
"url": url,
"url_length": url_length,
"error": error
}
) )

View File

@@ -1,42 +1,47 @@
import os, sys import os
import sys
import time import time
import warnings import warnings
from enum import Enum from colorama import Fore
from colorama import init, Fore, Back, Style
from pathlib import Path from pathlib import Path
from typing import Optional, List, Union from typing import Optional, List
import json import json
import asyncio import asyncio
# from contextlib import nullcontext, asynccontextmanager # from contextlib import nullcontext, asynccontextmanager
from contextlib import asynccontextmanager from contextlib import asynccontextmanager
from .models import CrawlResult, MarkdownGenerationResult, CrawlerTaskResult from .models import CrawlResult, MarkdownGenerationResult, CrawlerTaskResult, DispatchResult, RateLimiter
from .async_database import async_db_manager from .async_database import async_db_manager
from .chunking_strategy import * from .chunking_strategy import * # noqa: F403
from .content_filter_strategy import * from .chunking_strategy import RegexChunking, ChunkingStrategy, IdentityChunking
from .extraction_strategy import * from .content_filter_strategy import * # noqa: F403
from .async_crawler_strategy import AsyncCrawlerStrategy, AsyncPlaywrightCrawlerStrategy, AsyncCrawlResponse from .content_filter_strategy import RelevantContentFilter
from .extraction_strategy import * # noqa: F403
from .extraction_strategy import NoExtractionStrategy, ExtractionStrategy
from .async_crawler_strategy import (
AsyncCrawlerStrategy,
AsyncPlaywrightCrawlerStrategy,
AsyncCrawlResponse,
)
from .cache_context import CacheMode, CacheContext, _legacy_to_cache_mode from .cache_context import CacheMode, CacheContext, _legacy_to_cache_mode
from .markdown_generation_strategy import DefaultMarkdownGenerator, MarkdownGenerationStrategy from .markdown_generation_strategy import (
from .content_scraping_strategy import WebScrapingStrategy, LXMLWebScrapingStrategy DefaultMarkdownGenerator,
MarkdownGenerationStrategy,
)
from .async_logger import AsyncLogger from .async_logger import AsyncLogger
from .async_configs import BrowserConfig, CrawlerRunConfig from .async_configs import BrowserConfig, CrawlerRunConfig
from .async_dispatcher import * from .async_dispatcher import * # noqa: F403
from .async_dispatcher import BaseDispatcher, MemoryAdaptiveDispatcher
from .config import ( from .config import MIN_WORD_THRESHOLD
MIN_WORD_THRESHOLD,
IMAGE_DESCRIPTION_MIN_WORD_THRESHOLD,
URL_LOG_SHORTEN_LENGTH
)
from .utils import ( from .utils import (
sanitize_input_encode, sanitize_input_encode,
InvalidCSSSelectorError, InvalidCSSSelectorError,
format_html,
fast_format_html, fast_format_html,
create_box_message create_box_message,
get_error_context,
) )
from urllib.parse import urlparse
import random
from .__version__ import __version__ as crawl4ai_version from .__version__ import __version__ as crawl4ai_version
@@ -104,6 +109,7 @@ class AsyncWebCrawler:
result = await crawler.arun(url="https://example.com", config=crawler_config) result = await crawler.arun(url="https://example.com", config=crawler_config)
print(result.markdown) print(result.markdown)
""" """
_domain_last_hit = {} _domain_last_hit = {}
def __init__( def __init__(
@@ -131,10 +137,18 @@ class AsyncWebCrawler:
# Handle browser configuration # Handle browser configuration
browser_config = config browser_config = config
if browser_config is not None: if browser_config is not None:
if any(k in kwargs for k in ["browser_type", "headless", "viewport_width", "viewport_height"]): if any(
k in kwargs
for k in [
"browser_type",
"headless",
"viewport_width",
"viewport_height",
]
):
self.logger.warning( self.logger.warning(
message="Both browser_config and legacy browser parameters provided. browser_config will take precedence.", message="Both browser_config and legacy browser parameters provided. browser_config will take precedence.",
tag="WARNING" tag="WARNING",
) )
else: else:
# Create browser config from kwargs for backwards compatibility # Create browser config from kwargs for backwards compatibility
@@ -146,18 +160,15 @@ class AsyncWebCrawler:
self.logger = AsyncLogger( self.logger = AsyncLogger(
log_file=os.path.join(base_directory, ".crawl4ai", "crawler.log"), log_file=os.path.join(base_directory, ".crawl4ai", "crawler.log"),
verbose=self.browser_config.verbose, verbose=self.browser_config.verbose,
tag_width=10 tag_width=10,
) )
# Initialize crawler strategy # Initialize crawler strategy
params = { params = {k: v for k, v in kwargs.items() if k in ["browser_congig", "logger"]}
k:v for k, v in kwargs.items() if k in ['browser_congig', 'logger']
}
self.crawler_strategy = crawler_strategy or AsyncPlaywrightCrawlerStrategy( self.crawler_strategy = crawler_strategy or AsyncPlaywrightCrawlerStrategy(
browser_config=browser_config, browser_config=browser_config,
logger=self.logger, logger=self.logger,
**params # Pass remaining kwargs for backwards compatibility **params, # Pass remaining kwargs for backwards compatibility
) )
# If craweler strategy doesnt have logger, use crawler logger # If craweler strategy doesnt have logger, use crawler logger
@@ -172,7 +183,7 @@ class AsyncWebCrawler:
"Use 'always_bypass_cache' instead. " "Use 'always_bypass_cache' instead. "
"Pass warning=False to suppress this warning.", "Pass warning=False to suppress this warning.",
DeprecationWarning, DeprecationWarning,
stacklevel=2 stacklevel=2,
) )
self.always_bypass_cache = always_by_pass_cache self.always_bypass_cache = always_by_pass_cache
else: else:
@@ -323,7 +334,7 @@ class AsyncWebCrawler:
"screenshot": screenshot, "screenshot": screenshot,
"pdf": pdf, "pdf": pdf,
"verbose": verbose, "verbose": verbose,
**kwargs **kwargs,
} }
config = CrawlerRunConfig.from_kwargs(config_kwargs) config = CrawlerRunConfig.from_kwargs(config_kwargs)
@@ -334,7 +345,7 @@ class AsyncWebCrawler:
"Cache control boolean flags are deprecated and will be removed in version 0.5.0. " "Cache control boolean flags are deprecated and will be removed in version 0.5.0. "
"Use 'cache_mode' parameter instead.", "Use 'cache_mode' parameter instead.",
DeprecationWarning, DeprecationWarning,
stacklevel=2 stacklevel=2,
) )
# Convert legacy parameters if cache_mode not provided # Convert legacy parameters if cache_mode not provided
@@ -343,7 +354,7 @@ class AsyncWebCrawler:
disable_cache=disable_cache, disable_cache=disable_cache,
bypass_cache=bypass_cache, bypass_cache=bypass_cache,
no_cache_read=no_cache_read, no_cache_read=no_cache_read,
no_cache_write=no_cache_write no_cache_write=no_cache_write,
) )
# Default to ENABLED if no cache mode specified # Default to ENABLED if no cache mode specified
@@ -351,7 +362,9 @@ class AsyncWebCrawler:
config.cache_mode = CacheMode.ENABLED config.cache_mode = CacheMode.ENABLED
# Create cache context # Create cache context
cache_context = CacheContext(url, config.cache_mode, self.always_bypass_cache) cache_context = CacheContext(
url, config.cache_mode, self.always_bypass_cache
)
# Initialize processing variables # Initialize processing variables
async_response: AsyncCrawlResponse = None async_response: AsyncCrawlResponse = None
@@ -367,8 +380,14 @@ class AsyncWebCrawler:
if cached_result: if cached_result:
html = sanitize_input_encode(cached_result.html) html = sanitize_input_encode(cached_result.html)
extracted_content = sanitize_input_encode(cached_result.extracted_content or "") extracted_content = sanitize_input_encode(
extracted_content = None if not extracted_content or extracted_content == "[]" else extracted_content cached_result.extracted_content or ""
)
extracted_content = (
None
if not extracted_content or extracted_content == "[]"
else extracted_content
)
# If screenshot is requested but its not in cache, then set cache_result to None # If screenshot is requested but its not in cache, then set cache_result to None
screenshot_data = cached_result.screenshot screenshot_data = cached_result.screenshot
pdf_data = cached_result.pdf pdf_data = cached_result.pdf
@@ -379,7 +398,7 @@ class AsyncWebCrawler:
url=cache_context.display_url, url=cache_context.display_url,
success=bool(html), success=bool(html),
timing=time.perf_counter() - start_time, timing=time.perf_counter() - start_time,
tag="FETCH" tag="FETCH",
) )
# Fetch fresh content if needed # Fetch fresh content if needed
@@ -392,7 +411,7 @@ class AsyncWebCrawler:
# Pass config to crawl method # Pass config to crawl method
async_response = await self.crawler_strategy.crawl( async_response = await self.crawler_strategy.crawl(
url, url,
config=config # Pass the entire config object config=config, # Pass the entire config object
) )
html = sanitize_input_encode(async_response.html) html = sanitize_input_encode(async_response.html)
@@ -404,7 +423,7 @@ class AsyncWebCrawler:
url=cache_context.display_url, url=cache_context.display_url,
success=bool(html), success=bool(html),
timing=t2 - t1, timing=t2 - t1,
tag="FETCH" tag="FETCH",
) )
# Process the HTML content # Process the HTML content
@@ -417,13 +436,15 @@ class AsyncWebCrawler:
pdf_data=pdf_data, pdf_data=pdf_data,
verbose=config.verbose, verbose=config.verbose,
is_raw_html=True if url.startswith("raw:") else False, is_raw_html=True if url.startswith("raw:") else False,
**kwargs **kwargs,
) )
crawl_result.status_code = async_response.status_code crawl_result.status_code = async_response.status_code
crawl_result.response_headers = async_response.response_headers crawl_result.response_headers = async_response.response_headers
crawl_result.downloaded_files = async_response.downloaded_files crawl_result.downloaded_files = async_response.downloaded_files
crawl_result.ssl_certificate = async_response.ssl_certificate # Add SSL certificate crawl_result.ssl_certificate = (
async_response.ssl_certificate
) # Add SSL certificate
# # Check and set values from async_response to crawl_result # # Check and set values from async_response to crawl_result
# try: # try:
@@ -446,7 +467,7 @@ class AsyncWebCrawler:
# ) # )
crawl_result.success = bool(html) crawl_result.success = bool(html)
crawl_result.session_id = getattr(config, 'session_id', None) crawl_result.session_id = getattr(config, "session_id", None)
self.logger.success( self.logger.success(
message="{url:.50}... | Status: {status} | Total: {timing}", message="{url:.50}... | Status: {status} | Total: {timing}",
@@ -454,12 +475,12 @@ class AsyncWebCrawler:
params={ params={
"url": cache_context.display_url, "url": cache_context.display_url,
"status": crawl_result.success, "status": crawl_result.success,
"timing": f"{time.perf_counter() - start_time:.2f}s" "timing": f"{time.perf_counter() - start_time:.2f}s",
}, },
colors={ colors={
"status": Fore.GREEN if crawl_result.success else Fore.RED, "status": Fore.GREEN if crawl_result.success else Fore.RED,
"timing": Fore.YELLOW "timing": Fore.YELLOW,
} },
) )
# Update cache if appropriate # Update cache if appropriate
@@ -475,16 +496,13 @@ class AsyncWebCrawler:
params={ params={
"url": cache_context.display_url, "url": cache_context.display_url,
"status": True, "status": True,
"timing": f"{time.perf_counter() - start_time:.2f}s" "timing": f"{time.perf_counter() - start_time:.2f}s",
}, },
colors={ colors={"status": Fore.GREEN, "timing": Fore.YELLOW},
"status": Fore.GREEN,
"timing": Fore.YELLOW
}
) )
cached_result.success = bool(html) cached_result.success = bool(html)
cached_result.session_id = getattr(config, 'session_id', None) cached_result.session_id = getattr(config, "session_id", None)
return cached_result return cached_result
except Exception as e: except Exception as e:
@@ -502,14 +520,11 @@ class AsyncWebCrawler:
self.logger.error_status( self.logger.error_status(
url=url, url=url,
error=create_box_message(error_message, type="error"), error=create_box_message(error_message, type="error"),
tag="ERROR" tag="ERROR",
) )
return CrawlResult( return CrawlResult(
url=url, url=url, html="", success=False, error_message=error_message
html="",
success=False,
error_message=error_message
) )
async def aprocess_html( async def aprocess_html(
@@ -553,21 +568,19 @@ class AsyncWebCrawler:
# add keys from kwargs to params that doesn't exist in params # add keys from kwargs to params that doesn't exist in params
params.update({k: v for k, v in kwargs.items() if k not in params.keys()}) params.update({k: v for k, v in kwargs.items() if k not in params.keys()})
result = scraping_strategy.scrap( result = scraping_strategy.scrap(url, html, **params)
url,
html,
**params
)
if result is None: if result is None:
raise ValueError(f"Process HTML, Failed to extract content from the website: {url}") raise ValueError(
f"Process HTML, Failed to extract content from the website: {url}"
)
except InvalidCSSSelectorError as e: except InvalidCSSSelectorError as e:
raise ValueError(str(e)) raise ValueError(str(e))
except Exception as e: except Exception as e:
raise ValueError(f"Process HTML, Failed to extract content from the website: {url}, error: {str(e)}") raise ValueError(
f"Process HTML, Failed to extract content from the website: {url}, error: {str(e)}"
)
# Extract results - handle both dict and ScrapingResult # Extract results - handle both dict and ScrapingResult
if isinstance(result, dict): if isinstance(result, dict):
@@ -582,17 +595,21 @@ class AsyncWebCrawler:
metadata = result.metadata metadata = result.metadata
# Markdown Generation # Markdown Generation
markdown_generator: Optional[MarkdownGenerationStrategy] = config.markdown_generator or DefaultMarkdownGenerator() markdown_generator: Optional[MarkdownGenerationStrategy] = (
config.markdown_generator or DefaultMarkdownGenerator()
)
# Uncomment if by default we want to use PruningContentFilter # Uncomment if by default we want to use PruningContentFilter
# if not config.content_filter and not markdown_generator.content_filter: # if not config.content_filter and not markdown_generator.content_filter:
# markdown_generator.content_filter = PruningContentFilter() # markdown_generator.content_filter = PruningContentFilter()
markdown_result: MarkdownGenerationResult = markdown_generator.generate_markdown( markdown_result: MarkdownGenerationResult = (
markdown_generator.generate_markdown(
cleaned_html=cleaned_html, cleaned_html=cleaned_html,
base_url=url, base_url=url,
# html2text_options=kwargs.get('html2text', {}) # html2text_options=kwargs.get('html2text', {})
) )
)
markdown_v2 = markdown_result markdown_v2 = markdown_result
markdown = sanitize_input_encode(markdown_result.raw_markdown) markdown = sanitize_input_encode(markdown_result.raw_markdown)
@@ -600,15 +617,15 @@ class AsyncWebCrawler:
self.logger.info( self.logger.info(
message="Processed {url:.50}... | Time: {timing}ms", message="Processed {url:.50}... | Time: {timing}ms",
tag="SCRAPE", tag="SCRAPE",
params={ params={"url": _url, "timing": int((time.perf_counter() - t1) * 1000)},
"url": _url,
"timing": int((time.perf_counter() - t1) * 1000)
}
) )
# Handle content extraction if needed # Handle content extraction if needed
if (not bool(extracted_content) and config.extraction_strategy and not isinstance(config.extraction_strategy, NoExtractionStrategy)): if (
not bool(extracted_content)
and config.extraction_strategy
and not isinstance(config.extraction_strategy, NoExtractionStrategy)
):
t1 = time.perf_counter() t1 = time.perf_counter()
# Choose content based on input_format # Choose content based on input_format
@@ -617,30 +634,33 @@ class AsyncWebCrawler:
self.logger.warning( self.logger.warning(
message="Fit markdown requested but not available. Falling back to raw markdown.", message="Fit markdown requested but not available. Falling back to raw markdown.",
tag="EXTRACT", tag="EXTRACT",
params={"url": _url} params={"url": _url},
) )
content_format = "markdown" content_format = "markdown"
content = { content = {
"markdown": markdown, "markdown": markdown,
"html": html, "html": html,
"fit_markdown": markdown_result.raw_markdown "fit_markdown": markdown_result.raw_markdown,
}.get(content_format, markdown) }.get(content_format, markdown)
# Use IdentityChunking for HTML input, otherwise use provided chunking strategy # Use IdentityChunking for HTML input, otherwise use provided chunking strategy
chunking = IdentityChunking() if content_format == "html" else config.chunking_strategy chunking = (
IdentityChunking()
if content_format == "html"
else config.chunking_strategy
)
sections = chunking.chunk(content) sections = chunking.chunk(content)
extracted_content = config.extraction_strategy.run(url, sections) extracted_content = config.extraction_strategy.run(url, sections)
extracted_content = json.dumps(extracted_content, indent=4, default=str, ensure_ascii=False) extracted_content = json.dumps(
extracted_content, indent=4, default=str, ensure_ascii=False
)
# Log extraction completion # Log extraction completion
self.logger.info( self.logger.info(
message="Completed for {url:.50}... | Time: {timing}s", message="Completed for {url:.50}... | Time: {timing}s",
tag="EXTRACT", tag="EXTRACT",
params={ params={"url": _url, "timing": time.perf_counter() - t1},
"url": _url,
"timing": time.perf_counter() - t1
}
) )
# Handle screenshot and PDF data # Handle screenshot and PDF data
@@ -739,7 +759,7 @@ class AsyncWebCrawler:
screenshot=screenshot, screenshot=screenshot,
pdf=pdf, pdf=pdf,
verbose=verbose, verbose=verbose,
**kwargs **kwargs,
) )
# # Initialize the dispatcher with the selected strategy # # Initialize the dispatcher with the selected strategy
@@ -754,17 +774,13 @@ class AsyncWebCrawler:
dispatcher = MemoryAdaptiveDispatcher( dispatcher = MemoryAdaptiveDispatcher(
self, self,
rate_limiter=RateLimiter( rate_limiter=RateLimiter(
base_delay=(1.0, 3.0), base_delay=(1.0, 3.0), max_delay=60.0, max_retries=3
max_delay=60.0, ),
max_retries=3
)
) )
# Run the URLs through the dispatcher # Run the URLs through the dispatcher
_results: List[CrawlerTaskResult] = await dispatcher.run_urls( _results: List[CrawlerTaskResult] = await dispatcher.run_urls(
crawler=self, crawler=self, urls=urls, config=config
urls=urls,
config=config
) )
results: CrawlResult = [] results: CrawlResult = []
@@ -776,7 +792,7 @@ class AsyncWebCrawler:
peak_memory=res.peak_memory, peak_memory=res.peak_memory,
start_time=res.start_time, start_time=res.start_time,
end_time=res.end_time, end_time=res.end_time,
error_message=res.error_message error_message=res.error_message,
) )
_res.dispatch_result = dispatch_result _res.dispatch_result = dispatch_result
results.append(_res) results.append(_res)

View File

@@ -12,6 +12,7 @@ class CacheMode(Enum):
- WRITE_ONLY: Only write to cache, don't read - WRITE_ONLY: Only write to cache, don't read
- BYPASS: Bypass cache for this operation - BYPASS: Bypass cache for this operation
""" """
ENABLED = "enabled" ENABLED = "enabled"
DISABLED = "disabled" DISABLED = "disabled"
READ_ONLY = "read_only" READ_ONLY = "read_only"
@@ -36,6 +37,7 @@ class CacheContext:
is_raw_html (bool): True if the URL is raw HTML, False otherwise. is_raw_html (bool): True if the URL is raw HTML, False otherwise.
_url_display (str): The display name for the URL (web, local file, or raw HTML). _url_display (str): The display name for the URL (web, local file, or raw HTML).
""" """
def __init__(self, url: str, cache_mode: CacheMode, always_bypass: bool = False): def __init__(self, url: str, cache_mode: CacheMode, always_bypass: bool = False):
""" """
Initializes the CacheContext with the provided URL and cache mode. Initializes the CacheContext with the provided URL and cache mode.
@@ -48,8 +50,8 @@ class CacheContext:
self.url = url self.url = url
self.cache_mode = cache_mode self.cache_mode = cache_mode
self.always_bypass = always_bypass self.always_bypass = always_bypass
self.is_cacheable = url.startswith(('http://', 'https://', 'file://')) self.is_cacheable = url.startswith(("http://", "https://", "file://"))
self.is_web_url = url.startswith(('http://', 'https://')) self.is_web_url = url.startswith(("http://", "https://"))
self.is_local_file = url.startswith("file://") self.is_local_file = url.startswith("file://")
self.is_raw_html = url.startswith("raw:") self.is_raw_html = url.startswith("raw:")
self._url_display = url if not self.is_raw_html else "Raw HTML" self._url_display = url if not self.is_raw_html else "Raw HTML"
@@ -94,7 +96,7 @@ def _legacy_to_cache_mode(
disable_cache: bool = False, disable_cache: bool = False,
bypass_cache: bool = False, bypass_cache: bool = False,
no_cache_read: bool = False, no_cache_read: bool = False,
no_cache_write: bool = False no_cache_write: bool = False,
) -> CacheMode: ) -> CacheMode:
""" """
Converts legacy cache parameters to the new CacheMode enum. Converts legacy cache parameters to the new CacheMode enum.

View File

@@ -3,7 +3,7 @@ import re
from collections import Counter from collections import Counter
import string import string
from .model_loader import load_nltk_punkt from .model_loader import load_nltk_punkt
from .utils import *
# Define the abstract base class for chunking strategies # Define the abstract base class for chunking strategies
class ChunkingStrategy(ABC): class ChunkingStrategy(ABC):
@@ -24,19 +24,23 @@ class ChunkingStrategy(ABC):
""" """
pass pass
# Create an identity chunking strategy f(x) = [x] # Create an identity chunking strategy f(x) = [x]
class IdentityChunking(ChunkingStrategy): class IdentityChunking(ChunkingStrategy):
""" """
Chunking strategy that returns the input text as a single chunk. Chunking strategy that returns the input text as a single chunk.
""" """
def chunk(self, text: str) -> list: def chunk(self, text: str) -> list:
return [text] return [text]
# Regex-based chunking # Regex-based chunking
class RegexChunking(ChunkingStrategy): class RegexChunking(ChunkingStrategy):
""" """
Chunking strategy that splits text based on regular expression patterns. Chunking strategy that splits text based on regular expression patterns.
""" """
def __init__(self, patterns=None, **kwargs): def __init__(self, patterns=None, **kwargs):
""" """
Initialize the RegexChunking object. Initialize the RegexChunking object.
@@ -45,7 +49,7 @@ class RegexChunking(ChunkingStrategy):
patterns (list): A list of regular expression patterns to split text. patterns (list): A list of regular expression patterns to split text.
""" """
if patterns is None: if patterns is None:
patterns = [r'\n\n'] # Default split pattern patterns = [r"\n\n"] # Default split pattern
self.patterns = patterns self.patterns = patterns
def chunk(self, text: str) -> list: def chunk(self, text: str) -> list:
@@ -57,18 +61,19 @@ class RegexChunking(ChunkingStrategy):
paragraphs = new_paragraphs paragraphs = new_paragraphs
return paragraphs return paragraphs
# NLP-based sentence chunking # NLP-based sentence chunking
class NlpSentenceChunking(ChunkingStrategy): class NlpSentenceChunking(ChunkingStrategy):
""" """
Chunking strategy that splits text into sentences using NLTK's sentence tokenizer. Chunking strategy that splits text into sentences using NLTK's sentence tokenizer.
""" """
def __init__(self, **kwargs): def __init__(self, **kwargs):
""" """
Initialize the NlpSentenceChunking object. Initialize the NlpSentenceChunking object.
""" """
load_nltk_punkt() load_nltk_punkt()
def chunk(self, text: str) -> list: def chunk(self, text: str) -> list:
# Improved regex for sentence splitting # Improved regex for sentence splitting
# sentence_endings = re.compile( # sentence_endings = re.compile(
@@ -77,11 +82,13 @@ class NlpSentenceChunking(ChunkingStrategy):
# sentences = sentence_endings.split(text) # sentences = sentence_endings.split(text)
# sens = [sent.strip() for sent in sentences if sent] # sens = [sent.strip() for sent in sentences if sent]
from nltk.tokenize import sent_tokenize from nltk.tokenize import sent_tokenize
sentences = sent_tokenize(text) sentences = sent_tokenize(text)
sens = [sent.strip() for sent in sentences] sens = [sent.strip() for sent in sentences]
return list(set(sens)) return list(set(sens))
# Topic-based segmentation using TextTiling # Topic-based segmentation using TextTiling
class TopicSegmentationChunking(ChunkingStrategy): class TopicSegmentationChunking(ChunkingStrategy):
""" """
@@ -100,6 +107,7 @@ class TopicSegmentationChunking(ChunkingStrategy):
num_keywords (int): The number of keywords to extract for each topic segment. num_keywords (int): The number of keywords to extract for each topic segment.
""" """
import nltk as nl import nltk as nl
self.tokenizer = nl.tokenize.TextTilingTokenizer() self.tokenizer = nl.tokenize.TextTilingTokenizer()
self.num_keywords = num_keywords self.num_keywords = num_keywords
@@ -111,8 +119,14 @@ class TopicSegmentationChunking(ChunkingStrategy):
def extract_keywords(self, text: str) -> list: def extract_keywords(self, text: str) -> list:
# Tokenize and remove stopwords and punctuation # Tokenize and remove stopwords and punctuation
import nltk as nl import nltk as nl
tokens = nl.toknize.word_tokenize(text) tokens = nl.toknize.word_tokenize(text)
tokens = [token.lower() for token in tokens if token not in nl.corpus.stopwords.words('english') and token not in string.punctuation] tokens = [
token.lower()
for token in tokens
if token not in nl.corpus.stopwords.words("english")
and token not in string.punctuation
]
# Calculate frequency distribution # Calculate frequency distribution
freq_dist = Counter(tokens) freq_dist = Counter(tokens)
@@ -123,9 +137,12 @@ class TopicSegmentationChunking(ChunkingStrategy):
# Segment the text into topics # Segment the text into topics
segments = self.chunk(text) segments = self.chunk(text)
# Extract keywords for each topic segment # Extract keywords for each topic segment
segments_with_topics = [(segment, self.extract_keywords(segment)) for segment in segments] segments_with_topics = [
(segment, self.extract_keywords(segment)) for segment in segments
]
return segments_with_topics return segments_with_topics
# Fixed-length word chunks # Fixed-length word chunks
class FixedLengthWordChunking(ChunkingStrategy): class FixedLengthWordChunking(ChunkingStrategy):
""" """
@@ -136,6 +153,7 @@ class FixedLengthWordChunking(ChunkingStrategy):
2. Create chunks of fixed length 2. Create chunks of fixed length
3. Return the list of chunks 3. Return the list of chunks
""" """
def __init__(self, chunk_size=100, **kwargs): def __init__(self, chunk_size=100, **kwargs):
""" """
Initialize the fixed-length word chunking strategy with the given chunk size. Initialize the fixed-length word chunking strategy with the given chunk size.
@@ -147,7 +165,11 @@ class FixedLengthWordChunking(ChunkingStrategy):
def chunk(self, text: str) -> list: def chunk(self, text: str) -> list:
words = text.split() words = text.split()
return [' '.join(words[i:i + self.chunk_size]) for i in range(0, len(words), self.chunk_size)] return [
" ".join(words[i : i + self.chunk_size])
for i in range(0, len(words), self.chunk_size)
]
# Sliding window chunking # Sliding window chunking
class SlidingWindowChunking(ChunkingStrategy): class SlidingWindowChunking(ChunkingStrategy):
@@ -159,6 +181,7 @@ class SlidingWindowChunking(ChunkingStrategy):
2. Create chunks of fixed length 2. Create chunks of fixed length
3. Return the list of chunks 3. Return the list of chunks
""" """
def __init__(self, window_size=100, step=50, **kwargs): def __init__(self, window_size=100, step=50, **kwargs):
""" """
Initialize the sliding window chunking strategy with the given window size and Initialize the sliding window chunking strategy with the given window size and
@@ -179,15 +202,16 @@ class SlidingWindowChunking(ChunkingStrategy):
return [text] return [text]
for i in range(0, len(words) - self.window_size + 1, self.step): for i in range(0, len(words) - self.window_size + 1, self.step):
chunk = ' '.join(words[i:i + self.window_size]) chunk = " ".join(words[i : i + self.window_size])
chunks.append(chunk) chunks.append(chunk)
# Handle the last chunk if it doesn't align perfectly # Handle the last chunk if it doesn't align perfectly
if i + self.window_size < len(words): if i + self.window_size < len(words):
chunks.append(' '.join(words[-self.window_size:])) chunks.append(" ".join(words[-self.window_size :]))
return chunks return chunks
class OverlappingWindowChunking(ChunkingStrategy): class OverlappingWindowChunking(ChunkingStrategy):
""" """
Chunking strategy that splits text into overlapping word chunks. Chunking strategy that splits text into overlapping word chunks.
@@ -198,6 +222,7 @@ class OverlappingWindowChunking(ChunkingStrategy):
3. Slide the window by the overlap size 3. Slide the window by the overlap size
4. Return the list of chunks 4. Return the list of chunks
""" """
def __init__(self, window_size=1000, overlap=100, **kwargs): def __init__(self, window_size=1000, overlap=100, **kwargs):
""" """
Initialize the overlapping window chunking strategy with the given window size and Initialize the overlapping window chunking strategy with the given window size and
@@ -220,7 +245,7 @@ class OverlappingWindowChunking(ChunkingStrategy):
start = 0 start = 0
while start < len(words): while start < len(words):
end = start + self.window_size end = start + self.window_size
chunk = ' '.join(words[start:end]) chunk = " ".join(words[start:end])
chunks.append(chunk) chunks.append(chunk)
if end >= len(words): if end >= len(words):

View File

@@ -8,14 +8,21 @@ from .async_logger import AsyncLogger
logger = AsyncLogger(verbose=True) logger = AsyncLogger(verbose=True)
docs_manager = DocsManager(logger) docs_manager = DocsManager(logger)
def print_table(headers: List[str], rows: List[List[str]], padding: int = 2): def print_table(headers: List[str], rows: List[List[str]], padding: int = 2):
"""Print formatted table with headers and rows""" """Print formatted table with headers and rows"""
widths = [max(len(str(cell)) for cell in col) for col in zip(headers, *rows)] widths = [max(len(str(cell)) for cell in col) for col in zip(headers, *rows)]
border = '+' + '+'.join('-' * (w + 2 * padding) for w in widths) + '+' border = "+" + "+".join("-" * (w + 2 * padding) for w in widths) + "+"
def format_row(row): def format_row(row):
return '|' + '|'.join(f"{' ' * padding}{str(cell):<{w}}{' ' * padding}" return (
for cell, w in zip(row, widths)) + '|' "|"
+ "|".join(
f"{' ' * padding}{str(cell):<{w}}{' ' * padding}"
for cell, w in zip(row, widths)
)
+ "|"
)
click.echo(border) click.echo(border)
click.echo(format_row(headers)) click.echo(format_row(headers))
@@ -24,19 +31,24 @@ def print_table(headers: List[str], rows: List[List[str]], padding: int = 2):
click.echo(format_row(row)) click.echo(format_row(row))
click.echo(border) click.echo(border)
@click.group() @click.group()
def cli(): def cli():
"""Crawl4AI Command Line Interface""" """Crawl4AI Command Line Interface"""
pass pass
@cli.group() @cli.group()
def docs(): def docs():
"""Documentation operations""" """Documentation operations"""
pass pass
@docs.command() @docs.command()
@click.argument('sections', nargs=-1) @click.argument("sections", nargs=-1)
@click.option('--mode', type=click.Choice(['extended', 'condensed']), default='extended') @click.option(
"--mode", type=click.Choice(["extended", "condensed"]), default="extended"
)
def combine(sections: tuple, mode: str): def combine(sections: tuple, mode: str):
"""Combine documentation sections""" """Combine documentation sections"""
try: try:
@@ -46,16 +58,17 @@ def combine(sections: tuple, mode: str):
logger.error(str(e), tag="ERROR") logger.error(str(e), tag="ERROR")
sys.exit(1) sys.exit(1)
@docs.command() @docs.command()
@click.argument('query') @click.argument("query")
@click.option('--top-k', '-k', default=5) @click.option("--top-k", "-k", default=5)
@click.option('--build-index', is_flag=True, help='Build index if missing') @click.option("--build-index", is_flag=True, help="Build index if missing")
def search(query: str, top_k: int, build_index: bool): def search(query: str, top_k: int, build_index: bool):
"""Search documentation""" """Search documentation"""
try: try:
result = docs_manager.search(query, top_k) result = docs_manager.search(query, top_k)
if result == "No search index available. Call build_search_index() first.": if result == "No search index available. Call build_search_index() first.":
if build_index or click.confirm('No search index found. Build it now?'): if build_index or click.confirm("No search index found. Build it now?"):
asyncio.run(docs_manager.llm_text.generate_index_files()) asyncio.run(docs_manager.llm_text.generate_index_files())
result = docs_manager.search(query, top_k) result = docs_manager.search(query, top_k)
click.echo(result) click.echo(result)
@@ -63,6 +76,7 @@ def search(query: str, top_k: int, build_index: bool):
click.echo(f"Error: {str(e)}", err=True) click.echo(f"Error: {str(e)}", err=True)
sys.exit(1) sys.exit(1)
@docs.command() @docs.command()
def update(): def update():
"""Update docs from GitHub""" """Update docs from GitHub"""
@@ -73,22 +87,25 @@ def update():
click.echo(f"Error: {str(e)}", err=True) click.echo(f"Error: {str(e)}", err=True)
sys.exit(1) sys.exit(1)
@docs.command() @docs.command()
@click.option('--force-facts', is_flag=True, help='Force regenerate fact files') @click.option("--force-facts", is_flag=True, help="Force regenerate fact files")
@click.option('--clear-cache', is_flag=True, help='Clear BM25 cache') @click.option("--clear-cache", is_flag=True, help="Clear BM25 cache")
def index(force_facts: bool, clear_cache: bool): def index(force_facts: bool, clear_cache: bool):
"""Build or rebuild search indexes""" """Build or rebuild search indexes"""
try: try:
asyncio.run(docs_manager.ensure_docs_exist()) asyncio.run(docs_manager.ensure_docs_exist())
asyncio.run(docs_manager.llm_text.generate_index_files( asyncio.run(
force_generate_facts=force_facts, docs_manager.llm_text.generate_index_files(
clear_bm25_cache=clear_cache force_generate_facts=force_facts, clear_bm25_cache=clear_cache
)) )
)
click.echo("Search indexes built successfully") click.echo("Search indexes built successfully")
except Exception as e: except Exception as e:
click.echo(f"Error: {str(e)}", err=True) click.echo(f"Error: {str(e)}", err=True)
sys.exit(1) sys.exit(1)
# Add docs list command # Add docs list command
@docs.command() @docs.command()
def list(): def list():
@@ -101,5 +118,6 @@ def list():
click.echo(f"Error: {str(e)}", err=True) click.echo(f"Error: {str(e)}", err=True)
sys.exit(1) sys.exit(1)
if __name__ == '__main__':
if __name__ == "__main__":
cli() cli()

View File

@@ -30,18 +30,40 @@ WORD_TOKEN_RATE = 1.3
MIN_WORD_THRESHOLD = 1 MIN_WORD_THRESHOLD = 1
IMAGE_DESCRIPTION_MIN_WORD_THRESHOLD = 1 IMAGE_DESCRIPTION_MIN_WORD_THRESHOLD = 1
IMPORTANT_ATTRS = ['src', 'href', 'alt', 'title', 'width', 'height'] IMPORTANT_ATTRS = ["src", "href", "alt", "title", "width", "height"]
ONLY_TEXT_ELIGIBLE_TAGS = ['b', 'i', 'u', 'span', 'del', 'ins', 'sub', 'sup', 'strong', 'em', 'code', 'kbd', 'var', 's', 'q', 'abbr', 'cite', 'dfn', 'time', 'small', 'mark'] ONLY_TEXT_ELIGIBLE_TAGS = [
"b",
"i",
"u",
"span",
"del",
"ins",
"sub",
"sup",
"strong",
"em",
"code",
"kbd",
"var",
"s",
"q",
"abbr",
"cite",
"dfn",
"time",
"small",
"mark",
]
SOCIAL_MEDIA_DOMAINS = [ SOCIAL_MEDIA_DOMAINS = [
'facebook.com', "facebook.com",
'twitter.com', "twitter.com",
'x.com', "x.com",
'linkedin.com', "linkedin.com",
'instagram.com', "instagram.com",
'pinterest.com', "pinterest.com",
'tiktok.com', "tiktok.com",
'snapchat.com', "snapchat.com",
'reddit.com', "reddit.com",
] ]
# Threshold for the Image extraction - Range is 1 to 6 # Threshold for the Image extraction - Range is 1 to 6

View File

@@ -1,44 +1,85 @@
import re import re
from bs4 import BeautifulSoup, Tag from bs4 import BeautifulSoup, Tag
from typing import List, Tuple, Dict from typing import List, Tuple
from rank_bm25 import BM25Okapi from rank_bm25 import BM25Okapi
from time import perf_counter
from collections import deque from collections import deque
from bs4 import BeautifulSoup, NavigableString, Tag, Comment from bs4 import NavigableString, Comment
from .utils import clean_tokens from .utils import clean_tokens
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
import math import math
from snowballstemmer import stemmer from snowballstemmer import stemmer
class RelevantContentFilter(ABC): class RelevantContentFilter(ABC):
"""Abstract base class for content filtering strategies""" """Abstract base class for content filtering strategies"""
def __init__(self, user_query: str = None): def __init__(self, user_query: str = None):
self.user_query = user_query self.user_query = user_query
self.included_tags = { self.included_tags = {
# Primary structure # Primary structure
'article', 'main', 'section', 'div', "article",
"main",
"section",
"div",
# List structures # List structures
'ul', 'ol', 'li', 'dl', 'dt', 'dd', "ul",
"ol",
"li",
"dl",
"dt",
"dd",
# Text content # Text content
'p', 'span', 'blockquote', 'pre', 'code', "p",
"span",
"blockquote",
"pre",
"code",
# Headers # Headers
'h1', 'h2', 'h3', 'h4', 'h5', 'h6', "h1",
"h2",
"h3",
"h4",
"h5",
"h6",
# Tables # Tables
'table', 'thead', 'tbody', 'tr', 'td', 'th', "table",
"thead",
"tbody",
"tr",
"td",
"th",
# Other semantic elements # Other semantic elements
'figure', 'figcaption', 'details', 'summary', "figure",
"figcaption",
"details",
"summary",
# Text formatting # Text formatting
'em', 'strong', 'b', 'i', 'mark', 'small', "em",
"strong",
"b",
"i",
"mark",
"small",
# Rich content # Rich content
'time', 'address', 'cite', 'q' "time",
"address",
"cite",
"q",
} }
self.excluded_tags = { self.excluded_tags = {
'nav', 'footer', 'header', 'aside', 'script', "nav",
'style', 'form', 'iframe', 'noscript' "footer",
"header",
"aside",
"script",
"style",
"form",
"iframe",
"noscript",
} }
self.header_tags = {'h1', 'h2', 'h3', 'h4', 'h5', 'h6'} self.header_tags = {"h1", "h2", "h3", "h4", "h5", "h6"}
self.negative_patterns = re.compile( self.negative_patterns = re.compile(
r'nav|footer|header|sidebar|ads|comment|promo|advert|social|share', r"nav|footer|header|sidebar|ads|comment|promo|advert|social|share", re.I
re.I
) )
self.min_word_count = 2 self.min_word_count = 2
@@ -62,28 +103,30 @@ class RelevantContentFilter(ABC):
except Exception: except Exception:
pass pass
if soup.find('h1'): if soup.find("h1"):
query_parts.append(soup.find('h1').get_text()) query_parts.append(soup.find("h1").get_text())
# Meta tags # Meta tags
temp = "" temp = ""
for meta_name in ['keywords', 'description']: for meta_name in ["keywords", "description"]:
meta = soup.find('meta', attrs={'name': meta_name}) meta = soup.find("meta", attrs={"name": meta_name})
if meta and meta.get('content'): if meta and meta.get("content"):
query_parts.append(meta['content']) query_parts.append(meta["content"])
temp += meta['content'] temp += meta["content"]
# If still empty, grab first significant paragraph # If still empty, grab first significant paragraph
if not temp: if not temp:
# Find the first tag P thatits text contains more than 50 characters # Find the first tag P thatits text contains more than 50 characters
for p in body.find_all('p'): for p in body.find_all("p"):
if len(p.get_text()) > 150: if len(p.get_text()) > 150:
query_parts.append(p.get_text()[:150]) query_parts.append(p.get_text()[:150])
break break
return ' '.join(filter(None, query_parts)) return " ".join(filter(None, query_parts))
def extract_text_chunks(self, body: Tag, min_word_threshold: int = None) -> List[Tuple[str, str]]: def extract_text_chunks(
self, body: Tag, min_word_threshold: int = None
) -> List[Tuple[str, str]]:
""" """
Extracts text chunks from a BeautifulSoup body element while preserving order. Extracts text chunks from a BeautifulSoup body element while preserving order.
Returns list of tuples (text, tag_name) for classification. Returns list of tuples (text, tag_name) for classification.
@@ -96,14 +139,42 @@ class RelevantContentFilter(ABC):
""" """
# Tags to ignore - inline elements that shouldn't break text flow # Tags to ignore - inline elements that shouldn't break text flow
INLINE_TAGS = { INLINE_TAGS = {
'a', 'abbr', 'acronym', 'b', 'bdo', 'big', 'br', 'button', 'cite', 'code', "a",
'dfn', 'em', 'i', 'img', 'input', 'kbd', 'label', 'map', 'object', 'q', "abbr",
'samp', 'script', 'select', 'small', 'span', 'strong', 'sub', 'sup', "acronym",
'textarea', 'time', 'tt', 'var' "b",
"bdo",
"big",
"br",
"button",
"cite",
"code",
"dfn",
"em",
"i",
"img",
"input",
"kbd",
"label",
"map",
"object",
"q",
"samp",
"script",
"select",
"small",
"span",
"strong",
"sub",
"sup",
"textarea",
"time",
"tt",
"var",
} }
# Tags that typically contain meaningful headers # Tags that typically contain meaningful headers
HEADER_TAGS = {'h1', 'h2', 'h3', 'h4', 'h5', 'h6', 'header'} HEADER_TAGS = {"h1", "h2", "h3", "h4", "h5", "h6", "header"}
chunks = [] chunks = []
current_text = [] current_text = []
@@ -111,9 +182,8 @@ class RelevantContentFilter(ABC):
def should_break_chunk(tag: Tag) -> bool: def should_break_chunk(tag: Tag) -> bool:
"""Determine if a tag should cause a break in the current text chunk""" """Determine if a tag should cause a break in the current text chunk"""
return ( return tag.name not in INLINE_TAGS and not (
tag.name not in INLINE_TAGS tag.name == "p" and len(current_text) == 0
and not (tag.name == 'p' and len(current_text) == 0)
) )
# Use deque for efficient push/pop operations # Use deque for efficient push/pop operations
@@ -125,9 +195,11 @@ class RelevantContentFilter(ABC):
if visited: if visited:
# End of block element - flush accumulated text # End of block element - flush accumulated text
if current_text and should_break_chunk(element): if current_text and should_break_chunk(element):
text = ' '.join(''.join(current_text).split()) text = " ".join("".join(current_text).split())
if text: if text:
tag_type = 'header' if element.name in HEADER_TAGS else 'content' tag_type = (
"header" if element.name in HEADER_TAGS else "content"
)
chunks.append((chunk_index, text, tag_type, element)) chunks.append((chunk_index, text, tag_type, element))
chunk_index += 1 chunk_index += 1
current_text = [] current_text = []
@@ -153,18 +225,23 @@ class RelevantContentFilter(ABC):
# Handle any remaining text # Handle any remaining text
if current_text: if current_text:
text = ' '.join(''.join(current_text).split()) text = " ".join("".join(current_text).split())
if text: if text:
chunks.append((chunk_index, text, 'content', body)) chunks.append((chunk_index, text, "content", body))
if min_word_threshold: if min_word_threshold:
chunks = [chunk for chunk in chunks if len(chunk[1].split()) >= min_word_threshold] chunks = [
chunk for chunk in chunks if len(chunk[1].split()) >= min_word_threshold
]
return chunks return chunks
def _deprecated_extract_text_chunks(self, soup: BeautifulSoup) -> List[Tuple[int, str, Tag]]: def _deprecated_extract_text_chunks(
self, soup: BeautifulSoup
) -> List[Tuple[int, str, Tag]]:
"""Common method for extracting text chunks""" """Common method for extracting text chunks"""
_text_cache = {} _text_cache = {}
def fast_text(element: Tag) -> str: def fast_text(element: Tag) -> str:
elem_id = id(element) elem_id = id(element)
if elem_id in _text_cache: if elem_id in _text_cache:
@@ -175,7 +252,7 @@ class RelevantContentFilter(ABC):
text = content.strip() text = content.strip()
if text: if text:
texts.append(text) texts.append(text)
result = ' '.join(texts) result = " ".join(texts)
_text_cache[elem_id] = result _text_cache[elem_id] = result
return result return result
@@ -210,10 +287,9 @@ class RelevantContentFilter(ABC):
"""Common method for exclusion logic""" """Common method for exclusion logic"""
if tag.name in self.excluded_tags: if tag.name in self.excluded_tags:
return True return True
class_id = ' '.join(filter(None, [ class_id = " ".join(
' '.join(tag.get('class', [])), filter(None, [" ".join(tag.get("class", [])), tag.get("id", "")])
tag.get('id', '') )
]))
return bool(self.negative_patterns.search(class_id)) return bool(self.negative_patterns.search(class_id))
def clean_element(self, tag: Tag) -> str: def clean_element(self, tag: Tag) -> str:
@@ -221,8 +297,16 @@ class RelevantContentFilter(ABC):
if not tag or not isinstance(tag, Tag): if not tag or not isinstance(tag, Tag):
return "" return ""
unwanted_tags = {'script', 'style', 'aside', 'form', 'iframe', 'noscript'} unwanted_tags = {"script", "style", "aside", "form", "iframe", "noscript"}
unwanted_attrs = {'style', 'onclick', 'onmouseover', 'align', 'bgcolor', 'class', 'id'} unwanted_attrs = {
"style",
"onclick",
"onmouseover",
"align",
"bgcolor",
"class",
"id",
}
# Use string builder pattern for better performance # Use string builder pattern for better performance
builder = [] builder = []
@@ -237,28 +321,29 @@ class RelevantContentFilter(ABC):
return return
# Start tag # Start tag
builder.append(f'<{elem.name}') builder.append(f"<{elem.name}")
# Add cleaned attributes # Add cleaned attributes
attrs = {k: v for k, v in elem.attrs.items() if k not in unwanted_attrs} attrs = {k: v for k, v in elem.attrs.items() if k not in unwanted_attrs}
for key, value in attrs.items(): for key, value in attrs.items():
builder.append(f' {key}="{value}"') builder.append(f' {key}="{value}"')
builder.append('>') builder.append(">")
# Process children # Process children
for child in elem.children: for child in elem.children:
render_tag(child) render_tag(child)
# Close tag # Close tag
builder.append(f'</{elem.name}>') builder.append(f"</{elem.name}>")
try: try:
render_tag(tag) render_tag(tag)
return ''.join(builder) return "".join(builder)
except Exception: except Exception:
return str(tag) # Fallback to original if anything fails return str(tag) # Fallback to original if anything fails
class BM25ContentFilter(RelevantContentFilter): class BM25ContentFilter(RelevantContentFilter):
""" """
Content filtering using BM25 algorithm with priority tag handling. Content filtering using BM25 algorithm with priority tag handling.
@@ -280,7 +365,13 @@ class BM25ContentFilter(RelevantContentFilter):
Methods: Methods:
filter_content(self, html: str, min_word_threshold: int = None) filter_content(self, html: str, min_word_threshold: int = None)
""" """
def __init__(self, user_query: str = None, bm25_threshold: float = 1.0, language: str = 'english'):
def __init__(
self,
user_query: str = None,
bm25_threshold: float = 1.0,
language: str = "english",
):
""" """
Initializes the BM25ContentFilter class, if not provided, falls back to page metadata. Initializes the BM25ContentFilter class, if not provided, falls back to page metadata.
@@ -295,17 +386,17 @@ class BM25ContentFilter(RelevantContentFilter):
super().__init__(user_query=user_query) super().__init__(user_query=user_query)
self.bm25_threshold = bm25_threshold self.bm25_threshold = bm25_threshold
self.priority_tags = { self.priority_tags = {
'h1': 5.0, "h1": 5.0,
'h2': 4.0, "h2": 4.0,
'h3': 3.0, "h3": 3.0,
'title': 4.0, "title": 4.0,
'strong': 2.0, "strong": 2.0,
'b': 1.5, "b": 1.5,
'em': 1.5, "em": 1.5,
'blockquote': 2.0, "blockquote": 2.0,
'code': 2.0, "code": 2.0,
'pre': 1.5, "pre": 1.5,
'th': 1.5, # Table headers "th": 1.5, # Table headers
} }
self.stemmer = stemmer(language) self.stemmer = stemmer(language)
@@ -327,13 +418,13 @@ class BM25ContentFilter(RelevantContentFilter):
if not html or not isinstance(html, str): if not html or not isinstance(html, str):
return [] return []
soup = BeautifulSoup(html, 'lxml') soup = BeautifulSoup(html, "lxml")
# Check if body is present # Check if body is present
if not soup.body: if not soup.body:
# Wrap in body tag if missing # Wrap in body tag if missing
soup = BeautifulSoup(f'<body>{html}</body>', 'lxml') soup = BeautifulSoup(f"<body>{html}</body>", "lxml")
body = soup.find('body') body = soup.find("body")
query = self.extract_page_query(soup, body) query = self.extract_page_query(soup, body)
@@ -354,9 +445,13 @@ class BM25ContentFilter(RelevantContentFilter):
# for _, chunk, _, _ in candidates] # for _, chunk, _, _ in candidates]
# tokenized_query = [ps.stem(word) for word in query.lower().split()] # tokenized_query = [ps.stem(word) for word in query.lower().split()]
tokenized_corpus = [[self.stemmer.stemWord(word) for word in chunk.lower().split()] tokenized_corpus = [
for _, chunk, _, _ in candidates] [self.stemmer.stemWord(word) for word in chunk.lower().split()]
tokenized_query = [self.stemmer.stemWord(word) for word in query.lower().split()] for _, chunk, _, _ in candidates
]
tokenized_query = [
self.stemmer.stemWord(word) for word in query.lower().split()
]
# tokenized_corpus = [[self.stemmer.stemWord(word) for word in tokenize_text(chunk.lower())] # tokenized_corpus = [[self.stemmer.stemWord(word) for word in tokenize_text(chunk.lower())]
# for _, chunk, _, _ in candidates] # for _, chunk, _, _ in candidates]
@@ -378,7 +473,8 @@ class BM25ContentFilter(RelevantContentFilter):
# Filter candidates by threshold # Filter candidates by threshold
selected_candidates = [ selected_candidates = [
(index, chunk, tag) for adjusted_score, index, chunk, tag in adjusted_candidates (index, chunk, tag)
for adjusted_score, index, chunk, tag in adjusted_candidates
if adjusted_score >= self.bm25_threshold if adjusted_score >= self.bm25_threshold
] ]
@@ -390,6 +486,7 @@ class BM25ContentFilter(RelevantContentFilter):
return [self.clean_element(tag) for _, _, tag in selected_candidates] return [self.clean_element(tag) for _, _, tag in selected_candidates]
class PruningContentFilter(RelevantContentFilter): class PruningContentFilter(RelevantContentFilter):
""" """
Content filtering using pruning algorithm with dynamic threshold. Content filtering using pruning algorithm with dynamic threshold.
@@ -411,8 +508,14 @@ class PruningContentFilter(RelevantContentFilter):
Methods: Methods:
filter_content(self, html: str, min_word_threshold: int = None): filter_content(self, html: str, min_word_threshold: int = None):
""" """
def __init__(self, user_query: str = None, min_word_threshold: int = None,
threshold_type: str = 'fixed', threshold: float = 0.48): def __init__(
self,
user_query: str = None,
min_word_threshold: int = None,
threshold_type: str = "fixed",
threshold: float = 0.48,
):
""" """
Initializes the PruningContentFilter class, if not provided, falls back to page metadata. Initializes the PruningContentFilter class, if not provided, falls back to page metadata.
@@ -432,49 +535,49 @@ class PruningContentFilter(RelevantContentFilter):
# Add tag importance for dynamic threshold # Add tag importance for dynamic threshold
self.tag_importance = { self.tag_importance = {
'article': 1.5, "article": 1.5,
'main': 1.4, "main": 1.4,
'section': 1.3, "section": 1.3,
'p': 1.2, "p": 1.2,
'h1': 1.4, "h1": 1.4,
'h2': 1.3, "h2": 1.3,
'h3': 1.2, "h3": 1.2,
'div': 0.7, "div": 0.7,
'span': 0.6 "span": 0.6,
} }
# Metric configuration # Metric configuration
self.metric_config = { self.metric_config = {
'text_density': True, "text_density": True,
'link_density': True, "link_density": True,
'tag_weight': True, "tag_weight": True,
'class_id_weight': True, "class_id_weight": True,
'text_length': True, "text_length": True,
} }
self.metric_weights = { self.metric_weights = {
'text_density': 0.4, "text_density": 0.4,
'link_density': 0.2, "link_density": 0.2,
'tag_weight': 0.2, "tag_weight": 0.2,
'class_id_weight': 0.1, "class_id_weight": 0.1,
'text_length': 0.1, "text_length": 0.1,
} }
self.tag_weights = { self.tag_weights = {
'div': 0.5, "div": 0.5,
'p': 1.0, "p": 1.0,
'article': 1.5, "article": 1.5,
'section': 1.0, "section": 1.0,
'span': 0.3, "span": 0.3,
'li': 0.5, "li": 0.5,
'ul': 0.5, "ul": 0.5,
'ol': 0.5, "ol": 0.5,
'h1': 1.2, "h1": 1.2,
'h2': 1.1, "h2": 1.1,
'h3': 1.0, "h3": 1.0,
'h4': 0.9, "h4": 0.9,
'h5': 0.8, "h5": 0.8,
'h6': 0.7, "h6": 0.7,
} }
def filter_content(self, html: str, min_word_threshold: int = None) -> List[str]: def filter_content(self, html: str, min_word_threshold: int = None) -> List[str]:
@@ -495,22 +598,22 @@ class PruningContentFilter(RelevantContentFilter):
if not html or not isinstance(html, str): if not html or not isinstance(html, str):
return [] return []
soup = BeautifulSoup(html, 'lxml') soup = BeautifulSoup(html, "lxml")
if not soup.body: if not soup.body:
soup = BeautifulSoup(f'<body>{html}</body>', 'lxml') soup = BeautifulSoup(f"<body>{html}</body>", "lxml")
# Remove comments and unwanted tags # Remove comments and unwanted tags
self._remove_comments(soup) self._remove_comments(soup)
self._remove_unwanted_tags(soup) self._remove_unwanted_tags(soup)
# Prune tree starting from body # Prune tree starting from body
body = soup.find('body') body = soup.find("body")
self._prune_tree(body) self._prune_tree(body)
# Extract remaining content as list of HTML strings # Extract remaining content as list of HTML strings
content_blocks = [] content_blocks = []
for element in body.children: for element in body.children:
if isinstance(element, str) or not hasattr(element, 'name'): if isinstance(element, str) or not hasattr(element, "name"):
continue continue
if len(element.get_text(strip=True)) > 0: if len(element.get_text(strip=True)) > 0:
content_blocks.append(str(element)) content_blocks.append(str(element))
@@ -535,24 +638,28 @@ class PruningContentFilter(RelevantContentFilter):
Args: Args:
node (Tag): The node from which the pruning starts. node (Tag): The node from which the pruning starts.
""" """
if not node or not hasattr(node, 'name') or node.name is None: if not node or not hasattr(node, "name") or node.name is None:
return return
text_len = len(node.get_text(strip=True)) text_len = len(node.get_text(strip=True))
tag_len = len(node.encode_contents().decode('utf-8')) tag_len = len(node.encode_contents().decode("utf-8"))
link_text_len = sum(len(s.strip()) for s in (a.string for a in node.find_all('a', recursive=False)) if s) link_text_len = sum(
len(s.strip())
for s in (a.string for a in node.find_all("a", recursive=False))
if s
)
metrics = { metrics = {
'node': node, "node": node,
'tag_name': node.name, "tag_name": node.name,
'text_len': text_len, "text_len": text_len,
'tag_len': tag_len, "tag_len": tag_len,
'link_text_len': link_text_len "link_text_len": link_text_len,
} }
score = self._compute_composite_score(metrics, text_len, tag_len, link_text_len) score = self._compute_composite_score(metrics, text_len, tag_len, link_text_len)
if self.threshold_type == 'fixed': if self.threshold_type == "fixed":
should_remove = score < self.threshold should_remove = score < self.threshold
else: # dynamic else: # dynamic
tag_importance = self.tag_importance.get(node.name, 0.7) tag_importance = self.tag_importance.get(node.name, 0.7)
@@ -572,7 +679,7 @@ class PruningContentFilter(RelevantContentFilter):
if should_remove: if should_remove:
node.decompose() node.decompose()
else: else:
children = [child for child in node.children if hasattr(child, 'name')] children = [child for child in node.children if hasattr(child, "name")]
for child in children: for child in children:
self._prune_tree(child) self._prune_tree(child)
@@ -580,48 +687,48 @@ class PruningContentFilter(RelevantContentFilter):
"""Computes the composite score""" """Computes the composite score"""
if self.min_word_threshold: if self.min_word_threshold:
# Get raw text from metrics node - avoid extra processing # Get raw text from metrics node - avoid extra processing
text = metrics['node'].get_text(strip=True) text = metrics["node"].get_text(strip=True)
word_count = text.count(' ') + 1 word_count = text.count(" ") + 1
if word_count < self.min_word_threshold: if word_count < self.min_word_threshold:
return -1.0 # Guaranteed removal return -1.0 # Guaranteed removal
score = 0.0 score = 0.0
total_weight = 0.0 total_weight = 0.0
if self.metric_config['text_density']: if self.metric_config["text_density"]:
density = text_len / tag_len if tag_len > 0 else 0 density = text_len / tag_len if tag_len > 0 else 0
score += self.metric_weights['text_density'] * density score += self.metric_weights["text_density"] * density
total_weight += self.metric_weights['text_density'] total_weight += self.metric_weights["text_density"]
if self.metric_config['link_density']: if self.metric_config["link_density"]:
density = 1 - (link_text_len / text_len if text_len > 0 else 0) density = 1 - (link_text_len / text_len if text_len > 0 else 0)
score += self.metric_weights['link_density'] * density score += self.metric_weights["link_density"] * density
total_weight += self.metric_weights['link_density'] total_weight += self.metric_weights["link_density"]
if self.metric_config['tag_weight']: if self.metric_config["tag_weight"]:
tag_score = self.tag_weights.get(metrics['tag_name'], 0.5) tag_score = self.tag_weights.get(metrics["tag_name"], 0.5)
score += self.metric_weights['tag_weight'] * tag_score score += self.metric_weights["tag_weight"] * tag_score
total_weight += self.metric_weights['tag_weight'] total_weight += self.metric_weights["tag_weight"]
if self.metric_config['class_id_weight']: if self.metric_config["class_id_weight"]:
class_score = self._compute_class_id_weight(metrics['node']) class_score = self._compute_class_id_weight(metrics["node"])
score += self.metric_weights['class_id_weight'] * max(0, class_score) score += self.metric_weights["class_id_weight"] * max(0, class_score)
total_weight += self.metric_weights['class_id_weight'] total_weight += self.metric_weights["class_id_weight"]
if self.metric_config['text_length']: if self.metric_config["text_length"]:
score += self.metric_weights['text_length'] * math.log(text_len + 1) score += self.metric_weights["text_length"] * math.log(text_len + 1)
total_weight += self.metric_weights['text_length'] total_weight += self.metric_weights["text_length"]
return score / total_weight if total_weight > 0 else 0 return score / total_weight if total_weight > 0 else 0
def _compute_class_id_weight(self, node): def _compute_class_id_weight(self, node):
"""Computes the class ID weight""" """Computes the class ID weight"""
class_id_score = 0 class_id_score = 0
if 'class' in node.attrs: if "class" in node.attrs:
classes = ' '.join(node['class']) classes = " ".join(node["class"])
if self.negative_patterns.match(classes): if self.negative_patterns.match(classes):
class_id_score -= 0.5 class_id_score -= 0.5
if 'id' in node.attrs: if "id" in node.attrs:
element_id = node['id'] element_id = node["id"]
if self.negative_patterns.match(element_id): if self.negative_patterns.match(element_id):
class_id_score -= 0.5 class_id_score -= 0.5
return class_id_score return class_id_score

File diff suppressed because it is too large Load Diff

View File

@@ -15,32 +15,30 @@ import logging, time
import base64 import base64
from PIL import Image, ImageDraw, ImageFont from PIL import Image, ImageDraw, ImageFont
from io import BytesIO from io import BytesIO
from typing import List, Callable from typing import Callable
import requests import requests
import os import os
from pathlib import Path from pathlib import Path
from .utils import * from .utils import *
logger = logging.getLogger('selenium.webdriver.remote.remote_connection') logger = logging.getLogger("selenium.webdriver.remote.remote_connection")
logger.setLevel(logging.WARNING) logger.setLevel(logging.WARNING)
logger_driver = logging.getLogger('selenium.webdriver.common.service') logger_driver = logging.getLogger("selenium.webdriver.common.service")
logger_driver.setLevel(logging.WARNING) logger_driver.setLevel(logging.WARNING)
urllib3_logger = logging.getLogger('urllib3.connectionpool') urllib3_logger = logging.getLogger("urllib3.connectionpool")
urllib3_logger.setLevel(logging.WARNING) urllib3_logger.setLevel(logging.WARNING)
# Disable http.client logging # Disable http.client logging
http_client_logger = logging.getLogger('http.client') http_client_logger = logging.getLogger("http.client")
http_client_logger.setLevel(logging.WARNING) http_client_logger.setLevel(logging.WARNING)
# Disable driver_finder and service logging # Disable driver_finder and service logging
driver_finder_logger = logging.getLogger('selenium.webdriver.common.driver_finder') driver_finder_logger = logging.getLogger("selenium.webdriver.common.driver_finder")
driver_finder_logger.setLevel(logging.WARNING) driver_finder_logger.setLevel(logging.WARNING)
class CrawlerStrategy(ABC): class CrawlerStrategy(ABC):
@abstractmethod @abstractmethod
def crawl(self, url: str, **kwargs) -> str: def crawl(self, url: str, **kwargs) -> str:
@@ -58,6 +56,7 @@ class CrawlerStrategy(ABC):
def set_hook(self, hook_type: str, hook: Callable): def set_hook(self, hook_type: str, hook: Callable):
pass pass
class CloudCrawlerStrategy(CrawlerStrategy): class CloudCrawlerStrategy(CrawlerStrategy):
def __init__(self, use_cached_html=False): def __init__(self, use_cached_html=False):
super().__init__() super().__init__()
@@ -76,6 +75,7 @@ class CloudCrawlerStrategy(CrawlerStrategy):
html = response["results"][0]["html"] html = response["results"][0]["html"]
return sanitize_input_encode(html) return sanitize_input_encode(html)
class LocalSeleniumCrawlerStrategy(CrawlerStrategy): class LocalSeleniumCrawlerStrategy(CrawlerStrategy):
def __init__(self, use_cached_html=False, js_code=None, **kwargs): def __init__(self, use_cached_html=False, js_code=None, **kwargs):
super().__init__() super().__init__()
@@ -87,9 +87,14 @@ class LocalSeleniumCrawlerStrategy(CrawlerStrategy):
if kwargs.get("user_agent"): if kwargs.get("user_agent"):
self.options.add_argument("--user-agent=" + kwargs.get("user_agent")) self.options.add_argument("--user-agent=" + kwargs.get("user_agent"))
else: else:
user_agent = kwargs.get("user_agent", "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.124 Safari/537.36") user_agent = kwargs.get(
"user_agent",
"Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.124 Safari/537.36",
)
self.options.add_argument(f"--user-agent={user_agent}") self.options.add_argument(f"--user-agent={user_agent}")
self.options.add_argument("user-agent=Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.124 Safari/537.36") self.options.add_argument(
"user-agent=Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.124 Safari/537.36"
)
self.options.headless = kwargs.get("headless", True) self.options.headless = kwargs.get("headless", True)
if self.options.headless: if self.options.headless:
@@ -123,11 +128,11 @@ class LocalSeleniumCrawlerStrategy(CrawlerStrategy):
# Hooks # Hooks
self.hooks = { self.hooks = {
'on_driver_created': None, "on_driver_created": None,
'on_user_agent_updated': None, "on_user_agent_updated": None,
'before_get_url': None, "before_get_url": None,
'after_get_url': None, "after_get_url": None,
'before_return_html': None "before_return_html": None,
} }
# chromedriver_autoinstaller.install() # chromedriver_autoinstaller.install()
@@ -138,7 +143,6 @@ class LocalSeleniumCrawlerStrategy(CrawlerStrategy):
# chromedriver_path = chromedriver_autoinstaller.utils.download_chromedriver() # chromedriver_path = chromedriver_autoinstaller.utils.download_chromedriver()
# self.service = Service(chromedriver_autoinstaller.install()) # self.service = Service(chromedriver_autoinstaller.install())
# chromedriver_path = ChromeDriverManager().install() # chromedriver_path = ChromeDriverManager().install()
# self.service = Service(chromedriver_path) # self.service = Service(chromedriver_path)
# self.service.log_path = "NUL" # self.service.log_path = "NUL"
@@ -148,14 +152,12 @@ class LocalSeleniumCrawlerStrategy(CrawlerStrategy):
self.service = Service() self.service = Service()
self.driver = webdriver.Chrome(options=self.options) self.driver = webdriver.Chrome(options=self.options)
self.driver = self.execute_hook('on_driver_created', self.driver) self.driver = self.execute_hook("on_driver_created", self.driver)
if kwargs.get("cookies"): if kwargs.get("cookies"):
for cookie in kwargs.get("cookies"): for cookie in kwargs.get("cookies"):
self.driver.add_cookie(cookie) self.driver.add_cookie(cookie)
def set_hook(self, hook_type: str, hook: Callable): def set_hook(self, hook_type: str, hook: Callable):
if hook_type in self.hooks: if hook_type in self.hooks:
self.hooks[hook_type] = hook self.hooks[hook_type] = hook
@@ -170,7 +172,9 @@ class LocalSeleniumCrawlerStrategy(CrawlerStrategy):
if isinstance(result, webdriver.Chrome): if isinstance(result, webdriver.Chrome):
return result return result
else: else:
raise TypeError(f"Hook {hook_type} must return an instance of webdriver.Chrome or None.") raise TypeError(
f"Hook {hook_type} must return an instance of webdriver.Chrome or None."
)
# If the hook returns None or there is no hook, return self.driver # If the hook returns None or there is no hook, return self.driver
return self.driver return self.driver
@@ -178,13 +182,13 @@ class LocalSeleniumCrawlerStrategy(CrawlerStrategy):
self.options.add_argument(f"user-agent={user_agent}") self.options.add_argument(f"user-agent={user_agent}")
self.driver.quit() self.driver.quit()
self.driver = webdriver.Chrome(service=self.service, options=self.options) self.driver = webdriver.Chrome(service=self.service, options=self.options)
self.driver = self.execute_hook('on_user_agent_updated', self.driver) self.driver = self.execute_hook("on_user_agent_updated", self.driver)
def set_custom_headers(self, headers: dict): def set_custom_headers(self, headers: dict):
# Enable Network domain for sending headers # Enable Network domain for sending headers
self.driver.execute_cdp_cmd('Network.enable', {}) self.driver.execute_cdp_cmd("Network.enable", {})
# Set extra HTTP headers # Set extra HTTP headers
self.driver.execute_cdp_cmd('Network.setExtraHTTPHeaders', {'headers': headers}) self.driver.execute_cdp_cmd("Network.setExtraHTTPHeaders", {"headers": headers})
def _ensure_page_load(self, max_checks=6, check_interval=0.01): def _ensure_page_load(self, max_checks=6, check_interval=0.01):
initial_length = len(self.driver.page_source) initial_length = len(self.driver.page_source)
@@ -202,36 +206,53 @@ class LocalSeleniumCrawlerStrategy(CrawlerStrategy):
def crawl(self, url: str, **kwargs) -> str: def crawl(self, url: str, **kwargs) -> str:
# Create md5 hash of the URL # Create md5 hash of the URL
import hashlib import hashlib
url_hash = hashlib.md5(url.encode()).hexdigest() url_hash = hashlib.md5(url.encode()).hexdigest()
if self.use_cached_html: if self.use_cached_html:
cache_file_path = os.path.join(os.getenv("CRAWL4_AI_BASE_DIRECTORY", Path.home()), ".crawl4ai", "cache", url_hash) cache_file_path = os.path.join(
os.getenv("CRAWL4_AI_BASE_DIRECTORY", Path.home()),
".crawl4ai",
"cache",
url_hash,
)
if os.path.exists(cache_file_path): if os.path.exists(cache_file_path):
with open(cache_file_path, "r") as f: with open(cache_file_path, "r") as f:
return sanitize_input_encode(f.read()) return sanitize_input_encode(f.read())
try: try:
self.driver = self.execute_hook('before_get_url', self.driver) self.driver = self.execute_hook("before_get_url", self.driver)
if self.verbose: if self.verbose:
print(f"[LOG] 🕸️ Crawling {url} using LocalSeleniumCrawlerStrategy...") print(f"[LOG] 🕸️ Crawling {url} using LocalSeleniumCrawlerStrategy...")
self.driver.get(url) # <html><head></head><body></body></html> self.driver.get(url) # <html><head></head><body></body></html>
WebDriverWait(self.driver, 20).until( WebDriverWait(self.driver, 20).until(
lambda d: d.execute_script('return document.readyState') == 'complete' lambda d: d.execute_script("return document.readyState") == "complete"
) )
WebDriverWait(self.driver, 10).until( WebDriverWait(self.driver, 10).until(
EC.presence_of_all_elements_located((By.TAG_NAME, "body")) EC.presence_of_all_elements_located((By.TAG_NAME, "body"))
) )
self.driver.execute_script("window.scrollTo(0, document.body.scrollHeight);") self.driver.execute_script(
"window.scrollTo(0, document.body.scrollHeight);"
)
self.driver = self.execute_hook('after_get_url', self.driver) self.driver = self.execute_hook("after_get_url", self.driver)
html = sanitize_input_encode(self._ensure_page_load()) # self.driver.page_source html = sanitize_input_encode(
can_not_be_done_headless = False # Look at my creativity for naming variables self._ensure_page_load()
) # self.driver.page_source
can_not_be_done_headless = (
False # Look at my creativity for naming variables
)
# TODO: Very ugly approach, but promise to change it! # TODO: Very ugly approach, but promise to change it!
if kwargs.get('bypass_headless', False) or html == "<html><head></head><body></body></html>": if (
print("[LOG] 🙌 Page could not be loaded in headless mode. Trying non-headless mode...") kwargs.get("bypass_headless", False)
or html == "<html><head></head><body></body></html>"
):
print(
"[LOG] 🙌 Page could not be loaded in headless mode. Trying non-headless mode..."
)
can_not_be_done_headless = True can_not_be_done_headless = True
options = Options() options = Options()
options.headless = False options.headless = False
@@ -239,7 +260,7 @@ class LocalSeleniumCrawlerStrategy(CrawlerStrategy):
options.add_argument("--window-size=5,5") options.add_argument("--window-size=5,5")
driver = webdriver.Chrome(service=self.service, options=options) driver = webdriver.Chrome(service=self.service, options=options)
driver.get(url) driver.get(url)
self.driver = self.execute_hook('after_get_url', driver) self.driver = self.execute_hook("after_get_url", driver)
html = sanitize_input_encode(driver.page_source) html = sanitize_input_encode(driver.page_source)
driver.quit() driver.quit()
@@ -249,17 +270,21 @@ class LocalSeleniumCrawlerStrategy(CrawlerStrategy):
self.driver.execute_script(self.js_code) self.driver.execute_script(self.js_code)
# Optionally, wait for some condition after executing the JS code # Optionally, wait for some condition after executing the JS code
WebDriverWait(self.driver, 10).until( WebDriverWait(self.driver, 10).until(
lambda driver: driver.execute_script("return document.readyState") == "complete" lambda driver: driver.execute_script("return document.readyState")
== "complete"
) )
elif self.js_code and type(self.js_code) == list: elif self.js_code and type(self.js_code) == list:
for js in self.js_code: for js in self.js_code:
self.driver.execute_script(js) self.driver.execute_script(js)
WebDriverWait(self.driver, 10).until( WebDriverWait(self.driver, 10).until(
lambda driver: driver.execute_script("return document.readyState") == "complete" lambda driver: driver.execute_script(
"return document.readyState"
)
== "complete"
) )
# Optionally, wait for some condition after executing the JS code : Contributed by (https://github.com/jonymusky) # Optionally, wait for some condition after executing the JS code : Contributed by (https://github.com/jonymusky)
wait_for = kwargs.get('wait_for', False) wait_for = kwargs.get("wait_for", False)
if wait_for: if wait_for:
if callable(wait_for): if callable(wait_for):
print("[LOG] 🔄 Waiting for condition...") print("[LOG] 🔄 Waiting for condition...")
@@ -272,10 +297,15 @@ class LocalSeleniumCrawlerStrategy(CrawlerStrategy):
if not can_not_be_done_headless: if not can_not_be_done_headless:
html = sanitize_input_encode(self.driver.page_source) html = sanitize_input_encode(self.driver.page_source)
self.driver = self.execute_hook('before_return_html', self.driver, html) self.driver = self.execute_hook("before_return_html", self.driver, html)
# Store in cache # Store in cache
cache_file_path = os.path.join(os.getenv("CRAWL4_AI_BASE_DIRECTORY", Path.home()), ".crawl4ai", "cache", url_hash) cache_file_path = os.path.join(
os.getenv("CRAWL4_AI_BASE_DIRECTORY", Path.home()),
".crawl4ai",
"cache",
url_hash,
)
with open(cache_file_path, "w", encoding="utf-8") as f: with open(cache_file_path, "w", encoding="utf-8") as f:
f.write(html) f.write(html)
@@ -284,16 +314,16 @@ class LocalSeleniumCrawlerStrategy(CrawlerStrategy):
return html return html
except InvalidArgumentException as e: except InvalidArgumentException as e:
if not hasattr(e, 'msg'): if not hasattr(e, "msg"):
e.msg = sanitize_input_encode(str(e)) e.msg = sanitize_input_encode(str(e))
raise InvalidArgumentException(f"Failed to crawl {url}: {e.msg}") raise InvalidArgumentException(f"Failed to crawl {url}: {e.msg}")
except WebDriverException as e: except WebDriverException as e:
# If e does nlt have msg attribute create it and set it to str(e) # If e does nlt have msg attribute create it and set it to str(e)
if not hasattr(e, 'msg'): if not hasattr(e, "msg"):
e.msg = sanitize_input_encode(str(e)) e.msg = sanitize_input_encode(str(e))
raise WebDriverException(f"Failed to crawl {url}: {e.msg}") raise WebDriverException(f"Failed to crawl {url}: {e.msg}")
except Exception as e: except Exception as e:
if not hasattr(e, 'msg'): if not hasattr(e, "msg"):
e.msg = sanitize_input_encode(str(e)) e.msg = sanitize_input_encode(str(e))
raise Exception(f"Failed to crawl {url}: {e.msg}") raise Exception(f"Failed to crawl {url}: {e.msg}")
@@ -301,7 +331,9 @@ class LocalSeleniumCrawlerStrategy(CrawlerStrategy):
try: try:
# Get the dimensions of the page # Get the dimensions of the page
total_width = self.driver.execute_script("return document.body.scrollWidth") total_width = self.driver.execute_script("return document.body.scrollWidth")
total_height = self.driver.execute_script("return document.body.scrollHeight") total_height = self.driver.execute_script(
"return document.body.scrollHeight"
)
# Set the window size to the dimensions of the page # Set the window size to the dimensions of the page
self.driver.set_window_size(total_width, total_height) self.driver.set_window_size(total_width, total_height)
@@ -313,23 +345,25 @@ class LocalSeleniumCrawlerStrategy(CrawlerStrategy):
image = Image.open(BytesIO(screenshot)) image = Image.open(BytesIO(screenshot))
# Convert image to RGB mode (this will handle both RGB and RGBA images) # Convert image to RGB mode (this will handle both RGB and RGBA images)
rgb_image = image.convert('RGB') rgb_image = image.convert("RGB")
# Convert to JPEG and compress # Convert to JPEG and compress
buffered = BytesIO() buffered = BytesIO()
rgb_image.save(buffered, format="JPEG", quality=85) rgb_image.save(buffered, format="JPEG", quality=85)
img_base64 = base64.b64encode(buffered.getvalue()).decode('utf-8') img_base64 = base64.b64encode(buffered.getvalue()).decode("utf-8")
if self.verbose: if self.verbose:
print(f"[LOG] 📸 Screenshot taken and converted to base64") print("[LOG] 📸 Screenshot taken and converted to base64")
return img_base64 return img_base64
except Exception as e: except Exception as e:
error_message = sanitize_input_encode(f"Failed to take screenshot: {str(e)}") error_message = sanitize_input_encode(
f"Failed to take screenshot: {str(e)}"
)
print(error_message) print(error_message)
# Generate an image with black background # Generate an image with black background
img = Image.new('RGB', (800, 600), color='black') img = Image.new("RGB", (800, 600), color="black")
draw = ImageDraw.Draw(img) draw = ImageDraw.Draw(img)
# Load a font # Load a font
@@ -352,7 +386,7 @@ class LocalSeleniumCrawlerStrategy(CrawlerStrategy):
# Convert to base64 # Convert to base64
buffered = BytesIO() buffered = BytesIO()
img.save(buffered, format="JPEG") img.save(buffered, format="JPEG")
img_base64 = base64.b64encode(buffered.getvalue()).decode('utf-8') img_base64 = base64.b64encode(buffered.getvalue()).decode("utf-8")
return img_base64 return img_base64

View File

@@ -7,11 +7,13 @@ DB_PATH = os.path.join(os.getenv("CRAWL4_AI_BASE_DIRECTORY", Path.home()), ".cra
os.makedirs(DB_PATH, exist_ok=True) os.makedirs(DB_PATH, exist_ok=True)
DB_PATH = os.path.join(DB_PATH, "crawl4ai.db") DB_PATH = os.path.join(DB_PATH, "crawl4ai.db")
def init_db(): def init_db():
global DB_PATH global DB_PATH
conn = sqlite3.connect(DB_PATH) conn = sqlite3.connect(DB_PATH)
cursor = conn.cursor() cursor = conn.cursor()
cursor.execute(''' cursor.execute(
"""
CREATE TABLE IF NOT EXISTS crawled_data ( CREATE TABLE IF NOT EXISTS crawled_data (
url TEXT PRIMARY KEY, url TEXT PRIMARY KEY,
html TEXT, html TEXT,
@@ -24,31 +26,42 @@ def init_db():
metadata TEXT DEFAULT "{}", metadata TEXT DEFAULT "{}",
screenshot TEXT DEFAULT "" screenshot TEXT DEFAULT ""
) )
''') """
)
conn.commit() conn.commit()
conn.close() conn.close()
def alter_db_add_screenshot(new_column: str = "media"): def alter_db_add_screenshot(new_column: str = "media"):
check_db_path() check_db_path()
try: try:
conn = sqlite3.connect(DB_PATH) conn = sqlite3.connect(DB_PATH)
cursor = conn.cursor() cursor = conn.cursor()
cursor.execute(f'ALTER TABLE crawled_data ADD COLUMN {new_column} TEXT DEFAULT ""') cursor.execute(
f'ALTER TABLE crawled_data ADD COLUMN {new_column} TEXT DEFAULT ""'
)
conn.commit() conn.commit()
conn.close() conn.close()
except Exception as e: except Exception as e:
print(f"Error altering database to add screenshot column: {e}") print(f"Error altering database to add screenshot column: {e}")
def check_db_path(): def check_db_path():
if not DB_PATH: if not DB_PATH:
raise ValueError("Database path is not set or is empty.") raise ValueError("Database path is not set or is empty.")
def get_cached_url(url: str) -> Optional[Tuple[str, str, str, str, str, str, str, bool, str]]:
def get_cached_url(
url: str,
) -> Optional[Tuple[str, str, str, str, str, str, str, bool, str]]:
check_db_path() check_db_path()
try: try:
conn = sqlite3.connect(DB_PATH) conn = sqlite3.connect(DB_PATH)
cursor = conn.cursor() cursor = conn.cursor()
cursor.execute('SELECT url, html, cleaned_html, markdown, extracted_content, success, media, links, metadata, screenshot FROM crawled_data WHERE url = ?', (url,)) cursor.execute(
"SELECT url, html, cleaned_html, markdown, extracted_content, success, media, links, metadata, screenshot FROM crawled_data WHERE url = ?",
(url,),
)
result = cursor.fetchone() result = cursor.fetchone()
conn.close() conn.close()
return result return result
@@ -56,12 +69,25 @@ def get_cached_url(url: str) -> Optional[Tuple[str, str, str, str, str, str, str
print(f"Error retrieving cached URL: {e}") print(f"Error retrieving cached URL: {e}")
return None return None
def cache_url(url: str, html: str, cleaned_html: str, markdown: str, extracted_content: str, success: bool, media : str = "{}", links : str = "{}", metadata : str = "{}", screenshot: str = ""):
def cache_url(
url: str,
html: str,
cleaned_html: str,
markdown: str,
extracted_content: str,
success: bool,
media: str = "{}",
links: str = "{}",
metadata: str = "{}",
screenshot: str = "",
):
check_db_path() check_db_path()
try: try:
conn = sqlite3.connect(DB_PATH) conn = sqlite3.connect(DB_PATH)
cursor = conn.cursor() cursor = conn.cursor()
cursor.execute(''' cursor.execute(
"""
INSERT INTO crawled_data (url, html, cleaned_html, markdown, extracted_content, success, media, links, metadata, screenshot) INSERT INTO crawled_data (url, html, cleaned_html, markdown, extracted_content, success, media, links, metadata, screenshot)
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
ON CONFLICT(url) DO UPDATE SET ON CONFLICT(url) DO UPDATE SET
@@ -74,18 +100,32 @@ def cache_url(url: str, html: str, cleaned_html: str, markdown: str, extracted_c
links = excluded.links, links = excluded.links,
metadata = excluded.metadata, metadata = excluded.metadata,
screenshot = excluded.screenshot screenshot = excluded.screenshot
''', (url, html, cleaned_html, markdown, extracted_content, success, media, links, metadata, screenshot)) """,
(
url,
html,
cleaned_html,
markdown,
extracted_content,
success,
media,
links,
metadata,
screenshot,
),
)
conn.commit() conn.commit()
conn.close() conn.close()
except Exception as e: except Exception as e:
print(f"Error caching URL: {e}") print(f"Error caching URL: {e}")
def get_total_count() -> int: def get_total_count() -> int:
check_db_path() check_db_path()
try: try:
conn = sqlite3.connect(DB_PATH) conn = sqlite3.connect(DB_PATH)
cursor = conn.cursor() cursor = conn.cursor()
cursor.execute('SELECT COUNT(*) FROM crawled_data') cursor.execute("SELECT COUNT(*) FROM crawled_data")
result = cursor.fetchone() result = cursor.fetchone()
conn.close() conn.close()
return result[0] return result[0]
@@ -93,43 +133,48 @@ def get_total_count() -> int:
print(f"Error getting total count: {e}") print(f"Error getting total count: {e}")
return 0 return 0
def clear_db(): def clear_db():
check_db_path() check_db_path()
try: try:
conn = sqlite3.connect(DB_PATH) conn = sqlite3.connect(DB_PATH)
cursor = conn.cursor() cursor = conn.cursor()
cursor.execute('DELETE FROM crawled_data') cursor.execute("DELETE FROM crawled_data")
conn.commit() conn.commit()
conn.close() conn.close()
except Exception as e: except Exception as e:
print(f"Error clearing database: {e}") print(f"Error clearing database: {e}")
def flush_db(): def flush_db():
check_db_path() check_db_path()
try: try:
conn = sqlite3.connect(DB_PATH) conn = sqlite3.connect(DB_PATH)
cursor = conn.cursor() cursor = conn.cursor()
cursor.execute('DROP TABLE crawled_data') cursor.execute("DROP TABLE crawled_data")
conn.commit() conn.commit()
conn.close() conn.close()
except Exception as e: except Exception as e:
print(f"Error flushing database: {e}") print(f"Error flushing database: {e}")
def update_existing_records(new_column: str = "media", default_value: str = "{}"): def update_existing_records(new_column: str = "media", default_value: str = "{}"):
check_db_path() check_db_path()
try: try:
conn = sqlite3.connect(DB_PATH) conn = sqlite3.connect(DB_PATH)
cursor = conn.cursor() cursor = conn.cursor()
cursor.execute(f'UPDATE crawled_data SET {new_column} = "{default_value}" WHERE screenshot IS NULL') cursor.execute(
f'UPDATE crawled_data SET {new_column} = "{default_value}" WHERE screenshot IS NULL'
)
conn.commit() conn.commit()
conn.close() conn.close()
except Exception as e: except Exception as e:
print(f"Error updating existing records: {e}") print(f"Error updating existing records: {e}")
if __name__ == "__main__": if __name__ == "__main__":
# Delete the existing database file # Delete the existing database file
if os.path.exists(DB_PATH): if os.path.exists(DB_PATH):
os.remove(DB_PATH) os.remove(DB_PATH)
init_db() init_db()
# alter_db_add_screenshot("COL_NAME") # alter_db_add_screenshot("COL_NAME")

View File

@@ -4,6 +4,7 @@ from pathlib import Path
from crawl4ai.async_logger import AsyncLogger from crawl4ai.async_logger import AsyncLogger
from crawl4ai.llmtxt import AsyncLLMTextManager from crawl4ai.llmtxt import AsyncLLMTextManager
class DocsManager: class DocsManager:
def __init__(self, logger=None): def __init__(self, logger=None):
self.docs_dir = Path.home() / ".crawl4ai" / "docs" self.docs_dir = Path.home() / ".crawl4ai" / "docs"
@@ -21,7 +22,10 @@ class DocsManager:
"""Copy from local docs or download from GitHub""" """Copy from local docs or download from GitHub"""
try: try:
# Try local first # Try local first
if self.local_docs.exists() and (any(self.local_docs.glob("*.md")) or any(self.local_docs.glob("*.tokens"))): if self.local_docs.exists() and (
any(self.local_docs.glob("*.md"))
or any(self.local_docs.glob("*.tokens"))
):
# Empty the local docs directory # Empty the local docs directory
for file_path in self.docs_dir.glob("*.md"): for file_path in self.docs_dir.glob("*.md"):
file_path.unlink() file_path.unlink()
@@ -36,14 +40,14 @@ class DocsManager:
# Fallback to GitHub # Fallback to GitHub
response = requests.get( response = requests.get(
"https://api.github.com/repos/unclecode/crawl4ai/contents/docs/llm.txt", "https://api.github.com/repos/unclecode/crawl4ai/contents/docs/llm.txt",
headers={'Accept': 'application/vnd.github.v3+json'} headers={"Accept": "application/vnd.github.v3+json"},
) )
response.raise_for_status() response.raise_for_status()
for item in response.json(): for item in response.json():
if item['type'] == 'file' and item['name'].endswith('.md'): if item["type"] == "file" and item["name"].endswith(".md"):
content = requests.get(item['download_url']).text content = requests.get(item["download_url"]).text
with open(self.docs_dir / item['name'], 'w', encoding='utf-8') as f: with open(self.docs_dir / item["name"], "w", encoding="utf-8") as f:
f.write(content) f.write(content)
return True return True
@@ -57,7 +61,11 @@ class DocsManager:
# Remove [0-9]+_ prefix # Remove [0-9]+_ prefix
names = [name.split("_", 1)[1] if name[0].isdigit() else name for name in names] names = [name.split("_", 1)[1] if name[0].isdigit() else name for name in names]
# Exclude those end with .xs.md and .q.md # Exclude those end with .xs.md and .q.md
names = [name for name in names if not name.endswith(".xs") and not name.endswith(".q")] names = [
name
for name in names
if not name.endswith(".xs") and not name.endswith(".q")
]
return names return names
def generate(self, sections, mode="extended"): def generate(self, sections, mode="extended"):

View File

@@ -1,20 +1,48 @@
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from typing import Any, List, Dict, Optional, Union from typing import Any, List, Dict, Optional
from concurrent.futures import ThreadPoolExecutor, as_completed from concurrent.futures import ThreadPoolExecutor, as_completed
import json, time import json
# from optimum.intel import IPEXModel import time
from .prompts import * import os
from .config import *
from .utils import * from .prompts import PROMPT_EXTRACT_BLOCKS
from .models import * from .config import (
DEFAULT_PROVIDER, PROVIDER_MODELS,
CHUNK_TOKEN_THRESHOLD,
OVERLAP_RATE,
WORD_TOKEN_RATE,
PROMPT_EXTRACT_BLOCKS_WITH_INSTRUCTION,
PROMPT_EXTRACT_SCHEMA_WITH_INSTRUCTION
)
from .utils import * # noqa: F403
from .utils import (
sanitize_html,
calculate_batch_size,
escape_json_string,
perform_completion_with_backoff,
extract_xml_data,
split_and_parse_json_objects,
sanitize_input_encode,
)
from .models import * # noqa: F403
from .models import TokenUsage
from .model_loader import * # noqa: F403
from .model_loader import (
get_device,
load_HF_embedding_model,
load_text_multilabel_classifier,
)
from functools import partial from functools import partial
from .model_loader import *
import math import math
import numpy as np import numpy as np
import re import re
from bs4 import BeautifulSoup from bs4 import BeautifulSoup
from lxml import html, etree from lxml import html, etree
from dataclasses import dataclass
class ExtractionStrategy(ABC): class ExtractionStrategy(ABC):
""" """
@@ -56,15 +84,20 @@ class ExtractionStrategy(ABC):
""" """
extracted_content = [] extracted_content = []
with ThreadPoolExecutor() as executor: with ThreadPoolExecutor() as executor:
futures = [executor.submit(self.extract, url, section, **kwargs) for section in sections] futures = [
executor.submit(self.extract, url, section, **kwargs)
for section in sections
]
for future in as_completed(futures): for future in as_completed(futures):
extracted_content.extend(future.result()) extracted_content.extend(future.result())
return extracted_content return extracted_content
class NoExtractionStrategy(ExtractionStrategy): class NoExtractionStrategy(ExtractionStrategy):
""" """
A strategy that does not extract any meaningful content from the HTML. It simply returns the entire HTML as a single block. A strategy that does not extract any meaningful content from the HTML. It simply returns the entire HTML as a single block.
""" """
def extract(self, url: str, html: str, *q, **kwargs) -> List[Dict[str, Any]]: def extract(self, url: str, html: str, *q, **kwargs) -> List[Dict[str, Any]]:
""" """
Extract meaningful blocks or chunks from the given HTML. Extract meaningful blocks or chunks from the given HTML.
@@ -72,13 +105,17 @@ class NoExtractionStrategy(ExtractionStrategy):
return [{"index": 0, "content": html}] return [{"index": 0, "content": html}]
def run(self, url: str, sections: List[str], *q, **kwargs) -> List[Dict[str, Any]]: def run(self, url: str, sections: List[str], *q, **kwargs) -> List[Dict[str, Any]]:
return [{"index": i, "tags": [], "content": section} for i, section in enumerate(sections)] return [
{"index": i, "tags": [], "content": section}
for i, section in enumerate(sections)
]
####################################################### #######################################################
# Strategies using clustering for text data extraction # # Strategies using clustering for text data extraction #
####################################################### #######################################################
class CosineStrategy(ExtractionStrategy): class CosineStrategy(ExtractionStrategy):
""" """
Extract meaningful blocks or chunks from the given HTML using cosine similarity. Extract meaningful blocks or chunks from the given HTML using cosine similarity.
@@ -99,7 +136,18 @@ class CosineStrategy(ExtractionStrategy):
model_name (str): The name of the sentence-transformers model. model_name (str): The name of the sentence-transformers model.
sim_threshold (float): The similarity threshold for clustering. sim_threshold (float): The similarity threshold for clustering.
""" """
def __init__(self, semantic_filter = None, word_count_threshold=10, max_dist=0.2, linkage_method='ward', top_k=3, model_name = 'sentence-transformers/all-MiniLM-L6-v2', sim_threshold = 0.3, **kwargs):
def __init__(
self,
semantic_filter=None,
word_count_threshold=10,
max_dist=0.2,
linkage_method="ward",
top_k=3,
model_name="sentence-transformers/all-MiniLM-L6-v2",
sim_threshold=0.3,
**kwargs,
):
""" """
Initialize the strategy with clustering parameters. Initialize the strategy with clustering parameters.
@@ -162,7 +210,6 @@ class CosineStrategy(ExtractionStrategy):
# self.tokenizer = self.model.tokenizer # self.tokenizer = self.model.tokenizer
# self.get_embedding_method = "direct" # self.get_embedding_method = "direct"
if self.verbose: if self.verbose:
print(f"[LOG] Loading Multilabel Classifier for {self.device.type} device.") print(f"[LOG] Loading Multilabel Classifier for {self.device.type} device.")
@@ -170,9 +217,15 @@ class CosineStrategy(ExtractionStrategy):
# self.default_batch_size = 16 if self.device.type == 'cpu' else 64 # self.default_batch_size = 16 if self.device.type == 'cpu' else 64
if self.verbose: if self.verbose:
print(f"[LOG] Model loaded {model_name}, models/reuters, took " + str(time.time() - self.timer) + " seconds") print(
f"[LOG] Model loaded {model_name}, models/reuters, took "
+ str(time.time() - self.timer)
+ " seconds"
)
def filter_documents_embeddings(self, documents: List[str], semantic_filter: str, at_least_k: int = 20) -> List[str]: def filter_documents_embeddings(
self, documents: List[str], semantic_filter: str, at_least_k: int = 20
) -> List[str]:
""" """
Filter and sort documents based on the cosine similarity of their embeddings with the semantic_filter embedding. Filter and sort documents based on the cosine similarity of their embeddings with the semantic_filter embedding.
@@ -200,14 +253,24 @@ class CosineStrategy(ExtractionStrategy):
document_embeddings = self.get_embeddings(documents) document_embeddings = self.get_embeddings(documents)
# Calculate cosine similarity between the query embedding and document embeddings # Calculate cosine similarity between the query embedding and document embeddings
similarities = cosine_similarity([query_embedding], document_embeddings).flatten() similarities = cosine_similarity(
[query_embedding], document_embeddings
).flatten()
# Filter documents based on the similarity threshold # Filter documents based on the similarity threshold
filtered_docs = [(doc, sim) for doc, sim in zip(documents, similarities) if sim >= self.sim_threshold] filtered_docs = [
(doc, sim)
for doc, sim in zip(documents, similarities)
if sim >= self.sim_threshold
]
# If the number of filtered documents is less than at_least_k, sort remaining documents by similarity # If the number of filtered documents is less than at_least_k, sort remaining documents by similarity
if len(filtered_docs) < at_least_k: if len(filtered_docs) < at_least_k:
remaining_docs = [(doc, sim) for doc, sim in zip(documents, similarities) if sim < self.sim_threshold] remaining_docs = [
(doc, sim)
for doc, sim in zip(documents, similarities)
if sim < self.sim_threshold
]
remaining_docs.sort(key=lambda x: x[1], reverse=True) remaining_docs.sort(key=lambda x: x[1], reverse=True)
filtered_docs.extend(remaining_docs[: at_least_k - len(filtered_docs)]) filtered_docs.extend(remaining_docs[: at_least_k - len(filtered_docs)])
@@ -216,7 +279,9 @@ class CosineStrategy(ExtractionStrategy):
return filtered_docs[:at_least_k] return filtered_docs[:at_least_k]
def get_embeddings(self, sentences: List[str], batch_size=None, bypass_buffer=False): def get_embeddings(
self, sentences: List[str], batch_size=None, bypass_buffer=False
):
""" """
Get BERT embeddings for a list of sentences. Get BERT embeddings for a list of sentences.
@@ -231,6 +296,7 @@ class CosineStrategy(ExtractionStrategy):
if self.device.type in ["cpu", "gpu", "cuda", "mps"]: if self.device.type in ["cpu", "gpu", "cuda", "mps"]:
import torch import torch
# Tokenize sentences and convert to tensor # Tokenize sentences and convert to tensor
if batch_size is None: if batch_size is None:
batch_size = self.default_batch_size batch_size = self.default_batch_size
@@ -238,8 +304,12 @@ class CosineStrategy(ExtractionStrategy):
all_embeddings = [] all_embeddings = []
for i in range(0, len(sentences), batch_size): for i in range(0, len(sentences), batch_size):
batch_sentences = sentences[i : i + batch_size] batch_sentences = sentences[i : i + batch_size]
encoded_input = self.tokenizer(batch_sentences, padding=True, truncation=True, return_tensors='pt') encoded_input = self.tokenizer(
encoded_input = {key: tensor.to(self.device) for key, tensor in encoded_input.items()} batch_sentences, padding=True, truncation=True, return_tensors="pt"
)
encoded_input = {
key: tensor.to(self.device) for key, tensor in encoded_input.items()
}
# Ensure no gradients are calculated # Ensure no gradients are calculated
with torch.no_grad(): with torch.no_grad():
@@ -277,18 +347,21 @@ class CosineStrategy(ExtractionStrategy):
# Get embeddings # Get embeddings
from scipy.cluster.hierarchy import linkage, fcluster from scipy.cluster.hierarchy import linkage, fcluster
from scipy.spatial.distance import pdist from scipy.spatial.distance import pdist
self.timer = time.time() self.timer = time.time()
embeddings = self.get_embeddings(sentences, bypass_buffer=True) embeddings = self.get_embeddings(sentences, bypass_buffer=True)
# print(f"[LOG] 🚀 Embeddings computed in {time.time() - self.timer:.2f} seconds") # print(f"[LOG] 🚀 Embeddings computed in {time.time() - self.timer:.2f} seconds")
# Compute pairwise cosine distances # Compute pairwise cosine distances
distance_matrix = pdist(embeddings, 'cosine') distance_matrix = pdist(embeddings, "cosine")
# Perform agglomerative clustering respecting order # Perform agglomerative clustering respecting order
linked = linkage(distance_matrix, method=self.linkage_method) linked = linkage(distance_matrix, method=self.linkage_method)
# Form flat clusters # Form flat clusters
labels = fcluster(linked, self.max_dist, criterion='distance') labels = fcluster(linked, self.max_dist, criterion="distance")
return labels return labels
def filter_clusters_by_word_count(self, clusters: Dict[int, List[str]]) -> Dict[int, List[str]]: def filter_clusters_by_word_count(
self, clusters: Dict[int, List[str]]
) -> Dict[int, List[str]]:
""" """
Filter clusters to remove those with a word count below the threshold. Filter clusters to remove those with a word count below the threshold.
@@ -327,7 +400,9 @@ class CosineStrategy(ExtractionStrategy):
text_chunks = html.split(self.DEL) # Split by lines or paragraphs as needed text_chunks = html.split(self.DEL) # Split by lines or paragraphs as needed
# Pre-filter documents using embeddings and semantic_filter # Pre-filter documents using embeddings and semantic_filter
text_chunks = self.filter_documents_embeddings(text_chunks, self.semantic_filter) text_chunks = self.filter_documents_embeddings(
text_chunks, self.semantic_filter
)
if not text_chunks: if not text_chunks:
return [] return []
@@ -346,16 +421,19 @@ class CosineStrategy(ExtractionStrategy):
filtered_clusters = self.filter_clusters_by_word_count(clusters) filtered_clusters = self.filter_clusters_by_word_count(clusters)
# Convert filtered clusters to a sorted list of dictionaries # Convert filtered clusters to a sorted list of dictionaries
cluster_list = [{"index": int(idx), "tags" : [], "content": " ".join(filtered_clusters[idx])} for idx in sorted(filtered_clusters)] cluster_list = [
{"index": int(idx), "tags": [], "content": " ".join(filtered_clusters[idx])}
for idx in sorted(filtered_clusters)
]
if self.verbose: if self.verbose:
print(f"[LOG] 🚀 Assign tags using {self.device}") print(f"[LOG] 🚀 Assign tags using {self.device}")
if self.device.type in ["gpu", "cuda", "mps", "cpu"]: if self.device.type in ["gpu", "cuda", "mps", "cpu"]:
labels = self.nlp([cluster['content'] for cluster in cluster_list]) labels = self.nlp([cluster["content"] for cluster in cluster_list])
for cluster, label in zip(cluster_list, labels): for cluster, label in zip(cluster_list, labels):
cluster['tags'] = label cluster["tags"] = label
# elif self.device.type == "cpu": # elif self.device.type == "cpu":
# # Process the text with the loaded model # # Process the text with the loaded model
# texts = [cluster['content'] for cluster in cluster_list] # texts = [cluster['content'] for cluster in cluster_list]
@@ -393,7 +471,6 @@ class CosineStrategy(ExtractionStrategy):
return self.extract(url, self.DEL.join(sections), **kwargs) return self.extract(url, self.DEL.join(sections), **kwargs)
####################################################### #######################################################
# Strategies using LLM-based extraction for text data # # Strategies using LLM-based extraction for text data #
####################################################### #######################################################
@@ -419,9 +496,15 @@ class LLMExtractionStrategy(ExtractionStrategy):
total_usage: Accumulated token usage. total_usage: Accumulated token usage.
""" """
def __init__(self, def __init__(
provider: str = DEFAULT_PROVIDER, api_token: Optional[str] = None, self,
instruction:str = None, schema:Dict = None, extraction_type = "block", **kwargs): provider: str = DEFAULT_PROVIDER,
api_token: Optional[str] = None,
instruction: str = None,
schema: Dict = None,
extraction_type="block",
**kwargs,
):
""" """
Initialize the strategy with clustering parameters. Initialize the strategy with clustering parameters.
@@ -445,14 +528,20 @@ class LLMExtractionStrategy(ExtractionStrategy):
""" """
super().__init__(**kwargs) super().__init__(**kwargs)
self.provider = provider self.provider = provider
self.api_token = api_token or PROVIDER_MODELS.get(provider, "no-token") or os.getenv("OPENAI_API_KEY") self.api_token = (
api_token
or PROVIDER_MODELS.get(provider, "no-token")
or os.getenv("OPENAI_API_KEY")
)
self.instruction = instruction self.instruction = instruction
self.extract_type = extraction_type self.extract_type = extraction_type
self.schema = schema self.schema = schema
if schema: if schema:
self.extract_type = "schema" self.extract_type = "schema"
self.chunk_token_threshold = kwargs.get("chunk_token_threshold", CHUNK_TOKEN_THRESHOLD) self.chunk_token_threshold = kwargs.get(
"chunk_token_threshold", CHUNK_TOKEN_THRESHOLD
)
self.overlap_rate = kwargs.get("overlap_rate", OVERLAP_RATE) self.overlap_rate = kwargs.get("overlap_rate", OVERLAP_RATE)
self.word_token_rate = kwargs.get("word_token_rate", WORD_TOKEN_RATE) self.word_token_rate = kwargs.get("word_token_rate", WORD_TOKEN_RATE)
self.apply_chunking = kwargs.get("apply_chunking", True) self.apply_chunking = kwargs.get("apply_chunking", True)
@@ -467,8 +556,9 @@ class LLMExtractionStrategy(ExtractionStrategy):
self.total_usage = TokenUsage() # Accumulated usage self.total_usage = TokenUsage() # Accumulated usage
if not self.api_token: if not self.api_token:
raise ValueError("API token must be provided for LLMExtractionStrategy. Update the config.py or set OPENAI_API_KEY environment variable.") raise ValueError(
"API token must be provided for LLMExtractionStrategy. Update the config.py or set OPENAI_API_KEY environment variable."
)
def extract(self, url: str, ix: int, html: str) -> List[Dict[str, Any]]: def extract(self, url: str, ix: int, html: str) -> List[Dict[str, Any]]:
""" """
@@ -515,15 +605,19 @@ class LLMExtractionStrategy(ExtractionStrategy):
prompt_with_variables, prompt_with_variables,
self.api_token, self.api_token,
base_url=self.api_base or self.base_url, base_url=self.api_base or self.base_url,
extra_args = self.extra_args extra_args=self.extra_args,
) # , json_response=self.extract_type == "schema") ) # , json_response=self.extract_type == "schema")
# Track usage # Track usage
usage = TokenUsage( usage = TokenUsage(
completion_tokens=response.usage.completion_tokens, completion_tokens=response.usage.completion_tokens,
prompt_tokens=response.usage.prompt_tokens, prompt_tokens=response.usage.prompt_tokens,
total_tokens=response.usage.total_tokens, total_tokens=response.usage.total_tokens,
completion_tokens_details=response.usage.completion_tokens_details.__dict__ if response.usage.completion_tokens_details else {}, completion_tokens_details=response.usage.completion_tokens_details.__dict__
prompt_tokens_details=response.usage.prompt_tokens_details.__dict__ if response.usage.prompt_tokens_details else {} if response.usage.completion_tokens_details
else {},
prompt_tokens_details=response.usage.prompt_tokens_details.__dict__
if response.usage.prompt_tokens_details
else {},
) )
self.usages.append(usage) self.usages.append(usage)
@@ -533,36 +627,44 @@ class LLMExtractionStrategy(ExtractionStrategy):
self.total_usage.total_tokens += usage.total_tokens self.total_usage.total_tokens += usage.total_tokens
try: try:
blocks = extract_xml_data(["blocks"], response.choices[0].message.content)['blocks'] blocks = extract_xml_data(["blocks"], response.choices[0].message.content)[
"blocks"
]
blocks = json.loads(blocks) blocks = json.loads(blocks)
for block in blocks: for block in blocks:
block['error'] = False block["error"] = False
except Exception as e: except Exception:
parsed, unparsed = split_and_parse_json_objects(response.choices[0].message.content) parsed, unparsed = split_and_parse_json_objects(
response.choices[0].message.content
)
blocks = parsed blocks = parsed
if unparsed: if unparsed:
blocks.append({ blocks.append(
"index": 0, {"index": 0, "error": True, "tags": ["error"], "content": unparsed}
"error": True, )
"tags": ["error"],
"content": unparsed
})
if self.verbose: if self.verbose:
print("[LOG] Extracted", len(blocks), "blocks from URL:", url, "block index:", ix) print(
"[LOG] Extracted",
len(blocks),
"blocks from URL:",
url,
"block index:",
ix,
)
return blocks return blocks
def _merge(self, documents, chunk_token_threshold, overlap): def _merge(self, documents, chunk_token_threshold, overlap):
""" """
Merge documents into sections based on chunk_token_threshold and overlap. Merge documents into sections based on chunk_token_threshold and overlap.
""" """
chunks = [] # chunks = []
sections = [] sections = []
total_tokens = 0 total_tokens = 0
# Calculate the total tokens across all documents # Calculate the total tokens across all documents
for document in documents: for document in documents:
total_tokens += len(document.split(' ')) * self.word_token_rate total_tokens += len(document.split(" ")) * self.word_token_rate
# Calculate the number of sections needed # Calculate the number of sections needed
num_sections = math.floor(total_tokens / chunk_token_threshold) num_sections = math.floor(total_tokens / chunk_token_threshold)
@@ -574,7 +676,7 @@ class LLMExtractionStrategy(ExtractionStrategy):
current_chunk = [] current_chunk = []
for document in documents: for document in documents:
tokens = document.split(' ') tokens = document.split(" ")
token_count = len(tokens) * self.word_token_rate token_count = len(tokens) * self.word_token_rate
if total_token_so_far + token_count <= adjusted_chunk_threshold: if total_token_so_far + token_count <= adjusted_chunk_threshold:
@@ -591,17 +693,16 @@ class LLMExtractionStrategy(ExtractionStrategy):
overlap_tokens = current_chunk[-overlap:] overlap_tokens = current_chunk[-overlap:]
current_chunk.extend(overlap_tokens) current_chunk.extend(overlap_tokens)
sections.append(' '.join(current_chunk)) sections.append(" ".join(current_chunk))
current_chunk = tokens current_chunk = tokens
total_token_so_far = token_count total_token_so_far = token_count
# Add the last chunk # Add the last chunk
if current_chunk: if current_chunk:
sections.append(' '.join(current_chunk)) sections.append(" ".join(current_chunk))
return sections return sections
def run(self, url: str, sections: List[str]) -> List[Dict[str, Any]]: def run(self, url: str, sections: List[str]) -> List[Dict[str, Any]]:
""" """
Process sections sequentially with a delay for rate limiting issues, specifically for LLMExtractionStrategy. Process sections sequentially with a delay for rate limiting issues, specifically for LLMExtractionStrategy.
@@ -615,15 +716,18 @@ class LLMExtractionStrategy(ExtractionStrategy):
""" """
merged_sections = self._merge( merged_sections = self._merge(
sections, self.chunk_token_threshold, sections,
overlap= int(self.chunk_token_threshold * self.overlap_rate) self.chunk_token_threshold,
overlap=int(self.chunk_token_threshold * self.overlap_rate),
) )
extracted_content = [] extracted_content = []
if self.provider.startswith("groq/"): if self.provider.startswith("groq/"):
# Sequential processing with a delay # Sequential processing with a delay
for ix, section in enumerate(merged_sections): for ix, section in enumerate(merged_sections):
extract_func = partial(self.extract, url) extract_func = partial(self.extract, url)
extracted_content.extend(extract_func(ix, sanitize_input_encode(section))) extracted_content.extend(
extract_func(ix, sanitize_input_encode(section))
)
time.sleep(0.5) # 500 ms delay between each processing time.sleep(0.5) # 500 ms delay between each processing
else: else:
# Parallel processing using ThreadPoolExecutor # Parallel processing using ThreadPoolExecutor
@@ -633,7 +737,10 @@ class LLMExtractionStrategy(ExtractionStrategy):
with ThreadPoolExecutor(max_workers=4) as executor: with ThreadPoolExecutor(max_workers=4) as executor:
extract_func = partial(self.extract, url) extract_func = partial(self.extract, url)
futures = [executor.submit(extract_func, ix, sanitize_input_encode(section)) for ix, section in enumerate(merged_sections)] futures = [
executor.submit(extract_func, ix, sanitize_input_encode(section))
for ix, section in enumerate(merged_sections)
]
for future in as_completed(futures): for future in as_completed(futures):
try: try:
@@ -642,17 +749,17 @@ class LLMExtractionStrategy(ExtractionStrategy):
if self.verbose: if self.verbose:
print(f"Error in thread execution: {e}") print(f"Error in thread execution: {e}")
# Add error information to extracted_content # Add error information to extracted_content
extracted_content.append({ extracted_content.append(
{
"index": 0, "index": 0,
"error": True, "error": True,
"tags": ["error"], "tags": ["error"],
"content": str(e) "content": str(e),
}) }
)
return extracted_content return extracted_content
def show_usage(self) -> None: def show_usage(self) -> None:
"""Print a detailed token usage report showing total and per-request usage.""" """Print a detailed token usage report showing total and per-request usage."""
print("\n=== Token Usage Summary ===") print("\n=== Token Usage Summary ===")
@@ -666,14 +773,16 @@ class LLMExtractionStrategy(ExtractionStrategy):
print(f"{'Request #':<10} {'Completion':>12} {'Prompt':>12} {'Total':>12}") print(f"{'Request #':<10} {'Completion':>12} {'Prompt':>12} {'Total':>12}")
print("-" * 48) print("-" * 48)
for i, usage in enumerate(self.usages, 1): for i, usage in enumerate(self.usages, 1):
print(f"{i:<10} {usage.completion_tokens:>12,} {usage.prompt_tokens:>12,} {usage.total_tokens:>12,}") print(
f"{i:<10} {usage.completion_tokens:>12,} {usage.prompt_tokens:>12,} {usage.total_tokens:>12,}"
)
####################################################### #######################################################
# New extraction strategies for JSON-based extraction # # New extraction strategies for JSON-based extraction #
####################################################### #######################################################
class JsonElementExtractionStrategy(ExtractionStrategy): class JsonElementExtractionStrategy(ExtractionStrategy):
""" """
Abstract base class for extracting structured JSON from HTML content. Abstract base class for extracting structured JSON from HTML content.
@@ -706,8 +815,7 @@ class JsonElementExtractionStrategy(ExtractionStrategy):
_get_element_attribute(element, attribute): Extracts an attribute's value from an element. _get_element_attribute(element, attribute): Extracts an attribute's value from an element.
""" """
DEL = "\n"
DEL = '\n'
def __init__(self, schema: Dict[str, Any], **kwargs): def __init__(self, schema: Dict[str, Any], **kwargs):
""" """
@@ -718,9 +826,11 @@ class JsonElementExtractionStrategy(ExtractionStrategy):
""" """
super().__init__(**kwargs) super().__init__(**kwargs)
self.schema = schema self.schema = schema
self.verbose = kwargs.get('verbose', False) self.verbose = kwargs.get("verbose", False)
def extract(self, url: str, html_content: str, *q, **kwargs) -> List[Dict[str, Any]]: def extract(
self, url: str, html_content: str, *q, **kwargs
) -> List[Dict[str, Any]]:
""" """
Extract structured data from HTML content. Extract structured data from HTML content.
@@ -740,20 +850,22 @@ class JsonElementExtractionStrategy(ExtractionStrategy):
""" """
parsed_html = self._parse_html(html_content) parsed_html = self._parse_html(html_content)
base_elements = self._get_base_elements(parsed_html, self.schema['baseSelector']) base_elements = self._get_base_elements(
parsed_html, self.schema["baseSelector"]
)
results = [] results = []
for element in base_elements: for element in base_elements:
# Extract base element attributes # Extract base element attributes
item = {} item = {}
if 'baseFields' in self.schema: if "baseFields" in self.schema:
for field in self.schema['baseFields']: for field in self.schema["baseFields"]:
value = self._extract_single_field(element, field) value = self._extract_single_field(element, field)
if value is not None: if value is not None:
item[field['name']] = value item[field["name"]] = value
# Extract child fields # Extract child fields
field_data = self._extract_item(element, self.schema['fields']) field_data = self._extract_item(element, self.schema["fields"])
item.update(field_data) item.update(field_data)
if item: if item:
@@ -778,24 +890,28 @@ class JsonElementExtractionStrategy(ExtractionStrategy):
def _extract_field(self, element, field): def _extract_field(self, element, field):
try: try:
if field['type'] == 'nested': if field["type"] == "nested":
nested_elements = self._get_elements(element, field['selector']) nested_elements = self._get_elements(element, field["selector"])
nested_element = nested_elements[0] if nested_elements else None nested_element = nested_elements[0] if nested_elements else None
return self._extract_item(nested_element, field['fields']) if nested_element else {} return (
self._extract_item(nested_element, field["fields"])
if nested_element
else {}
)
if field['type'] == 'list': if field["type"] == "list":
elements = self._get_elements(element, field['selector']) elements = self._get_elements(element, field["selector"])
return [self._extract_list_item(el, field['fields']) for el in elements] return [self._extract_list_item(el, field["fields"]) for el in elements]
if field['type'] == 'nested_list': if field["type"] == "nested_list":
elements = self._get_elements(element, field['selector']) elements = self._get_elements(element, field["selector"])
return [self._extract_item(el, field['fields']) for el in elements] return [self._extract_item(el, field["fields"]) for el in elements]
return self._extract_single_field(element, field) return self._extract_single_field(element, field)
except Exception as e: except Exception as e:
if self.verbose: if self.verbose:
print(f"Error extracting field {field['name']}: {str(e)}") print(f"Error extracting field {field['name']}: {str(e)}")
return field.get('default') return field.get("default")
def _extract_single_field(self, element, field): def _extract_single_field(self, element, field):
""" """
@@ -814,37 +930,37 @@ class JsonElementExtractionStrategy(ExtractionStrategy):
Any: The extracted field value. Any: The extracted field value.
""" """
if 'selector' in field: if "selector" in field:
selected = self._get_elements(element, field['selector']) selected = self._get_elements(element, field["selector"])
if not selected: if not selected:
return field.get('default') return field.get("default")
selected = selected[0] selected = selected[0]
else: else:
selected = element selected = element
value = None value = None
if field['type'] == 'text': if field["type"] == "text":
value = self._get_element_text(selected) value = self._get_element_text(selected)
elif field['type'] == 'attribute': elif field["type"] == "attribute":
value = self._get_element_attribute(selected, field['attribute']) value = self._get_element_attribute(selected, field["attribute"])
elif field['type'] == 'html': elif field["type"] == "html":
value = self._get_element_html(selected) value = self._get_element_html(selected)
elif field['type'] == 'regex': elif field["type"] == "regex":
text = self._get_element_text(selected) text = self._get_element_text(selected)
match = re.search(field['pattern'], text) match = re.search(field["pattern"], text)
value = match.group(1) if match else None value = match.group(1) if match else None
if 'transform' in field: if "transform" in field:
value = self._apply_transform(value, field['transform']) value = self._apply_transform(value, field["transform"])
return value if value is not None else field.get('default') return value if value is not None else field.get("default")
def _extract_list_item(self, element, fields): def _extract_list_item(self, element, fields):
item = {} item = {}
for field in fields: for field in fields:
value = self._extract_single_field(element, field) value = self._extract_single_field(element, field)
if value is not None: if value is not None:
item[field['name']] = value item[field["name"]] = value
return item return item
def _extract_item(self, element, fields): def _extract_item(self, element, fields):
@@ -866,12 +982,12 @@ class JsonElementExtractionStrategy(ExtractionStrategy):
item = {} item = {}
for field in fields: for field in fields:
if field['type'] == 'computed': if field["type"] == "computed":
value = self._compute_field(item, field) value = self._compute_field(item, field)
else: else:
value = self._extract_field(element, field) value = self._extract_field(element, field)
if value is not None: if value is not None:
item[field['name']] = value item[field["name"]] = value
return item return item
def _apply_transform(self, value, transform): def _apply_transform(self, value, transform):
@@ -891,24 +1007,24 @@ class JsonElementExtractionStrategy(ExtractionStrategy):
str: The transformed value. str: The transformed value.
""" """
if transform == 'lowercase': if transform == "lowercase":
return value.lower() return value.lower()
elif transform == 'uppercase': elif transform == "uppercase":
return value.upper() return value.upper()
elif transform == 'strip': elif transform == "strip":
return value.strip() return value.strip()
return value return value
def _compute_field(self, item, field): def _compute_field(self, item, field):
try: try:
if 'expression' in field: if "expression" in field:
return eval(field['expression'], {}, item) return eval(field["expression"], {}, item)
elif 'function' in field: elif "function" in field:
return field['function'](item) return field["function"](item)
except Exception as e: except Exception as e:
if self.verbose: if self.verbose:
print(f"Error computing field {field['name']}: {str(e)}") print(f"Error computing field {field['name']}: {str(e)}")
return field.get('default') return field.get("default")
def run(self, url: str, sections: List[str], *q, **kwargs) -> List[Dict[str, Any]]: def run(self, url: str, sections: List[str], *q, **kwargs) -> List[Dict[str, Any]]:
""" """
@@ -946,6 +1062,7 @@ class JsonElementExtractionStrategy(ExtractionStrategy):
"""Get attribute value from element""" """Get attribute value from element"""
pass pass
class JsonCssExtractionStrategy(JsonElementExtractionStrategy): class JsonCssExtractionStrategy(JsonElementExtractionStrategy):
""" """
Concrete implementation of `JsonElementExtractionStrategy` using CSS selectors. Concrete implementation of `JsonElementExtractionStrategy` using CSS selectors.
@@ -969,11 +1086,11 @@ class JsonCssExtractionStrategy(JsonElementExtractionStrategy):
""" """
def __init__(self, schema: Dict[str, Any], **kwargs): def __init__(self, schema: Dict[str, Any], **kwargs):
kwargs['input_format'] = 'html' # Force HTML input kwargs["input_format"] = "html" # Force HTML input
super().__init__(schema, **kwargs) super().__init__(schema, **kwargs)
def _parse_html(self, html_content: str): def _parse_html(self, html_content: str):
return BeautifulSoup(html_content, 'html.parser') return BeautifulSoup(html_content, "html.parser")
def _get_base_elements(self, parsed_html, selector: str): def _get_base_elements(self, parsed_html, selector: str):
return parsed_html.select(selector) return parsed_html.select(selector)
@@ -992,6 +1109,7 @@ class JsonCssExtractionStrategy(JsonElementExtractionStrategy):
def _get_element_attribute(self, element, attribute: str): def _get_element_attribute(self, element, attribute: str):
return element.get(attribute) return element.get(attribute)
class JsonXPathExtractionStrategy(JsonElementExtractionStrategy): class JsonXPathExtractionStrategy(JsonElementExtractionStrategy):
""" """
Concrete implementation of `JsonElementExtractionStrategy` using XPath selectors. Concrete implementation of `JsonElementExtractionStrategy` using XPath selectors.
@@ -1016,7 +1134,7 @@ class JsonXPathExtractionStrategy(JsonElementExtractionStrategy):
""" """
def __init__(self, schema: Dict[str, Any], **kwargs): def __init__(self, schema: Dict[str, Any], **kwargs):
kwargs['input_format'] = 'html' # Force HTML input kwargs["input_format"] = "html" # Force HTML input
super().__init__(schema, **kwargs) super().__init__(schema, **kwargs)
def _parse_html(self, html_content: str): def _parse_html(self, html_content: str):
@@ -1027,31 +1145,31 @@ class JsonXPathExtractionStrategy(JsonElementExtractionStrategy):
def _css_to_xpath(self, css_selector: str) -> str: def _css_to_xpath(self, css_selector: str) -> str:
"""Convert CSS selector to XPath if needed""" """Convert CSS selector to XPath if needed"""
if '/' in css_selector: # Already an XPath if "/" in css_selector: # Already an XPath
return css_selector return css_selector
return self._basic_css_to_xpath(css_selector) return self._basic_css_to_xpath(css_selector)
def _basic_css_to_xpath(self, css_selector: str) -> str: def _basic_css_to_xpath(self, css_selector: str) -> str:
"""Basic CSS to XPath conversion for common cases""" """Basic CSS to XPath conversion for common cases"""
if ' > ' in css_selector: if " > " in css_selector:
parts = css_selector.split(' > ') parts = css_selector.split(" > ")
return '//' + '/'.join(parts) return "//" + "/".join(parts)
if ' ' in css_selector: if " " in css_selector:
parts = css_selector.split(' ') parts = css_selector.split(" ")
return '//' + '//'.join(parts) return "//" + "//".join(parts)
return '//' + css_selector return "//" + css_selector
def _get_elements(self, element, selector: str): def _get_elements(self, element, selector: str):
xpath = self._css_to_xpath(selector) xpath = self._css_to_xpath(selector)
if not xpath.startswith('.'): if not xpath.startswith("."):
xpath = '.' + xpath xpath = "." + xpath
return element.xpath(xpath) return element.xpath(xpath)
def _get_element_text(self, element) -> str: def _get_element_text(self, element) -> str:
return ''.join(element.xpath('.//text()')).strip() return "".join(element.xpath(".//text()")).strip()
def _get_element_html(self, element) -> str: def _get_element_html(self, element) -> str:
return etree.tostring(element, encoding='unicode') return etree.tostring(element, encoding="unicode")
def _get_element_attribute(self, element, attribute: str): def _get_element_attribute(self, element, attribute: str):
return element.get(attribute) return element.get(attribute)

View File

@@ -903,7 +903,13 @@ class HTML2Text(html.parser.HTMLParser):
self.empty_link = False self.empty_link = False
if not self.code and not self.pre and not entity_char: if not self.code and not self.pre and not entity_char:
data = escape_md_section(data, snob=self.escape_snob, escape_dot=self.escape_dot, escape_plus=self.escape_plus, escape_dash=self.escape_dash) data = escape_md_section(
data,
snob=self.escape_snob,
escape_dot=self.escape_dot,
escape_plus=self.escape_plus,
escape_dash=self.escape_dash,
)
self.preceding_data = data self.preceding_data = data
self.o(data, puredata=True) self.o(data, puredata=True)
@@ -1006,6 +1012,7 @@ class HTML2Text(html.parser.HTMLParser):
newlines += 1 newlines += 1
return result return result
def html2text(html: str, baseurl: str = "", bodywidth: Optional[int] = None) -> str: def html2text(html: str, baseurl: str = "", bodywidth: Optional[int] = None) -> str:
if bodywidth is None: if bodywidth is None:
bodywidth = config.BODY_WIDTH bodywidth = config.BODY_WIDTH
@@ -1013,6 +1020,7 @@ def html2text(html: str, baseurl: str = "", bodywidth: Optional[int] = None) ->
return h.handle(html) return h.handle(html)
class CustomHTML2Text(HTML2Text): class CustomHTML2Text(HTML2Text):
def __init__(self, *args, handle_code_in_pre=False, **kwargs): def __init__(self, *args, handle_code_in_pre=False, **kwargs):
super().__init__(*args, **kwargs) super().__init__(*args, **kwargs)
@@ -1041,9 +1049,9 @@ class CustomHTML2Text(HTML2Text):
def update_params(self, **kwargs): def update_params(self, **kwargs):
"""Update parameters and set preserved tags.""" """Update parameters and set preserved tags."""
for key, value in kwargs.items(): for key, value in kwargs.items():
if key == 'preserve_tags': if key == "preserve_tags":
self.preserve_tags = set(value) self.preserve_tags = set(value)
elif key == 'handle_code_in_pre': elif key == "handle_code_in_pre":
self.handle_code_in_pre = value self.handle_code_in_pre = value
else: else:
setattr(self, key, value) setattr(self, key, value)
@@ -1056,17 +1064,19 @@ class CustomHTML2Text(HTML2Text):
self.current_preserved_tag = tag self.current_preserved_tag = tag
self.preserved_content = [] self.preserved_content = []
# Format opening tag with attributes # Format opening tag with attributes
attr_str = ''.join(f' {k}="{v}"' for k, v in attrs.items() if v is not None) attr_str = "".join(
self.preserved_content.append(f'<{tag}{attr_str}>') f' {k}="{v}"' for k, v in attrs.items() if v is not None
)
self.preserved_content.append(f"<{tag}{attr_str}>")
self.preserve_depth += 1 self.preserve_depth += 1
return return
else: else:
self.preserve_depth -= 1 self.preserve_depth -= 1
if self.preserve_depth == 0: if self.preserve_depth == 0:
self.preserved_content.append(f'</{tag}>') self.preserved_content.append(f"</{tag}>")
# Output the preserved HTML block with proper spacing # Output the preserved HTML block with proper spacing
preserved_html = ''.join(self.preserved_content) preserved_html = "".join(self.preserved_content)
self.o('\n' + preserved_html + '\n') self.o("\n" + preserved_html + "\n")
self.current_preserved_tag = None self.current_preserved_tag = None
return return
@@ -1074,29 +1084,31 @@ class CustomHTML2Text(HTML2Text):
if self.preserve_depth > 0: if self.preserve_depth > 0:
if start: if start:
# Format nested tags with attributes # Format nested tags with attributes
attr_str = ''.join(f' {k}="{v}"' for k, v in attrs.items() if v is not None) attr_str = "".join(
self.preserved_content.append(f'<{tag}{attr_str}>') f' {k}="{v}"' for k, v in attrs.items() if v is not None
)
self.preserved_content.append(f"<{tag}{attr_str}>")
else: else:
self.preserved_content.append(f'</{tag}>') self.preserved_content.append(f"</{tag}>")
return return
# Handle pre tags # Handle pre tags
if tag == 'pre': if tag == "pre":
if start: if start:
self.o('```\n') # Markdown code block start self.o("```\n") # Markdown code block start
self.inside_pre = True self.inside_pre = True
else: else:
self.o('\n```\n') # Markdown code block end self.o("\n```\n") # Markdown code block end
self.inside_pre = False self.inside_pre = False
elif tag == 'code': elif tag == "code":
if self.inside_pre and not self.handle_code_in_pre: if self.inside_pre and not self.handle_code_in_pre:
# Ignore code tags inside pre blocks if handle_code_in_pre is False # Ignore code tags inside pre blocks if handle_code_in_pre is False
return return
if start: if start:
self.o('`') # Markdown inline code start self.o("`") # Markdown inline code start
self.inside_code = True self.inside_code = True
else: else:
self.o('`') # Markdown inline code end self.o("`") # Markdown inline code end
self.inside_code = False self.inside_code = False
else: else:
super().handle_tag(tag, attrs, start) super().handle_tag(tag, attrs, start)
@@ -1113,13 +1125,12 @@ class CustomHTML2Text(HTML2Text):
return return
if self.inside_code: if self.inside_code:
# Inline code: no newlines allowed # Inline code: no newlines allowed
self.o(data.replace('\n', ' ')) self.o(data.replace("\n", " "))
return return
# Default behavior for other tags # Default behavior for other tags
super().handle_data(data, entity_char) super().handle_data(data, entity_char)
# # Handle pre tags # # Handle pre tags
# if tag == 'pre': # if tag == 'pre':
# if start: # if start:

View File

@@ -1,2 +1,3 @@
class OutCallback: class OutCallback:
def __call__(self, s: str) -> None: ... def __call__(self, s: str) -> None:
...

View File

@@ -210,7 +210,7 @@ def escape_md_section(
snob: bool = False, snob: bool = False,
escape_dot: bool = True, escape_dot: bool = True,
escape_plus: bool = True, escape_plus: bool = True,
escape_dash: bool = True escape_dash: bool = True,
) -> str: ) -> str:
""" """
Escapes markdown-sensitive characters across whole document sections. Escapes markdown-sensitive characters across whole document sections.
@@ -233,6 +233,7 @@ def escape_md_section(
return text return text
def reformat_table(lines: List[str], right_margin: int) -> List[str]: def reformat_table(lines: List[str], right_margin: int) -> List[str]:
""" """
Given the lines of a table Given the lines of a table

View File

@@ -6,6 +6,7 @@ from .async_logger import AsyncLogger, LogLevel
# Initialize logger # Initialize logger
logger = AsyncLogger(log_level=LogLevel.DEBUG, verbose=True) logger = AsyncLogger(log_level=LogLevel.DEBUG, verbose=True)
def post_install(): def post_install():
"""Run all post-installation tasks""" """Run all post-installation tasks"""
logger.info("Running post-installation setup...", tag="INIT") logger.info("Running post-installation setup...", tag="INIT")
@@ -13,18 +14,36 @@ def post_install():
run_migration() run_migration()
logger.success("Post-installation setup completed!", tag="COMPLETE") logger.success("Post-installation setup completed!", tag="COMPLETE")
def install_playwright(): def install_playwright():
logger.info("Installing Playwright browsers...", tag="INIT") logger.info("Installing Playwright browsers...", tag="INIT")
try: try:
# subprocess.check_call([sys.executable, "-m", "playwright", "install", "--with-deps", "--force", "chrome"]) # subprocess.check_call([sys.executable, "-m", "playwright", "install", "--with-deps", "--force", "chrome"])
subprocess.check_call([sys.executable, "-m", "playwright", "install", "--with-deps", "--force", "chromium"]) subprocess.check_call(
logger.success("Playwright installation completed successfully.", tag="COMPLETE") [
except subprocess.CalledProcessError as e: sys.executable,
"-m",
"playwright",
"install",
"--with-deps",
"--force",
"chromium",
]
)
logger.success(
"Playwright installation completed successfully.", tag="COMPLETE"
)
except subprocess.CalledProcessError:
# logger.error(f"Error during Playwright installation: {e}", tag="ERROR") # logger.error(f"Error during Playwright installation: {e}", tag="ERROR")
logger.warning(f"Please run '{sys.executable} -m playwright install --with-deps' manually after the installation.") logger.warning(
except Exception as e: f"Please run '{sys.executable} -m playwright install --with-deps' manually after the installation."
)
except Exception:
# logger.error(f"Unexpected error during Playwright installation: {e}", tag="ERROR") # logger.error(f"Unexpected error during Playwright installation: {e}", tag="ERROR")
logger.warning(f"Please run '{sys.executable} -m playwright install --with-deps' manually after the installation.") logger.warning(
f"Please run '{sys.executable} -m playwright install --with-deps' manually after the installation."
)
def run_migration(): def run_migration():
"""Initialize database during installation""" """Initialize database during installation"""
@@ -33,18 +52,26 @@ def run_migration():
from crawl4ai.async_database import async_db_manager from crawl4ai.async_database import async_db_manager
asyncio.run(async_db_manager.initialize()) asyncio.run(async_db_manager.initialize())
logger.success("Database initialization completed successfully.", tag="COMPLETE") logger.success(
"Database initialization completed successfully.", tag="COMPLETE"
)
except ImportError: except ImportError:
logger.warning("Database module not found. Will initialize on first use.") logger.warning("Database module not found. Will initialize on first use.")
except Exception as e: except Exception as e:
logger.warning(f"Database initialization failed: {e}") logger.warning(f"Database initialization failed: {e}")
logger.warning("Database will be initialized on first use") logger.warning("Database will be initialized on first use")
async def run_doctor(): async def run_doctor():
"""Test if Crawl4AI is working properly""" """Test if Crawl4AI is working properly"""
logger.info("Running Crawl4AI health check...", tag="INIT") logger.info("Running Crawl4AI health check...", tag="INIT")
try: try:
from .async_webcrawler import AsyncWebCrawler, BrowserConfig, CrawlerRunConfig, CacheMode from .async_webcrawler import (
AsyncWebCrawler,
BrowserConfig,
CrawlerRunConfig,
CacheMode,
)
browser_config = BrowserConfig( browser_config = BrowserConfig(
headless=True, headless=True,
@@ -52,7 +79,7 @@ async def run_doctor():
ignore_https_errors=True, ignore_https_errors=True,
light_mode=True, light_mode=True,
viewport_width=1280, viewport_width=1280,
viewport_height=720 viewport_height=720,
) )
run_config = CrawlerRunConfig( run_config = CrawlerRunConfig(
@@ -62,10 +89,7 @@ async def run_doctor():
async with AsyncWebCrawler(config=browser_config) as crawler: async with AsyncWebCrawler(config=browser_config) as crawler:
logger.info("Testing crawling capabilities...", tag="TEST") logger.info("Testing crawling capabilities...", tag="TEST")
result = await crawler.arun( result = await crawler.arun(url="https://crawl4ai.com", config=run_config)
url="https://crawl4ai.com",
config=run_config
)
if result and result.markdown: if result and result.markdown:
logger.success("✅ Crawling test passed!", tag="COMPLETE") logger.success("✅ Crawling test passed!", tag="COMPLETE")
@@ -77,7 +101,9 @@ async def run_doctor():
logger.error(f"❌ Test failed: {e}", tag="ERROR") logger.error(f"❌ Test failed: {e}", tag="ERROR")
return False return False
def doctor(): def doctor():
"""Entry point for the doctor command""" """Entry point for the doctor command"""
import asyncio import asyncio
return asyncio.run(run_doctor()) return asyncio.run(run_doctor())

View File

@@ -1,15 +1,18 @@
import os, sys import os
# Create a function get name of a js script, then load from the CURRENT folder of this script and return its content as string, make sure its error free # Create a function get name of a js script, then load from the CURRENT folder of this script and return its content as string, make sure its error free
def load_js_script(script_name): def load_js_script(script_name):
# Get the path of the current script # Get the path of the current script
current_script_path = os.path.dirname(os.path.realpath(__file__)) current_script_path = os.path.dirname(os.path.realpath(__file__))
# Get the path of the script to load # Get the path of the script to load
script_path = os.path.join(current_script_path, script_name + '.js') script_path = os.path.join(current_script_path, script_name + ".js")
# Check if the script exists # Check if the script exists
if not os.path.exists(script_path): if not os.path.exists(script_path):
raise ValueError(f"Script {script_name} not found in the folder {current_script_path}") raise ValueError(
f"Script {script_name} not found in the folder {current_script_path}"
)
# Load the content of the script # Load the content of the script
with open(script_path, 'r') as f: with open(script_path, "r") as f:
script_content = f.read() script_content = f.read()
return script_content return script_content

View File

@@ -11,16 +11,16 @@ from rank_bm25 import BM25Okapi
from nltk.tokenize import word_tokenize from nltk.tokenize import word_tokenize
from nltk.corpus import stopwords from nltk.corpus import stopwords
from nltk.stem import WordNetLemmatizer from nltk.stem import WordNetLemmatizer
from litellm import completion, batch_completion from litellm import batch_completion
from .async_logger import AsyncLogger from .async_logger import AsyncLogger
import litellm import litellm
import pickle import pickle
import hashlib # <--- ADDED for file-hash import hashlib # <--- ADDED for file-hash
from fnmatch import fnmatch
import glob import glob
litellm.set_verbose = False litellm.set_verbose = False
def _compute_file_hash(file_path: Path) -> str: def _compute_file_hash(file_path: Path) -> str:
"""Compute MD5 hash for the file's entire content.""" """Compute MD5 hash for the file's entire content."""
hash_md5 = hashlib.md5() hash_md5 = hashlib.md5()
@@ -29,13 +29,14 @@ def _compute_file_hash(file_path: Path) -> str:
hash_md5.update(chunk) hash_md5.update(chunk)
return hash_md5.hexdigest() return hash_md5.hexdigest()
class AsyncLLMTextManager: class AsyncLLMTextManager:
def __init__( def __init__(
self, self,
docs_dir: Path, docs_dir: Path,
logger: Optional[AsyncLogger] = None, logger: Optional[AsyncLogger] = None,
max_concurrent_calls: int = 5, max_concurrent_calls: int = 5,
batch_size: int = 3 batch_size: int = 3,
) -> None: ) -> None:
self.docs_dir = docs_dir self.docs_dir = docs_dir
self.logger = logger self.logger = logger
@@ -51,7 +52,7 @@ class AsyncLLMTextManager:
contents = [] contents = []
for file_path in doc_batch: for file_path in doc_batch:
try: try:
with open(file_path, 'r', encoding='utf-8') as f: with open(file_path, "r", encoding="utf-8") as f:
contents.append(f.read()) contents.append(f.read())
except Exception as e: except Exception as e:
self.logger.error(f"Error reading {file_path}: {str(e)}") self.logger.error(f"Error reading {file_path}: {str(e)}")
@@ -77,43 +78,53 @@ Wrap your response in <index>...</index> tags.
# Prepare messages for batch processing # Prepare messages for batch processing
messages_list = [ messages_list = [
[ [
{"role": "user", "content": f"{prompt}\n\nGenerate index for this documentation:\n\n{content}"} {
"role": "user",
"content": f"{prompt}\n\nGenerate index for this documentation:\n\n{content}",
}
] ]
for content in contents if content for content in contents
if content
] ]
try: try:
responses = batch_completion( responses = batch_completion(
model="anthropic/claude-3-5-sonnet-latest", model="anthropic/claude-3-5-sonnet-latest",
messages=messages_list, messages=messages_list,
logger_fn=None logger_fn=None,
) )
# Process responses and save index files # Process responses and save index files
for response, file_path in zip(responses, doc_batch): for response, file_path in zip(responses, doc_batch):
try: try:
index_content_match = re.search( index_content_match = re.search(
r'<index>(.*?)</index>', r"<index>(.*?)</index>",
response.choices[0].message.content, response.choices[0].message.content,
re.DOTALL re.DOTALL,
) )
if not index_content_match: if not index_content_match:
self.logger.warning(f"No <index>...</index> content found for {file_path}") self.logger.warning(
f"No <index>...</index> content found for {file_path}"
)
continue continue
index_content = re.sub( index_content = re.sub(
r"\n\s*\n", "\n", index_content_match.group(1) r"\n\s*\n", "\n", index_content_match.group(1)
).strip() ).strip()
if index_content: if index_content:
index_file = file_path.with_suffix('.q.md') index_file = file_path.with_suffix(".q.md")
with open(index_file, 'w', encoding='utf-8') as f: with open(index_file, "w", encoding="utf-8") as f:
f.write(index_content) f.write(index_content)
self.logger.info(f"Created index file: {index_file}") self.logger.info(f"Created index file: {index_file}")
else: else:
self.logger.warning(f"No index content found in response for {file_path}") self.logger.warning(
f"No index content found in response for {file_path}"
)
except Exception as e: except Exception as e:
self.logger.error(f"Error processing response for {file_path}: {str(e)}") self.logger.error(
f"Error processing response for {file_path}: {str(e)}"
)
except Exception as e: except Exception as e:
self.logger.error(f"Error in batch completion: {str(e)}") self.logger.error(f"Error in batch completion: {str(e)}")
@@ -171,7 +182,12 @@ Wrap your response in <index>...</index> tags.
lemmatizer = WordNetLemmatizer() lemmatizer = WordNetLemmatizer()
stop_words = set(stopwords.words("english")) - { stop_words = set(stopwords.words("english")) - {
"how", "what", "when", "where", "why", "which", "how",
"what",
"when",
"where",
"why",
"which",
} }
tokens = [] tokens = []
@@ -222,7 +238,9 @@ Wrap your response in <index>...</index> tags.
self.logger.info("Checking which .q.md files need (re)indexing...") self.logger.info("Checking which .q.md files need (re)indexing...")
# Gather all .q.md files # Gather all .q.md files
q_files = [self.docs_dir / f for f in os.listdir(self.docs_dir) if f.endswith(".q.md")] q_files = [
self.docs_dir / f for f in os.listdir(self.docs_dir) if f.endswith(".q.md")
]
# We'll store known (unchanged) facts in these lists # We'll store known (unchanged) facts in these lists
existing_facts: List[str] = [] existing_facts: List[str] = []
@@ -243,7 +261,9 @@ Wrap your response in <index>...</index> tags.
# Otherwise, load the existing cache and compare hash # Otherwise, load the existing cache and compare hash
cache = self._load_or_create_token_cache(qf) cache = self._load_or_create_token_cache(qf)
# If the .q.tokens was out of date (i.e. changed hash), we reindex # If the .q.tokens was out of date (i.e. changed hash), we reindex
if len(cache["facts"]) == 0 or cache.get("content_hash") != _compute_file_hash(qf): if len(cache["facts"]) == 0 or cache.get(
"content_hash"
) != _compute_file_hash(qf):
needSet.append(qf) needSet.append(qf)
else: else:
# File is unchanged → retrieve cached token data # File is unchanged → retrieve cached token data
@@ -255,20 +275,29 @@ Wrap your response in <index>...</index> tags.
if not needSet and not clear_cache: if not needSet and not clear_cache:
# If no file needs reindexing, try loading existing index # If no file needs reindexing, try loading existing index
if self.maybe_load_bm25_index(clear_cache=False): if self.maybe_load_bm25_index(clear_cache=False):
self.logger.info("No new/changed .q.md files found. Using existing BM25 index.") self.logger.info(
"No new/changed .q.md files found. Using existing BM25 index."
)
return return
else: else:
# If there's no existing index, we must build a fresh index from the old caches # If there's no existing index, we must build a fresh index from the old caches
self.logger.info("No existing BM25 index found. Building from cached facts.") self.logger.info(
"No existing BM25 index found. Building from cached facts."
)
if existing_facts: if existing_facts:
self.logger.info(f"Building BM25 index with {len(existing_facts)} cached facts.") self.logger.info(
f"Building BM25 index with {len(existing_facts)} cached facts."
)
self.bm25_index = BM25Okapi(existing_tokens) self.bm25_index = BM25Okapi(existing_tokens)
self.tokenized_facts = existing_facts self.tokenized_facts = existing_facts
with open(self.bm25_index_file, "wb") as f: with open(self.bm25_index_file, "wb") as f:
pickle.dump({ pickle.dump(
{
"bm25_index": self.bm25_index, "bm25_index": self.bm25_index,
"tokenized_facts": self.tokenized_facts "tokenized_facts": self.tokenized_facts,
}, f) },
f,
)
else: else:
self.logger.warning("No facts found at all. Index remains empty.") self.logger.warning("No facts found at all. Index remains empty.")
return return
@@ -311,7 +340,9 @@ Wrap your response in <index>...</index> tags.
self._save_token_cache(file, fresh_cache) self._save_token_cache(file, fresh_cache)
mem_usage = process.memory_info().rss / 1024 / 1024 mem_usage = process.memory_info().rss / 1024 / 1024
self.logger.debug(f"Memory usage after {file.name}: {mem_usage:.2f}MB") self.logger.debug(
f"Memory usage after {file.name}: {mem_usage:.2f}MB"
)
except Exception as e: except Exception as e:
self.logger.error(f"Error processing {file}: {str(e)}") self.logger.error(f"Error processing {file}: {str(e)}")
@@ -328,21 +359,28 @@ Wrap your response in <index>...</index> tags.
all_tokens = existing_tokens + new_tokens all_tokens = existing_tokens + new_tokens
# 3) Build BM25 index from combined facts # 3) Build BM25 index from combined facts
self.logger.info(f"Building BM25 index with {len(all_facts)} total facts (old + new).") self.logger.info(
f"Building BM25 index with {len(all_facts)} total facts (old + new)."
)
self.bm25_index = BM25Okapi(all_tokens) self.bm25_index = BM25Okapi(all_tokens)
self.tokenized_facts = all_facts self.tokenized_facts = all_facts
# 4) Save the updated BM25 index to disk # 4) Save the updated BM25 index to disk
with open(self.bm25_index_file, "wb") as f: with open(self.bm25_index_file, "wb") as f:
pickle.dump({ pickle.dump(
{
"bm25_index": self.bm25_index, "bm25_index": self.bm25_index,
"tokenized_facts": self.tokenized_facts "tokenized_facts": self.tokenized_facts,
}, f) },
f,
)
final_mem = process.memory_info().rss / 1024 / 1024 final_mem = process.memory_info().rss / 1024 / 1024
self.logger.info(f"Search index updated. Final memory usage: {final_mem:.2f}MB") self.logger.info(f"Search index updated. Final memory usage: {final_mem:.2f}MB")
async def generate_index_files(self, force_generate_facts: bool = False, clear_bm25_cache: bool = False) -> None: async def generate_index_files(
self, force_generate_facts: bool = False, clear_bm25_cache: bool = False
) -> None:
""" """
Generate index files for all documents in parallel batches Generate index files for all documents in parallel batches
@@ -353,15 +391,17 @@ Wrap your response in <index>...</index> tags.
self.logger.info("Starting index generation for documentation files.") self.logger.info("Starting index generation for documentation files.")
md_files = [ md_files = [
self.docs_dir / f for f in os.listdir(self.docs_dir) self.docs_dir / f
if f.endswith('.md') and not any(f.endswith(x) for x in ['.q.md', '.xs.md']) for f in os.listdir(self.docs_dir)
if f.endswith(".md") and not any(f.endswith(x) for x in [".q.md", ".xs.md"])
] ]
# Filter out files that already have .q files unless force=True # Filter out files that already have .q files unless force=True
if not force_generate_facts: if not force_generate_facts:
md_files = [ md_files = [
f for f in md_files f
if not (self.docs_dir / f.name.replace('.md', '.q.md')).exists() for f in md_files
if not (self.docs_dir / f.name.replace(".md", ".q.md")).exists()
] ]
if not md_files: if not md_files:
@@ -370,7 +410,9 @@ Wrap your response in <index>...</index> tags.
# Process documents in batches # Process documents in batches
for i in range(0, len(md_files), self.batch_size): for i in range(0, len(md_files), self.batch_size):
batch = md_files[i : i + self.batch_size] batch = md_files[i : i + self.batch_size]
self.logger.info(f"Processing batch {i//self.batch_size + 1}/{(len(md_files)//self.batch_size) + 1}") self.logger.info(
f"Processing batch {i//self.batch_size + 1}/{(len(md_files)//self.batch_size) + 1}"
)
await self._process_document_batch(batch) await self._process_document_batch(batch)
self.logger.info("Index generation complete, building/updating search index.") self.logger.info("Index generation complete, building/updating search index.")
@@ -378,21 +420,31 @@ Wrap your response in <index>...</index> tags.
def generate(self, sections: List[str], mode: str = "extended") -> str: def generate(self, sections: List[str], mode: str = "extended") -> str:
# Get all markdown files # Get all markdown files
all_files = glob.glob(str(self.docs_dir / "[0-9]*.md")) + \ all_files = glob.glob(str(self.docs_dir / "[0-9]*.md")) + glob.glob(
glob.glob(str(self.docs_dir / "[0-9]*.xs.md")) str(self.docs_dir / "[0-9]*.xs.md")
)
# Extract base names without extensions # Extract base names without extensions
base_docs = {Path(f).name.split('.')[0] for f in all_files base_docs = {
if not Path(f).name.endswith('.q.md')} Path(f).name.split(".")[0]
for f in all_files
if not Path(f).name.endswith(".q.md")
}
# Filter by sections if provided # Filter by sections if provided
if sections: if sections:
base_docs = {doc for doc in base_docs base_docs = {
if any(section.lower() in doc.lower() for section in sections)} doc
for doc in base_docs
if any(section.lower() in doc.lower() for section in sections)
}
# Get file paths based on mode # Get file paths based on mode
files = [] files = []
for doc in sorted(base_docs, key=lambda x: int(x.split('_')[0]) if x.split('_')[0].isdigit() else 999999): for doc in sorted(
base_docs,
key=lambda x: int(x.split("_")[0]) if x.split("_")[0].isdigit() else 999999,
):
if mode == "condensed": if mode == "condensed":
xs_file = self.docs_dir / f"{doc}.xs.md" xs_file = self.docs_dir / f"{doc}.xs.md"
regular_file = self.docs_dir / f"{doc}.md" regular_file = self.docs_dir / f"{doc}.md"
@@ -404,7 +456,7 @@ Wrap your response in <index>...</index> tags.
content = [] content = []
for file in files: for file in files:
try: try:
with open(file, 'r', encoding='utf-8') as f: with open(file, "r", encoding="utf-8") as f:
fname = Path(file).name fname = Path(file).name
content.append(f"{'#'*20}\n# {fname}\n{'#'*20}\n\n{f.read()}") content.append(f"{'#'*20}\n# {fname}\n{'#'*20}\n\n{f.read()}")
except Exception as e: except Exception as e:
@@ -443,15 +495,9 @@ Wrap your response in <index>...</index> tags.
for file, _ in ranked_files: for file, _ in ranked_files:
main_doc = str(file).replace(".q.md", ".md") main_doc = str(file).replace(".q.md", ".md")
if os.path.exists(self.docs_dir / main_doc): if os.path.exists(self.docs_dir / main_doc):
with open(self.docs_dir / main_doc, "r", encoding='utf-8') as f: with open(self.docs_dir / main_doc, "r", encoding="utf-8") as f:
only_file_name = main_doc.split("/")[-1] only_file_name = main_doc.split("/")[-1]
content = [ content = ["#" * 20, f"# {only_file_name}", "#" * 20, "", f.read()]
"#" * 20,
f"# {only_file_name}",
"#" * 20,
"",
f.read()
]
results.append("\n".join(content)) results.append("\n".join(content))
return "\n\n---\n\n".join(results) return "\n\n---\n\n".join(results)
@@ -482,7 +528,9 @@ Wrap your response in <index>...</index> tags.
if len(components) == 3: if len(components) == 3:
code_ref = components[2].strip() code_ref = components[2].strip()
code_tokens = self.preprocess_text(code_ref) code_tokens = self.preprocess_text(code_ref)
code_match_score = len(set(query_tokens) & set(code_tokens)) / len(query_tokens) code_match_score = len(set(query_tokens) & set(code_tokens)) / len(
query_tokens
)
file_data[file_path]["total_score"] += score file_data[file_path]["total_score"] += score
file_data[file_path]["match_count"] += 1 file_data[file_path]["match_count"] += 1

View File

@@ -2,41 +2,51 @@ from abc import ABC, abstractmethod
from typing import Optional, Dict, Any, Tuple from typing import Optional, Dict, Any, Tuple
from .models import MarkdownGenerationResult from .models import MarkdownGenerationResult
from .html2text import CustomHTML2Text from .html2text import CustomHTML2Text
from .content_filter_strategy import RelevantContentFilter, BM25ContentFilter from .content_filter_strategy import RelevantContentFilter
import re import re
from urllib.parse import urljoin from urllib.parse import urljoin
# Pre-compile the regex pattern # Pre-compile the regex pattern
LINK_PATTERN = re.compile(r'!?\[([^\]]+)\]\(([^)]+?)(?:\s+"([^"]*)")?\)') LINK_PATTERN = re.compile(r'!?\[([^\]]+)\]\(([^)]+?)(?:\s+"([^"]*)")?\)')
def fast_urljoin(base: str, url: str) -> str: def fast_urljoin(base: str, url: str) -> str:
"""Fast URL joining for common cases.""" """Fast URL joining for common cases."""
if url.startswith(('http://', 'https://', 'mailto:', '//')): if url.startswith(("http://", "https://", "mailto:", "//")):
return url return url
if url.startswith('/'): if url.startswith("/"):
# Handle absolute paths # Handle absolute paths
if base.endswith('/'): if base.endswith("/"):
return base[:-1] + url return base[:-1] + url
return base + url return base + url
return urljoin(base, url) return urljoin(base, url)
class MarkdownGenerationStrategy(ABC): class MarkdownGenerationStrategy(ABC):
"""Abstract base class for markdown generation strategies.""" """Abstract base class for markdown generation strategies."""
def __init__(self, content_filter: Optional[RelevantContentFilter] = None, options: Optional[Dict[str, Any]] = None):
def __init__(
self,
content_filter: Optional[RelevantContentFilter] = None,
options: Optional[Dict[str, Any]] = None,
):
self.content_filter = content_filter self.content_filter = content_filter
self.options = options or {} self.options = options or {}
@abstractmethod @abstractmethod
def generate_markdown(self, def generate_markdown(
self,
cleaned_html: str, cleaned_html: str,
base_url: str = "", base_url: str = "",
html2text_options: Optional[Dict[str, Any]] = None, html2text_options: Optional[Dict[str, Any]] = None,
content_filter: Optional[RelevantContentFilter] = None, content_filter: Optional[RelevantContentFilter] = None,
citations: bool = True, citations: bool = True,
**kwargs) -> MarkdownGenerationResult: **kwargs,
) -> MarkdownGenerationResult:
"""Generate markdown from cleaned HTML.""" """Generate markdown from cleaned HTML."""
pass pass
class DefaultMarkdownGenerator(MarkdownGenerationStrategy): class DefaultMarkdownGenerator(MarkdownGenerationStrategy):
""" """
Default implementation of markdown generation strategy. Default implementation of markdown generation strategy.
@@ -54,10 +64,17 @@ class DefaultMarkdownGenerator(MarkdownGenerationStrategy):
Returns: Returns:
MarkdownGenerationResult: Result containing raw markdown, fit markdown, fit HTML, and references markdown. MarkdownGenerationResult: Result containing raw markdown, fit markdown, fit HTML, and references markdown.
""" """
def __init__(self, content_filter: Optional[RelevantContentFilter] = None, options: Optional[Dict[str, Any]] = None):
def __init__(
self,
content_filter: Optional[RelevantContentFilter] = None,
options: Optional[Dict[str, Any]] = None,
):
super().__init__(content_filter, options) super().__init__(content_filter, options)
def convert_links_to_citations(self, markdown: str, base_url: str = "") -> Tuple[str, str]: def convert_links_to_citations(
self, markdown: str, base_url: str = ""
) -> Tuple[str, str]:
""" """
Convert links in markdown to citations. Convert links in markdown to citations.
@@ -87,24 +104,30 @@ class DefaultMarkdownGenerator(MarkdownGenerationStrategy):
text, url, title = match.groups() text, url, title = match.groups()
# Use cached URL if available, otherwise compute and cache # Use cached URL if available, otherwise compute and cache
if base_url and not url.startswith(('http://', 'https://', 'mailto:')): if base_url and not url.startswith(("http://", "https://", "mailto:")):
if url not in url_cache: if url not in url_cache:
url_cache[url] = fast_urljoin(base_url, url) url_cache[url] = fast_urljoin(base_url, url)
url = url_cache[url] url = url_cache[url]
if url not in link_map: if url not in link_map:
desc = [] desc = []
if title: desc.append(title) if title:
if text and text != title: desc.append(text) desc.append(title)
if text and text != title:
desc.append(text)
link_map[url] = (counter, ": " + " - ".join(desc) if desc else "") link_map[url] = (counter, ": " + " - ".join(desc) if desc else "")
counter += 1 counter += 1
num = link_map[url][0] num = link_map[url][0]
parts.append(f"{text}{num}" if not match.group(0).startswith('!') else f"![{text}{num}⟩]") parts.append(
f"{text}{num}"
if not match.group(0).startswith("!")
else f"![{text}{num}⟩]"
)
last_end = match.end() last_end = match.end()
parts.append(markdown[last_end:]) parts.append(markdown[last_end:])
converted_text = ''.join(parts) converted_text = "".join(parts)
# Pre-build reference strings # Pre-build reference strings
references = ["\n\n## References\n\n"] references = ["\n\n## References\n\n"]
@@ -113,16 +136,18 @@ class DefaultMarkdownGenerator(MarkdownGenerationStrategy):
for url, (num, desc) in sorted(link_map.items(), key=lambda x: x[1][0]) for url, (num, desc) in sorted(link_map.items(), key=lambda x: x[1][0])
) )
return converted_text, ''.join(references) return converted_text, "".join(references)
def generate_markdown(self, def generate_markdown(
self,
cleaned_html: str, cleaned_html: str,
base_url: str = "", base_url: str = "",
html2text_options: Optional[Dict[str, Any]] = None, html2text_options: Optional[Dict[str, Any]] = None,
options: Optional[Dict[str, Any]] = None, options: Optional[Dict[str, Any]] = None,
content_filter: Optional[RelevantContentFilter] = None, content_filter: Optional[RelevantContentFilter] = None,
citations: bool = True, citations: bool = True,
**kwargs) -> MarkdownGenerationResult: **kwargs,
) -> MarkdownGenerationResult:
""" """
Generate markdown with citations from cleaned HTML. Generate markdown with citations from cleaned HTML.
@@ -147,14 +172,14 @@ class DefaultMarkdownGenerator(MarkdownGenerationStrategy):
# Initialize HTML2Text with default options for better conversion # Initialize HTML2Text with default options for better conversion
h = CustomHTML2Text(baseurl=base_url) h = CustomHTML2Text(baseurl=base_url)
default_options = { default_options = {
'body_width': 0, # Disable text wrapping "body_width": 0, # Disable text wrapping
'ignore_emphasis': False, "ignore_emphasis": False,
'ignore_links': False, "ignore_links": False,
'ignore_images': False, "ignore_images": False,
'protect_links': True, "protect_links": True,
'single_line_break': True, "single_line_break": True,
'mark_code': True, "mark_code": True,
'escape_snob': False "escape_snob": False,
} }
# Update with custom options if provided # Update with custom options if provided
@@ -179,16 +204,17 @@ class DefaultMarkdownGenerator(MarkdownGenerationStrategy):
except Exception as e: except Exception as e:
raw_markdown = f"Error converting HTML to markdown: {str(e)}" raw_markdown = f"Error converting HTML to markdown: {str(e)}"
raw_markdown = raw_markdown.replace(' ```', '```') raw_markdown = raw_markdown.replace(" ```", "```")
# Convert links to citations # Convert links to citations
markdown_with_citations: str = raw_markdown markdown_with_citations: str = raw_markdown
references_markdown: str = "" references_markdown: str = ""
if citations: if citations:
try: try:
markdown_with_citations, references_markdown = self.convert_links_to_citations( (
raw_markdown, base_url markdown_with_citations,
) references_markdown,
) = self.convert_links_to_citations(raw_markdown, base_url)
except Exception as e: except Exception as e:
markdown_with_citations = raw_markdown markdown_with_citations = raw_markdown
references_markdown = f"Error generating citations: {str(e)}" references_markdown = f"Error generating citations: {str(e)}"
@@ -200,7 +226,9 @@ class DefaultMarkdownGenerator(MarkdownGenerationStrategy):
try: try:
content_filter = content_filter or self.content_filter content_filter = content_filter or self.content_filter
filtered_html = content_filter.filter_content(cleaned_html) filtered_html = content_filter.filter_content(cleaned_html)
filtered_html = '\n'.join('<div>{}</div>'.format(s) for s in filtered_html) filtered_html = "\n".join(
"<div>{}</div>".format(s) for s in filtered_html
)
fit_markdown = h.handle(filtered_html) fit_markdown = h.handle(filtered_html)
except Exception as e: except Exception as e:
fit_markdown = f"Error generating fit markdown: {str(e)}" fit_markdown = f"Error generating fit markdown: {str(e)}"

View File

@@ -1,13 +1,11 @@
import os import os
import asyncio import asyncio
import logging
from pathlib import Path from pathlib import Path
import aiosqlite import aiosqlite
from typing import Optional from typing import Optional
import xxhash import xxhash
import aiofiles import aiofiles
import shutil import shutil
import time
from datetime import datetime from datetime import datetime
from .async_logger import AsyncLogger, LogLevel from .async_logger import AsyncLogger, LogLevel
@@ -17,6 +15,7 @@ logger = AsyncLogger(log_level=LogLevel.DEBUG, verbose=True)
# logging.basicConfig(level=logging.INFO) # logging.basicConfig(level=logging.INFO)
# logger = logging.getLogger(__name__) # logger = logging.getLogger(__name__)
class DatabaseMigration: class DatabaseMigration:
def __init__(self, db_path: str): def __init__(self, db_path: str):
self.db_path = db_path self.db_path = db_path
@@ -24,11 +23,11 @@ class DatabaseMigration:
def _ensure_content_dirs(self, base_path: str) -> dict: def _ensure_content_dirs(self, base_path: str) -> dict:
dirs = { dirs = {
'html': 'html_content', "html": "html_content",
'cleaned': 'cleaned_html', "cleaned": "cleaned_html",
'markdown': 'markdown_content', "markdown": "markdown_content",
'extracted': 'extracted_content', "extracted": "extracted_content",
'screenshots': 'screenshots' "screenshots": "screenshots",
} }
content_paths = {} content_paths = {}
for key, dirname in dirs.items(): for key, dirname in dirs.items():
@@ -52,7 +51,7 @@ class DatabaseMigration:
file_path = os.path.join(self.content_paths[content_type], content_hash) file_path = os.path.join(self.content_paths[content_type], content_hash)
if not os.path.exists(file_path): if not os.path.exists(file_path):
async with aiofiles.open(file_path, 'w', encoding='utf-8') as f: async with aiofiles.open(file_path, "w", encoding="utf-8") as f:
await f.write(content) await f.write(content)
return content_hash return content_hash
@@ -66,24 +65,36 @@ class DatabaseMigration:
async with aiosqlite.connect(self.db_path) as db: async with aiosqlite.connect(self.db_path) as db:
# Get all rows # Get all rows
async with db.execute( async with db.execute(
'''SELECT url, html, cleaned_html, markdown, """SELECT url, html, cleaned_html, markdown,
extracted_content, screenshot FROM crawled_data''' extracted_content, screenshot FROM crawled_data"""
) as cursor: ) as cursor:
rows = await cursor.fetchall() rows = await cursor.fetchall()
migrated_count = 0 migrated_count = 0
for row in rows: for row in rows:
url, html, cleaned_html, markdown, extracted_content, screenshot = row (
url,
html,
cleaned_html,
markdown,
extracted_content,
screenshot,
) = row
# Store content in files and get hashes # Store content in files and get hashes
html_hash = await self._store_content(html, 'html') html_hash = await self._store_content(html, "html")
cleaned_hash = await self._store_content(cleaned_html, 'cleaned') cleaned_hash = await self._store_content(cleaned_html, "cleaned")
markdown_hash = await self._store_content(markdown, 'markdown') markdown_hash = await self._store_content(markdown, "markdown")
extracted_hash = await self._store_content(extracted_content, 'extracted') extracted_hash = await self._store_content(
screenshot_hash = await self._store_content(screenshot, 'screenshots') extracted_content, "extracted"
)
screenshot_hash = await self._store_content(
screenshot, "screenshots"
)
# Update database with hashes # Update database with hashes
await db.execute(''' await db.execute(
"""
UPDATE crawled_data UPDATE crawled_data
SET html = ?, SET html = ?,
cleaned_html = ?, cleaned_html = ?,
@@ -91,26 +102,37 @@ class DatabaseMigration:
extracted_content = ?, extracted_content = ?,
screenshot = ? screenshot = ?
WHERE url = ? WHERE url = ?
''', (html_hash, cleaned_hash, markdown_hash, """,
extracted_hash, screenshot_hash, url)) (
html_hash,
cleaned_hash,
markdown_hash,
extracted_hash,
screenshot_hash,
url,
),
)
migrated_count += 1 migrated_count += 1
if migrated_count % 100 == 0: if migrated_count % 100 == 0:
logger.info(f"Migrated {migrated_count} records...", tag="INIT") logger.info(f"Migrated {migrated_count} records...", tag="INIT")
await db.commit() await db.commit()
logger.success(f"Migration completed. {migrated_count} records processed.", tag="COMPLETE") logger.success(
f"Migration completed. {migrated_count} records processed.",
tag="COMPLETE",
)
except Exception as e: except Exception as e:
# logger.error(f"Migration failed: {e}") # logger.error(f"Migration failed: {e}")
logger.error( logger.error(
message="Migration failed: {error}", message="Migration failed: {error}",
tag="ERROR", tag="ERROR",
params={"error": str(e)} params={"error": str(e)},
) )
raise e raise e
async def backup_database(db_path: str) -> str: async def backup_database(db_path: str) -> str:
"""Create backup of existing database""" """Create backup of existing database"""
if not os.path.exists(db_path): if not os.path.exists(db_path):
@@ -118,7 +140,7 @@ async def backup_database(db_path: str) -> str:
return None return None
# Create backup with timestamp # Create backup with timestamp
timestamp = datetime.now().strftime('%Y%m%d_%H%M%S') timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
backup_path = f"{db_path}.backup_{timestamp}" backup_path = f"{db_path}.backup_{timestamp}"
try: try:
@@ -132,12 +154,11 @@ async def backup_database(db_path: str) -> str:
except Exception as e: except Exception as e:
# logger.error(f"Backup failed: {e}") # logger.error(f"Backup failed: {e}")
logger.error( logger.error(
message="Migration failed: {error}", message="Migration failed: {error}", tag="ERROR", params={"error": str(e)}
tag="ERROR",
params={"error": str(e)}
) )
raise e raise e
async def run_migration(db_path: Optional[str] = None): async def run_migration(db_path: Optional[str] = None):
"""Run database migration""" """Run database migration"""
if db_path is None: if db_path is None:
@@ -155,14 +176,19 @@ async def run_migration(db_path: Optional[str] = None):
migration = DatabaseMigration(db_path) migration = DatabaseMigration(db_path)
await migration.migrate_database() await migration.migrate_database()
def main(): def main():
"""CLI entry point for migration""" """CLI entry point for migration"""
import argparse import argparse
parser = argparse.ArgumentParser(description='Migrate Crawl4AI database to file-based storage')
parser.add_argument('--db-path', help='Custom database path') parser = argparse.ArgumentParser(
description="Migrate Crawl4AI database to file-based storage"
)
parser.add_argument("--db-path", help="Custom database path")
args = parser.parse_args() args = parser.parse_args()
asyncio.run(run_migration(args.db_path)) asyncio.run(run_migration(args.db_path))
if __name__ == "__main__": if __name__ == "__main__":
main() main()

View File

@@ -2,30 +2,32 @@ from functools import lru_cache
from pathlib import Path from pathlib import Path
import subprocess, os import subprocess, os
import shutil import shutil
import tarfile
from .model_loader import * from .model_loader import *
import argparse import argparse
import urllib.request
from crawl4ai.config import MODEL_REPO_BRANCH from crawl4ai.config import MODEL_REPO_BRANCH
__location__ = os.path.realpath(os.path.join(os.getcwd(), os.path.dirname(__file__))) __location__ = os.path.realpath(os.path.join(os.getcwd(), os.path.dirname(__file__)))
@lru_cache() @lru_cache()
def get_available_memory(device): def get_available_memory(device):
import torch import torch
if device.type == 'cuda':
if device.type == "cuda":
return torch.cuda.get_device_properties(device).total_memory return torch.cuda.get_device_properties(device).total_memory
elif device.type == 'mps': elif device.type == "mps":
return 48 * 1024**3 # Assuming 8GB for MPS, as a conservative estimate return 48 * 1024**3 # Assuming 8GB for MPS, as a conservative estimate
else: else:
return 0 return 0
@lru_cache() @lru_cache()
def calculate_batch_size(device): def calculate_batch_size(device):
available_memory = get_available_memory(device) available_memory = get_available_memory(device)
if device.type == 'cpu': if device.type == "cpu":
return 16 return 16
elif device.type in ['cuda', 'mps']: elif device.type in ["cuda", "mps"]:
# Adjust these thresholds based on your model size and available memory # Adjust these thresholds based on your model size and available memory
if available_memory >= 31 * 1024**3: # > 32GB if available_memory >= 31 * 1024**3: # > 32GB
return 256 return 256
@@ -38,39 +40,48 @@ def calculate_batch_size(device):
else: else:
return 16 # Default batch size return 16 # Default batch size
@lru_cache() @lru_cache()
def get_device(): def get_device():
import torch import torch
if torch.cuda.is_available(): if torch.cuda.is_available():
device = torch.device('cuda') device = torch.device("cuda")
elif torch.backends.mps.is_available(): elif torch.backends.mps.is_available():
device = torch.device('mps') device = torch.device("mps")
else: else:
device = torch.device('cpu') device = torch.device("cpu")
return device return device
def set_model_device(model): def set_model_device(model):
device = get_device() device = get_device()
model.to(device) model.to(device)
return model, device return model, device
@lru_cache() @lru_cache()
def get_home_folder(): def get_home_folder():
home_folder = os.path.join(os.getenv("CRAWL4_AI_BASE_DIRECTORY", Path.home()), ".crawl4ai") home_folder = os.path.join(
os.getenv("CRAWL4_AI_BASE_DIRECTORY", Path.home()), ".crawl4ai"
)
os.makedirs(home_folder, exist_ok=True) os.makedirs(home_folder, exist_ok=True)
os.makedirs(f"{home_folder}/cache", exist_ok=True) os.makedirs(f"{home_folder}/cache", exist_ok=True)
os.makedirs(f"{home_folder}/models", exist_ok=True) os.makedirs(f"{home_folder}/models", exist_ok=True)
return home_folder return home_folder
@lru_cache() @lru_cache()
def load_bert_base_uncased(): def load_bert_base_uncased():
from transformers import BertTokenizer, BertModel, AutoTokenizer, AutoModel from transformers import BertTokenizer, BertModel
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased', resume_download=None)
model = BertModel.from_pretrained('bert-base-uncased', resume_download=None) tokenizer = BertTokenizer.from_pretrained("bert-base-uncased", resume_download=None)
model = BertModel.from_pretrained("bert-base-uncased", resume_download=None)
model.eval() model.eval()
model, device = set_model_device(model) model, device = set_model_device(model)
return tokenizer, model return tokenizer, model
@lru_cache() @lru_cache()
def load_HF_embedding_model(model_name="BAAI/bge-small-en-v1.5") -> tuple: def load_HF_embedding_model(model_name="BAAI/bge-small-en-v1.5") -> tuple:
"""Load the Hugging Face model for embedding. """Load the Hugging Face model for embedding.
@@ -81,30 +92,35 @@ def load_HF_embedding_model(model_name="BAAI/bge-small-en-v1.5") -> tuple:
Returns: Returns:
tuple: The tokenizer and model. tuple: The tokenizer and model.
""" """
from transformers import BertTokenizer, BertModel, AutoTokenizer, AutoModel from transformers import AutoTokenizer, AutoModel
tokenizer = AutoTokenizer.from_pretrained(model_name, resume_download=None) tokenizer = AutoTokenizer.from_pretrained(model_name, resume_download=None)
model = AutoModel.from_pretrained(model_name, resume_download=None) model = AutoModel.from_pretrained(model_name, resume_download=None)
model.eval() model.eval()
model, device = set_model_device(model) model, device = set_model_device(model)
return tokenizer, model return tokenizer, model
@lru_cache() @lru_cache()
def load_text_classifier(): def load_text_classifier():
from transformers import AutoTokenizer, AutoModelForSequenceClassification from transformers import AutoTokenizer, AutoModelForSequenceClassification
from transformers import pipeline from transformers import pipeline
import torch
tokenizer = AutoTokenizer.from_pretrained("dstefa/roberta-base_topic_classification_nyt_news") tokenizer = AutoTokenizer.from_pretrained(
model = AutoModelForSequenceClassification.from_pretrained("dstefa/roberta-base_topic_classification_nyt_news") "dstefa/roberta-base_topic_classification_nyt_news"
)
model = AutoModelForSequenceClassification.from_pretrained(
"dstefa/roberta-base_topic_classification_nyt_news"
)
model.eval() model.eval()
model, device = set_model_device(model) model, device = set_model_device(model)
pipe = pipeline("text-classification", model=model, tokenizer=tokenizer) pipe = pipeline("text-classification", model=model, tokenizer=tokenizer)
return pipe return pipe
@lru_cache() @lru_cache()
def load_text_multilabel_classifier(): def load_text_multilabel_classifier():
from transformers import AutoModelForSequenceClassification, AutoTokenizer from transformers import AutoModelForSequenceClassification, AutoTokenizer
import numpy as np
from scipy.special import expit from scipy.special import expit
import torch import torch
@@ -117,17 +133,26 @@ def load_text_multilabel_classifier():
# device = torch.device("cpu") # device = torch.device("cpu")
# # return load_spacy_model(), torch.device("cpu") # # return load_spacy_model(), torch.device("cpu")
MODEL = "cardiffnlp/tweet-topic-21-multi" MODEL = "cardiffnlp/tweet-topic-21-multi"
tokenizer = AutoTokenizer.from_pretrained(MODEL, resume_download=None) tokenizer = AutoTokenizer.from_pretrained(MODEL, resume_download=None)
model = AutoModelForSequenceClassification.from_pretrained(MODEL, resume_download=None) model = AutoModelForSequenceClassification.from_pretrained(
MODEL, resume_download=None
)
model.eval() model.eval()
model, device = set_model_device(model) model, device = set_model_device(model)
class_mapping = model.config.id2label class_mapping = model.config.id2label
def _classifier(texts, threshold=0.5, max_length=64): def _classifier(texts, threshold=0.5, max_length=64):
tokens = tokenizer(texts, return_tensors='pt', padding=True, truncation=True, max_length=max_length) tokens = tokenizer(
tokens = {key: val.to(device) for key, val in tokens.items()} # Move tokens to the selected device texts,
return_tensors="pt",
padding=True,
truncation=True,
max_length=max_length,
)
tokens = {
key: val.to(device) for key, val in tokens.items()
} # Move tokens to the selected device
with torch.no_grad(): with torch.no_grad():
output = model(**tokens) output = model(**tokens)
@@ -138,25 +163,31 @@ def load_text_multilabel_classifier():
batch_labels = [] batch_labels = []
for prediction in predictions: for prediction in predictions:
labels = [class_mapping[i] for i, value in enumerate(prediction) if value == 1] labels = [
class_mapping[i] for i, value in enumerate(prediction) if value == 1
]
batch_labels.append(labels) batch_labels.append(labels)
return batch_labels return batch_labels
return _classifier, device return _classifier, device
@lru_cache() @lru_cache()
def load_nltk_punkt(): def load_nltk_punkt():
import nltk import nltk
try: try:
nltk.data.find('tokenizers/punkt') nltk.data.find("tokenizers/punkt")
except LookupError: except LookupError:
nltk.download('punkt') nltk.download("punkt")
return nltk.data.find('tokenizers/punkt') return nltk.data.find("tokenizers/punkt")
@lru_cache() @lru_cache()
def load_spacy_model(): def load_spacy_model():
import spacy import spacy
name = "models/reuters" name = "models/reuters"
home_folder = get_home_folder() home_folder = get_home_folder()
model_folder = Path(home_folder) / name model_folder = Path(home_folder) / name
@@ -176,7 +207,9 @@ def load_spacy_model():
if model_folder.exists(): if model_folder.exists():
shutil.rmtree(model_folder) shutil.rmtree(model_folder)
except PermissionError: except PermissionError:
print("[WARNING] Unable to remove existing folders. Please manually delete the following folders and try again:") print(
"[WARNING] Unable to remove existing folders. Please manually delete the following folders and try again:"
)
print(f"- {repo_folder}") print(f"- {repo_folder}")
print(f"- {model_folder}") print(f"- {model_folder}")
return None return None
@@ -187,7 +220,7 @@ def load_spacy_model():
["git", "clone", "-b", branch, repo_url, str(repo_folder)], ["git", "clone", "-b", branch, repo_url, str(repo_folder)],
stdout=subprocess.DEVNULL, stdout=subprocess.DEVNULL,
stderr=subprocess.DEVNULL, stderr=subprocess.DEVNULL,
check=True check=True,
) )
# Create the models directory if it doesn't exist # Create the models directory if it doesn't exist
@@ -215,6 +248,7 @@ def load_spacy_model():
print(f"Error loading spacy model: {e}") print(f"Error loading spacy model: {e}")
return None return None
def download_all_models(remove_existing=False): def download_all_models(remove_existing=False):
"""Download all models required for Crawl4AI.""" """Download all models required for Crawl4AI."""
if remove_existing: if remove_existing:
@@ -243,14 +277,20 @@ def download_all_models(remove_existing=False):
load_nltk_punkt() load_nltk_punkt()
print("[LOG] ✅ All models downloaded successfully.") print("[LOG] ✅ All models downloaded successfully.")
def main(): def main():
print("[LOG] Welcome to the Crawl4AI Model Downloader!") print("[LOG] Welcome to the Crawl4AI Model Downloader!")
print("[LOG] This script will download all the models required for Crawl4AI.") print("[LOG] This script will download all the models required for Crawl4AI.")
parser = argparse.ArgumentParser(description="Crawl4AI Model Downloader") parser = argparse.ArgumentParser(description="Crawl4AI Model Downloader")
parser.add_argument('--remove-existing', action='store_true', help="Remove existing models before downloading") parser.add_argument(
"--remove-existing",
action="store_true",
help="Remove existing models before downloading",
)
args = parser.parse_args() args = parser.parse_args()
download_all_models(remove_existing=args.remove_existing) download_all_models(remove_existing=args.remove_existing)
if __name__ == "__main__": if __name__ == "__main__":
main() main()

View File

@@ -1,18 +1,12 @@
from pydantic import BaseModel, HttpUrl from pydantic import BaseModel, HttpUrl
from typing import List, Dict, Optional, Callable, Awaitable, Union, Tuple, Any from typing import List, Dict, Optional, Callable, Awaitable, Union, Any
from enum import Enum from enum import Enum
from dataclasses import dataclass, field
from .ssl_certificate import SSLCertificate
from dataclasses import dataclass from dataclasses import dataclass
from .ssl_certificate import SSLCertificate
from datetime import datetime from datetime import datetime
from enum import Enum
from typing import Optional
from datetime import timedelta from datetime import timedelta
############################### ###############################
# Dispatcher Models # Dispatcher Models
############################### ###############################
@@ -22,6 +16,7 @@ class DomainState:
current_delay: float = 0 current_delay: float = 0
fail_count: int = 0 fail_count: int = 0
@dataclass @dataclass
class CrawlerTaskResult: class CrawlerTaskResult:
task_id: str task_id: str
@@ -33,12 +28,14 @@ class CrawlerTaskResult:
end_time: datetime end_time: datetime
error_message: str = "" error_message: str = ""
class CrawlStatus(Enum): class CrawlStatus(Enum):
QUEUED = "QUEUED" QUEUED = "QUEUED"
IN_PROGRESS = "IN_PROGRESS" IN_PROGRESS = "IN_PROGRESS"
COMPLETED = "COMPLETED" COMPLETED = "COMPLETED"
FAILED = "FAILED" FAILED = "FAILED"
@dataclass @dataclass
class CrawlStats: class CrawlStats:
task_id: str task_id: str
@@ -58,10 +55,12 @@ class CrawlStats:
duration = end - self.start_time duration = end - self.start_time
return str(timedelta(seconds=int(duration.total_seconds()))) return str(timedelta(seconds=int(duration.total_seconds())))
class DisplayMode(Enum): class DisplayMode(Enum):
DETAILED = "DETAILED" DETAILED = "DETAILED"
AGGREGATED = "AGGREGATED" AGGREGATED = "AGGREGATED"
############################### ###############################
# Crawler Models # Crawler Models
############################### ###############################
@@ -78,6 +77,7 @@ class UrlModel(BaseModel):
url: HttpUrl url: HttpUrl
forced: bool = False forced: bool = False
class MarkdownGenerationResult(BaseModel): class MarkdownGenerationResult(BaseModel):
raw_markdown: str raw_markdown: str
markdown_with_citations: str markdown_with_citations: str
@@ -85,6 +85,7 @@ class MarkdownGenerationResult(BaseModel):
fit_markdown: Optional[str] = None fit_markdown: Optional[str] = None
fit_html: Optional[str] = None fit_html: Optional[str] = None
class DispatchResult(BaseModel): class DispatchResult(BaseModel):
task_id: str task_id: str
memory_usage: float memory_usage: float
@@ -92,6 +93,8 @@ class DispatchResult(BaseModel):
start_time: datetime start_time: datetime
end_time: datetime end_time: datetime
error_message: str = "" error_message: str = ""
class CrawlResult(BaseModel): class CrawlResult(BaseModel):
url: str url: str
html: str html: str
@@ -114,9 +117,11 @@ class CrawlResult(BaseModel):
status_code: Optional[int] = None status_code: Optional[int] = None
ssl_certificate: Optional[SSLCertificate] = None ssl_certificate: Optional[SSLCertificate] = None
dispatch_result: Optional[DispatchResult] = None dispatch_result: Optional[DispatchResult] = None
class Config: class Config:
arbitrary_types_allowed = True arbitrary_types_allowed = True
class AsyncCrawlResponse(BaseModel): class AsyncCrawlResponse(BaseModel):
html: str html: str
response_headers: Dict[str, str] response_headers: Dict[str, str]
@@ -130,6 +135,7 @@ class AsyncCrawlResponse(BaseModel):
class Config: class Config:
arbitrary_types_allowed = True arbitrary_types_allowed = True
############################### ###############################
# Scraping Models # Scraping Models
############################### ###############################
@@ -143,21 +149,29 @@ class MediaItem(BaseModel):
format: Optional[str] = None format: Optional[str] = None
width: Optional[int] = None width: Optional[int] = None
class Link(BaseModel): class Link(BaseModel):
href: str href: str
text: str text: str
title: Optional[str] = None title: Optional[str] = None
base_domain: str base_domain: str
class Media(BaseModel): class Media(BaseModel):
images: List[MediaItem] = [] images: List[MediaItem] = []
videos: List[MediaItem] = [] # Using MediaItem model for now, can be extended with Video model if needed videos: List[
audios: List[MediaItem] = [] # Using MediaItem model for now, can be extended with Audio model if needed MediaItem
] = [] # Using MediaItem model for now, can be extended with Video model if needed
audios: List[
MediaItem
] = [] # Using MediaItem model for now, can be extended with Audio model if needed
class Links(BaseModel): class Links(BaseModel):
internal: List[Link] = [] internal: List[Link] = []
external: List[Link] = [] external: List[Link] = []
class ScrapingResult(BaseModel): class ScrapingResult(BaseModel):
cleaned_html: str cleaned_html: str
success: bool success: bool

View File

@@ -26,11 +26,12 @@ class SSLCertificate:
export_as_json() -> Dict[str, Any]: Export the certificate as JSON format. export_as_json() -> Dict[str, Any]: Export the certificate as JSON format.
export_as_text() -> str: Export the certificate as text format. export_as_text() -> str: Export the certificate as text format.
""" """
def __init__(self, cert_info: Dict[str, Any]): def __init__(self, cert_info: Dict[str, Any]):
self._cert_info = self._decode_cert_data(cert_info) self._cert_info = self._decode_cert_data(cert_info)
@staticmethod @staticmethod
def from_url(url: str, timeout: int = 10) -> Optional['SSLCertificate']: def from_url(url: str, timeout: int = 10) -> Optional["SSLCertificate"]:
""" """
Create SSLCertificate instance from a URL. Create SSLCertificate instance from a URL.
@@ -43,14 +44,16 @@ class SSLCertificate:
""" """
try: try:
hostname = urlparse(url).netloc hostname = urlparse(url).netloc
if ':' in hostname: if ":" in hostname:
hostname = hostname.split(':')[0] hostname = hostname.split(":")[0]
context = ssl.create_default_context() context = ssl.create_default_context()
with socket.create_connection((hostname, 443), timeout=timeout) as sock: with socket.create_connection((hostname, 443), timeout=timeout) as sock:
with context.wrap_socket(sock, server_hostname=hostname) as ssock: with context.wrap_socket(sock, server_hostname=hostname) as ssock:
cert_binary = ssock.getpeercert(binary_form=True) cert_binary = ssock.getpeercert(binary_form=True)
x509 = OpenSSL.crypto.load_certificate(OpenSSL.crypto.FILETYPE_ASN1, cert_binary) x509 = OpenSSL.crypto.load_certificate(
OpenSSL.crypto.FILETYPE_ASN1, cert_binary
)
cert_info = { cert_info = {
"subject": dict(x509.get_subject().get_components()), "subject": dict(x509.get_subject().get_components()),
@@ -61,32 +64,33 @@ class SSLCertificate:
"not_after": x509.get_notAfter(), "not_after": x509.get_notAfter(),
"fingerprint": x509.digest("sha256").hex(), "fingerprint": x509.digest("sha256").hex(),
"signature_algorithm": x509.get_signature_algorithm(), "signature_algorithm": x509.get_signature_algorithm(),
"raw_cert": base64.b64encode(cert_binary) "raw_cert": base64.b64encode(cert_binary),
} }
# Add extensions # Add extensions
extensions = [] extensions = []
for i in range(x509.get_extension_count()): for i in range(x509.get_extension_count()):
ext = x509.get_extension(i) ext = x509.get_extension(i)
extensions.append({ extensions.append(
"name": ext.get_short_name(), {"name": ext.get_short_name(), "value": str(ext)}
"value": str(ext) )
})
cert_info["extensions"] = extensions cert_info["extensions"] = extensions
return SSLCertificate(cert_info) return SSLCertificate(cert_info)
except Exception as e: except Exception:
return None return None
@staticmethod @staticmethod
def _decode_cert_data(data: Any) -> Any: def _decode_cert_data(data: Any) -> Any:
"""Helper method to decode bytes in certificate data.""" """Helper method to decode bytes in certificate data."""
if isinstance(data, bytes): if isinstance(data, bytes):
return data.decode('utf-8') return data.decode("utf-8")
elif isinstance(data, dict): elif isinstance(data, dict):
return { return {
(k.decode('utf-8') if isinstance(k, bytes) else k): SSLCertificate._decode_cert_data(v) (
k.decode("utf-8") if isinstance(k, bytes) else k
): SSLCertificate._decode_cert_data(v)
for k, v in data.items() for k, v in data.items()
} }
elif isinstance(data, list): elif isinstance(data, list):
@@ -105,7 +109,7 @@ class SSLCertificate:
""" """
json_str = json.dumps(self._cert_info, indent=2, ensure_ascii=False) json_str = json.dumps(self._cert_info, indent=2, ensure_ascii=False)
if filepath: if filepath:
Path(filepath).write_text(json_str, encoding='utf-8') Path(filepath).write_text(json_str, encoding="utf-8")
return None return None
return json_str return json_str
@@ -122,18 +126,17 @@ class SSLCertificate:
try: try:
x509 = OpenSSL.crypto.load_certificate( x509 = OpenSSL.crypto.load_certificate(
OpenSSL.crypto.FILETYPE_ASN1, OpenSSL.crypto.FILETYPE_ASN1,
base64.b64decode(self._cert_info['raw_cert']) base64.b64decode(self._cert_info["raw_cert"]),
) )
pem_data = OpenSSL.crypto.dump_certificate( pem_data = OpenSSL.crypto.dump_certificate(
OpenSSL.crypto.FILETYPE_PEM, OpenSSL.crypto.FILETYPE_PEM, x509
x509 ).decode("utf-8")
).decode('utf-8')
if filepath: if filepath:
Path(filepath).write_text(pem_data, encoding='utf-8') Path(filepath).write_text(pem_data, encoding="utf-8")
return None return None
return pem_data return pem_data
except Exception as e: except Exception:
return None return None
def to_der(self, filepath: Optional[str] = None) -> Optional[bytes]: def to_der(self, filepath: Optional[str] = None) -> Optional[bytes]:
@@ -147,7 +150,7 @@ class SSLCertificate:
Optional[bytes]: DER bytes if successful, None otherwise. Optional[bytes]: DER bytes if successful, None otherwise.
""" """
try: try:
der_data = base64.b64decode(self._cert_info['raw_cert']) der_data = base64.b64decode(self._cert_info["raw_cert"])
if filepath: if filepath:
Path(filepath).write_bytes(der_data) Path(filepath).write_bytes(der_data)
return None return None
@@ -158,24 +161,24 @@ class SSLCertificate:
@property @property
def issuer(self) -> Dict[str, str]: def issuer(self) -> Dict[str, str]:
"""Get certificate issuer information.""" """Get certificate issuer information."""
return self._cert_info.get('issuer', {}) return self._cert_info.get("issuer", {})
@property @property
def subject(self) -> Dict[str, str]: def subject(self) -> Dict[str, str]:
"""Get certificate subject information.""" """Get certificate subject information."""
return self._cert_info.get('subject', {}) return self._cert_info.get("subject", {})
@property @property
def valid_from(self) -> str: def valid_from(self) -> str:
"""Get certificate validity start date.""" """Get certificate validity start date."""
return self._cert_info.get('not_before', '') return self._cert_info.get("not_before", "")
@property @property
def valid_until(self) -> str: def valid_until(self) -> str:
"""Get certificate validity end date.""" """Get certificate validity end date."""
return self._cert_info.get('not_after', '') return self._cert_info.get("not_after", "")
@property @property
def fingerprint(self) -> str: def fingerprint(self) -> str:
"""Get certificate fingerprint.""" """Get certificate fingerprint."""
return self._cert_info.get('fingerprint', '') return self._cert_info.get("fingerprint", "")

View File

@@ -32,6 +32,7 @@ class UserAgentGenerator:
android_version: Optional[str] = None android_version: Optional[str] = None
): Generates a random user agent string based on the specified parameters. ): Generates a random user agent string based on the specified parameters.
""" """
def __init__(self): def __init__(self):
# Previous platform definitions remain the same... # Previous platform definitions remain the same...
self.desktop_platforms = { self.desktop_platforms = {
@@ -47,7 +48,7 @@ class UserAgentGenerator:
"generic": "(X11; Linux x86_64)", "generic": "(X11; Linux x86_64)",
"ubuntu": "(X11; Ubuntu; Linux x86_64)", "ubuntu": "(X11; Ubuntu; Linux x86_64)",
"chrome_os": "(X11; CrOS x86_64 14541.0.0)", "chrome_os": "(X11; CrOS x86_64 14541.0.0)",
} },
} }
self.mobile_platforms = { self.mobile_platforms = {
@@ -60,26 +61,14 @@ class UserAgentGenerator:
"ios": { "ios": {
"iphone": "(iPhone; CPU iPhone OS 16_5 like Mac OS X)", "iphone": "(iPhone; CPU iPhone OS 16_5 like Mac OS X)",
"ipad": "(iPad; CPU OS 16_5 like Mac OS X)", "ipad": "(iPad; CPU OS 16_5 like Mac OS X)",
} },
} }
# Browser Combinations # Browser Combinations
self.browser_combinations = { self.browser_combinations = {
1: [ 1: [["chrome"], ["firefox"], ["safari"], ["edge"]],
["chrome"], 2: [["gecko", "firefox"], ["chrome", "safari"], ["webkit", "safari"]],
["firefox"], 3: [["chrome", "safari", "edge"], ["webkit", "chrome", "safari"]],
["safari"],
["edge"]
],
2: [
["gecko", "firefox"],
["chrome", "safari"],
["webkit", "safari"]
],
3: [
["chrome", "safari", "edge"],
["webkit", "chrome", "safari"]
]
} }
# Rendering Engines with versions # Rendering Engines with versions
@@ -90,7 +79,7 @@ class UserAgentGenerator:
"Gecko/20100101", "Gecko/20100101",
"Gecko/20100101", # Firefox usually uses this constant version "Gecko/20100101", # Firefox usually uses this constant version
"Gecko/2010010", "Gecko/2010010",
] ],
} }
# Browser Versions # Browser Versions
@@ -170,12 +159,14 @@ class UserAgentGenerator:
return browser_stack return browser_stack
def generate(self, def generate(
device_type: Optional[Literal['desktop', 'mobile']] = None, self,
device_type: Optional[Literal["desktop", "mobile"]] = None,
os_type: Optional[str] = None, os_type: Optional[str] = None,
device_brand: Optional[str] = None, device_brand: Optional[str] = None,
browser_type: Optional[Literal['chrome', 'edge', 'safari', 'firefox']] = None, browser_type: Optional[Literal["chrome", "edge", "safari", "firefox"]] = None,
num_browsers: int = 3) -> str: num_browsers: int = 3,
) -> str:
""" """
Generate a random user agent with specified constraints. Generate a random user agent with specified constraints.
@@ -215,9 +206,13 @@ class UserAgentGenerator:
def get_random_platform(self, device_type, os_type, device_brand): def get_random_platform(self, device_type, os_type, device_brand):
"""Helper method to get random platform based on constraints""" """Helper method to get random platform based on constraints"""
platforms = self.desktop_platforms if device_type == 'desktop' else \ platforms = (
self.mobile_platforms if device_type == 'mobile' else \ self.desktop_platforms
{**self.desktop_platforms, **self.mobile_platforms} if device_type == "desktop"
else self.mobile_platforms
if device_type == "mobile"
else {**self.desktop_platforms, **self.mobile_platforms}
)
if os_type: if os_type:
for platform_group in [self.desktop_platforms, self.mobile_platforms]: for platform_group in [self.desktop_platforms, self.mobile_platforms]:
@@ -233,10 +228,10 @@ class UserAgentGenerator:
def parse_user_agent(self, user_agent: str) -> Dict[str, str]: def parse_user_agent(self, user_agent: str) -> Dict[str, str]:
"""Parse a user agent string to extract browser and version information""" """Parse a user agent string to extract browser and version information"""
browsers = { browsers = {
'chrome': r'Chrome/(\d+)', "chrome": r"Chrome/(\d+)",
'edge': r'Edg/(\d+)', "edge": r"Edg/(\d+)",
'safari': r'Version/(\d+)', "safari": r"Version/(\d+)",
'firefox': r'Firefox/(\d+)' "firefox": r"Firefox/(\d+)",
} }
result = {} result = {}
@@ -255,25 +250,26 @@ class UserAgentGenerator:
hints = [] hints = []
# Handle different browser combinations # Handle different browser combinations
if 'chrome' in browsers: if "chrome" in browsers:
hints.append(f'"Chromium";v="{browsers["chrome"]}"') hints.append(f'"Chromium";v="{browsers["chrome"]}"')
hints.append('"Not_A Brand";v="8"') hints.append('"Not_A Brand";v="8"')
if 'edge' in browsers: if "edge" in browsers:
hints.append(f'"Microsoft Edge";v="{browsers["edge"]}"') hints.append(f'"Microsoft Edge";v="{browsers["edge"]}"')
else: else:
hints.append(f'"Google Chrome";v="{browsers["chrome"]}"') hints.append(f'"Google Chrome";v="{browsers["chrome"]}"')
elif 'firefox' in browsers: elif "firefox" in browsers:
# Firefox doesn't typically send Sec-CH-UA # Firefox doesn't typically send Sec-CH-UA
return '""' return '""'
elif 'safari' in browsers: elif "safari" in browsers:
# Safari's format for client hints # Safari's format for client hints
hints.append(f'"Safari";v="{browsers["safari"]}"') hints.append(f'"Safari";v="{browsers["safari"]}"')
hints.append('"Not_A Brand";v="8"') hints.append('"Not_A Brand";v="8"')
return ', '.join(hints) return ", ".join(hints)
# Example usage: # Example usage:
if __name__ == "__main__": if __name__ == "__main__":
@@ -281,7 +277,7 @@ if __name__ == "__main__":
print(generator.generate()) print(generator.generate())
print("\nSingle browser (Chrome):") print("\nSingle browser (Chrome):")
print(generator.generate(num_browsers=1, browser_type='chrome')) print(generator.generate(num_browsers=1, browser_type="chrome"))
print("\nTwo browsers (Gecko/Firefox):") print("\nTwo browsers (Gecko/Firefox):")
print(generator.generate(num_browsers=2)) print(generator.generate(num_browsers=2))
@@ -290,16 +286,14 @@ if __name__ == "__main__":
print(generator.generate(num_browsers=3)) print(generator.generate(num_browsers=3))
print("\nFirefox on Linux:") print("\nFirefox on Linux:")
print(generator.generate( print(
device_type='desktop', generator.generate(
os_type='linux', device_type="desktop",
browser_type='firefox', os_type="linux",
num_browsers=2 browser_type="firefox",
)) num_browsers=2,
)
)
print("\nChrome/Safari/Edge on Windows:") print("\nChrome/Safari/Edge on Windows:")
print(generator.generate( print(generator.generate(device_type="desktop", os_type="windows", num_browsers=3))
device_type='desktop',
os_type='windows',
num_browsers=3
))

File diff suppressed because it is too large Load Diff

View File

@@ -1,9 +1,9 @@
# version_manager.py # version_manager.py
import os
from pathlib import Path from pathlib import Path
from packaging import version from packaging import version
from . import __version__ from . import __version__
class VersionManager: class VersionManager:
def __init__(self): def __init__(self):
self.home_dir = Path.home() / ".crawl4ai" self.home_dir = Path.home() / ".crawl4ai"
@@ -27,4 +27,3 @@ class VersionManager:
installed = self.get_installed_version() installed = self.get_installed_version()
current = version.parse(__version__.__version__) current = version.parse(__version__.__version__)
return installed is None or installed < current return installed is None or installed < current

View File

@@ -1,9 +1,10 @@
import os, time import os, time
os.environ["TOKENIZERS_PARALLELISM"] = "false" os.environ["TOKENIZERS_PARALLELISM"] = "false"
from pathlib import Path from pathlib import Path
from .models import UrlModel, CrawlResult from .models import UrlModel, CrawlResult
from .database import init_db, get_cached_url, cache_url, DB_PATH, flush_db from .database import init_db, get_cached_url, cache_url
from .utils import * from .utils import *
from .chunking_strategy import * from .chunking_strategy import *
from .extraction_strategy import * from .extraction_strategy import *
@@ -14,14 +15,27 @@ from .content_scraping_strategy import WebScrapingStrategy
from .config import * from .config import *
import warnings import warnings
import json import json
warnings.filterwarnings("ignore", message='Field "model_name" has conflict with protected namespace "model_".')
warnings.filterwarnings(
"ignore",
message='Field "model_name" has conflict with protected namespace "model_".',
)
class WebCrawler: class WebCrawler:
def __init__(self, crawler_strategy: CrawlerStrategy = None, always_by_pass_cache: bool = False, verbose: bool = False): def __init__(
self.crawler_strategy = crawler_strategy or LocalSeleniumCrawlerStrategy(verbose=verbose) self,
crawler_strategy: CrawlerStrategy = None,
always_by_pass_cache: bool = False,
verbose: bool = False,
):
self.crawler_strategy = crawler_strategy or LocalSeleniumCrawlerStrategy(
verbose=verbose
)
self.always_by_pass_cache = always_by_pass_cache self.always_by_pass_cache = always_by_pass_cache
self.crawl4ai_folder = os.path.join(os.getenv("CRAWL4_AI_BASE_DIRECTORY", Path.home()), ".crawl4ai") self.crawl4ai_folder = os.path.join(
os.getenv("CRAWL4_AI_BASE_DIRECTORY", Path.home()), ".crawl4ai"
)
os.makedirs(self.crawl4ai_folder, exist_ok=True) os.makedirs(self.crawl4ai_folder, exist_ok=True)
os.makedirs(f"{self.crawl4ai_folder}/cache", exist_ok=True) os.makedirs(f"{self.crawl4ai_folder}/cache", exist_ok=True)
init_db() init_db()
@@ -30,11 +44,11 @@ class WebCrawler:
def warmup(self): def warmup(self):
print("[LOG] 🌤️ Warming up the WebCrawler") print("[LOG] 🌤️ Warming up the WebCrawler")
self.run( self.run(
url='https://google.com/', url="https://google.com/",
word_count_threshold=5, word_count_threshold=5,
extraction_strategy=NoExtractionStrategy(), extraction_strategy=NoExtractionStrategy(),
bypass_cache=False, bypass_cache=False,
verbose=False verbose=False,
) )
self.ready = True self.ready = True
print("[LOG] 🌞 WebCrawler is ready to crawl") print("[LOG] 🌞 WebCrawler is ready to crawl")
@@ -80,6 +94,7 @@ class WebCrawler:
**kwargs, **kwargs,
) -> List[CrawlResult]: ) -> List[CrawlResult]:
extraction_strategy = extraction_strategy or NoExtractionStrategy() extraction_strategy = extraction_strategy or NoExtractionStrategy()
def fetch_page_wrapper(url_model, *args, **kwargs): def fetch_page_wrapper(url_model, *args, **kwargs):
return self.fetch_page(url_model, *args, **kwargs) return self.fetch_page(url_model, *args, **kwargs)
@@ -150,12 +165,25 @@ class WebCrawler:
html = sanitize_input_encode(self.crawler_strategy.crawl(url, **kwargs)) html = sanitize_input_encode(self.crawler_strategy.crawl(url, **kwargs))
t2 = time.time() t2 = time.time()
if verbose: if verbose:
print(f"[LOG] 🚀 Crawling done for {url}, success: {bool(html)}, time taken: {t2 - t1:.2f} seconds") print(
f"[LOG] 🚀 Crawling done for {url}, success: {bool(html)}, time taken: {t2 - t1:.2f} seconds"
)
if screenshot: if screenshot:
screenshot_data = self.crawler_strategy.take_screenshot() screenshot_data = self.crawler_strategy.take_screenshot()
crawl_result = self.process_html(
crawl_result = self.process_html(url, html, extracted_content, word_count_threshold, extraction_strategy, chunking_strategy, css_selector, screenshot_data, verbose, bool(cached), **kwargs) url,
html,
extracted_content,
word_count_threshold,
extraction_strategy,
chunking_strategy,
css_selector,
screenshot_data,
verbose,
bool(cached),
**kwargs,
)
crawl_result.success = bool(html) crawl_result.success = bool(html)
return crawl_result return crawl_result
except Exception as e: except Exception as e:
@@ -183,7 +211,11 @@ class WebCrawler:
try: try:
t1 = time.time() t1 = time.time()
scrapping_strategy = WebScrapingStrategy() scrapping_strategy = WebScrapingStrategy()
extra_params = {k: v for k, v in kwargs.items() if k not in ["only_text", "image_description_min_word_threshold"]} extra_params = {
k: v
for k, v in kwargs.items()
if k not in ["only_text", "image_description_min_word_threshold"]
}
result = scrapping_strategy.scrap( result = scrapping_strategy.scrap(
url, url,
html, html,
@@ -191,14 +223,17 @@ class WebCrawler:
css_selector=css_selector, css_selector=css_selector,
only_text=kwargs.get("only_text", False), only_text=kwargs.get("only_text", False),
image_description_min_word_threshold=kwargs.get( image_description_min_word_threshold=kwargs.get(
"image_description_min_word_threshold", IMAGE_DESCRIPTION_MIN_WORD_THRESHOLD "image_description_min_word_threshold",
IMAGE_DESCRIPTION_MIN_WORD_THRESHOLD,
), ),
**extra_params, **extra_params,
) )
# result = get_content_of_website_optimized(url, html, word_count_threshold, css_selector=css_selector, only_text=kwargs.get("only_text", False)) # result = get_content_of_website_optimized(url, html, word_count_threshold, css_selector=css_selector, only_text=kwargs.get("only_text", False))
if verbose: if verbose:
print(f"[LOG] 🚀 Content extracted for {url}, success: True, time taken: {time.time() - t1:.2f} seconds") print(
f"[LOG] 🚀 Content extracted for {url}, success: True, time taken: {time.time() - t1:.2f} seconds"
)
if result is None: if result is None:
raise ValueError(f"Failed to extract content from the website: {url}") raise ValueError(f"Failed to extract content from the website: {url}")
@@ -213,14 +248,20 @@ class WebCrawler:
if extracted_content is None: if extracted_content is None:
if verbose: if verbose:
print(f"[LOG] 🔥 Extracting semantic blocks for {url}, Strategy: {extraction_strategy.name}") print(
f"[LOG] 🔥 Extracting semantic blocks for {url}, Strategy: {extraction_strategy.name}"
)
sections = chunking_strategy.chunk(markdown) sections = chunking_strategy.chunk(markdown)
extracted_content = extraction_strategy.run(url, sections) extracted_content = extraction_strategy.run(url, sections)
extracted_content = json.dumps(extracted_content, indent=4, default=str, ensure_ascii=False) extracted_content = json.dumps(
extracted_content, indent=4, default=str, ensure_ascii=False
)
if verbose: if verbose:
print(f"[LOG] 🚀 Extraction done for {url}, time taken: {time.time() - t:.2f} seconds.") print(
f"[LOG] 🚀 Extraction done for {url}, time taken: {time.time() - t:.2f} seconds."
)
screenshot = None if not screenshot else screenshot screenshot = None if not screenshot else screenshot

View File

@@ -9,12 +9,10 @@ from crawl4ai.extraction_strategy import JsonCssExtractionStrategy
from crawl4ai.async_configs import BrowserConfig, CrawlerRunConfig from crawl4ai.async_configs import BrowserConfig, CrawlerRunConfig
import json import json
async def extract_amazon_products(): async def extract_amazon_products():
# Initialize browser config # Initialize browser config
browser_config = BrowserConfig( browser_config = BrowserConfig(browser_type="chromium", headless=True)
browser_type="chromium",
headless=True
)
# Initialize crawler config with JSON CSS extraction strategy # Initialize crawler config with JSON CSS extraction strategy
crawler_config = CrawlerRunConfig( crawler_config = CrawlerRunConfig(
@@ -27,57 +25,53 @@ async def extract_amazon_products():
"name": "asin", "name": "asin",
"selector": "", "selector": "",
"type": "attribute", "type": "attribute",
"attribute": "data-asin" "attribute": "data-asin",
},
{
"name": "title",
"selector": "h2 a span",
"type": "text"
}, },
{"name": "title", "selector": "h2 a span", "type": "text"},
{ {
"name": "url", "name": "url",
"selector": "h2 a", "selector": "h2 a",
"type": "attribute", "type": "attribute",
"attribute": "href" "attribute": "href",
}, },
{ {
"name": "image", "name": "image",
"selector": ".s-image", "selector": ".s-image",
"type": "attribute", "type": "attribute",
"attribute": "src" "attribute": "src",
}, },
{ {
"name": "rating", "name": "rating",
"selector": ".a-icon-star-small .a-icon-alt", "selector": ".a-icon-star-small .a-icon-alt",
"type": "text" "type": "text",
}, },
{ {
"name": "reviews_count", "name": "reviews_count",
"selector": "[data-csa-c-func-deps='aui-da-a-popover'] ~ span span", "selector": "[data-csa-c-func-deps='aui-da-a-popover'] ~ span span",
"type": "text" "type": "text",
}, },
{ {
"name": "price", "name": "price",
"selector": ".a-price .a-offscreen", "selector": ".a-price .a-offscreen",
"type": "text" "type": "text",
}, },
{ {
"name": "original_price", "name": "original_price",
"selector": ".a-price.a-text-price .a-offscreen", "selector": ".a-price.a-text-price .a-offscreen",
"type": "text" "type": "text",
}, },
{ {
"name": "sponsored", "name": "sponsored",
"selector": ".puis-sponsored-label-text", "selector": ".puis-sponsored-label-text",
"type": "exists" "type": "exists",
}, },
{ {
"name": "delivery_info", "name": "delivery_info",
"selector": "[data-cy='delivery-recipe'] .a-color-base", "selector": "[data-cy='delivery-recipe'] .a-color-base",
"type": "text", "type": "text",
"multiple": True "multiple": True,
} },
] ],
} }
) )
) )
@@ -105,10 +99,12 @@ async def extract_amazon_products():
print(f"Rating: {product.get('rating')}") print(f"Rating: {product.get('rating')}")
print(f"Reviews: {product.get('reviews_count')}") print(f"Reviews: {product.get('reviews_count')}")
print(f"Sponsored: {'Yes' if product.get('sponsored') else 'No'}") print(f"Sponsored: {'Yes' if product.get('sponsored') else 'No'}")
if product.get('delivery_info'): if product.get("delivery_info"):
print(f"Delivery: {' '.join(product['delivery_info'])}") print(f"Delivery: {' '.join(product['delivery_info'])}")
print("-" * 80) print("-" * 80)
if __name__ == "__main__": if __name__ == "__main__":
import asyncio import asyncio
asyncio.run(extract_amazon_products()) asyncio.run(extract_amazon_products())

View File

@@ -10,6 +10,7 @@ from crawl4ai.async_configs import BrowserConfig, CrawlerRunConfig
import json import json
from playwright.async_api import Page, BrowserContext from playwright.async_api import Page, BrowserContext
async def extract_amazon_products(): async def extract_amazon_products():
# Initialize browser config # Initialize browser config
browser_config = BrowserConfig( browser_config = BrowserConfig(
@@ -20,7 +21,6 @@ async def extract_amazon_products():
# Initialize crawler config with JSON CSS extraction strategy nav-search-submit-button # Initialize crawler config with JSON CSS extraction strategy nav-search-submit-button
crawler_config = CrawlerRunConfig( crawler_config = CrawlerRunConfig(
cache_mode=CacheMode.BYPASS, cache_mode=CacheMode.BYPASS,
extraction_strategy=JsonCssExtractionStrategy( extraction_strategy=JsonCssExtractionStrategy(
schema={ schema={
"name": "Amazon Product Search Results", "name": "Amazon Product Search Results",
@@ -30,82 +30,86 @@ async def extract_amazon_products():
"name": "asin", "name": "asin",
"selector": "", "selector": "",
"type": "attribute", "type": "attribute",
"attribute": "data-asin" "attribute": "data-asin",
},
{
"name": "title",
"selector": "h2 a span",
"type": "text"
}, },
{"name": "title", "selector": "h2 a span", "type": "text"},
{ {
"name": "url", "name": "url",
"selector": "h2 a", "selector": "h2 a",
"type": "attribute", "type": "attribute",
"attribute": "href" "attribute": "href",
}, },
{ {
"name": "image", "name": "image",
"selector": ".s-image", "selector": ".s-image",
"type": "attribute", "type": "attribute",
"attribute": "src" "attribute": "src",
}, },
{ {
"name": "rating", "name": "rating",
"selector": ".a-icon-star-small .a-icon-alt", "selector": ".a-icon-star-small .a-icon-alt",
"type": "text" "type": "text",
}, },
{ {
"name": "reviews_count", "name": "reviews_count",
"selector": "[data-csa-c-func-deps='aui-da-a-popover'] ~ span span", "selector": "[data-csa-c-func-deps='aui-da-a-popover'] ~ span span",
"type": "text" "type": "text",
}, },
{ {
"name": "price", "name": "price",
"selector": ".a-price .a-offscreen", "selector": ".a-price .a-offscreen",
"type": "text" "type": "text",
}, },
{ {
"name": "original_price", "name": "original_price",
"selector": ".a-price.a-text-price .a-offscreen", "selector": ".a-price.a-text-price .a-offscreen",
"type": "text" "type": "text",
}, },
{ {
"name": "sponsored", "name": "sponsored",
"selector": ".puis-sponsored-label-text", "selector": ".puis-sponsored-label-text",
"type": "exists" "type": "exists",
}, },
{ {
"name": "delivery_info", "name": "delivery_info",
"selector": "[data-cy='delivery-recipe'] .a-color-base", "selector": "[data-cy='delivery-recipe'] .a-color-base",
"type": "text", "type": "text",
"multiple": True "multiple": True,
},
],
} }
] ),
}
)
) )
url = "https://www.amazon.com/" url = "https://www.amazon.com/"
async def after_goto(page: Page, context: BrowserContext, url: str, response: dict, **kwargs): async def after_goto(
page: Page, context: BrowserContext, url: str, response: dict, **kwargs
):
"""Hook called after navigating to each URL""" """Hook called after navigating to each URL"""
print(f"[HOOK] after_goto - Successfully loaded: {url}") print(f"[HOOK] after_goto - Successfully loaded: {url}")
try: try:
# Wait for search box to be available # Wait for search box to be available
search_box = await page.wait_for_selector('#twotabsearchtextbox', timeout=1000) search_box = await page.wait_for_selector(
"#twotabsearchtextbox", timeout=1000
)
# Type the search query # Type the search query
await search_box.fill('Samsung Galaxy Tab') await search_box.fill("Samsung Galaxy Tab")
# Get the search button and prepare for navigation # Get the search button and prepare for navigation
search_button = await page.wait_for_selector('#nav-search-submit-button', timeout=1000) search_button = await page.wait_for_selector(
"#nav-search-submit-button", timeout=1000
)
# Click with navigation waiting # Click with navigation waiting
await search_button.click() await search_button.click()
# Wait for search results to load # Wait for search results to load
await page.wait_for_selector('[data-component-type="s-search-result"]', timeout=10000) await page.wait_for_selector(
'[data-component-type="s-search-result"]', timeout=10000
)
print("[HOOK] Search completed and results loaded!") print("[HOOK] Search completed and results loaded!")
except Exception as e: except Exception as e:
@@ -115,7 +119,6 @@ async def extract_amazon_products():
# Use context manager for proper resource handling # Use context manager for proper resource handling
async with AsyncWebCrawler(config=browser_config) as crawler: async with AsyncWebCrawler(config=browser_config) as crawler:
crawler.crawler_strategy.set_hook("after_goto", after_goto) crawler.crawler_strategy.set_hook("after_goto", after_goto)
# Extract the data # Extract the data
@@ -136,10 +139,12 @@ async def extract_amazon_products():
print(f"Rating: {product.get('rating')}") print(f"Rating: {product.get('rating')}")
print(f"Reviews: {product.get('reviews_count')}") print(f"Reviews: {product.get('reviews_count')}")
print(f"Sponsored: {'Yes' if product.get('sponsored') else 'No'}") print(f"Sponsored: {'Yes' if product.get('sponsored') else 'No'}")
if product.get('delivery_info'): if product.get("delivery_info"):
print(f"Delivery: {' '.join(product['delivery_info'])}") print(f"Delivery: {' '.join(product['delivery_info'])}")
print("-" * 80) print("-" * 80)
if __name__ == "__main__": if __name__ == "__main__":
import asyncio import asyncio
asyncio.run(extract_amazon_products()) asyncio.run(extract_amazon_products())

View File

@@ -8,7 +8,7 @@ from crawl4ai import AsyncWebCrawler, CacheMode
from crawl4ai.extraction_strategy import JsonCssExtractionStrategy from crawl4ai.extraction_strategy import JsonCssExtractionStrategy
from crawl4ai.async_configs import BrowserConfig, CrawlerRunConfig from crawl4ai.async_configs import BrowserConfig, CrawlerRunConfig
import json import json
from playwright.async_api import Page, BrowserContext
async def extract_amazon_products(): async def extract_amazon_products():
# Initialize browser config # Initialize browser config
@@ -41,65 +41,60 @@ async def extract_amazon_products():
"name": "asin", "name": "asin",
"selector": "", "selector": "",
"type": "attribute", "type": "attribute",
"attribute": "data-asin" "attribute": "data-asin",
},
{
"name": "title",
"selector": "h2 a span",
"type": "text"
}, },
{"name": "title", "selector": "h2 a span", "type": "text"},
{ {
"name": "url", "name": "url",
"selector": "h2 a", "selector": "h2 a",
"type": "attribute", "type": "attribute",
"attribute": "href" "attribute": "href",
}, },
{ {
"name": "image", "name": "image",
"selector": ".s-image", "selector": ".s-image",
"type": "attribute", "type": "attribute",
"attribute": "src" "attribute": "src",
}, },
{ {
"name": "rating", "name": "rating",
"selector": ".a-icon-star-small .a-icon-alt", "selector": ".a-icon-star-small .a-icon-alt",
"type": "text" "type": "text",
}, },
{ {
"name": "reviews_count", "name": "reviews_count",
"selector": "[data-csa-c-func-deps='aui-da-a-popover'] ~ span span", "selector": "[data-csa-c-func-deps='aui-da-a-popover'] ~ span span",
"type": "text" "type": "text",
}, },
{ {
"name": "price", "name": "price",
"selector": ".a-price .a-offscreen", "selector": ".a-price .a-offscreen",
"type": "text" "type": "text",
}, },
{ {
"name": "original_price", "name": "original_price",
"selector": ".a-price.a-text-price .a-offscreen", "selector": ".a-price.a-text-price .a-offscreen",
"type": "text" "type": "text",
}, },
{ {
"name": "sponsored", "name": "sponsored",
"selector": ".puis-sponsored-label-text", "selector": ".puis-sponsored-label-text",
"type": "exists" "type": "exists",
}, },
{ {
"name": "delivery_info", "name": "delivery_info",
"selector": "[data-cy='delivery-recipe'] .a-color-base", "selector": "[data-cy='delivery-recipe'] .a-color-base",
"type": "text", "type": "text",
"multiple": True "multiple": True,
},
],
} }
] ),
}
)
) )
# Example search URL (you should replace with your actual Amazon URL) # Example search URL (you should replace with your actual Amazon URL)
url = "https://www.amazon.com/" url = "https://www.amazon.com/"
# Use context manager for proper resource handling # Use context manager for proper resource handling
async with AsyncWebCrawler(config=browser_config) as crawler: async with AsyncWebCrawler(config=browser_config) as crawler:
# Extract the data # Extract the data
@@ -120,10 +115,12 @@ async def extract_amazon_products():
print(f"Rating: {product.get('rating')}") print(f"Rating: {product.get('rating')}")
print(f"Reviews: {product.get('reviews_count')}") print(f"Reviews: {product.get('reviews_count')}")
print(f"Sponsored: {'Yes' if product.get('sponsored') else 'No'}") print(f"Sponsored: {'Yes' if product.get('sponsored') else 'No'}")
if product.get('delivery_info'): if product.get("delivery_info"):
print(f"Delivery: {' '.join(product['delivery_info'])}") print(f"Delivery: {' '.join(product['delivery_info'])}")
print("-" * 80) print("-" * 80)
if __name__ == "__main__": if __name__ == "__main__":
import asyncio import asyncio
asyncio.run(extract_amazon_products()) asyncio.run(extract_amazon_products())

View File

@@ -1,12 +1,16 @@
# File: async_webcrawler_multiple_urls_example.py # File: async_webcrawler_multiple_urls_example.py
import os, sys import os, sys
# append 2 parent directories to sys.path to import crawl4ai # append 2 parent directories to sys.path to import crawl4ai
parent_dir = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) parent_dir = os.path.dirname(
os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
)
sys.path.append(parent_dir) sys.path.append(parent_dir)
import asyncio import asyncio
from crawl4ai import AsyncWebCrawler from crawl4ai import AsyncWebCrawler
async def main(): async def main():
# Initialize the AsyncWebCrawler # Initialize the AsyncWebCrawler
async with AsyncWebCrawler(verbose=True) as crawler: async with AsyncWebCrawler(verbose=True) as crawler:
@@ -16,7 +20,7 @@ async def main():
"https://python.org", "https://python.org",
"https://github.com", "https://github.com",
"https://stackoverflow.com", "https://stackoverflow.com",
"https://news.ycombinator.com" "https://news.ycombinator.com",
] ]
# Set up crawling parameters # Set up crawling parameters
@@ -27,7 +31,7 @@ async def main():
urls=urls, urls=urls,
word_count_threshold=word_count_threshold, word_count_threshold=word_count_threshold,
bypass_cache=True, bypass_cache=True,
verbose=True verbose=True,
) )
# Process the results # Process the results
@@ -36,7 +40,9 @@ async def main():
print(f"Successfully crawled: {result.url}") print(f"Successfully crawled: {result.url}")
print(f"Title: {result.metadata.get('title', 'N/A')}") print(f"Title: {result.metadata.get('title', 'N/A')}")
print(f"Word count: {len(result.markdown.split())}") print(f"Word count: {len(result.markdown.split())}")
print(f"Number of links: {len(result.links.get('internal', [])) + len(result.links.get('external', []))}") print(
f"Number of links: {len(result.links.get('internal', [])) + len(result.links.get('external', []))}"
)
print(f"Number of images: {len(result.media.get('images', []))}") print(f"Number of images: {len(result.media.get('images', []))}")
print("---") print("---")
else: else:
@@ -44,5 +50,6 @@ async def main():
print(f"Error: {result.error_message}") print(f"Error: {result.error_message}")
print("---") print("---")
if __name__ == "__main__": if __name__ == "__main__":
asyncio.run(main()) asyncio.run(main())

View File

@@ -6,10 +6,8 @@ This example demonstrates optimal browser usage patterns in Crawl4AI:
""" """
import asyncio import asyncio
import os
from typing import List from typing import List
from crawl4ai import AsyncWebCrawler, BrowserConfig, CrawlerRunConfig from crawl4ai import AsyncWebCrawler, BrowserConfig, CrawlerRunConfig
from crawl4ai.content_filter_strategy import PruningContentFilter
from crawl4ai.markdown_generation_strategy import DefaultMarkdownGenerator from crawl4ai.markdown_generation_strategy import DefaultMarkdownGenerator

View File

@@ -1,31 +1,32 @@
import os, time import os, time
# append the path to the root of the project # append the path to the root of the project
import sys import sys
import asyncio import asyncio
sys.path.append(os.path.join(os.path.dirname(__file__), '..', '..'))
sys.path.append(os.path.join(os.path.dirname(__file__), "..", ".."))
from firecrawl import FirecrawlApp from firecrawl import FirecrawlApp
from crawl4ai import AsyncWebCrawler from crawl4ai import AsyncWebCrawler
__data__ = os.path.join(os.path.dirname(__file__), '..', '..') + '/.data'
__data__ = os.path.join(os.path.dirname(__file__), "..", "..") + "/.data"
async def compare(): async def compare():
app = FirecrawlApp(api_key=os.environ['FIRECRAWL_API_KEY']) app = FirecrawlApp(api_key=os.environ["FIRECRAWL_API_KEY"])
# Tet Firecrawl with a simple crawl # Tet Firecrawl with a simple crawl
start = time.time() start = time.time()
scrape_status = app.scrape_url( scrape_status = app.scrape_url(
'https://www.nbcnews.com/business', "https://www.nbcnews.com/business", params={"formats": ["markdown", "html"]}
params={'formats': ['markdown', 'html']}
) )
end = time.time() end = time.time()
print(f"Time taken: {end - start} seconds") print(f"Time taken: {end - start} seconds")
print(len(scrape_status['markdown'])) print(len(scrape_status["markdown"]))
# save the markdown content with provider name # save the markdown content with provider name
with open(f"{__data__}/firecrawl_simple.md", "w") as f: with open(f"{__data__}/firecrawl_simple.md", "w") as f:
f.write(scrape_status['markdown']) f.write(scrape_status["markdown"])
# Count how many "cldnry.s-nbcnews.com" are in the markdown # Count how many "cldnry.s-nbcnews.com" are in the markdown
print(scrape_status['markdown'].count("cldnry.s-nbcnews.com")) print(scrape_status["markdown"].count("cldnry.s-nbcnews.com"))
async with AsyncWebCrawler() as crawler: async with AsyncWebCrawler() as crawler:
start = time.time() start = time.time()
@@ -34,7 +35,7 @@ async def compare():
# js_code=["const loadMoreButton = Array.from(document.querySelectorAll('button')).find(button => button.textContent.includes('Load More')); loadMoreButton && loadMoreButton.click();"], # js_code=["const loadMoreButton = Array.from(document.querySelectorAll('button')).find(button => button.textContent.includes('Load More')); loadMoreButton && loadMoreButton.click();"],
word_count_threshold=0, word_count_threshold=0,
bypass_cache=True, bypass_cache=True,
verbose=False verbose=False,
) )
end = time.time() end = time.time()
print(f"Time taken: {end - start} seconds") print(f"Time taken: {end - start} seconds")
@@ -48,10 +49,12 @@ async def compare():
start = time.time() start = time.time()
result = await crawler.arun( result = await crawler.arun(
url="https://www.nbcnews.com/business", url="https://www.nbcnews.com/business",
js_code=["const loadMoreButton = Array.from(document.querySelectorAll('button')).find(button => button.textContent.includes('Load More')); loadMoreButton && loadMoreButton.click();"], js_code=[
"const loadMoreButton = Array.from(document.querySelectorAll('button')).find(button => button.textContent.includes('Load More')); loadMoreButton && loadMoreButton.click();"
],
word_count_threshold=0, word_count_threshold=0,
bypass_cache=True, bypass_cache=True,
verbose=False verbose=False,
) )
end = time.time() end = time.time()
print(f"Time taken: {end - start} seconds") print(f"Time taken: {end - start} seconds")
@@ -62,6 +65,6 @@ async def compare():
# count how many "cldnry.s-nbcnews.com" are in the markdown # count how many "cldnry.s-nbcnews.com" are in the markdown
print(result.markdown.count("cldnry.s-nbcnews.com")) print(result.markdown.count("cldnry.s-nbcnews.com"))
if __name__ == "__main__": if __name__ == "__main__":
asyncio.run(compare()) asyncio.run(compare())

View File

@@ -3,11 +3,18 @@ import time
from rich import print from rich import print
from rich.table import Table from rich.table import Table
from crawl4ai import ( from crawl4ai import (
AsyncWebCrawler, BrowserConfig, CrawlerRunConfig, AsyncWebCrawler,
MemoryAdaptiveDispatcher, SemaphoreDispatcher, BrowserConfig,
RateLimiter, CrawlerMonitor, DisplayMode, CacheMode CrawlerRunConfig,
MemoryAdaptiveDispatcher,
SemaphoreDispatcher,
RateLimiter,
CrawlerMonitor,
DisplayMode,
CacheMode,
) )
async def memory_adaptive(urls, browser_config, run_config): async def memory_adaptive(urls, browser_config, run_config):
"""Memory adaptive crawler with monitoring""" """Memory adaptive crawler with monitoring"""
start = time.perf_counter() start = time.perf_counter()
@@ -16,14 +23,16 @@ async def memory_adaptive(urls, browser_config, run_config):
memory_threshold_percent=70.0, memory_threshold_percent=70.0,
max_session_permit=10, max_session_permit=10,
monitor=CrawlerMonitor( monitor=CrawlerMonitor(
max_visible_rows=15, max_visible_rows=15, display_mode=DisplayMode.DETAILED
display_mode=DisplayMode.DETAILED ),
) )
results = await crawler.arun_many(
urls, config=run_config, dispatcher=dispatcher
) )
results = await crawler.arun_many(urls, config=run_config, dispatcher=dispatcher)
duration = time.perf_counter() - start duration = time.perf_counter() - start
return len(results), duration return len(results), duration
async def memory_adaptive_with_rate_limit(urls, browser_config, run_config): async def memory_adaptive_with_rate_limit(urls, browser_config, run_config):
"""Memory adaptive crawler with rate limiting""" """Memory adaptive crawler with rate limiting"""
start = time.perf_counter() start = time.perf_counter()
@@ -32,19 +41,19 @@ async def memory_adaptive_with_rate_limit(urls, browser_config, run_config):
memory_threshold_percent=70.0, memory_threshold_percent=70.0,
max_session_permit=10, max_session_permit=10,
rate_limiter=RateLimiter( rate_limiter=RateLimiter(
base_delay=(1.0, 2.0), base_delay=(1.0, 2.0), max_delay=30.0, max_retries=2
max_delay=30.0,
max_retries=2
), ),
monitor=CrawlerMonitor( monitor=CrawlerMonitor(
max_visible_rows=15, max_visible_rows=15, display_mode=DisplayMode.DETAILED
display_mode=DisplayMode.DETAILED ),
) )
results = await crawler.arun_many(
urls, config=run_config, dispatcher=dispatcher
) )
results = await crawler.arun_many(urls, config=run_config, dispatcher=dispatcher)
duration = time.perf_counter() - start duration = time.perf_counter() - start
return len(results), duration return len(results), duration
async def semaphore(urls, browser_config, run_config): async def semaphore(urls, browser_config, run_config):
"""Basic semaphore crawler""" """Basic semaphore crawler"""
start = time.perf_counter() start = time.perf_counter()
@@ -52,14 +61,16 @@ async def semaphore(urls, browser_config, run_config):
dispatcher = SemaphoreDispatcher( dispatcher = SemaphoreDispatcher(
semaphore_count=5, semaphore_count=5,
monitor=CrawlerMonitor( monitor=CrawlerMonitor(
max_visible_rows=15, max_visible_rows=15, display_mode=DisplayMode.DETAILED
display_mode=DisplayMode.DETAILED ),
) )
results = await crawler.arun_many(
urls, config=run_config, dispatcher=dispatcher
) )
results = await crawler.arun_many(urls, config=run_config, dispatcher=dispatcher)
duration = time.perf_counter() - start duration = time.perf_counter() - start
return len(results), duration return len(results), duration
async def semaphore_with_rate_limit(urls, browser_config, run_config): async def semaphore_with_rate_limit(urls, browser_config, run_config):
"""Semaphore crawler with rate limiting""" """Semaphore crawler with rate limiting"""
start = time.perf_counter() start = time.perf_counter()
@@ -67,19 +78,19 @@ async def semaphore_with_rate_limit(urls, browser_config, run_config):
dispatcher = SemaphoreDispatcher( dispatcher = SemaphoreDispatcher(
semaphore_count=5, semaphore_count=5,
rate_limiter=RateLimiter( rate_limiter=RateLimiter(
base_delay=(1.0, 2.0), base_delay=(1.0, 2.0), max_delay=30.0, max_retries=2
max_delay=30.0,
max_retries=2
), ),
monitor=CrawlerMonitor( monitor=CrawlerMonitor(
max_visible_rows=15, max_visible_rows=15, display_mode=DisplayMode.DETAILED
display_mode=DisplayMode.DETAILED ),
) )
results = await crawler.arun_many(
urls, config=run_config, dispatcher=dispatcher
) )
results = await crawler.arun_many(urls, config=run_config, dispatcher=dispatcher)
duration = time.perf_counter() - start duration = time.perf_counter() - start
return len(results), duration return len(results), duration
def create_performance_table(results): def create_performance_table(results):
"""Creates a rich table showing performance results""" """Creates a rich table showing performance results"""
table = Table(title="Crawler Strategy Performance Comparison") table = Table(title="Crawler Strategy Performance Comparison")
@@ -93,14 +104,12 @@ def create_performance_table(results):
for strategy, (urls_crawled, duration) in sorted_results: for strategy, (urls_crawled, duration) in sorted_results:
urls_per_second = urls_crawled / duration urls_per_second = urls_crawled / duration
table.add_row( table.add_row(
strategy, strategy, str(urls_crawled), f"{duration:.2f}", f"{urls_per_second:.2f}"
str(urls_crawled),
f"{duration:.2f}",
f"{urls_per_second:.2f}"
) )
return table return table
async def main(): async def main():
urls = [f"https://example.com/page{i}" for i in range(1, 20)] urls = [f"https://example.com/page{i}" for i in range(1, 20)]
browser_config = BrowserConfig(headless=True, verbose=False) browser_config = BrowserConfig(headless=True, verbose=False)
@@ -108,14 +117,19 @@ async def main():
results = { results = {
"Memory Adaptive": await memory_adaptive(urls, browser_config, run_config), "Memory Adaptive": await memory_adaptive(urls, browser_config, run_config),
"Memory Adaptive + Rate Limit": await memory_adaptive_with_rate_limit(urls, browser_config, run_config), "Memory Adaptive + Rate Limit": await memory_adaptive_with_rate_limit(
urls, browser_config, run_config
),
"Semaphore": await semaphore(urls, browser_config, run_config), "Semaphore": await semaphore(urls, browser_config, run_config),
"Semaphore + Rate Limit": await semaphore_with_rate_limit(urls, browser_config, run_config), "Semaphore + Rate Limit": await semaphore_with_rate_limit(
urls, browser_config, run_config
),
} }
table = create_performance_table(results) table = create_performance_table(results)
print("\nPerformance Summary:") print("\nPerformance Summary:")
print(table) print(table)
if __name__ == "__main__": if __name__ == "__main__":
asyncio.run(main()) asyncio.run(main())

View File

@@ -6,15 +6,24 @@ import base64
import os import os
from typing import Dict, Any from typing import Dict, Any
class Crawl4AiTester: class Crawl4AiTester:
def __init__(self, base_url: str = "http://localhost:11235", api_token: str = None): def __init__(self, base_url: str = "http://localhost:11235", api_token: str = None):
self.base_url = base_url self.base_url = base_url
self.api_token = api_token or os.getenv('CRAWL4AI_API_TOKEN') or "test_api_code" # Check environment variable as fallback self.api_token = (
self.headers = {'Authorization': f'Bearer {self.api_token}'} if self.api_token else {} api_token or os.getenv("CRAWL4AI_API_TOKEN") or "test_api_code"
) # Check environment variable as fallback
self.headers = (
{"Authorization": f"Bearer {self.api_token}"} if self.api_token else {}
)
def submit_and_wait(self, request_data: Dict[str, Any], timeout: int = 300) -> Dict[str, Any]: def submit_and_wait(
self, request_data: Dict[str, Any], timeout: int = 300
) -> Dict[str, Any]:
# Submit crawl job # Submit crawl job
response = requests.post(f"{self.base_url}/crawl", json=request_data, headers=self.headers) response = requests.post(
f"{self.base_url}/crawl", json=request_data, headers=self.headers
)
if response.status_code == 403: if response.status_code == 403:
raise Exception("API token is invalid or missing") raise Exception("API token is invalid or missing")
task_id = response.json()["task_id"] task_id = response.json()["task_id"]
@@ -24,9 +33,13 @@ class Crawl4AiTester:
start_time = time.time() start_time = time.time()
while True: while True:
if time.time() - start_time > timeout: if time.time() - start_time > timeout:
raise TimeoutError(f"Task {task_id} did not complete within {timeout} seconds") raise TimeoutError(
f"Task {task_id} did not complete within {timeout} seconds"
)
result = requests.get(f"{self.base_url}/task/{task_id}", headers=self.headers) result = requests.get(
f"{self.base_url}/task/{task_id}", headers=self.headers
)
status = result.json() status = result.json()
if status["status"] == "failed": if status["status"] == "failed":
@@ -39,7 +52,12 @@ class Crawl4AiTester:
time.sleep(2) time.sleep(2)
def submit_sync(self, request_data: Dict[str, Any]) -> Dict[str, Any]: def submit_sync(self, request_data: Dict[str, Any]) -> Dict[str, Any]:
response = requests.post(f"{self.base_url}/crawl_sync", json=request_data, headers=self.headers, timeout=60) response = requests.post(
f"{self.base_url}/crawl_sync",
json=request_data,
headers=self.headers,
timeout=60,
)
if response.status_code == 408: if response.status_code == 408:
raise TimeoutError("Task did not complete within server timeout") raise TimeoutError("Task did not complete within server timeout")
response.raise_for_status() response.raise_for_status()
@@ -48,13 +66,12 @@ class Crawl4AiTester:
def crawl_direct(self, request_data: Dict[str, Any]) -> Dict[str, Any]: def crawl_direct(self, request_data: Dict[str, Any]) -> Dict[str, Any]:
"""Directly crawl without using task queue""" """Directly crawl without using task queue"""
response = requests.post( response = requests.post(
f"{self.base_url}/crawl_direct", f"{self.base_url}/crawl_direct", json=request_data, headers=self.headers
json=request_data,
headers=self.headers
) )
response.raise_for_status() response.raise_for_status()
return response.json() return response.json()
def test_docker_deployment(version="basic"): def test_docker_deployment(version="basic"):
tester = Crawl4AiTester( tester = Crawl4AiTester(
base_url="http://localhost:11235", base_url="http://localhost:11235",
@@ -70,7 +87,7 @@ def test_docker_deployment(version="basic"):
health = requests.get(f"{tester.base_url}/health", timeout=10) health = requests.get(f"{tester.base_url}/health", timeout=10)
print("Health check:", health.json()) print("Health check:", health.json())
break break
except requests.exceptions.RequestException as e: except requests.exceptions.RequestException:
if i == max_retries - 1: if i == max_retries - 1:
print(f"Failed to connect after {max_retries} attempts") print(f"Failed to connect after {max_retries} attempts")
sys.exit(1) sys.exit(1)
@@ -99,7 +116,7 @@ def test_basic_crawl(tester: Crawl4AiTester):
request = { request = {
"urls": "https://www.nbcnews.com/business", "urls": "https://www.nbcnews.com/business",
"priority": 10, "priority": 10,
"session_id": "test" "session_id": "test",
} }
result = tester.submit_and_wait(request) result = tester.submit_and_wait(request)
@@ -107,19 +124,21 @@ def test_basic_crawl(tester: Crawl4AiTester):
assert result["result"]["success"] assert result["result"]["success"]
assert len(result["result"]["markdown"]) > 0 assert len(result["result"]["markdown"]) > 0
def test_basic_crawl_sync(tester: Crawl4AiTester): def test_basic_crawl_sync(tester: Crawl4AiTester):
print("\n=== Testing Basic Crawl (Sync) ===") print("\n=== Testing Basic Crawl (Sync) ===")
request = { request = {
"urls": "https://www.nbcnews.com/business", "urls": "https://www.nbcnews.com/business",
"priority": 10, "priority": 10,
"session_id": "test" "session_id": "test",
} }
result = tester.submit_sync(request) result = tester.submit_sync(request)
print(f"Basic crawl result length: {len(result['result']['markdown'])}") print(f"Basic crawl result length: {len(result['result']['markdown'])}")
assert result['status'] == 'completed' assert result["status"] == "completed"
assert result['result']['success'] assert result["result"]["success"]
assert len(result['result']['markdown']) > 0 assert len(result["result"]["markdown"]) > 0
def test_basic_crawl_direct(tester: Crawl4AiTester): def test_basic_crawl_direct(tester: Crawl4AiTester):
print("\n=== Testing Basic Crawl (Direct) ===") print("\n=== Testing Basic Crawl (Direct) ===")
@@ -127,13 +146,14 @@ def test_basic_crawl_direct(tester: Crawl4AiTester):
"urls": "https://www.nbcnews.com/business", "urls": "https://www.nbcnews.com/business",
"priority": 10, "priority": 10,
# "session_id": "test" # "session_id": "test"
"cache_mode": "bypass" # or "enabled", "disabled", "read_only", "write_only" "cache_mode": "bypass", # or "enabled", "disabled", "read_only", "write_only"
} }
result = tester.crawl_direct(request) result = tester.crawl_direct(request)
print(f"Basic crawl result length: {len(result['result']['markdown'])}") print(f"Basic crawl result length: {len(result['result']['markdown'])}")
assert result['result']['success'] assert result["result"]["success"]
assert len(result['result']['markdown']) > 0 assert len(result["result"]["markdown"]) > 0
def test_js_execution(tester: Crawl4AiTester): def test_js_execution(tester: Crawl4AiTester):
print("\n=== Testing JS Execution ===") print("\n=== Testing JS Execution ===")
@@ -144,32 +164,29 @@ def test_js_execution(tester: Crawl4AiTester):
"const loadMoreButton = Array.from(document.querySelectorAll('button')).find(button => button.textContent.includes('Load More')); loadMoreButton && loadMoreButton.click();" "const loadMoreButton = Array.from(document.querySelectorAll('button')).find(button => button.textContent.includes('Load More')); loadMoreButton && loadMoreButton.click();"
], ],
"wait_for": "article.tease-card:nth-child(10)", "wait_for": "article.tease-card:nth-child(10)",
"crawler_params": { "crawler_params": {"headless": True},
"headless": True
}
} }
result = tester.submit_and_wait(request) result = tester.submit_and_wait(request)
print(f"JS execution result length: {len(result['result']['markdown'])}") print(f"JS execution result length: {len(result['result']['markdown'])}")
assert result["result"]["success"] assert result["result"]["success"]
def test_css_selector(tester: Crawl4AiTester): def test_css_selector(tester: Crawl4AiTester):
print("\n=== Testing CSS Selector ===") print("\n=== Testing CSS Selector ===")
request = { request = {
"urls": "https://www.nbcnews.com/business", "urls": "https://www.nbcnews.com/business",
"priority": 7, "priority": 7,
"css_selector": ".wide-tease-item__description", "css_selector": ".wide-tease-item__description",
"crawler_params": { "crawler_params": {"headless": True},
"headless": True "extra": {"word_count_threshold": 10},
},
"extra": {"word_count_threshold": 10}
} }
result = tester.submit_and_wait(request) result = tester.submit_and_wait(request)
print(f"CSS selector result length: {len(result['result']['markdown'])}") print(f"CSS selector result length: {len(result['result']['markdown'])}")
assert result["result"]["success"] assert result["result"]["success"]
def test_structured_extraction(tester: Crawl4AiTester): def test_structured_extraction(tester: Crawl4AiTester):
print("\n=== Testing Structured Extraction ===") print("\n=== Testing Structured Extraction ===")
schema = { schema = {
@@ -190,19 +207,14 @@ def test_structured_extraction(tester: Crawl4AiTester):
"name": "price", "name": "price",
"selector": "td:nth-child(2)", "selector": "td:nth-child(2)",
"type": "text", "type": "text",
} },
], ],
} }
request = { request = {
"urls": "https://www.coinbase.com/explore", "urls": "https://www.coinbase.com/explore",
"priority": 9, "priority": 9,
"extraction_config": { "extraction_config": {"type": "json_css", "params": {"schema": schema}},
"type": "json_css",
"params": {
"schema": schema
}
}
} }
result = tester.submit_and_wait(request) result = tester.submit_and_wait(request)
@@ -212,6 +224,7 @@ def test_structured_extraction(tester: Crawl4AiTester):
assert result["result"]["success"] assert result["result"]["success"]
assert len(extracted) > 0 assert len(extracted) > 0
def test_llm_extraction(tester: Crawl4AiTester): def test_llm_extraction(tester: Crawl4AiTester):
print("\n=== Testing LLM Extraction ===") print("\n=== Testing LLM Extraction ===")
schema = { schema = {
@@ -219,18 +232,18 @@ def test_llm_extraction(tester: Crawl4AiTester):
"properties": { "properties": {
"model_name": { "model_name": {
"type": "string", "type": "string",
"description": "Name of the OpenAI model." "description": "Name of the OpenAI model.",
}, },
"input_fee": { "input_fee": {
"type": "string", "type": "string",
"description": "Fee for input token for the OpenAI model." "description": "Fee for input token for the OpenAI model.",
}, },
"output_fee": { "output_fee": {
"type": "string", "type": "string",
"description": "Fee for output token for the OpenAI model." "description": "Fee for output token for the OpenAI model.",
}
}, },
"required": ["model_name", "input_fee", "output_fee"] },
"required": ["model_name", "input_fee", "output_fee"],
} }
request = { request = {
@@ -243,10 +256,10 @@ def test_llm_extraction(tester: Crawl4AiTester):
"api_token": os.getenv("OPENAI_API_KEY"), "api_token": os.getenv("OPENAI_API_KEY"),
"schema": schema, "schema": schema,
"extraction_type": "schema", "extraction_type": "schema",
"instruction": """From the crawled content, extract all mentioned model names along with their fees for input and output tokens.""" "instruction": """From the crawled content, extract all mentioned model names along with their fees for input and output tokens.""",
}
}, },
"crawler_params": {"word_count_threshold": 1} },
"crawler_params": {"word_count_threshold": 1},
} }
try: try:
@@ -258,6 +271,7 @@ def test_llm_extraction(tester: Crawl4AiTester):
except Exception as e: except Exception as e:
print(f"LLM extraction test failed (might be due to missing API key): {str(e)}") print(f"LLM extraction test failed (might be due to missing API key): {str(e)}")
def test_llm_with_ollama(tester: Crawl4AiTester): def test_llm_with_ollama(tester: Crawl4AiTester):
print("\n=== Testing LLM with Ollama ===") print("\n=== Testing LLM with Ollama ===")
schema = { schema = {
@@ -265,18 +279,18 @@ def test_llm_with_ollama(tester: Crawl4AiTester):
"properties": { "properties": {
"article_title": { "article_title": {
"type": "string", "type": "string",
"description": "The main title of the news article" "description": "The main title of the news article",
}, },
"summary": { "summary": {
"type": "string", "type": "string",
"description": "A brief summary of the article content" "description": "A brief summary of the article content",
}, },
"main_topics": { "main_topics": {
"type": "array", "type": "array",
"items": {"type": "string"}, "items": {"type": "string"},
"description": "Main topics or themes discussed in the article" "description": "Main topics or themes discussed in the article",
} },
} },
} }
request = { request = {
@@ -288,11 +302,11 @@ def test_llm_with_ollama(tester: Crawl4AiTester):
"provider": "ollama/llama2", "provider": "ollama/llama2",
"schema": schema, "schema": schema,
"extraction_type": "schema", "extraction_type": "schema",
"instruction": "Extract the main article information including title, summary, and main topics." "instruction": "Extract the main article information including title, summary, and main topics.",
} },
}, },
"extra": {"word_count_threshold": 1}, "extra": {"word_count_threshold": 1},
"crawler_params": {"verbose": True} "crawler_params": {"verbose": True},
} }
try: try:
@@ -303,6 +317,7 @@ def test_llm_with_ollama(tester: Crawl4AiTester):
except Exception as e: except Exception as e:
print(f"Ollama extraction test failed: {str(e)}") print(f"Ollama extraction test failed: {str(e)}")
def test_cosine_extraction(tester: Crawl4AiTester): def test_cosine_extraction(tester: Crawl4AiTester):
print("\n=== Testing Cosine Extraction ===") print("\n=== Testing Cosine Extraction ===")
request = { request = {
@@ -314,9 +329,9 @@ def test_cosine_extraction(tester: Crawl4AiTester):
"semantic_filter": "business finance economy", "semantic_filter": "business finance economy",
"word_count_threshold": 10, "word_count_threshold": 10,
"max_dist": 0.2, "max_dist": 0.2,
"top_k": 3 "top_k": 3,
} },
} },
} }
try: try:
@@ -328,15 +343,14 @@ def test_cosine_extraction(tester: Crawl4AiTester):
except Exception as e: except Exception as e:
print(f"Cosine extraction test failed: {str(e)}") print(f"Cosine extraction test failed: {str(e)}")
def test_screenshot(tester: Crawl4AiTester): def test_screenshot(tester: Crawl4AiTester):
print("\n=== Testing Screenshot ===") print("\n=== Testing Screenshot ===")
request = { request = {
"urls": "https://www.nbcnews.com/business", "urls": "https://www.nbcnews.com/business",
"priority": 5, "priority": 5,
"screenshot": True, "screenshot": True,
"crawler_params": { "crawler_params": {"headless": True},
"headless": True
}
} }
result = tester.submit_and_wait(request) result = tester.submit_and_wait(request)
@@ -351,6 +365,7 @@ def test_screenshot(tester: Crawl4AiTester):
assert result["result"]["success"] assert result["result"]["success"]
if __name__ == "__main__": if __name__ == "__main__":
version = sys.argv[1] if len(sys.argv) > 1 else "basic" version = sys.argv[1] if len(sys.argv) > 1 else "basic"
# version = "full" # version = "full"

View File

@@ -9,18 +9,17 @@ This example shows how to:
import asyncio import asyncio
import os import os
from typing import Dict, Any
from crawl4ai import AsyncWebCrawler, BrowserConfig, CrawlerRunConfig, CacheMode from crawl4ai import AsyncWebCrawler, BrowserConfig, CrawlerRunConfig, CacheMode
from crawl4ai.extraction_strategy import ( from crawl4ai.extraction_strategy import (
LLMExtractionStrategy, LLMExtractionStrategy,
JsonCssExtractionStrategy, JsonCssExtractionStrategy,
JsonXPathExtractionStrategy JsonXPathExtractionStrategy,
) )
from crawl4ai.chunking_strategy import RegexChunking, IdentityChunking
from crawl4ai.content_filter_strategy import PruningContentFilter from crawl4ai.content_filter_strategy import PruningContentFilter
from crawl4ai.markdown_generation_strategy import DefaultMarkdownGenerator from crawl4ai.markdown_generation_strategy import DefaultMarkdownGenerator
async def run_extraction(crawler: AsyncWebCrawler, url: str, strategy, name: str): async def run_extraction(crawler: AsyncWebCrawler, url: str, strategy, name: str):
"""Helper function to run extraction with proper configuration""" """Helper function to run extraction with proper configuration"""
try: try:
@@ -30,7 +29,7 @@ async def run_extraction(crawler: AsyncWebCrawler, url: str, strategy, name: str
extraction_strategy=strategy, extraction_strategy=strategy,
markdown_generator=DefaultMarkdownGenerator( markdown_generator=DefaultMarkdownGenerator(
content_filter=PruningContentFilter() # For fit_markdown support content_filter=PruningContentFilter() # For fit_markdown support
) ),
) )
# Run the crawler # Run the crawler
@@ -40,22 +39,22 @@ async def run_extraction(crawler: AsyncWebCrawler, url: str, strategy, name: str
print(f"\n=== {name} Results ===") print(f"\n=== {name} Results ===")
print(f"Extracted Content: {result.extracted_content}") print(f"Extracted Content: {result.extracted_content}")
print(f"Raw Markdown Length: {len(result.markdown_v2.raw_markdown)}") print(f"Raw Markdown Length: {len(result.markdown_v2.raw_markdown)}")
print(f"Citations Markdown Length: {len(result.markdown_v2.markdown_with_citations)}") print(
f"Citations Markdown Length: {len(result.markdown_v2.markdown_with_citations)}"
)
else: else:
print(f"Error in {name}: Crawl failed") print(f"Error in {name}: Crawl failed")
except Exception as e: except Exception as e:
print(f"Error in {name}: {str(e)}") print(f"Error in {name}: {str(e)}")
async def main(): async def main():
# Example URL (replace with actual URL) # Example URL (replace with actual URL)
url = "https://example.com/product-page" url = "https://example.com/product-page"
# Configure browser settings # Configure browser settings
browser_config = BrowserConfig( browser_config = BrowserConfig(headless=True, verbose=True)
headless=True,
verbose=True
)
# Initialize extraction strategies # Initialize extraction strategies
@@ -63,21 +62,21 @@ async def main():
markdown_strategy = LLMExtractionStrategy( markdown_strategy = LLMExtractionStrategy(
provider="openai/gpt-4o-mini", provider="openai/gpt-4o-mini",
api_token=os.getenv("OPENAI_API_KEY"), api_token=os.getenv("OPENAI_API_KEY"),
instruction="Extract product information including name, price, and description" instruction="Extract product information including name, price, and description",
) )
html_strategy = LLMExtractionStrategy( html_strategy = LLMExtractionStrategy(
input_format="html", input_format="html",
provider="openai/gpt-4o-mini", provider="openai/gpt-4o-mini",
api_token=os.getenv("OPENAI_API_KEY"), api_token=os.getenv("OPENAI_API_KEY"),
instruction="Extract product information from HTML including structured data" instruction="Extract product information from HTML including structured data",
) )
fit_markdown_strategy = LLMExtractionStrategy( fit_markdown_strategy = LLMExtractionStrategy(
input_format="fit_markdown", input_format="fit_markdown",
provider="openai/gpt-4o-mini", provider="openai/gpt-4o-mini",
api_token=os.getenv("OPENAI_API_KEY"), api_token=os.getenv("OPENAI_API_KEY"),
instruction="Extract product information from cleaned markdown" instruction="Extract product information from cleaned markdown",
) )
# 2. JSON CSS Extraction (automatically uses HTML input) # 2. JSON CSS Extraction (automatically uses HTML input)
@@ -86,8 +85,8 @@ async def main():
"fields": [ "fields": [
{"name": "title", "selector": "h1.product-title", "type": "text"}, {"name": "title", "selector": "h1.product-title", "type": "text"},
{"name": "price", "selector": ".price", "type": "text"}, {"name": "price", "selector": ".price", "type": "text"},
{"name": "description", "selector": ".description", "type": "text"} {"name": "description", "selector": ".description", "type": "text"},
] ],
} }
css_strategy = JsonCssExtractionStrategy(schema=css_schema) css_strategy = JsonCssExtractionStrategy(schema=css_schema)
@@ -95,10 +94,22 @@ async def main():
xpath_schema = { xpath_schema = {
"baseSelector": "//div[@class='product']", "baseSelector": "//div[@class='product']",
"fields": [ "fields": [
{"name": "title", "selector": ".//h1[@class='product-title']/text()", "type": "text"}, {
{"name": "price", "selector": ".//span[@class='price']/text()", "type": "text"}, "name": "title",
{"name": "description", "selector": ".//div[@class='description']/text()", "type": "text"} "selector": ".//h1[@class='product-title']/text()",
] "type": "text",
},
{
"name": "price",
"selector": ".//span[@class='price']/text()",
"type": "text",
},
{
"name": "description",
"selector": ".//div[@class='description']/text()",
"type": "text",
},
],
} }
xpath_strategy = JsonXPathExtractionStrategy(schema=xpath_schema) xpath_strategy = JsonXPathExtractionStrategy(schema=xpath_schema)
@@ -111,5 +122,6 @@ async def main():
await run_extraction(crawler, url, css_strategy, "CSS Extraction") await run_extraction(crawler, url, css_strategy, "CSS Extraction")
await run_extraction(crawler, url, xpath_strategy, "XPath Extraction") await run_extraction(crawler, url, xpath_strategy, "XPath Extraction")
if __name__ == "__main__": if __name__ == "__main__":
asyncio.run(main()) asyncio.run(main())

View File

@@ -1,20 +1,23 @@
import asyncio import asyncio
from crawl4ai import * from crawl4ai import *
async def main(): async def main():
browser_config = BrowserConfig(headless=True, verbose=True) browser_config = BrowserConfig(headless=True, verbose=True)
async with AsyncWebCrawler(config=browser_config) as crawler: async with AsyncWebCrawler(config=browser_config) as crawler:
crawler_config = CrawlerRunConfig( crawler_config = CrawlerRunConfig(
cache_mode=CacheMode.BYPASS, cache_mode=CacheMode.BYPASS,
markdown_generator=DefaultMarkdownGenerator( markdown_generator=DefaultMarkdownGenerator(
content_filter=PruningContentFilter(threshold=0.48, threshold_type="fixed", min_word_threshold=0) content_filter=PruningContentFilter(
threshold=0.48, threshold_type="fixed", min_word_threshold=0
) )
),
) )
result = await crawler.arun( result = await crawler.arun(
url="https://www.helloworld.org", url="https://www.helloworld.org", config=crawler_config
config=crawler_config
) )
print(result.markdown_v2.raw_markdown[:500]) print(result.markdown_v2.raw_markdown[:500])
if __name__ == "__main__": if __name__ == "__main__":
asyncio.run(main()) asyncio.run(main())

View File

@@ -1,19 +1,18 @@
from crawl4ai import AsyncWebCrawler, BrowserConfig, CrawlerRunConfig, CacheMode from crawl4ai import AsyncWebCrawler, BrowserConfig, CrawlerRunConfig, CacheMode
from playwright.async_api import Page, BrowserContext from playwright.async_api import Page, BrowserContext
async def main(): async def main():
print("🔗 Hooks Example: Demonstrating different hook use cases") print("🔗 Hooks Example: Demonstrating different hook use cases")
# Configure browser settings # Configure browser settings
browser_config = BrowserConfig( browser_config = BrowserConfig(headless=True)
headless=True
)
# Configure crawler settings # Configure crawler settings
crawler_run_config = CrawlerRunConfig( crawler_run_config = CrawlerRunConfig(
js_code="window.scrollTo(0, document.body.scrollHeight);", js_code="window.scrollTo(0, document.body.scrollHeight);",
wait_for="body", wait_for="body",
cache_mode=CacheMode.BYPASS cache_mode=CacheMode.BYPASS,
) )
# Create crawler instance # Create crawler instance
@@ -30,16 +29,22 @@ async def main():
"""Hook called after a new page and context are created""" """Hook called after a new page and context are created"""
print("[HOOK] on_page_context_created - New page created!") print("[HOOK] on_page_context_created - New page created!")
# Example: Set default viewport size # Example: Set default viewport size
await context.add_cookies([{ await context.add_cookies(
'name': 'session_id', [
'value': 'example_session', {
'domain': '.example.com', "name": "session_id",
'path': '/' "value": "example_session",
}]) "domain": ".example.com",
"path": "/",
}
]
)
await page.set_viewport_size({"width": 1080, "height": 800}) await page.set_viewport_size({"width": 1080, "height": 800})
return page return page
async def on_user_agent_updated(page: Page, context: BrowserContext, user_agent: str, **kwargs): async def on_user_agent_updated(
page: Page, context: BrowserContext, user_agent: str, **kwargs
):
"""Hook called when the user agent is updated""" """Hook called when the user agent is updated"""
print(f"[HOOK] on_user_agent_updated - New user agent: {user_agent}") print(f"[HOOK] on_user_agent_updated - New user agent: {user_agent}")
return page return page
@@ -53,17 +58,17 @@ async def main():
"""Hook called before navigating to each URL""" """Hook called before navigating to each URL"""
print(f"[HOOK] before_goto - About to visit: {url}") print(f"[HOOK] before_goto - About to visit: {url}")
# Example: Add custom headers for the request # Example: Add custom headers for the request
await page.set_extra_http_headers({ await page.set_extra_http_headers({"Custom-Header": "my-value"})
"Custom-Header": "my-value"
})
return page return page
async def after_goto(page: Page, context: BrowserContext, url: str, response: dict, **kwargs): async def after_goto(
page: Page, context: BrowserContext, url: str, response: dict, **kwargs
):
"""Hook called after navigating to each URL""" """Hook called after navigating to each URL"""
print(f"[HOOK] after_goto - Successfully loaded: {url}") print(f"[HOOK] after_goto - Successfully loaded: {url}")
# Example: Wait for a specific element to be loaded # Example: Wait for a specific element to be loaded
try: try:
await page.wait_for_selector('.content', timeout=1000) await page.wait_for_selector(".content", timeout=1000)
print("Content element found!") print("Content element found!")
except: except:
print("Content element not found, continuing anyway") print("Content element not found, continuing anyway")
@@ -76,7 +81,9 @@ async def main():
await page.evaluate("window.scrollTo(0, document.body.scrollHeight);") await page.evaluate("window.scrollTo(0, document.body.scrollHeight);")
return page return page
async def before_return_html(page: Page, context: BrowserContext, html:str, **kwargs): async def before_return_html(
page: Page, context: BrowserContext, html: str, **kwargs
):
"""Hook called before returning the HTML content""" """Hook called before returning the HTML content"""
print(f"[HOOK] before_return_html - Got HTML content (length: {len(html)})") print(f"[HOOK] before_return_html - Got HTML content (length: {len(html)})")
# Example: You could modify the HTML content here if needed # Example: You could modify the HTML content here if needed
@@ -84,7 +91,9 @@ async def main():
# Set all the hooks # Set all the hooks
crawler.crawler_strategy.set_hook("on_browser_created", on_browser_created) crawler.crawler_strategy.set_hook("on_browser_created", on_browser_created)
crawler.crawler_strategy.set_hook("on_page_context_created", on_page_context_created) crawler.crawler_strategy.set_hook(
"on_page_context_created", on_page_context_created
)
crawler.crawler_strategy.set_hook("on_user_agent_updated", on_user_agent_updated) crawler.crawler_strategy.set_hook("on_user_agent_updated", on_user_agent_updated)
crawler.crawler_strategy.set_hook("on_execution_started", on_execution_started) crawler.crawler_strategy.set_hook("on_execution_started", on_execution_started)
crawler.crawler_strategy.set_hook("before_goto", before_goto) crawler.crawler_strategy.set_hook("before_goto", before_goto)
@@ -95,13 +104,15 @@ async def main():
await crawler.start() await crawler.start()
# Example usage: crawl a simple website # Example usage: crawl a simple website
url = 'https://example.com' url = "https://example.com"
result = await crawler.arun(url, config=crawler_run_config) result = await crawler.arun(url, config=crawler_run_config)
print(f"\nCrawled URL: {result.url}") print(f"\nCrawled URL: {result.url}")
print(f"HTML length: {len(result.html)}") print(f"HTML length: {len(result.html)}")
await crawler.close() await crawler.close()
if __name__ == "__main__": if __name__ == "__main__":
import asyncio import asyncio
asyncio.run(main()) asyncio.run(main())

View File

@@ -1,6 +1,7 @@
import asyncio import asyncio
from crawl4ai import AsyncWebCrawler, AsyncPlaywrightCrawlerStrategy from crawl4ai import AsyncWebCrawler, AsyncPlaywrightCrawlerStrategy
async def main(): async def main():
# Example 1: Setting language when creating the crawler # Example 1: Setting language when creating the crawler
crawler1 = AsyncWebCrawler( crawler1 = AsyncWebCrawler(
@@ -9,11 +10,15 @@ async def main():
) )
) )
result1 = await crawler1.arun("https://www.example.com") result1 = await crawler1.arun("https://www.example.com")
print("Example 1 result:", result1.extracted_content[:100]) # Print first 100 characters print(
"Example 1 result:", result1.extracted_content[:100]
) # Print first 100 characters
# Example 2: Setting language before crawling # Example 2: Setting language before crawling
crawler2 = AsyncWebCrawler() crawler2 = AsyncWebCrawler()
crawler2.crawler_strategy.headers["Accept-Language"] = "es-ES,es;q=0.9,en-US;q=0.8,en;q=0.7" crawler2.crawler_strategy.headers[
"Accept-Language"
] = "es-ES,es;q=0.9,en-US;q=0.8,en;q=0.7"
result2 = await crawler2.arun("https://www.example.com") result2 = await crawler2.arun("https://www.example.com")
print("Example 2 result:", result2.extracted_content[:100]) print("Example 2 result:", result2.extracted_content[:100])
@@ -21,7 +26,7 @@ async def main():
crawler3 = AsyncWebCrawler() crawler3 = AsyncWebCrawler()
result3 = await crawler3.arun( result3 = await crawler3.arun(
"https://www.example.com", "https://www.example.com",
headers={"Accept-Language": "de-DE,de;q=0.9,en-US;q=0.8,en;q=0.7"} headers={"Accept-Language": "de-DE,de;q=0.9,en-US;q=0.8,en;q=0.7"},
) )
print("Example 3 result:", result3.extracted_content[:100]) print("Example 3 result:", result3.extracted_content[:100])
@@ -33,13 +38,13 @@ async def main():
] ]
crawler4 = AsyncWebCrawler() crawler4 = AsyncWebCrawler()
results = await asyncio.gather(*[ results = await asyncio.gather(
crawler4.arun(url, headers={"Accept-Language": lang}) *[crawler4.arun(url, headers={"Accept-Language": lang}) for url, lang in urls]
for url, lang in urls )
])
for url, result in zip([u for u, _ in urls], results): for url, result in zip([u for u, _ in urls], results):
print(f"Result for {url}:", result.extracted_content[:100]) print(f"Result for {url}:", result.extracted_content[:100])
if __name__ == "__main__": if __name__ == "__main__":
asyncio.run(main()) asyncio.run(main())

View File

@@ -3,15 +3,20 @@ from crawl4ai.crawler_strategy import *
import asyncio import asyncio
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
url = r'https://openai.com/api/pricing/' url = r"https://openai.com/api/pricing/"
class OpenAIModelFee(BaseModel): class OpenAIModelFee(BaseModel):
model_name: str = Field(..., description="Name of the OpenAI model.") model_name: str = Field(..., description="Name of the OpenAI model.")
input_fee: str = Field(..., description="Fee for input token for the OpenAI model.") input_fee: str = Field(..., description="Fee for input token for the OpenAI model.")
output_fee: str = Field(..., description="Fee for output token for the OpenAI model.") output_fee: str = Field(
..., description="Fee for output token for the OpenAI model."
)
from crawl4ai import AsyncWebCrawler from crawl4ai import AsyncWebCrawler
async def main(): async def main():
# Use AsyncWebCrawler # Use AsyncWebCrawler
async with AsyncWebCrawler() as crawler: async with AsyncWebCrawler() as crawler:
@@ -20,15 +25,15 @@ async def main():
word_count_threshold=1, word_count_threshold=1,
extraction_strategy=LLMExtractionStrategy( extraction_strategy=LLMExtractionStrategy(
# provider= "openai/gpt-4o", api_token = os.getenv('OPENAI_API_KEY'), # provider= "openai/gpt-4o", api_token = os.getenv('OPENAI_API_KEY'),
provider= "groq/llama-3.1-70b-versatile", api_token = os.getenv('GROQ_API_KEY'), provider="groq/llama-3.1-70b-versatile",
api_token=os.getenv("GROQ_API_KEY"),
schema=OpenAIModelFee.model_json_schema(), schema=OpenAIModelFee.model_json_schema(),
extraction_type="schema", extraction_type="schema",
instruction="From the crawled content, extract all mentioned model names along with their " \ instruction="From the crawled content, extract all mentioned model names along with their "
"fees for input and output tokens. Make sure not to miss anything in the entire content. " \ "fees for input and output tokens. Make sure not to miss anything in the entire content. "
'One extracted model JSON format should look like this: ' \ "One extracted model JSON format should look like this: "
'{ "model_name": "GPT-4", "input_fee": "US$10.00 / 1M tokens", "output_fee": "US$30.00 / 1M tokens" }' '{ "model_name": "GPT-4", "input_fee": "US$10.00 / 1M tokens", "output_fee": "US$30.00 / 1M tokens" }',
), ),
) )
print("Success:", result.success) print("Success:", result.success)
model_fees = json.loads(result.extracted_content) model_fees = json.loads(result.extracted_content)
@@ -37,4 +42,5 @@ async def main():
with open(".data/data.json", "w", encoding="utf-8") as f: with open(".data/data.json", "w", encoding="utf-8") as f:
f.write(result.extracted_content) f.write(result.extracted_content)
asyncio.run(main()) asyncio.run(main())

View File

@@ -8,12 +8,12 @@ import asyncio
import time import time
import json import json
import re import re
from typing import Dict, List from typing import Dict
from bs4 import BeautifulSoup from bs4 import BeautifulSoup
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
from crawl4ai import AsyncWebCrawler, CacheMode, BrowserConfig, CrawlerRunConfig from crawl4ai import AsyncWebCrawler, CacheMode, BrowserConfig, CrawlerRunConfig
from crawl4ai.markdown_generation_strategy import DefaultMarkdownGenerator from crawl4ai.markdown_generation_strategy import DefaultMarkdownGenerator
from crawl4ai.content_filter_strategy import BM25ContentFilter, PruningContentFilter from crawl4ai.content_filter_strategy import PruningContentFilter
from crawl4ai.extraction_strategy import ( from crawl4ai.extraction_strategy import (
JsonCssExtractionStrategy, JsonCssExtractionStrategy,
LLMExtractionStrategy, LLMExtractionStrategy,
@@ -62,6 +62,7 @@ async def clean_content():
print(f"Full Markdown Length: {full_markdown_length}") print(f"Full Markdown Length: {full_markdown_length}")
print(f"Fit Markdown Length: {fit_markdown_length}") print(f"Fit Markdown Length: {fit_markdown_length}")
async def link_analysis(): async def link_analysis():
crawler_config = CrawlerRunConfig( crawler_config = CrawlerRunConfig(
cache_mode=CacheMode.ENABLED, cache_mode=CacheMode.ENABLED,
@@ -76,9 +77,10 @@ async def link_analysis():
print(f"Found {len(result.links['internal'])} internal links") print(f"Found {len(result.links['internal'])} internal links")
print(f"Found {len(result.links['external'])} external links") print(f"Found {len(result.links['external'])} external links")
for link in result.links['internal'][:5]: for link in result.links["internal"][:5]:
print(f"Href: {link['href']}\nText: {link['text']}\n") print(f"Href: {link['href']}\nText: {link['text']}\n")
# JavaScript Execution Example # JavaScript Execution Example
async def simple_example_with_running_js_code(): async def simple_example_with_running_js_code():
print("\n--- Executing JavaScript and Using CSS Selectors ---") print("\n--- Executing JavaScript and Using CSS Selectors ---")
@@ -112,25 +114,29 @@ async def simple_example_with_css_selector():
) )
print(result.markdown[:500]) print(result.markdown[:500])
async def media_handling(): async def media_handling():
crawler_config = CrawlerRunConfig(cache_mode=CacheMode.BYPASS, exclude_external_images=True, screenshot=True) crawler_config = CrawlerRunConfig(
cache_mode=CacheMode.BYPASS, exclude_external_images=True, screenshot=True
)
async with AsyncWebCrawler() as crawler: async with AsyncWebCrawler() as crawler:
result = await crawler.arun( result = await crawler.arun(
url="https://www.nbcnews.com/business", url="https://www.nbcnews.com/business", config=crawler_config
config=crawler_config
) )
for img in result.media['images'][:5]: for img in result.media["images"][:5]:
print(f"Image URL: {img['src']}, Alt: {img['alt']}, Score: {img['score']}") print(f"Image URL: {img['src']}, Alt: {img['alt']}, Score: {img['score']}")
async def custom_hook_workflow(verbose=True): async def custom_hook_workflow(verbose=True):
async with AsyncWebCrawler() as crawler: async with AsyncWebCrawler() as crawler:
# Set a 'before_goto' hook to run custom code just before navigation # Set a 'before_goto' hook to run custom code just before navigation
crawler.crawler_strategy.set_hook("before_goto", lambda page, context: print("[Hook] Preparing to navigate...")) crawler.crawler_strategy.set_hook(
"before_goto",
lambda page, context: print("[Hook] Preparing to navigate..."),
)
# Perform the crawl operation # Perform the crawl operation
result = await crawler.arun( result = await crawler.arun(url="https://crawl4ai.com")
url="https://crawl4ai.com"
)
print(result.markdown_v2.raw_markdown[:500].replace("\n", " -- ")) print(result.markdown_v2.raw_markdown[:500].replace("\n", " -- "))
@@ -417,16 +423,17 @@ async def cosine_similarity_extraction():
top_k=3, # Number of top keywords to extract top_k=3, # Number of top keywords to extract
sim_threshold=0.3, # Similarity threshold for clustering sim_threshold=0.3, # Similarity threshold for clustering
semantic_filter="McDonald's economic impact, American consumer trends", # Keywords to filter the content semantically using embeddings semantic_filter="McDonald's economic impact, American consumer trends", # Keywords to filter the content semantically using embeddings
verbose=True verbose=True,
), ),
) )
async with AsyncWebCrawler() as crawler: async with AsyncWebCrawler() as crawler:
result = await crawler.arun( result = await crawler.arun(
url="https://www.nbcnews.com/business/consumer/how-mcdonalds-e-coli-crisis-inflation-politics-reflect-american-story-rcna177156", url="https://www.nbcnews.com/business/consumer/how-mcdonalds-e-coli-crisis-inflation-politics-reflect-american-story-rcna177156",
config=crawl_config config=crawl_config,
) )
print(json.loads(result.extracted_content)[:5]) print(json.loads(result.extracted_content)[:5])
# Browser Comparison # Browser Comparison
async def crawl_custom_browser_type(): async def crawl_custom_browser_type():
print("\n--- Browser Comparison ---") print("\n--- Browser Comparison ---")
@@ -484,18 +491,16 @@ async def crawl_with_user_simulation():
result = await crawler.arun(url="YOUR-URL-HERE", config=crawler_config) result = await crawler.arun(url="YOUR-URL-HERE", config=crawler_config)
print(result.markdown) print(result.markdown)
async def ssl_certification(): async def ssl_certification():
# Configure crawler to fetch SSL certificate # Configure crawler to fetch SSL certificate
config = CrawlerRunConfig( config = CrawlerRunConfig(
fetch_ssl_certificate=True, fetch_ssl_certificate=True,
cache_mode=CacheMode.BYPASS # Bypass cache to always get fresh certificates cache_mode=CacheMode.BYPASS, # Bypass cache to always get fresh certificates
) )
async with AsyncWebCrawler() as crawler: async with AsyncWebCrawler() as crawler:
result = await crawler.arun( result = await crawler.arun(url="https://example.com", config=config)
url='https://example.com',
config=config
)
if result.success and result.ssl_certificate: if result.success and result.ssl_certificate:
cert = result.ssl_certificate cert = result.ssl_certificate
@@ -511,12 +516,17 @@ async def ssl_certification():
print("\nCertificate exported to:") print("\nCertificate exported to:")
print(f"- JSON: {os.path.join(tmp_dir, 'certificate.json')}") print(f"- JSON: {os.path.join(tmp_dir, 'certificate.json')}")
pem_data = cert.to_pem(os.path.join(tmp_dir, "certificate.pem")) # For web servers pem_data = cert.to_pem(
os.path.join(tmp_dir, "certificate.pem")
) # For web servers
print(f"- PEM: {os.path.join(tmp_dir, 'certificate.pem')}") print(f"- PEM: {os.path.join(tmp_dir, 'certificate.pem')}")
der_data = cert.to_der(os.path.join(tmp_dir, "certificate.der")) # For Java apps der_data = cert.to_der(
os.path.join(tmp_dir, "certificate.der")
) # For Java apps
print(f"- DER: {os.path.join(tmp_dir, 'certificate.der')}") print(f"- DER: {os.path.join(tmp_dir, 'certificate.der')}")
# Speed Comparison # Speed Comparison
async def speed_comparison(): async def speed_comparison():
print("\n--- Speed Comparison ---") print("\n--- Speed Comparison ---")

View File

@@ -1,6 +1,10 @@
import os, sys import os, sys
# append parent directory to system path # append parent directory to system path
sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))); os.environ['FIRECRAWL_API_KEY'] = "fc-84b370ccfad44beabc686b38f1769692"; sys.path.append(
os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
)
os.environ["FIRECRAWL_API_KEY"] = "fc-84b370ccfad44beabc686b38f1769692"
import asyncio import asyncio
# import nest_asyncio # import nest_asyncio
@@ -15,7 +19,7 @@ from bs4 import BeautifulSoup
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
from crawl4ai import AsyncWebCrawler, CacheMode from crawl4ai import AsyncWebCrawler, CacheMode
from crawl4ai.markdown_generation_strategy import DefaultMarkdownGenerator from crawl4ai.markdown_generation_strategy import DefaultMarkdownGenerator
from crawl4ai.content_filter_strategy import BM25ContentFilter, PruningContentFilter from crawl4ai.content_filter_strategy import PruningContentFilter
from crawl4ai.extraction_strategy import ( from crawl4ai.extraction_strategy import (
JsonCssExtractionStrategy, JsonCssExtractionStrategy,
LLMExtractionStrategy, LLMExtractionStrategy,
@@ -32,9 +36,12 @@ print("Website: https://crawl4ai.com")
async def simple_crawl(): async def simple_crawl():
print("\n--- Basic Usage ---") print("\n--- Basic Usage ---")
async with AsyncWebCrawler(verbose=True) as crawler: async with AsyncWebCrawler(verbose=True) as crawler:
result = await crawler.arun(url="https://www.nbcnews.com/business", cache_mode= CacheMode.BYPASS) result = await crawler.arun(
url="https://www.nbcnews.com/business", cache_mode=CacheMode.BYPASS
)
print(result.markdown[:500]) # Print first 500 characters print(result.markdown[:500]) # Print first 500 characters
async def simple_example_with_running_js_code(): async def simple_example_with_running_js_code():
print("\n--- Executing JavaScript and Using CSS Selectors ---") print("\n--- Executing JavaScript and Using CSS Selectors ---")
# New code to handle the wait_for parameter # New code to handle the wait_for parameter
@@ -57,6 +64,7 @@ async def simple_example_with_running_js_code():
) )
print(result.markdown[:500]) # Print first 500 characters print(result.markdown[:500]) # Print first 500 characters
async def simple_example_with_css_selector(): async def simple_example_with_css_selector():
print("\n--- Using CSS Selectors ---") print("\n--- Using CSS Selectors ---")
async with AsyncWebCrawler(verbose=True) as crawler: async with AsyncWebCrawler(verbose=True) as crawler:
@@ -67,26 +75,27 @@ async def simple_example_with_css_selector():
) )
print(result.markdown[:500]) # Print first 500 characters print(result.markdown[:500]) # Print first 500 characters
async def use_proxy(): async def use_proxy():
print("\n--- Using a Proxy ---") print("\n--- Using a Proxy ---")
print( print(
"Note: Replace 'http://your-proxy-url:port' with a working proxy to run this example." "Note: Replace 'http://your-proxy-url:port' with a working proxy to run this example."
) )
# Uncomment and modify the following lines to use a proxy # Uncomment and modify the following lines to use a proxy
async with AsyncWebCrawler(verbose=True, proxy="http://your-proxy-url:port") as crawler: async with AsyncWebCrawler(
verbose=True, proxy="http://your-proxy-url:port"
) as crawler:
result = await crawler.arun( result = await crawler.arun(
url="https://www.nbcnews.com/business", url="https://www.nbcnews.com/business", cache_mode=CacheMode.BYPASS
cache_mode= CacheMode.BYPASS
) )
if result.success: if result.success:
print(result.markdown[:500]) # Print first 500 characters print(result.markdown[:500]) # Print first 500 characters
async def capture_and_save_screenshot(url: str, output_path: str): async def capture_and_save_screenshot(url: str, output_path: str):
async with AsyncWebCrawler(verbose=True) as crawler: async with AsyncWebCrawler(verbose=True) as crawler:
result = await crawler.arun( result = await crawler.arun(
url=url, url=url, screenshot=True, cache_mode=CacheMode.BYPASS
screenshot=True,
cache_mode= CacheMode.BYPASS
) )
if result.success and result.screenshot: if result.success and result.screenshot:
@@ -96,13 +105,14 @@ async def capture_and_save_screenshot(url: str, output_path: str):
screenshot_data = base64.b64decode(result.screenshot) screenshot_data = base64.b64decode(result.screenshot)
# Save the screenshot as a JPEG file # Save the screenshot as a JPEG file
with open(output_path, 'wb') as f: with open(output_path, "wb") as f:
f.write(screenshot_data) f.write(screenshot_data)
print(f"Screenshot saved successfully to {output_path}") print(f"Screenshot saved successfully to {output_path}")
else: else:
print("Failed to capture screenshot") print("Failed to capture screenshot")
class OpenAIModelFee(BaseModel): class OpenAIModelFee(BaseModel):
model_name: str = Field(..., description="Name of the OpenAI model.") model_name: str = Field(..., description="Name of the OpenAI model.")
input_fee: str = Field(..., description="Fee for input token for the OpenAI model.") input_fee: str = Field(..., description="Fee for input token for the OpenAI model.")
@@ -110,7 +120,10 @@ class OpenAIModelFee(BaseModel):
..., description="Fee for output token for the OpenAI model." ..., description="Fee for output token for the OpenAI model."
) )
async def extract_structured_data_using_llm(provider: str, api_token: str = None, extra_headers: Dict[str, str] = None):
async def extract_structured_data_using_llm(
provider: str, api_token: str = None, extra_headers: Dict[str, str] = None
):
print(f"\n--- Extracting Structured Data with {provider} ---") print(f"\n--- Extracting Structured Data with {provider} ---")
if api_token is None and provider != "ollama": if api_token is None and provider != "ollama":
@@ -139,12 +152,13 @@ async def extract_structured_data_using_llm(provider: str, api_token: str = None
instruction="""From the crawled content, extract all mentioned model names along with their fees for input and output tokens. instruction="""From the crawled content, extract all mentioned model names along with their fees for input and output tokens.
Do not miss any models in the entire content. One extracted model JSON format should look like this: Do not miss any models in the entire content. One extracted model JSON format should look like this:
{"model_name": "GPT-4", "input_fee": "US$10.00 / 1M tokens", "output_fee": "US$30.00 / 1M tokens"}.""", {"model_name": "GPT-4", "input_fee": "US$10.00 / 1M tokens", "output_fee": "US$30.00 / 1M tokens"}.""",
extra_args=extra_args extra_args=extra_args,
), ),
cache_mode=CacheMode.BYPASS, cache_mode=CacheMode.BYPASS,
) )
print(result.extracted_content) print(result.extracted_content)
async def extract_structured_data_using_css_extractor(): async def extract_structured_data_using_css_extractor():
print("\n--- Using JsonCssExtractionStrategy for Fast Structured Output ---") print("\n--- Using JsonCssExtractionStrategy for Fast Structured Output ---")
schema = { schema = {
@@ -175,16 +189,12 @@ async def extract_structured_data_using_css_extractor():
"name": "course_icon", "name": "course_icon",
"selector": ".image-92", "selector": ".image-92",
"type": "attribute", "type": "attribute",
"attribute": "src" "attribute": "src",
} },
] ],
} }
async with AsyncWebCrawler( async with AsyncWebCrawler(headless=True, verbose=True) as crawler:
headless=True,
verbose=True
) as crawler:
# Create the JavaScript that handles clicking multiple times # Create the JavaScript that handles clicking multiple times
js_click_tabs = """ js_click_tabs = """
(async () => { (async () => {
@@ -204,13 +214,14 @@ async def extract_structured_data_using_css_extractor():
url="https://www.kidocode.com/degrees/technology", url="https://www.kidocode.com/degrees/technology",
extraction_strategy=JsonCssExtractionStrategy(schema, verbose=True), extraction_strategy=JsonCssExtractionStrategy(schema, verbose=True),
js_code=[js_click_tabs], js_code=[js_click_tabs],
cache_mode=CacheMode.BYPASS cache_mode=CacheMode.BYPASS,
) )
companies = json.loads(result.extracted_content) companies = json.loads(result.extracted_content)
print(f"Successfully extracted {len(companies)} companies") print(f"Successfully extracted {len(companies)} companies")
print(json.dumps(companies[0], indent=2)) print(json.dumps(companies[0], indent=2))
# Advanced Session-Based Crawling with Dynamic Content 🔄 # Advanced Session-Based Crawling with Dynamic Content 🔄
async def crawl_dynamic_content_pages_method_1(): async def crawl_dynamic_content_pages_method_1():
print("\n--- Advanced Multi-Page Crawling with JavaScript Execution ---") print("\n--- Advanced Multi-Page Crawling with JavaScript Execution ---")
@@ -267,6 +278,7 @@ async def crawl_dynamic_content_pages_method_1():
await crawler.crawler_strategy.kill_session(session_id) await crawler.crawler_strategy.kill_session(session_id)
print(f"Successfully crawled {len(all_commits)} commits across 3 pages") print(f"Successfully crawled {len(all_commits)} commits across 3 pages")
async def crawl_dynamic_content_pages_method_2(): async def crawl_dynamic_content_pages_method_2():
print("\n--- Advanced Multi-Page Crawling with JavaScript Execution ---") print("\n--- Advanced Multi-Page Crawling with JavaScript Execution ---")
@@ -334,8 +346,11 @@ async def crawl_dynamic_content_pages_method_2():
await crawler.crawler_strategy.kill_session(session_id) await crawler.crawler_strategy.kill_session(session_id)
print(f"Successfully crawled {len(all_commits)} commits across 3 pages") print(f"Successfully crawled {len(all_commits)} commits across 3 pages")
async def crawl_dynamic_content_pages_method_3(): async def crawl_dynamic_content_pages_method_3():
print("\n--- Advanced Multi-Page Crawling with JavaScript Execution using `wait_for` ---") print(
"\n--- Advanced Multi-Page Crawling with JavaScript Execution using `wait_for` ---"
)
async with AsyncWebCrawler(verbose=True) as crawler: async with AsyncWebCrawler(verbose=True) as crawler:
url = "https://github.com/microsoft/TypeScript/commits/main" url = "https://github.com/microsoft/TypeScript/commits/main"
@@ -395,28 +410,40 @@ async def crawl_dynamic_content_pages_method_3():
await crawler.crawler_strategy.kill_session(session_id) await crawler.crawler_strategy.kill_session(session_id)
print(f"Successfully crawled {len(all_commits)} commits across 3 pages") print(f"Successfully crawled {len(all_commits)} commits across 3 pages")
async def crawl_custom_browser_type(): async def crawl_custom_browser_type():
# Use Firefox # Use Firefox
start = time.time() start = time.time()
async with AsyncWebCrawler(browser_type="firefox", verbose=True, headless = True) as crawler: async with AsyncWebCrawler(
result = await crawler.arun(url="https://www.example.com", cache_mode= CacheMode.BYPASS) browser_type="firefox", verbose=True, headless=True
) as crawler:
result = await crawler.arun(
url="https://www.example.com", cache_mode=CacheMode.BYPASS
)
print(result.markdown[:500]) print(result.markdown[:500])
print("Time taken: ", time.time() - start) print("Time taken: ", time.time() - start)
# Use WebKit # Use WebKit
start = time.time() start = time.time()
async with AsyncWebCrawler(browser_type="webkit", verbose=True, headless = True) as crawler: async with AsyncWebCrawler(
result = await crawler.arun(url="https://www.example.com", cache_mode= CacheMode.BYPASS) browser_type="webkit", verbose=True, headless=True
) as crawler:
result = await crawler.arun(
url="https://www.example.com", cache_mode=CacheMode.BYPASS
)
print(result.markdown[:500]) print(result.markdown[:500])
print("Time taken: ", time.time() - start) print("Time taken: ", time.time() - start)
# Use Chromium (default) # Use Chromium (default)
start = time.time() start = time.time()
async with AsyncWebCrawler(verbose=True, headless=True) as crawler: async with AsyncWebCrawler(verbose=True, headless=True) as crawler:
result = await crawler.arun(url="https://www.example.com", cache_mode= CacheMode.BYPASS) result = await crawler.arun(
url="https://www.example.com", cache_mode=CacheMode.BYPASS
)
print(result.markdown[:500]) print(result.markdown[:500])
print("Time taken: ", time.time() - start) print("Time taken: ", time.time() - start)
async def crawl_with_user_simultion(): async def crawl_with_user_simultion():
async with AsyncWebCrawler(verbose=True, headless=True) as crawler: async with AsyncWebCrawler(verbose=True, headless=True) as crawler:
url = "YOUR-URL-HERE" url = "YOUR-URL-HERE"
@@ -430,6 +457,7 @@ async def crawl_with_user_simultion():
print(result.markdown) print(result.markdown)
async def speed_comparison(): async def speed_comparison():
# print("\n--- Speed Comparison ---") # print("\n--- Speed Comparison ---")
# print("Firecrawl (simulated):") # print("Firecrawl (simulated):")
@@ -439,11 +467,11 @@ async def speed_comparison():
# print() # print()
# Simulated Firecrawl performance # Simulated Firecrawl performance
from firecrawl import FirecrawlApp from firecrawl import FirecrawlApp
app = FirecrawlApp(api_key=os.environ['FIRECRAWL_API_KEY'])
app = FirecrawlApp(api_key=os.environ["FIRECRAWL_API_KEY"])
start = time.time() start = time.time()
scrape_status = app.scrape_url( scrape_status = app.scrape_url(
'https://www.nbcnews.com/business', "https://www.nbcnews.com/business", params={"formats": ["markdown", "html"]}
params={'formats': ['markdown', 'html']}
) )
end = time.time() end = time.time()
print("Firecrawl:") print("Firecrawl:")
@@ -474,7 +502,9 @@ async def speed_comparison():
url="https://www.nbcnews.com/business", url="https://www.nbcnews.com/business",
word_count_threshold=0, word_count_threshold=0,
markdown_generator=DefaultMarkdownGenerator( markdown_generator=DefaultMarkdownGenerator(
content_filter = PruningContentFilter(threshold=0.48, threshold_type="fixed", min_word_threshold=0) content_filter=PruningContentFilter(
threshold=0.48, threshold_type="fixed", min_word_threshold=0
)
# content_filter=BM25ContentFilter(user_query=None, bm25_threshold=1.0) # content_filter=BM25ContentFilter(user_query=None, bm25_threshold=1.0)
), ),
cache_mode=CacheMode.BYPASS, cache_mode=CacheMode.BYPASS,
@@ -498,7 +528,9 @@ async def speed_comparison():
word_count_threshold=0, word_count_threshold=0,
cache_mode=CacheMode.BYPASS, cache_mode=CacheMode.BYPASS,
markdown_generator=DefaultMarkdownGenerator( markdown_generator=DefaultMarkdownGenerator(
content_filter = PruningContentFilter(threshold=0.48, threshold_type="fixed", min_word_threshold=0) content_filter=PruningContentFilter(
threshold=0.48, threshold_type="fixed", min_word_threshold=0
)
# content_filter=BM25ContentFilter(user_query=None, bm25_threshold=1.0) # content_filter=BM25ContentFilter(user_query=None, bm25_threshold=1.0)
), ),
verbose=False, verbose=False,
@@ -520,6 +552,7 @@ async def speed_comparison():
print("If you run these tests in an environment with better network conditions,") print("If you run these tests in an environment with better network conditions,")
print("you may observe an even more significant speed advantage for Crawl4AI.") print("you may observe an even more significant speed advantage for Crawl4AI.")
async def generate_knowledge_graph(): async def generate_knowledge_graph():
class Entity(BaseModel): class Entity(BaseModel):
name: str name: str
@@ -536,11 +569,11 @@ async def generate_knowledge_graph():
relationships: List[Relationship] relationships: List[Relationship]
extraction_strategy = LLMExtractionStrategy( extraction_strategy = LLMExtractionStrategy(
provider='openai/gpt-4o-mini', # Or any other provider, including Ollama and open source models provider="openai/gpt-4o-mini", # Or any other provider, including Ollama and open source models
api_token=os.getenv('OPENAI_API_KEY'), # In case of Ollama just pass "no-token" api_token=os.getenv("OPENAI_API_KEY"), # In case of Ollama just pass "no-token"
schema=KnowledgeGraph.model_json_schema(), schema=KnowledgeGraph.model_json_schema(),
extraction_type="schema", extraction_type="schema",
instruction="""Extract entities and relationships from the given text.""" instruction="""Extract entities and relationships from the given text.""",
) )
async with AsyncWebCrawler() as crawler: async with AsyncWebCrawler() as crawler:
url = "https://paulgraham.com/love.html" url = "https://paulgraham.com/love.html"
@@ -554,27 +587,22 @@ async def generate_knowledge_graph():
with open(os.path.join(__location__, "kb.json"), "w") as f: with open(os.path.join(__location__, "kb.json"), "w") as f:
f.write(result.extracted_content) f.write(result.extracted_content)
async def fit_markdown_remove_overlay():
async def fit_markdown_remove_overlay():
async with AsyncWebCrawler( async with AsyncWebCrawler(
headless=True, # Set to False to see what is happening headless=True, # Set to False to see what is happening
verbose=True, verbose=True,
user_agent_mode="random", user_agent_mode="random",
user_agent_generator_config={ user_agent_generator_config={"device_type": "mobile", "os_type": "android"},
"device_type": "mobile",
"os_type": "android"
},
) as crawler: ) as crawler:
result = await crawler.arun( result = await crawler.arun(
url='https://www.kidocode.com/degrees/technology', url="https://www.kidocode.com/degrees/technology",
cache_mode=CacheMode.BYPASS, cache_mode=CacheMode.BYPASS,
markdown_generator=DefaultMarkdownGenerator( markdown_generator=DefaultMarkdownGenerator(
content_filter=PruningContentFilter( content_filter=PruningContentFilter(
threshold=0.48, threshold_type="fixed", min_word_threshold=0 threshold=0.48, threshold_type="fixed", min_word_threshold=0
), ),
options={ options={"ignore_links": True},
"ignore_links": True
}
), ),
# markdown_generator=DefaultMarkdownGenerator( # markdown_generator=DefaultMarkdownGenerator(
# content_filter=BM25ContentFilter(user_query="", bm25_threshold=1.0), # content_filter=BM25ContentFilter(user_query="", bm25_threshold=1.0),
@@ -593,13 +621,20 @@ async def fit_markdown_remove_overlay():
with open(os.path.join(__location__, "output/cleaned_html.html"), "w") as f: with open(os.path.join(__location__, "output/cleaned_html.html"), "w") as f:
f.write(result.cleaned_html) f.write(result.cleaned_html)
with open(os.path.join(__location__, "output/output_raw_markdown.md"), "w") as f: with open(
os.path.join(__location__, "output/output_raw_markdown.md"), "w"
) as f:
f.write(result.markdown_v2.raw_markdown) f.write(result.markdown_v2.raw_markdown)
with open(os.path.join(__location__, "output/output_markdown_with_citations.md"), "w") as f: with open(
os.path.join(__location__, "output/output_markdown_with_citations.md"),
"w",
) as f:
f.write(result.markdown_v2.markdown_with_citations) f.write(result.markdown_v2.markdown_with_citations)
with open(os.path.join(__location__, "output/output_fit_markdown.md"), "w") as f: with open(
os.path.join(__location__, "output/output_fit_markdown.md"), "w"
) as f:
f.write(result.markdown_v2.fit_markdown) f.write(result.markdown_v2.fit_markdown)
print("Done") print("Done")

View File

@@ -10,15 +10,17 @@ from functools import lru_cache
console = Console() console = Console()
@lru_cache() @lru_cache()
def create_crawler(): def create_crawler():
crawler = WebCrawler(verbose=True) crawler = WebCrawler(verbose=True)
crawler.warmup() crawler.warmup()
return crawler return crawler
def print_result(result): def print_result(result):
# Print each key in one line and just the first 10 characters of each one's value and three dots # Print each key in one line and just the first 10 characters of each one's value and three dots
console.print(f"\t[bold]Result:[/bold]") console.print("\t[bold]Result:[/bold]")
for key, value in result.model_dump().items(): for key, value in result.model_dump().items():
if isinstance(value, str) and value: if isinstance(value, str) and value:
console.print(f"\t{key}: [green]{value[:20]}...[/green]") console.print(f"\t{key}: [green]{value[:20]}...[/green]")
@@ -33,18 +35,27 @@ def cprint(message, press_any_key=False):
console.print("Press any key to continue...", style="") console.print("Press any key to continue...", style="")
input() input()
def basic_usage(crawler): def basic_usage(crawler):
cprint("🛠️ [bold cyan]Basic Usage: Simply provide a URL and let Crawl4ai do the magic![/bold cyan]") cprint(
"🛠️ [bold cyan]Basic Usage: Simply provide a URL and let Crawl4ai do the magic![/bold cyan]"
)
result = crawler.run(url="https://www.nbcnews.com/business", only_text=True) result = crawler.run(url="https://www.nbcnews.com/business", only_text=True)
cprint("[LOG] 📦 [bold yellow]Basic crawl result:[/bold yellow]") cprint("[LOG] 📦 [bold yellow]Basic crawl result:[/bold yellow]")
print_result(result) print_result(result)
def basic_usage_some_params(crawler): def basic_usage_some_params(crawler):
cprint("🛠️ [bold cyan]Basic Usage: Simply provide a URL and let Crawl4ai do the magic![/bold cyan]") cprint(
result = crawler.run(url="https://www.nbcnews.com/business", word_count_threshold=1, only_text = True) "🛠️ [bold cyan]Basic Usage: Simply provide a URL and let Crawl4ai do the magic![/bold cyan]"
)
result = crawler.run(
url="https://www.nbcnews.com/business", word_count_threshold=1, only_text=True
)
cprint("[LOG] 📦 [bold yellow]Basic crawl result:[/bold yellow]") cprint("[LOG] 📦 [bold yellow]Basic crawl result:[/bold yellow]")
print_result(result) print_result(result)
def screenshot_usage(crawler): def screenshot_usage(crawler):
cprint("\n📸 [bold cyan]Let's take a screenshot of the page![/bold cyan]") cprint("\n📸 [bold cyan]Let's take a screenshot of the page![/bold cyan]")
result = crawler.run(url="https://www.nbcnews.com/business", screenshot=True) result = crawler.run(url="https://www.nbcnews.com/business", screenshot=True)
@@ -55,16 +66,23 @@ def screenshot_usage(crawler):
cprint("Screenshot saved to 'screenshot.png'!") cprint("Screenshot saved to 'screenshot.png'!")
print_result(result) print_result(result)
def understanding_parameters(crawler): def understanding_parameters(crawler):
cprint("\n🧠 [bold cyan]Understanding 'bypass_cache' and 'include_raw_html' parameters:[/bold cyan]") cprint(
cprint("By default, Crawl4ai caches the results of your crawls. This means that subsequent crawls of the same URL will be much faster! Let's see this in action.") "\n🧠 [bold cyan]Understanding 'bypass_cache' and 'include_raw_html' parameters:[/bold cyan]"
)
cprint(
"By default, Crawl4ai caches the results of your crawls. This means that subsequent crawls of the same URL will be much faster! Let's see this in action."
)
# First crawl (reads from cache) # First crawl (reads from cache)
cprint("1⃣ First crawl (caches the result):", True) cprint("1⃣ First crawl (caches the result):", True)
start_time = time.time() start_time = time.time()
result = crawler.run(url="https://www.nbcnews.com/business") result = crawler.run(url="https://www.nbcnews.com/business")
end_time = time.time() end_time = time.time()
cprint(f"[LOG] 📦 [bold yellow]First crawl took {end_time - start_time} seconds and result (from cache):[/bold yellow]") cprint(
f"[LOG] 📦 [bold yellow]First crawl took {end_time - start_time} seconds and result (from cache):[/bold yellow]"
)
print_result(result) print_result(result)
# Force to crawl again # Force to crawl again
@@ -72,132 +90,194 @@ def understanding_parameters(crawler):
start_time = time.time() start_time = time.time()
result = crawler.run(url="https://www.nbcnews.com/business", bypass_cache=True) result = crawler.run(url="https://www.nbcnews.com/business", bypass_cache=True)
end_time = time.time() end_time = time.time()
cprint(f"[LOG] 📦 [bold yellow]Second crawl took {end_time - start_time} seconds and result (forced to crawl):[/bold yellow]") cprint(
f"[LOG] 📦 [bold yellow]Second crawl took {end_time - start_time} seconds and result (forced to crawl):[/bold yellow]"
)
print_result(result) print_result(result)
def add_chunking_strategy(crawler): def add_chunking_strategy(crawler):
# Adding a chunking strategy: RegexChunking # Adding a chunking strategy: RegexChunking
cprint("\n🧩 [bold cyan]Let's add a chunking strategy: RegexChunking![/bold cyan]", True) cprint(
cprint("RegexChunking is a simple chunking strategy that splits the text based on a given regex pattern. Let's see it in action!") "\n🧩 [bold cyan]Let's add a chunking strategy: RegexChunking![/bold cyan]",
True,
)
cprint(
"RegexChunking is a simple chunking strategy that splits the text based on a given regex pattern. Let's see it in action!"
)
result = crawler.run( result = crawler.run(
url="https://www.nbcnews.com/business", url="https://www.nbcnews.com/business",
chunking_strategy=RegexChunking(patterns=["\n\n"]) chunking_strategy=RegexChunking(patterns=["\n\n"]),
) )
cprint("[LOG] 📦 [bold yellow]RegexChunking result:[/bold yellow]") cprint("[LOG] 📦 [bold yellow]RegexChunking result:[/bold yellow]")
print_result(result) print_result(result)
# Adding another chunking strategy: NlpSentenceChunking # Adding another chunking strategy: NlpSentenceChunking
cprint("\n🔍 [bold cyan]Time to explore another chunking strategy: NlpSentenceChunking![/bold cyan]", True) cprint(
cprint("NlpSentenceChunking uses NLP techniques to split the text into sentences. Let's see how it performs!") "\n🔍 [bold cyan]Time to explore another chunking strategy: NlpSentenceChunking![/bold cyan]",
True,
)
cprint(
"NlpSentenceChunking uses NLP techniques to split the text into sentences. Let's see how it performs!"
)
result = crawler.run( result = crawler.run(
url="https://www.nbcnews.com/business", url="https://www.nbcnews.com/business", chunking_strategy=NlpSentenceChunking()
chunking_strategy=NlpSentenceChunking()
) )
cprint("[LOG] 📦 [bold yellow]NlpSentenceChunking result:[/bold yellow]") cprint("[LOG] 📦 [bold yellow]NlpSentenceChunking result:[/bold yellow]")
print_result(result) print_result(result)
def add_extraction_strategy(crawler): def add_extraction_strategy(crawler):
# Adding an extraction strategy: CosineStrategy # Adding an extraction strategy: CosineStrategy
cprint("\n🧠 [bold cyan]Let's get smarter with an extraction strategy: CosineStrategy![/bold cyan]", True) cprint(
cprint("CosineStrategy uses cosine similarity to extract semantically similar blocks of text. Let's see it in action!") "\n🧠 [bold cyan]Let's get smarter with an extraction strategy: CosineStrategy![/bold cyan]",
True,
)
cprint(
"CosineStrategy uses cosine similarity to extract semantically similar blocks of text. Let's see it in action!"
)
result = crawler.run( result = crawler.run(
url="https://www.nbcnews.com/business", url="https://www.nbcnews.com/business",
extraction_strategy=CosineStrategy(word_count_threshold=10, max_dist=0.2, linkage_method="ward", top_k=3, sim_threshold = 0.3, verbose=True) extraction_strategy=CosineStrategy(
word_count_threshold=10,
max_dist=0.2,
linkage_method="ward",
top_k=3,
sim_threshold=0.3,
verbose=True,
),
) )
cprint("[LOG] 📦 [bold yellow]CosineStrategy result:[/bold yellow]") cprint("[LOG] 📦 [bold yellow]CosineStrategy result:[/bold yellow]")
print_result(result) print_result(result)
# Using semantic_filter with CosineStrategy # Using semantic_filter with CosineStrategy
cprint("You can pass other parameters like 'semantic_filter' to the CosineStrategy to extract semantically similar blocks of text. Let's see it in action!") cprint(
"You can pass other parameters like 'semantic_filter' to the CosineStrategy to extract semantically similar blocks of text. Let's see it in action!"
)
result = crawler.run( result = crawler.run(
url="https://www.nbcnews.com/business", url="https://www.nbcnews.com/business",
extraction_strategy=CosineStrategy( extraction_strategy=CosineStrategy(
semantic_filter="inflation rent prices", semantic_filter="inflation rent prices",
),
) )
cprint(
"[LOG] 📦 [bold yellow]CosineStrategy result with semantic filter:[/bold yellow]"
) )
cprint("[LOG] 📦 [bold yellow]CosineStrategy result with semantic filter:[/bold yellow]")
print_result(result) print_result(result)
def add_llm_extraction_strategy(crawler): def add_llm_extraction_strategy(crawler):
# Adding an LLM extraction strategy without instructions # Adding an LLM extraction strategy without instructions
cprint("\n🤖 [bold cyan]Time to bring in the big guns: LLMExtractionStrategy without instructions![/bold cyan]", True) cprint(
cprint("LLMExtractionStrategy uses a large language model to extract relevant information from the web page. Let's see it in action!") "\n🤖 [bold cyan]Time to bring in the big guns: LLMExtractionStrategy without instructions![/bold cyan]",
True,
)
cprint(
"LLMExtractionStrategy uses a large language model to extract relevant information from the web page. Let's see it in action!"
)
result = crawler.run( result = crawler.run(
url="https://www.nbcnews.com/business", url="https://www.nbcnews.com/business",
extraction_strategy=LLMExtractionStrategy(provider="openai/gpt-4o", api_token=os.getenv('OPENAI_API_KEY')) extraction_strategy=LLMExtractionStrategy(
provider="openai/gpt-4o", api_token=os.getenv("OPENAI_API_KEY")
),
)
cprint(
"[LOG] 📦 [bold yellow]LLMExtractionStrategy (no instructions) result:[/bold yellow]"
) )
cprint("[LOG] 📦 [bold yellow]LLMExtractionStrategy (no instructions) result:[/bold yellow]")
print_result(result) print_result(result)
# Adding an LLM extraction strategy with instructions # Adding an LLM extraction strategy with instructions
cprint("\n📜 [bold cyan]Let's make it even more interesting: LLMExtractionStrategy with instructions![/bold cyan]", True) cprint(
cprint("Let's say we are only interested in financial news. Let's see how LLMExtractionStrategy performs with instructions!") "\n📜 [bold cyan]Let's make it even more interesting: LLMExtractionStrategy with instructions![/bold cyan]",
True,
)
cprint(
"Let's say we are only interested in financial news. Let's see how LLMExtractionStrategy performs with instructions!"
)
result = crawler.run( result = crawler.run(
url="https://www.nbcnews.com/business", url="https://www.nbcnews.com/business",
extraction_strategy=LLMExtractionStrategy( extraction_strategy=LLMExtractionStrategy(
provider="openai/gpt-4o", provider="openai/gpt-4o",
api_token=os.getenv('OPENAI_API_KEY'), api_token=os.getenv("OPENAI_API_KEY"),
instruction="I am interested in only financial news" instruction="I am interested in only financial news",
),
) )
cprint(
"[LOG] 📦 [bold yellow]LLMExtractionStrategy (with instructions) result:[/bold yellow]"
) )
cprint("[LOG] 📦 [bold yellow]LLMExtractionStrategy (with instructions) result:[/bold yellow]")
print_result(result) print_result(result)
result = crawler.run( result = crawler.run(
url="https://www.nbcnews.com/business", url="https://www.nbcnews.com/business",
extraction_strategy=LLMExtractionStrategy( extraction_strategy=LLMExtractionStrategy(
provider="openai/gpt-4o", provider="openai/gpt-4o",
api_token=os.getenv('OPENAI_API_KEY'), api_token=os.getenv("OPENAI_API_KEY"),
instruction="Extract only content related to technology" instruction="Extract only content related to technology",
),
) )
cprint(
"[LOG] 📦 [bold yellow]LLMExtractionStrategy (with technology instruction) result:[/bold yellow]"
) )
cprint("[LOG] 📦 [bold yellow]LLMExtractionStrategy (with technology instruction) result:[/bold yellow]")
print_result(result) print_result(result)
def targeted_extraction(crawler): def targeted_extraction(crawler):
# Using a CSS selector to extract only H2 tags # Using a CSS selector to extract only H2 tags
cprint("\n🎯 [bold cyan]Targeted extraction: Let's use a CSS selector to extract only H2 tags![/bold cyan]", True) cprint(
result = crawler.run( "\n🎯 [bold cyan]Targeted extraction: Let's use a CSS selector to extract only H2 tags![/bold cyan]",
url="https://www.nbcnews.com/business", True,
css_selector="h2"
) )
result = crawler.run(url="https://www.nbcnews.com/business", css_selector="h2")
cprint("[LOG] 📦 [bold yellow]CSS Selector (H2 tags) result:[/bold yellow]") cprint("[LOG] 📦 [bold yellow]CSS Selector (H2 tags) result:[/bold yellow]")
print_result(result) print_result(result)
def interactive_extraction(crawler): def interactive_extraction(crawler):
# Passing JavaScript code to interact with the page # Passing JavaScript code to interact with the page
cprint("\n🖱️ [bold cyan]Let's get interactive: Passing JavaScript code to click 'Load More' button![/bold cyan]", True) cprint(
cprint("In this example we try to click the 'Load More' button on the page using JavaScript code.") "\n🖱️ [bold cyan]Let's get interactive: Passing JavaScript code to click 'Load More' button![/bold cyan]",
True,
)
cprint(
"In this example we try to click the 'Load More' button on the page using JavaScript code."
)
js_code = """ js_code = """
const loadMoreButton = Array.from(document.querySelectorAll('button')).find(button => button.textContent.includes('Load More')); const loadMoreButton = Array.from(document.querySelectorAll('button')).find(button => button.textContent.includes('Load More'));
loadMoreButton && loadMoreButton.click(); loadMoreButton && loadMoreButton.click();
""" """
# crawler_strategy = LocalSeleniumCrawlerStrategy(js_code=js_code) # crawler_strategy = LocalSeleniumCrawlerStrategy(js_code=js_code)
# crawler = WebCrawler(crawler_strategy=crawler_strategy, always_by_pass_cache=True) # crawler = WebCrawler(crawler_strategy=crawler_strategy, always_by_pass_cache=True)
result = crawler.run( result = crawler.run(url="https://www.nbcnews.com/business", js=js_code)
url="https://www.nbcnews.com/business", cprint(
js = js_code "[LOG] 📦 [bold yellow]JavaScript Code (Load More button) result:[/bold yellow]"
) )
cprint("[LOG] 📦 [bold yellow]JavaScript Code (Load More button) result:[/bold yellow]")
print_result(result) print_result(result)
def multiple_scrip(crawler): def multiple_scrip(crawler):
# Passing JavaScript code to interact with the page # Passing JavaScript code to interact with the page
cprint("\n🖱️ [bold cyan]Let's get interactive: Passing JavaScript code to click 'Load More' button![/bold cyan]", True) cprint(
cprint("In this example we try to click the 'Load More' button on the page using JavaScript code.") "\n🖱️ [bold cyan]Let's get interactive: Passing JavaScript code to click 'Load More' button![/bold cyan]",
js_code = [""" True,
)
cprint(
"In this example we try to click the 'Load More' button on the page using JavaScript code."
)
js_code = [
"""
const loadMoreButton = Array.from(document.querySelectorAll('button')).find(button => button.textContent.includes('Load More')); const loadMoreButton = Array.from(document.querySelectorAll('button')).find(button => button.textContent.includes('Load More'));
loadMoreButton && loadMoreButton.click(); loadMoreButton && loadMoreButton.click();
"""] * 2 """
] * 2
# crawler_strategy = LocalSeleniumCrawlerStrategy(js_code=js_code) # crawler_strategy = LocalSeleniumCrawlerStrategy(js_code=js_code)
# crawler = WebCrawler(crawler_strategy=crawler_strategy, always_by_pass_cache=True) # crawler = WebCrawler(crawler_strategy=crawler_strategy, always_by_pass_cache=True)
result = crawler.run( result = crawler.run(url="https://www.nbcnews.com/business", js=js_code)
url="https://www.nbcnews.com/business", cprint(
js = js_code "[LOG] 📦 [bold yellow]JavaScript Code (Load More button) result:[/bold yellow]"
) )
cprint("[LOG] 📦 [bold yellow]JavaScript Code (Load More button) result:[/bold yellow]")
print_result(result) print_result(result)
def using_crawler_hooks(crawler): def using_crawler_hooks(crawler):
# Example usage of the hooks for authentication and setting a cookie # Example usage of the hooks for authentication and setting a cookie
def on_driver_created(driver): def on_driver_created(driver):
@@ -206,33 +286,34 @@ def using_crawler_hooks(crawler):
driver.maximize_window() driver.maximize_window()
# Example customization: logging in to a hypothetical website # Example customization: logging in to a hypothetical website
driver.get('https://example.com/login') driver.get("https://example.com/login")
from selenium.webdriver.support.ui import WebDriverWait from selenium.webdriver.support.ui import WebDriverWait
from selenium.webdriver.common.by import By from selenium.webdriver.common.by import By
from selenium.webdriver.support import expected_conditions as EC from selenium.webdriver.support import expected_conditions as EC
WebDriverWait(driver, 10).until( WebDriverWait(driver, 10).until(
EC.presence_of_element_located((By.NAME, 'username')) EC.presence_of_element_located((By.NAME, "username"))
) )
driver.find_element(By.NAME, 'username').send_keys('testuser') driver.find_element(By.NAME, "username").send_keys("testuser")
driver.find_element(By.NAME, 'password').send_keys('password123') driver.find_element(By.NAME, "password").send_keys("password123")
driver.find_element(By.NAME, 'login').click() driver.find_element(By.NAME, "login").click()
WebDriverWait(driver, 10).until( WebDriverWait(driver, 10).until(
EC.presence_of_element_located((By.ID, 'welcome')) EC.presence_of_element_located((By.ID, "welcome"))
) )
# Add a custom cookie # Add a custom cookie
driver.add_cookie({'name': 'test_cookie', 'value': 'cookie_value'}) driver.add_cookie({"name": "test_cookie", "value": "cookie_value"})
return driver return driver
def before_get_url(driver): def before_get_url(driver):
print("[HOOK] before_get_url") print("[HOOK] before_get_url")
# Example customization: add a custom header # Example customization: add a custom header
# Enable Network domain for sending headers # Enable Network domain for sending headers
driver.execute_cdp_cmd('Network.enable', {}) driver.execute_cdp_cmd("Network.enable", {})
# Add a custom header # Add a custom header
driver.execute_cdp_cmd('Network.setExtraHTTPHeaders', {'headers': {'X-Test-Header': 'test'}}) driver.execute_cdp_cmd(
"Network.setExtraHTTPHeaders", {"headers": {"X-Test-Header": "test"}}
)
return driver return driver
def after_get_url(driver): def after_get_url(driver):
@@ -247,13 +328,16 @@ def using_crawler_hooks(crawler):
print(len(html)) print(len(html))
return driver return driver
cprint("\n🔗 [bold cyan]Using Crawler Hooks: Let's see how we can customize the crawler using hooks![/bold cyan]", True) cprint(
"\n🔗 [bold cyan]Using Crawler Hooks: Let's see how we can customize the crawler using hooks![/bold cyan]",
True,
)
crawler_strategy = LocalSeleniumCrawlerStrategy(verbose=True) crawler_strategy = LocalSeleniumCrawlerStrategy(verbose=True)
crawler_strategy.set_hook('on_driver_created', on_driver_created) crawler_strategy.set_hook("on_driver_created", on_driver_created)
crawler_strategy.set_hook('before_get_url', before_get_url) crawler_strategy.set_hook("before_get_url", before_get_url)
crawler_strategy.set_hook('after_get_url', after_get_url) crawler_strategy.set_hook("after_get_url", after_get_url)
crawler_strategy.set_hook('before_return_html', before_return_html) crawler_strategy.set_hook("before_return_html", before_return_html)
crawler = WebCrawler(verbose=True, crawler_strategy=crawler_strategy) crawler = WebCrawler(verbose=True, crawler_strategy=crawler_strategy)
crawler.warmup() crawler.warmup()
@@ -262,6 +346,7 @@ def using_crawler_hooks(crawler):
cprint("[LOG] 📦 [bold yellow]Crawler Hooks result:[/bold yellow]") cprint("[LOG] 📦 [bold yellow]Crawler Hooks result:[/bold yellow]")
print_result(result=result) print_result(result=result)
def using_crawler_hooks_dleay_example(crawler): def using_crawler_hooks_dleay_example(crawler):
def delay(driver): def delay(driver):
print("Delaying for 5 seconds...") print("Delaying for 5 seconds...")
@@ -270,12 +355,14 @@ def using_crawler_hooks_dleay_example(crawler):
def create_crawler(): def create_crawler():
crawler_strategy = LocalSeleniumCrawlerStrategy(verbose=True) crawler_strategy = LocalSeleniumCrawlerStrategy(verbose=True)
crawler_strategy.set_hook('after_get_url', delay) crawler_strategy.set_hook("after_get_url", delay)
crawler = WebCrawler(verbose=True, crawler_strategy=crawler_strategy) crawler = WebCrawler(verbose=True, crawler_strategy=crawler_strategy)
crawler.warmup() crawler.warmup()
return crawler return crawler
cprint("\n🔗 [bold cyan]Using Crawler Hooks: Let's add a delay after fetching the url to make sure entire page is fetched.[/bold cyan]") cprint(
"\n🔗 [bold cyan]Using Crawler Hooks: Let's add a delay after fetching the url to make sure entire page is fetched.[/bold cyan]"
)
crawler = create_crawler() crawler = create_crawler()
result = crawler.run(url="https://google.com", bypass_cache=True) result = crawler.run(url="https://google.com", bypass_cache=True)
@@ -283,11 +370,16 @@ def using_crawler_hooks_dleay_example(crawler):
print_result(result) print_result(result)
def main(): def main():
cprint("🌟 [bold green]Welcome to the Crawl4ai Quickstart Guide! Let's dive into some web crawling fun! 🌐[/bold green]") cprint(
cprint("⛳️ [bold cyan]First Step: Create an instance of WebCrawler and call the `warmup()` function.[/bold cyan]") "🌟 [bold green]Welcome to the Crawl4ai Quickstart Guide! Let's dive into some web crawling fun! 🌐[/bold green]"
cprint("If this is the first time you're running Crawl4ai, this might take a few seconds to load required model files.") )
cprint(
"⛳️ [bold cyan]First Step: Create an instance of WebCrawler and call the `warmup()` function.[/bold cyan]"
)
cprint(
"If this is the first time you're running Crawl4ai, this might take a few seconds to load required model files."
)
crawler = create_crawler() crawler = create_crawler()
@@ -305,8 +397,10 @@ def main():
interactive_extraction(crawler) interactive_extraction(crawler)
multiple_scrip(crawler) multiple_scrip(crawler)
cprint("\n🎉 [bold green]Congratulations! You've made it through the Crawl4ai Quickstart Guide! Now go forth and crawl the web like a pro! 🕸️[/bold green]") cprint(
"\n🎉 [bold green]Congratulations! You've made it through the Crawl4ai Quickstart Guide! Now go forth and crawl the web like a pro! 🕸️[/bold green]"
)
if __name__ == "__main__": if __name__ == "__main__":
main() main()

View File

@@ -11,7 +11,9 @@ from groq import Groq
# Import threadpools to run the crawl_url function in a separate thread # Import threadpools to run the crawl_url function in a separate thread
from concurrent.futures import ThreadPoolExecutor from concurrent.futures import ThreadPoolExecutor
client = AsyncOpenAI(base_url="https://api.groq.com/openai/v1", api_key=os.getenv("GROQ_API_KEY")) client = AsyncOpenAI(
base_url="https://api.groq.com/openai/v1", api_key=os.getenv("GROQ_API_KEY")
)
# Instrument the OpenAI client # Instrument the OpenAI client
cl.instrument_openai() cl.instrument_openai()
@@ -25,32 +27,31 @@ settings = {
"presence_penalty": 0, "presence_penalty": 0,
} }
def extract_urls(text): def extract_urls(text):
url_pattern = re.compile(r'(https?://\S+)') url_pattern = re.compile(r"(https?://\S+)")
return url_pattern.findall(text) return url_pattern.findall(text)
def crawl_url(url): def crawl_url(url):
data = { data = {
"urls": [url], "urls": [url],
"include_raw_html": True, "include_raw_html": True,
"word_count_threshold": 10, "word_count_threshold": 10,
"extraction_strategy": "NoExtractionStrategy", "extraction_strategy": "NoExtractionStrategy",
"chunking_strategy": "RegexChunking" "chunking_strategy": "RegexChunking",
} }
response = requests.post("https://crawl4ai.com/crawl", json=data) response = requests.post("https://crawl4ai.com/crawl", json=data)
response_data = response.json() response_data = response.json()
response_data = response_data['results'][0] response_data = response_data["results"][0]
return response_data['markdown'] return response_data["markdown"]
@cl.on_chat_start @cl.on_chat_start
async def on_chat_start(): async def on_chat_start():
cl.user_session.set("session", { cl.user_session.set("session", {"history": [], "context": {}})
"history": [], await cl.Message(content="Welcome to the chat! How can I assist you today?").send()
"context": {}
})
await cl.Message(
content="Welcome to the chat! How can I assist you today?"
).send()
@cl.on_message @cl.on_message
async def on_message(message: cl.Message): async def on_message(message: cl.Message):
@@ -59,7 +60,6 @@ async def on_message(message: cl.Message):
# Extract URLs from the user's message # Extract URLs from the user's message
urls = extract_urls(message.content) urls = extract_urls(message.content)
futures = [] futures = []
with ThreadPoolExecutor() as executor: with ThreadPoolExecutor() as executor:
for url in urls: for url in urls:
@@ -69,16 +69,9 @@ async def on_message(message: cl.Message):
for url, result in zip(urls, results): for url, result in zip(urls, results):
ref_number = f"REF_{len(user_session['context']) + 1}" ref_number = f"REF_{len(user_session['context']) + 1}"
user_session["context"][ref_number] = { user_session["context"][ref_number] = {"url": url, "content": result}
"url": url,
"content": result
}
user_session["history"].append({"role": "user", "content": message.content})
user_session["history"].append({
"role": "user",
"content": message.content
})
# Create a system message that includes the context # Create a system message that includes the context
context_messages = [ context_messages = [
@@ -95,26 +88,17 @@ async def on_message(message: cl.Message):
"If not, there is no need to add a references section. " "If not, there is no need to add a references section. "
"At the end of your response, provide a reference section listing the URLs and their REF numbers only if sources from the appendices were used.\n\n" "At the end of your response, provide a reference section listing the URLs and their REF numbers only if sources from the appendices were used.\n\n"
"\n\n".join(context_messages) "\n\n".join(context_messages)
) ),
} }
else: else:
system_message = { system_message = {"role": "system", "content": "You are a helpful assistant."}
"role": "system",
"content": "You are a helpful assistant."
}
msg = cl.Message(content="") msg = cl.Message(content="")
await msg.send() await msg.send()
# Get response from the LLM # Get response from the LLM
stream = await client.chat.completions.create( stream = await client.chat.completions.create(
messages=[ messages=[system_message, *user_session["history"]], stream=True, **settings
system_message,
*user_session["history"]
],
stream=True,
**settings
) )
assistant_response = "" assistant_response = ""
@@ -124,10 +108,7 @@ async def on_message(message: cl.Message):
await msg.stream_token(token) await msg.stream_token(token)
# Add assistant message to the history # Add assistant message to the history
user_session["history"].append({ user_session["history"].append({"role": "assistant", "content": assistant_response})
"role": "assistant",
"content": assistant_response
})
await msg.update() await msg.update()
# Append the reference section to the assistant's response # Append the reference section to the assistant's response
@@ -154,6 +135,7 @@ async def on_audio_chunk(chunk: cl.AudioChunk):
pass pass
@cl.step(type="tool") @cl.step(type="tool")
async def speech_to_text(audio_file): async def speech_to_text(audio_file):
cli = Groq() cli = Groq()
@@ -179,17 +161,12 @@ async def on_audio_end(elements: list[ElementBased]):
end_time = time.time() end_time = time.time()
print(f"Transcription took {end_time - start_time} seconds") print(f"Transcription took {end_time - start_time} seconds")
user_msg = cl.Message( user_msg = cl.Message(author="You", type="user_message", content=transcription)
author="You",
type="user_message",
content=transcription
)
await user_msg.send() await user_msg.send()
await on_message(user_msg) await on_message(user_msg)
if __name__ == "__main__": if __name__ == "__main__":
from chainlit.cli import run_chainlit from chainlit.cli import run_chainlit
run_chainlit(__file__) run_chainlit(__file__)

View File

@@ -1,4 +1,3 @@
import requests, base64, os import requests, base64, os
data = { data = {
@@ -7,58 +6,49 @@ data = {
} }
response = requests.post("https://crawl4ai.com/crawl", json=data) response = requests.post("https://crawl4ai.com/crawl", json=data)
result = response.json()['results'][0] result = response.json()["results"][0]
print(result.keys()) print(result.keys())
# dict_keys(['url', 'html', 'success', 'cleaned_html', 'media', # dict_keys(['url', 'html', 'success', 'cleaned_html', 'media',
# 'links', 'screenshot', 'markdown', 'extracted_content', # 'links', 'screenshot', 'markdown', 'extracted_content',
# 'metadata', 'error_message']) # 'metadata', 'error_message'])
with open("screenshot.png", "wb") as f: with open("screenshot.png", "wb") as f:
f.write(base64.b64decode(result['screenshot'])) f.write(base64.b64decode(result["screenshot"]))
# Example of filtering the content using CSS selectors # Example of filtering the content using CSS selectors
data = { data = {
"urls": [ "urls": ["https://www.nbcnews.com/business"],
"https://www.nbcnews.com/business"
],
"css_selector": "article", "css_selector": "article",
"screenshot": True, "screenshot": True,
} }
# Example of executing a JS script on the page before extracting the content # Example of executing a JS script on the page before extracting the content
data = { data = {
"urls": [ "urls": ["https://www.nbcnews.com/business"],
"https://www.nbcnews.com/business"
],
"screenshot": True, "screenshot": True,
'js' : [""" "js": [
"""
const loadMoreButton = Array.from(document.querySelectorAll('button')). const loadMoreButton = Array.from(document.querySelectorAll('button')).
find(button => button.textContent.includes('Load More')); find(button => button.textContent.includes('Load More'));
loadMoreButton && loadMoreButton.click(); loadMoreButton && loadMoreButton.click();
"""] """
],
} }
# Example of using a custom extraction strategy # Example of using a custom extraction strategy
data = { data = {
"urls": [ "urls": ["https://www.nbcnews.com/business"],
"https://www.nbcnews.com/business"
],
"extraction_strategy": "CosineStrategy", "extraction_strategy": "CosineStrategy",
"extraction_strategy_args": { "extraction_strategy_args": {"semantic_filter": "inflation rent prices"},
"semantic_filter": "inflation rent prices"
},
} }
# Example of using LLM to extract content # Example of using LLM to extract content
data = { data = {
"urls": [ "urls": ["https://www.nbcnews.com/business"],
"https://www.nbcnews.com/business"
],
"extraction_strategy": "LLMExtractionStrategy", "extraction_strategy": "LLMExtractionStrategy",
"extraction_strategy_args": { "extraction_strategy_args": {
"provider": "groq/llama3-8b-8192", "provider": "groq/llama3-8b-8192",
"api_token": os.environ.get("GROQ_API_KEY"), "api_token": os.environ.get("GROQ_API_KEY"),
"instruction": """I am interested in only financial news, "instruction": """I am interested in only financial news,
and translate them in French.""" and translate them in French.""",
}, },
} }

View File

@@ -5,22 +5,22 @@ import os
from crawl4ai import AsyncWebCrawler, CrawlerRunConfig, CacheMode from crawl4ai import AsyncWebCrawler, CrawlerRunConfig, CacheMode
# Create tmp directory if it doesn't exist # Create tmp directory if it doesn't exist
parent_dir = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) parent_dir = os.path.dirname(
os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
)
tmp_dir = os.path.join(parent_dir, "tmp") tmp_dir = os.path.join(parent_dir, "tmp")
os.makedirs(tmp_dir, exist_ok=True) os.makedirs(tmp_dir, exist_ok=True)
async def main(): async def main():
# Configure crawler to fetch SSL certificate # Configure crawler to fetch SSL certificate
config = CrawlerRunConfig( config = CrawlerRunConfig(
fetch_ssl_certificate=True, fetch_ssl_certificate=True,
cache_mode=CacheMode.BYPASS # Bypass cache to always get fresh certificates cache_mode=CacheMode.BYPASS, # Bypass cache to always get fresh certificates
) )
async with AsyncWebCrawler() as crawler: async with AsyncWebCrawler() as crawler:
result = await crawler.arun( result = await crawler.arun(url="https://example.com", config=config)
url='https://example.com',
config=config
)
if result.success and result.ssl_certificate: if result.success and result.ssl_certificate:
cert = result.ssl_certificate cert = result.ssl_certificate
@@ -36,11 +36,16 @@ async def main():
print("\nCertificate exported to:") print("\nCertificate exported to:")
print(f"- JSON: {os.path.join(tmp_dir, 'certificate.json')}") print(f"- JSON: {os.path.join(tmp_dir, 'certificate.json')}")
pem_data = cert.to_pem(os.path.join(tmp_dir, "certificate.pem")) # For web servers pem_data = cert.to_pem(
os.path.join(tmp_dir, "certificate.pem")
) # For web servers
print(f"- PEM: {os.path.join(tmp_dir, 'certificate.pem')}") print(f"- PEM: {os.path.join(tmp_dir, 'certificate.pem')}")
der_data = cert.to_der(os.path.join(tmp_dir, "certificate.der")) # For Java apps der_data = cert.to_der(
os.path.join(tmp_dir, "certificate.der")
) # For Java apps
print(f"- DER: {os.path.join(tmp_dir, 'certificate.der')}") print(f"- DER: {os.path.join(tmp_dir, 'certificate.der')}")
if __name__ == "__main__": if __name__ == "__main__":
asyncio.run(main()) asyncio.run(main())

View File

@@ -1,39 +1,41 @@
import os import os
import time
import json import json
from crawl4ai.web_crawler import WebCrawler from crawl4ai.web_crawler import WebCrawler
from crawl4ai.chunking_strategy import * from crawl4ai.chunking_strategy import *
from crawl4ai.extraction_strategy import * from crawl4ai.extraction_strategy import *
from crawl4ai.crawler_strategy import * from crawl4ai.crawler_strategy import *
url = r'https://marketplace.visualstudio.com/items?itemName=Unclecode.groqopilot' url = r"https://marketplace.visualstudio.com/items?itemName=Unclecode.groqopilot"
crawler = WebCrawler() crawler = WebCrawler()
crawler.warmup() crawler.warmup()
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
class PageSummary(BaseModel): class PageSummary(BaseModel):
title: str = Field(..., description="Title of the page.") title: str = Field(..., description="Title of the page.")
summary: str = Field(..., description="Summary of the page.") summary: str = Field(..., description="Summary of the page.")
brief_summary: str = Field(..., description="Brief summary of the page.") brief_summary: str = Field(..., description="Brief summary of the page.")
keywords: list = Field(..., description="Keywords assigned to the page.") keywords: list = Field(..., description="Keywords assigned to the page.")
result = crawler.run( result = crawler.run(
url=url, url=url,
word_count_threshold=1, word_count_threshold=1,
extraction_strategy=LLMExtractionStrategy( extraction_strategy=LLMExtractionStrategy(
provider= "openai/gpt-4o", api_token = os.getenv('OPENAI_API_KEY'), provider="openai/gpt-4o",
api_token=os.getenv("OPENAI_API_KEY"),
schema=PageSummary.model_json_schema(), schema=PageSummary.model_json_schema(),
extraction_type="schema", extraction_type="schema",
apply_chunking=False, apply_chunking=False,
instruction="From the crawled content, extract the following details: "\ instruction="From the crawled content, extract the following details: "
"1. Title of the page "\ "1. Title of the page "
"2. Summary of the page, which is a detailed summary "\ "2. Summary of the page, which is a detailed summary "
"3. Brief summary of the page, which is a paragraph text "\ "3. Brief summary of the page, which is a paragraph text "
"4. Keywords assigned to the page, which is a list of keywords. "\ "4. Keywords assigned to the page, which is a list of keywords. "
'The extracted JSON format should look like this: '\ "The extracted JSON format should look like this: "
'{ "title": "Page Title", "summary": "Detailed summary of the page.", "brief_summary": "Brief summary in a paragraph.", "keywords": ["keyword1", "keyword2", "keyword3"] }' '{ "title": "Page Title", "summary": "Detailed summary of the page.", "brief_summary": "Brief summary in a paragraph.", "keywords": ["keyword1", "keyword2", "keyword3"] }',
), ),
bypass_cache=True, bypass_cache=True,
) )

View File

@@ -1,4 +1,5 @@
import os, sys import os, sys
# append the parent directory to the sys.path # append the parent directory to the sys.path
parent_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) parent_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
sys.path.append(parent_dir) sys.path.append(parent_dir)
@@ -13,6 +14,7 @@ import json
from crawl4ai import AsyncWebCrawler, CacheMode from crawl4ai import AsyncWebCrawler, CacheMode
from crawl4ai.content_filter_strategy import BM25ContentFilter from crawl4ai.content_filter_strategy import BM25ContentFilter
# 1. File Download Processing Example # 1. File Download Processing Example
async def download_example(): async def download_example():
"""Example of downloading files from Python.org""" """Example of downloading files from Python.org"""
@@ -23,9 +25,7 @@ async def download_example():
print(f"Downloads will be saved to: {downloads_path}") print(f"Downloads will be saved to: {downloads_path}")
async with AsyncWebCrawler( async with AsyncWebCrawler(
accept_downloads=True, accept_downloads=True, downloads_path=downloads_path, verbose=True
downloads_path=downloads_path,
verbose=True
) as crawler: ) as crawler:
result = await crawler.arun( result = await crawler.arun(
url="https://www.python.org/downloads/", url="https://www.python.org/downloads/",
@@ -40,7 +40,7 @@ async def download_example():
} }
""", """,
delay_before_return_html=1, # Wait 5 seconds to ensure download starts delay_before_return_html=1, # Wait 5 seconds to ensure download starts
cache_mode=CacheMode.BYPASS cache_mode=CacheMode.BYPASS,
) )
if result.downloaded_files: if result.downloaded_files:
@@ -52,24 +52,25 @@ async def download_example():
else: else:
print("\nNo files were downloaded") print("\nNo files were downloaded")
# 2. Local File and Raw HTML Processing Example # 2. Local File and Raw HTML Processing Example
async def local_and_raw_html_example(): async def local_and_raw_html_example():
"""Example of processing local files and raw HTML""" """Example of processing local files and raw HTML"""
# Create a sample HTML file # Create a sample HTML file
sample_file = os.path.join(__data__, "sample.html") sample_file = os.path.join(__data__, "sample.html")
with open(sample_file, "w") as f: with open(sample_file, "w") as f:
f.write(""" f.write(
"""
<html><body> <html><body>
<h1>Test Content</h1> <h1>Test Content</h1>
<p>This is a test paragraph.</p> <p>This is a test paragraph.</p>
</body></html> </body></html>
""") """
)
async with AsyncWebCrawler(verbose=True) as crawler: async with AsyncWebCrawler(verbose=True) as crawler:
# Process local file # Process local file
local_result = await crawler.arun( local_result = await crawler.arun(url=f"file://{os.path.abspath(sample_file)}")
url=f"file://{os.path.abspath(sample_file)}"
)
# Process raw HTML # Process raw HTML
raw_html = """ raw_html = """
@@ -78,9 +79,7 @@ async def local_and_raw_html_example():
<p>This is a test of raw HTML processing.</p> <p>This is a test of raw HTML processing.</p>
</body></html> </body></html>
""" """
raw_result = await crawler.arun( raw_result = await crawler.arun(url=f"raw:{raw_html}")
url=f"raw:{raw_html}"
)
# Clean up # Clean up
os.remove(sample_file) os.remove(sample_file)
@@ -88,6 +87,7 @@ async def local_and_raw_html_example():
print("Local file content:", local_result.markdown) print("Local file content:", local_result.markdown)
print("\nRaw HTML content:", raw_result.markdown) print("\nRaw HTML content:", raw_result.markdown)
# 3. Enhanced Markdown Generation Example # 3. Enhanced Markdown Generation Example
async def markdown_generation_example(): async def markdown_generation_example():
"""Example of enhanced markdown generation with citations and LLM-friendly features""" """Example of enhanced markdown generation with citations and LLM-friendly features"""
@@ -102,27 +102,32 @@ async def markdown_generation_example():
url="https://en.wikipedia.org/wiki/Apple", url="https://en.wikipedia.org/wiki/Apple",
css_selector="main div#bodyContent", css_selector="main div#bodyContent",
content_filter=content_filter, content_filter=content_filter,
cache_mode=CacheMode.BYPASS cache_mode=CacheMode.BYPASS,
) )
from crawl4ai import AsyncWebCrawler
from crawl4ai.content_filter_strategy import BM25ContentFilter from crawl4ai.content_filter_strategy import BM25ContentFilter
result = await crawler.arun( result = await crawler.arun(
url="https://en.wikipedia.org/wiki/Apple", url="https://en.wikipedia.org/wiki/Apple",
css_selector="main div#bodyContent", css_selector="main div#bodyContent",
content_filter=BM25ContentFilter() content_filter=BM25ContentFilter(),
) )
print(result.markdown_v2.fit_markdown) print(result.markdown_v2.fit_markdown)
print("\nMarkdown Generation Results:") print("\nMarkdown Generation Results:")
print(f"1. Original markdown length: {len(result.markdown)}") print(f"1. Original markdown length: {len(result.markdown)}")
print(f"2. New markdown versions (markdown_v2):") print("2. New markdown versions (markdown_v2):")
print(f" - Raw markdown length: {len(result.markdown_v2.raw_markdown)}") print(f" - Raw markdown length: {len(result.markdown_v2.raw_markdown)}")
print(f" - Citations markdown length: {len(result.markdown_v2.markdown_with_citations)}") print(
print(f" - References section length: {len(result.markdown_v2.references_markdown)}") f" - Citations markdown length: {len(result.markdown_v2.markdown_with_citations)}"
)
print(
f" - References section length: {len(result.markdown_v2.references_markdown)}"
)
if result.markdown_v2.fit_markdown: if result.markdown_v2.fit_markdown:
print(f" - Filtered markdown length: {len(result.markdown_v2.fit_markdown)}") print(
f" - Filtered markdown length: {len(result.markdown_v2.fit_markdown)}"
)
# Save examples to files # Save examples to files
output_dir = os.path.join(__data__, "markdown_examples") output_dir = os.path.join(__data__, "markdown_examples")
@@ -148,7 +153,10 @@ async def markdown_generation_example():
print("\nSample of markdown with citations:") print("\nSample of markdown with citations:")
print(result.markdown_v2.markdown_with_citations[:500] + "...\n") print(result.markdown_v2.markdown_with_citations[:500] + "...\n")
print("Sample of references:") print("Sample of references:")
print('\n'.join(result.markdown_v2.references_markdown.split('\n')[:10]) + "...") print(
"\n".join(result.markdown_v2.references_markdown.split("\n")[:10]) + "..."
)
# 4. Browser Management Example # 4. Browser Management Example
async def browser_management_example(): async def browser_management_example():
@@ -163,31 +171,31 @@ async def browser_management_example():
use_managed_browser=True, use_managed_browser=True,
user_data_dir=user_data_dir, user_data_dir=user_data_dir,
headless=False, headless=False,
verbose=True verbose=True,
) as crawler: ) as crawler:
result = await crawler.arun( result = await crawler.arun(
url="https://crawl4ai.com", url="https://crawl4ai.com",
# session_id="persistent_session_1", # session_id="persistent_session_1",
cache_mode=CacheMode.BYPASS cache_mode=CacheMode.BYPASS,
) )
# Use GitHub as an example - it's a good test for browser management # Use GitHub as an example - it's a good test for browser management
# because it requires proper browser handling # because it requires proper browser handling
result = await crawler.arun( result = await crawler.arun(
url="https://github.com/trending", url="https://github.com/trending",
# session_id="persistent_session_1", # session_id="persistent_session_1",
cache_mode=CacheMode.BYPASS cache_mode=CacheMode.BYPASS,
) )
print("\nBrowser session result:", result.success) print("\nBrowser session result:", result.success)
if result.success: if result.success:
print("Page title:", result.metadata.get('title', 'No title found')) print("Page title:", result.metadata.get("title", "No title found"))
# 5. API Usage Example # 5. API Usage Example
async def api_example(): async def api_example():
"""Example of using the new API endpoints""" """Example of using the new API endpoints"""
api_token = os.getenv('CRAWL4AI_API_TOKEN') or "test_api_code" api_token = os.getenv("CRAWL4AI_API_TOKEN") or "test_api_code"
headers = {'Authorization': f'Bearer {api_token}'} headers = {"Authorization": f"Bearer {api_token}"}
async with aiohttp.ClientSession() as session: async with aiohttp.ClientSession() as session:
# Submit crawl job # Submit crawl job
crawl_request = { crawl_request = {
@@ -199,26 +207,18 @@ async def api_example():
"name": "Hacker News Articles", "name": "Hacker News Articles",
"baseSelector": ".athing", "baseSelector": ".athing",
"fields": [ "fields": [
{ {"name": "title", "selector": ".title a", "type": "text"},
"name": "title", {"name": "score", "selector": ".score", "type": "text"},
"selector": ".title a",
"type": "text"
},
{
"name": "score",
"selector": ".score",
"type": "text"
},
{ {
"name": "url", "name": "url",
"selector": ".title a", "selector": ".title a",
"type": "attribute", "type": "attribute",
"attribute": "href" "attribute": "href",
} },
] ],
}
} }
}, },
},
"crawler_params": { "crawler_params": {
"headless": True, "headless": True,
# "use_managed_browser": True # "use_managed_browser": True
@@ -229,9 +229,7 @@ async def api_example():
} }
async with session.post( async with session.post(
"http://localhost:11235/crawl", "http://localhost:11235/crawl", json=crawl_request, headers=headers
json=crawl_request,
headers=headers
) as response: ) as response:
task_data = await response.json() task_data = await response.json()
task_id = task_data["task_id"] task_id = task_data["task_id"]
@@ -239,8 +237,7 @@ async def api_example():
# Check task status # Check task status
while True: while True:
async with session.get( async with session.get(
f"http://localhost:11235/task/{task_id}", f"http://localhost:11235/task/{task_id}", headers=headers
headers=headers
) as status_response: ) as status_response:
result = await status_response.json() result = await status_response.json()
print(f"Task status: {result['status']}") print(f"Task status: {result['status']}")
@@ -248,12 +245,13 @@ async def api_example():
if result["status"] == "completed": if result["status"] == "completed":
print("Task completed!") print("Task completed!")
print("Results:") print("Results:")
news = json.loads(result["results"][0]['extracted_content']) news = json.loads(result["results"][0]["extracted_content"])
print(json.dumps(news[:4], indent=2)) print(json.dumps(news[:4], indent=2))
break break
else: else:
await asyncio.sleep(1) await asyncio.sleep(1)
# Main execution # Main execution
async def main(): async def main():
# print("Running Crawl4AI feature examples...") # print("Running Crawl4AI feature examples...")
@@ -273,5 +271,6 @@ async def main():
# print("\n5. Running API Example:") # print("\n5. Running API Example:")
await api_example() await api_example()
if __name__ == "__main__": if __name__ == "__main__":
asyncio.run(main()) asyncio.run(main())

View File

@@ -10,15 +10,14 @@ import asyncio
import os import os
import json import json
import re import re
from typing import List, Optional, Dict, Any from typing import List
from pydantic import BaseModel, Field
from crawl4ai import ( from crawl4ai import (
AsyncWebCrawler, AsyncWebCrawler,
BrowserConfig, BrowserConfig,
CrawlerRunConfig, CrawlerRunConfig,
CacheMode, CacheMode,
LLMExtractionStrategy, LLMExtractionStrategy,
JsonCssExtractionStrategy JsonCssExtractionStrategy,
) )
from crawl4ai.content_filter_strategy import RelevantContentFilter from crawl4ai.content_filter_strategy import RelevantContentFilter
from crawl4ai.markdown_generation_strategy import DefaultMarkdownGenerator from crawl4ai.markdown_generation_strategy import DefaultMarkdownGenerator
@@ -52,6 +51,7 @@ SAMPLE_HTML = """
</div> </div>
""" """
async def demo_ssl_features(): async def demo_ssl_features():
""" """
Enhanced SSL & Security Features Demo Enhanced SSL & Security Features Demo
@@ -76,14 +76,11 @@ async def demo_ssl_features():
run_config = CrawlerRunConfig( run_config = CrawlerRunConfig(
cache_mode=CacheMode.BYPASS, cache_mode=CacheMode.BYPASS,
fetch_ssl_certificate=True # Enable SSL certificate fetching fetch_ssl_certificate=True, # Enable SSL certificate fetching
) )
async with AsyncWebCrawler(config=browser_config) as crawler: async with AsyncWebCrawler(config=browser_config) as crawler:
result = await crawler.arun( result = await crawler.arun(url="https://example.com", config=run_config)
url="https://example.com",
config=run_config
)
print(f"SSL Crawl Success: {result.success}") print(f"SSL Crawl Success: {result.success}")
result.ssl_certificate.to_json( result.ssl_certificate.to_json(
os.path.join(os.getcwd(), "ssl_certificate.json") os.path.join(os.getcwd(), "ssl_certificate.json")
@@ -91,6 +88,7 @@ async def demo_ssl_features():
if not result.success: if not result.success:
print(f"SSL Error: {result.error_message}") print(f"SSL Error: {result.error_message}")
async def demo_content_filtering(): async def demo_content_filtering():
""" """
Smart Content Filtering Demo Smart Content Filtering Demo
@@ -110,12 +108,14 @@ async def demo_content_filtering():
super().__init__() super().__init__()
# Add news-specific patterns # Add news-specific patterns
self.negative_patterns = re.compile( self.negative_patterns = re.compile(
r'nav|footer|header|sidebar|ads|comment|share|related|recommended|popular|trending', r"nav|footer|header|sidebar|ads|comment|share|related|recommended|popular|trending",
re.I re.I,
) )
self.min_word_count = 30 # Higher threshold for news content self.min_word_count = 30 # Higher threshold for news content
def filter_content(self, html: str, min_word_threshold: int = None) -> List[str]: def filter_content(
self, html: str, min_word_threshold: int = None
) -> List[str]:
""" """
Implements news-specific content filtering logic. Implements news-specific content filtering logic.
@@ -129,14 +129,16 @@ async def demo_content_filtering():
if not html or not isinstance(html, str): if not html or not isinstance(html, str):
return [] return []
soup = BeautifulSoup(html, 'lxml') soup = BeautifulSoup(html, "lxml")
if not soup.body: if not soup.body:
soup = BeautifulSoup(f'<body>{html}</body>', 'lxml') soup = BeautifulSoup(f"<body>{html}</body>", "lxml")
body = soup.find('body') body = soup.find("body")
# Extract chunks with metadata # Extract chunks with metadata
chunks = self.extract_text_chunks(body, min_word_threshold or self.min_word_count) chunks = self.extract_text_chunks(
body, min_word_threshold or self.min_word_count
)
# Filter chunks based on news-specific criteria # Filter chunks based on news-specific criteria
filtered_chunks = [] filtered_chunks = []
@@ -146,7 +148,7 @@ async def demo_content_filtering():
continue continue
# Headers are important in news articles # Headers are important in news articles
if tag_type == 'header': if tag_type == "header":
filtered_chunks.append(self.clean_element(element)) filtered_chunks.append(self.clean_element(element))
continue continue
@@ -154,7 +156,9 @@ async def demo_content_filtering():
text = element.get_text(strip=True) text = element.get_text(strip=True)
if len(text.split()) >= (min_word_threshold or self.min_word_count): if len(text.split()) >= (min_word_threshold or self.min_word_count):
# Calculate link density # Calculate link density
links_text = ' '.join(a.get_text(strip=True) for a in element.find_all('a')) links_text = " ".join(
a.get_text(strip=True) for a in element.find_all("a")
)
link_density = len(links_text) / len(text) if text else 1 link_density = len(links_text) / len(text) if text else 1
# Accept if link density is reasonable # Accept if link density is reasonable
@@ -164,23 +168,20 @@ async def demo_content_filtering():
return filtered_chunks return filtered_chunks
# Create markdown generator with custom filter # Create markdown generator with custom filter
markdown_gen = DefaultMarkdownGenerator( markdown_gen = DefaultMarkdownGenerator(content_filter=CustomNewsFilter())
content_filter=CustomNewsFilter()
)
run_config = CrawlerRunConfig( run_config = CrawlerRunConfig(
markdown_generator=markdown_gen, markdown_generator=markdown_gen, cache_mode=CacheMode.BYPASS
cache_mode=CacheMode.BYPASS
) )
async with AsyncWebCrawler() as crawler: async with AsyncWebCrawler() as crawler:
result = await crawler.arun( result = await crawler.arun(
url="https://news.ycombinator.com", url="https://news.ycombinator.com", config=run_config
config=run_config
) )
print("Filtered Content Sample:") print("Filtered Content Sample:")
print(result.markdown[:500]) # Show first 500 chars print(result.markdown[:500]) # Show first 500 chars
async def demo_json_extraction(): async def demo_json_extraction():
""" """
Improved JSON Extraction Demo Improved JSON Extraction Demo
@@ -206,7 +207,7 @@ async def demo_json_extraction():
"baseSelector": "div.article-list", "baseSelector": "div.article-list",
"baseFields": [ "baseFields": [
{"name": "list_id", "type": "attribute", "attribute": "data-list-id"}, {"name": "list_id", "type": "attribute", "attribute": "data-list-id"},
{"name": "category", "type": "attribute", "attribute": "data-category"} {"name": "category", "type": "attribute", "attribute": "data-category"},
], ],
"fields": [ "fields": [
{ {
@@ -214,8 +215,16 @@ async def demo_json_extraction():
"selector": "article.post", "selector": "article.post",
"type": "nested_list", "type": "nested_list",
"baseFields": [ "baseFields": [
{"name": "post_id", "type": "attribute", "attribute": "data-post-id"}, {
{"name": "author_id", "type": "attribute", "attribute": "data-author"} "name": "post_id",
"type": "attribute",
"attribute": "data-post-id",
},
{
"name": "author_id",
"type": "attribute",
"attribute": "data-author",
},
], ],
"fields": [ "fields": [
{ {
@@ -223,51 +232,59 @@ async def demo_json_extraction():
"selector": "h2.title a", "selector": "h2.title a",
"type": "text", "type": "text",
"baseFields": [ "baseFields": [
{"name": "url", "type": "attribute", "attribute": "href"} {
] "name": "url",
"type": "attribute",
"attribute": "href",
}
],
}, },
{ {
"name": "author", "name": "author",
"selector": "div.meta a.author", "selector": "div.meta a.author",
"type": "text", "type": "text",
"baseFields": [ "baseFields": [
{"name": "profile_url", "type": "attribute", "attribute": "href"}
]
},
{ {
"name": "date", "name": "profile_url",
"selector": "span.date", "type": "attribute",
"type": "text" "attribute": "href",
}
],
}, },
{"name": "date", "selector": "span.date", "type": "text"},
{ {
"name": "read_more", "name": "read_more",
"selector": "a.read-more", "selector": "a.read-more",
"type": "nested", "type": "nested",
"fields": [ "fields": [
{"name": "text", "type": "text"}, {"name": "text", "type": "text"},
{"name": "url", "type": "attribute", "attribute": "href"} {
] "name": "url",
"type": "attribute",
"attribute": "href",
},
],
},
],
} }
] ],
}
]
} }
) )
# Demonstrate extraction from raw HTML # Demonstrate extraction from raw HTML
run_config = CrawlerRunConfig( run_config = CrawlerRunConfig(
extraction_strategy=json_strategy, extraction_strategy=json_strategy, cache_mode=CacheMode.BYPASS
cache_mode=CacheMode.BYPASS
) )
async with AsyncWebCrawler() as crawler: async with AsyncWebCrawler() as crawler:
result = await crawler.arun( result = await crawler.arun(
url="raw:" + SAMPLE_HTML, # Use raw: prefix for raw HTML url="raw:" + SAMPLE_HTML, # Use raw: prefix for raw HTML
config=run_config config=run_config,
) )
print("Extracted Content:") print("Extracted Content:")
print(result.extracted_content) print(result.extracted_content)
async def demo_input_formats(): async def demo_input_formats():
""" """
Input Format Handling Demo Input Format Handling Demo
@@ -359,18 +376,30 @@ async def demo_input_formats():
# Define our schema using Pydantic # Define our schema using Pydantic
class JobRequirement(BaseModel): class JobRequirement(BaseModel):
category: str = Field(description="Category of the requirement (e.g., Technical, Soft Skills)") category: str = Field(
items: List[str] = Field(description="List of specific requirements in this category") description="Category of the requirement (e.g., Technical, Soft Skills)"
priority: str = Field(description="Priority level (Required/Preferred) based on the HTML class or context") )
items: List[str] = Field(
description="List of specific requirements in this category"
)
priority: str = Field(
description="Priority level (Required/Preferred) based on the HTML class or context"
)
class JobPosting(BaseModel): class JobPosting(BaseModel):
title: str = Field(description="Job title") title: str = Field(description="Job title")
department: str = Field(description="Department or team") department: str = Field(description="Department or team")
location: str = Field(description="Job location, including remote options") location: str = Field(description="Job location, including remote options")
salary_range: Optional[str] = Field(description="Salary range if specified") salary_range: Optional[str] = Field(description="Salary range if specified")
requirements: List[JobRequirement] = Field(description="Categorized job requirements") requirements: List[JobRequirement] = Field(
application_deadline: Optional[str] = Field(description="Application deadline if specified") description="Categorized job requirements"
contact_info: Optional[dict] = Field(description="Contact information from footer or contact section") )
application_deadline: Optional[str] = Field(
description="Application deadline if specified"
)
contact_info: Optional[dict] = Field(
description="Contact information from footer or contact section"
)
# First try with markdown (default) # First try with markdown (default)
markdown_strategy = LLMExtractionStrategy( markdown_strategy = LLMExtractionStrategy(
@@ -382,7 +411,7 @@ async def demo_input_formats():
Extract job posting details into structured data. Focus on the visible text content Extract job posting details into structured data. Focus on the visible text content
and organize requirements into categories. and organize requirements into categories.
""", """,
input_format="markdown" # default input_format="markdown", # default
) )
# Then with HTML for better structure understanding # Then with HTML for better structure understanding
@@ -400,34 +429,25 @@ async def demo_input_formats():
Use HTML attributes and classes to enhance extraction accuracy. Use HTML attributes and classes to enhance extraction accuracy.
""", """,
input_format="html" # explicitly use HTML input_format="html", # explicitly use HTML
) )
async with AsyncWebCrawler() as crawler: async with AsyncWebCrawler() as crawler:
# Try with markdown first # Try with markdown first
markdown_config = CrawlerRunConfig( markdown_config = CrawlerRunConfig(extraction_strategy=markdown_strategy)
extraction_strategy=markdown_strategy markdown_result = await crawler.arun(url=url, config=markdown_config)
)
markdown_result = await crawler.arun(
url=url,
config=markdown_config
)
print("\nMarkdown-based Extraction Result:") print("\nMarkdown-based Extraction Result:")
items = json.loads(markdown_result.extracted_content) items = json.loads(markdown_result.extracted_content)
print(json.dumps(items, indent=2)) print(json.dumps(items, indent=2))
# Then with HTML for better structure understanding # Then with HTML for better structure understanding
html_config = CrawlerRunConfig( html_config = CrawlerRunConfig(extraction_strategy=html_strategy)
extraction_strategy=html_strategy html_result = await crawler.arun(url=url, config=html_config)
)
html_result = await crawler.arun(
url=url,
config=html_config
)
print("\nHTML-based Extraction Result:") print("\nHTML-based Extraction Result:")
items = json.loads(html_result.extracted_content) items = json.loads(html_result.extracted_content)
print(json.dumps(items, indent=2)) print(json.dumps(items, indent=2))
# Main execution # Main execution
async def main(): async def main():
print("Crawl4AI v0.4.24 Feature Walkthrough") print("Crawl4AI v0.4.24 Feature Walkthrough")
@@ -439,5 +459,6 @@ async def main():
await demo_json_extraction() await demo_json_extraction()
# await demo_input_formats() # await demo_input_formats()
if __name__ == "__main__": if __name__ == "__main__":
asyncio.run(main()) asyncio.run(main())

78
main.py
View File

@@ -1,14 +1,9 @@
import asyncio, os import asyncio, os
from fastapi import FastAPI, HTTPException, BackgroundTasks, Request from fastapi import FastAPI, HTTPException
from fastapi.responses import JSONResponse from fastapi import FastAPI, HTTPException
from fastapi import FastAPI, HTTPException, Request
from fastapi.responses import HTMLResponse, JSONResponse
from fastapi.staticfiles import StaticFiles from fastapi.staticfiles import StaticFiles
from fastapi.middleware.cors import CORSMiddleware from fastapi.middleware.cors import CORSMiddleware
from fastapi.templating import Jinja2Templates from fastapi.templating import Jinja2Templates
from fastapi.exceptions import RequestValidationError
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.responses import FileResponse
from fastapi.responses import RedirectResponse from fastapi.responses import RedirectResponse
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
from fastapi import Depends, Security from fastapi import Depends, Security
@@ -18,13 +13,10 @@ from typing import Optional, List, Dict, Any, Union
import psutil import psutil
import time import time
import uuid import uuid
from collections import defaultdict
from urllib.parse import urlparse
import math import math
import logging import logging
from enum import Enum from enum import Enum
from dataclasses import dataclass from dataclasses import dataclass
import json
from crawl4ai import AsyncWebCrawler, CrawlResult, CacheMode from crawl4ai import AsyncWebCrawler, CrawlResult, CacheMode
from crawl4ai.config import MIN_WORD_THRESHOLD from crawl4ai.config import MIN_WORD_THRESHOLD
from crawl4ai.extraction_strategy import ( from crawl4ai.extraction_strategy import (
@@ -38,30 +30,36 @@ __location__ = os.path.realpath(os.path.join(os.getcwd(), os.path.dirname(__file
logging.basicConfig(level=logging.INFO) logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class TaskStatus(str, Enum): class TaskStatus(str, Enum):
PENDING = "pending" PENDING = "pending"
PROCESSING = "processing" PROCESSING = "processing"
COMPLETED = "completed" COMPLETED = "completed"
FAILED = "failed" FAILED = "failed"
class CrawlerType(str, Enum): class CrawlerType(str, Enum):
BASIC = "basic" BASIC = "basic"
LLM = "llm" LLM = "llm"
COSINE = "cosine" COSINE = "cosine"
JSON_CSS = "json_css" JSON_CSS = "json_css"
class ExtractionConfig(BaseModel): class ExtractionConfig(BaseModel):
type: CrawlerType type: CrawlerType
params: Dict[str, Any] = {} params: Dict[str, Any] = {}
class ChunkingStrategy(BaseModel): class ChunkingStrategy(BaseModel):
type: str type: str
params: Dict[str, Any] = {} params: Dict[str, Any] = {}
class ContentFilter(BaseModel): class ContentFilter(BaseModel):
type: str = "bm25" type: str = "bm25"
params: Dict[str, Any] = {} params: Dict[str, Any] = {}
class CrawlRequest(BaseModel): class CrawlRequest(BaseModel):
urls: Union[HttpUrl, List[HttpUrl]] urls: Union[HttpUrl, List[HttpUrl]]
word_count_threshold: int = MIN_WORD_THRESHOLD word_count_threshold: int = MIN_WORD_THRESHOLD
@@ -80,6 +78,7 @@ class CrawlRequest(BaseModel):
ttl: Optional[int] = 3600 ttl: Optional[int] = 3600
crawler_params: Dict[str, Any] = {} crawler_params: Dict[str, Any] = {}
@dataclass @dataclass
class TaskInfo: class TaskInfo:
id: str id: str
@@ -89,6 +88,7 @@ class TaskInfo:
created_at: float = time.time() created_at: float = time.time()
ttl: int = 3600 ttl: int = 3600
class ResourceMonitor: class ResourceMonitor:
def __init__(self, max_concurrent_tasks: int = 10): def __init__(self, max_concurrent_tasks: int = 10):
self.max_concurrent_tasks = max_concurrent_tasks self.max_concurrent_tasks = max_concurrent_tasks
@@ -106,7 +106,9 @@ class ResourceMonitor:
mem_usage = psutil.virtual_memory().percent / 100 mem_usage = psutil.virtual_memory().percent / 100
cpu_usage = psutil.cpu_percent() / 100 cpu_usage = psutil.cpu_percent() / 100
memory_factor = max(0, (self.memory_threshold - mem_usage) / self.memory_threshold) memory_factor = max(
0, (self.memory_threshold - mem_usage) / self.memory_threshold
)
cpu_factor = max(0, (self.cpu_threshold - cpu_usage) / self.cpu_threshold) cpu_factor = max(0, (self.cpu_threshold - cpu_usage) / self.cpu_threshold)
self._last_available_slots = math.floor( self._last_available_slots = math.floor(
@@ -116,6 +118,7 @@ class ResourceMonitor:
return self._last_available_slots return self._last_available_slots
class TaskManager: class TaskManager:
def __init__(self, cleanup_interval: int = 300): def __init__(self, cleanup_interval: int = 300):
self.tasks: Dict[str, TaskInfo] = {} self.tasks: Dict[str, TaskInfo] = {}
@@ -149,12 +152,16 @@ class TaskManager:
except asyncio.TimeoutError: except asyncio.TimeoutError:
try: try:
# Then try low priority # Then try low priority
_, task_id = await asyncio.wait_for(self.low_priority.get(), timeout=0.1) _, task_id = await asyncio.wait_for(
self.low_priority.get(), timeout=0.1
)
return task_id return task_id
except asyncio.TimeoutError: except asyncio.TimeoutError:
return None return None
def update_task(self, task_id: str, status: TaskStatus, result: Any = None, error: str = None): def update_task(
self, task_id: str, status: TaskStatus, result: Any = None, error: str = None
):
if task_id in self.tasks: if task_id in self.tasks:
task_info = self.tasks[task_id] task_info = self.tasks[task_id]
task_info.status = status task_info.status = status
@@ -180,6 +187,7 @@ class TaskManager:
except Exception as e: except Exception as e:
logger.error(f"Error in cleanup loop: {e}") logger.error(f"Error in cleanup loop: {e}")
class CrawlerPool: class CrawlerPool:
def __init__(self, max_size: int = 10): def __init__(self, max_size: int = 10):
self.max_size = max_size self.max_size = max_size
@@ -222,6 +230,7 @@ class CrawlerPool:
await crawler.__aexit__(None, None, None) await crawler.__aexit__(None, None, None)
self.active_crawlers.clear() self.active_crawlers.clear()
class CrawlerService: class CrawlerService:
def __init__(self, max_concurrent_tasks: int = 10): def __init__(self, max_concurrent_tasks: int = 10):
self.resource_monitor = ResourceMonitor(max_concurrent_tasks) self.resource_monitor = ResourceMonitor(max_concurrent_tasks)
@@ -287,7 +296,9 @@ class CrawlerService:
try: try:
crawler = await self.crawler_pool.acquire(**request.crawler_params) crawler = await self.crawler_pool.acquire(**request.crawler_params)
extraction_strategy = self._create_extraction_strategy(request.extraction_config) extraction_strategy = self._create_extraction_strategy(
request.extraction_config
)
if isinstance(request.urls, list): if isinstance(request.urls, list):
results = await crawler.arun_many( results = await crawler.arun_many(
@@ -318,16 +329,21 @@ class CrawlerService:
) )
await self.crawler_pool.release(crawler) await self.crawler_pool.release(crawler)
self.task_manager.update_task(task_id, TaskStatus.COMPLETED, results) self.task_manager.update_task(
task_id, TaskStatus.COMPLETED, results
)
except Exception as e: except Exception as e:
logger.error(f"Error processing task {task_id}: {str(e)}") logger.error(f"Error processing task {task_id}: {str(e)}")
self.task_manager.update_task(task_id, TaskStatus.FAILED, error=str(e)) self.task_manager.update_task(
task_id, TaskStatus.FAILED, error=str(e)
)
except Exception as e: except Exception as e:
logger.error(f"Error in queue processing: {str(e)}") logger.error(f"Error in queue processing: {str(e)}")
await asyncio.sleep(1) await asyncio.sleep(1)
app = FastAPI(title="Crawl4AI API") app = FastAPI(title="Crawl4AI API")
# CORS configuration # CORS configuration
@@ -344,6 +360,7 @@ app.add_middleware(
security = HTTPBearer() security = HTTPBearer()
CRAWL4AI_API_TOKEN = os.getenv("CRAWL4AI_API_TOKEN") CRAWL4AI_API_TOKEN = os.getenv("CRAWL4AI_API_TOKEN")
async def verify_token(credentials: HTTPAuthorizationCredentials = Security(security)): async def verify_token(credentials: HTTPAuthorizationCredentials = Security(security)):
if not CRAWL4AI_API_TOKEN: if not CRAWL4AI_API_TOKEN:
return credentials # No token verification if CRAWL4AI_API_TOKEN is not set return credentials # No token verification if CRAWL4AI_API_TOKEN is not set
@@ -351,10 +368,12 @@ async def verify_token(credentials: HTTPAuthorizationCredentials = Security(secu
raise HTTPException(status_code=401, detail="Invalid token") raise HTTPException(status_code=401, detail="Invalid token")
return credentials return credentials
def secure_endpoint(): def secure_endpoint():
"""Returns security dependency only if CRAWL4AI_API_TOKEN is set""" """Returns security dependency only if CRAWL4AI_API_TOKEN is set"""
return Depends(verify_token) if CRAWL4AI_API_TOKEN else None return Depends(verify_token) if CRAWL4AI_API_TOKEN else None
# Check if site directory exists # Check if site directory exists
if os.path.exists(__location__ + "/site"): if os.path.exists(__location__ + "/site"):
# Mount the site directory as a static directory # Mount the site directory as a static directory
@@ -364,14 +383,17 @@ site_templates = Jinja2Templates(directory=__location__ + "/site")
crawler_service = CrawlerService() crawler_service = CrawlerService()
@app.on_event("startup") @app.on_event("startup")
async def startup_event(): async def startup_event():
await crawler_service.start() await crawler_service.start()
@app.on_event("shutdown") @app.on_event("shutdown")
async def shutdown_event(): async def shutdown_event():
await crawler_service.stop() await crawler_service.stop()
@app.get("/") @app.get("/")
def read_root(): def read_root():
if os.path.exists(__location__ + "/site"): if os.path.exists(__location__ + "/site"):
@@ -379,12 +401,16 @@ def read_root():
# Return a json response # Return a json response
return {"message": "Crawl4AI API service is running"} return {"message": "Crawl4AI API service is running"}
@app.post("/crawl", dependencies=[secure_endpoint()] if CRAWL4AI_API_TOKEN else []) @app.post("/crawl", dependencies=[secure_endpoint()] if CRAWL4AI_API_TOKEN else [])
async def crawl(request: CrawlRequest) -> Dict[str, str]: async def crawl(request: CrawlRequest) -> Dict[str, str]:
task_id = await crawler_service.submit_task(request) task_id = await crawler_service.submit_task(request)
return {"task_id": task_id} return {"task_id": task_id}
@app.get("/task/{task_id}", dependencies=[secure_endpoint()] if CRAWL4AI_API_TOKEN else [])
@app.get(
"/task/{task_id}", dependencies=[secure_endpoint()] if CRAWL4AI_API_TOKEN else []
)
async def get_task_status(task_id: str): async def get_task_status(task_id: str):
task_info = crawler_service.task_manager.get_task(task_id) task_info = crawler_service.task_manager.get_task(task_id)
if not task_info: if not task_info:
@@ -406,6 +432,7 @@ async def get_task_status(task_id: str):
return response return response
@app.post("/crawl_sync", dependencies=[secure_endpoint()] if CRAWL4AI_API_TOKEN else []) @app.post("/crawl_sync", dependencies=[secure_endpoint()] if CRAWL4AI_API_TOKEN else [])
async def crawl_sync(request: CrawlRequest) -> Dict[str, Any]: async def crawl_sync(request: CrawlRequest) -> Dict[str, Any]:
task_id = await crawler_service.submit_task(request) task_id = await crawler_service.submit_task(request)
@@ -419,7 +446,10 @@ async def crawl_sync(request: CrawlRequest) -> Dict[str, Any]:
if task_info.status == TaskStatus.COMPLETED: if task_info.status == TaskStatus.COMPLETED:
# Return same format as /task/{task_id} endpoint # Return same format as /task/{task_id} endpoint
if isinstance(task_info.result, list): if isinstance(task_info.result, list):
return {"status": task_info.status, "results": [result.dict() for result in task_info.result]} return {
"status": task_info.status,
"results": [result.dict() for result in task_info.result],
}
return {"status": task_info.status, "result": task_info.result.dict()} return {"status": task_info.status, "result": task_info.result.dict()}
if task_info.status == TaskStatus.FAILED: if task_info.status == TaskStatus.FAILED:
@@ -430,11 +460,16 @@ async def crawl_sync(request: CrawlRequest) -> Dict[str, Any]:
# If we get here, task didn't complete within timeout # If we get here, task didn't complete within timeout
raise HTTPException(status_code=408, detail="Task timed out") raise HTTPException(status_code=408, detail="Task timed out")
@app.post("/crawl_direct", dependencies=[secure_endpoint()] if CRAWL4AI_API_TOKEN else [])
@app.post(
"/crawl_direct", dependencies=[secure_endpoint()] if CRAWL4AI_API_TOKEN else []
)
async def crawl_direct(request: CrawlRequest) -> Dict[str, Any]: async def crawl_direct(request: CrawlRequest) -> Dict[str, Any]:
try: try:
crawler = await crawler_service.crawler_pool.acquire(**request.crawler_params) crawler = await crawler_service.crawler_pool.acquire(**request.crawler_params)
extraction_strategy = crawler_service._create_extraction_strategy(request.extraction_config) extraction_strategy = crawler_service._create_extraction_strategy(
request.extraction_config
)
try: try:
if isinstance(request.urls, list): if isinstance(request.urls, list):
@@ -471,6 +506,7 @@ async def crawl_direct(request: CrawlRequest) -> Dict[str, Any]:
logger.error(f"Error in direct crawl: {str(e)}") logger.error(f"Error in direct crawl: {str(e)}")
raise HTTPException(status_code=500, detail=str(e)) raise HTTPException(status_code=500, detail=str(e))
@app.get("/health") @app.get("/health")
async def health_check(): async def health_check():
available_slots = await crawler_service.resource_monitor.get_available_slots() available_slots = await crawler_service.resource_monitor.get_available_slots()
@@ -482,6 +518,8 @@ async def health_check():
"cpu_usage": psutil.cpu_percent(), "cpu_usage": psutil.cpu_percent(),
} }
if __name__ == "__main__": if __name__ == "__main__":
import uvicorn import uvicorn
uvicorn.run(app, host="0.0.0.0", port=11235) uvicorn.run(app, host="0.0.0.0", port=11235)

View File

@@ -51,9 +51,7 @@ setup(
author_email="unclecode@kidocode.com", author_email="unclecode@kidocode.com",
license="MIT", license="MIT",
packages=find_packages(), packages=find_packages(),
package_data={ package_data={"crawl4ai": ["js_snippet/*.js"]},
'crawl4ai': ['js_snippet/*.js']
},
classifiers=[ classifiers=[
"Development Status :: 3 - Alpha", "Development Status :: 3 - Alpha",
"Intended Audience :: Developers", "Intended Audience :: Developers",

View File

@@ -1,17 +1,18 @@
import os, sys import os
import sys
import asyncio
from crawl4ai import AsyncWebCrawler, CacheMode
from crawl4ai.markdown_generation_strategy import DefaultMarkdownGenerator
parent_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) parent_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
sys.path.append(parent_dir) sys.path.append(parent_dir)
__location__ = os.path.realpath(os.path.join(os.getcwd(), os.path.dirname(__file__))) __location__ = os.path.realpath(os.path.join(os.getcwd(), os.path.dirname(__file__)))
import os, sys
import asyncio
from crawl4ai import AsyncWebCrawler, CacheMode
from crawl4ai.content_filter_strategy import PruningContentFilter
from crawl4ai.markdown_generation_strategy import DefaultMarkdownGenerator
# Assuming that the changes made allow different configurations # Assuming that the changes made allow different configurations
# for managed browser, persistent context, and so forth. # for managed browser, persistent context, and so forth.
async def test_default_headless(): async def test_default_headless():
async with AsyncWebCrawler( async with AsyncWebCrawler(
headless=True, headless=True,
@@ -24,13 +25,14 @@ async def test_default_headless():
# Testing normal ephemeral context # Testing normal ephemeral context
) as crawler: ) as crawler:
result = await crawler.arun( result = await crawler.arun(
url='https://www.kidocode.com/degrees/technology', url="https://www.kidocode.com/degrees/technology",
cache_mode=CacheMode.BYPASS, cache_mode=CacheMode.BYPASS,
markdown_generator=DefaultMarkdownGenerator(options={"ignore_links": True}), markdown_generator=DefaultMarkdownGenerator(options={"ignore_links": True}),
) )
print("[test_default_headless] success:", result.success) print("[test_default_headless] success:", result.success)
print("HTML length:", len(result.html if result.html else "")) print("HTML length:", len(result.html if result.html else ""))
async def test_managed_browser_persistent(): async def test_managed_browser_persistent():
# Treating use_persistent_context=True as managed_browser scenario. # Treating use_persistent_context=True as managed_browser scenario.
async with AsyncWebCrawler( async with AsyncWebCrawler(
@@ -44,13 +46,14 @@ async def test_managed_browser_persistent():
# This should store and reuse profile data across runs # This should store and reuse profile data across runs
) as crawler: ) as crawler:
result = await crawler.arun( result = await crawler.arun(
url='https://www.google.com', url="https://www.google.com",
cache_mode=CacheMode.BYPASS, cache_mode=CacheMode.BYPASS,
markdown_generator=DefaultMarkdownGenerator(options={"ignore_links": True}) markdown_generator=DefaultMarkdownGenerator(options={"ignore_links": True}),
) )
print("[test_managed_browser_persistent] success:", result.success) print("[test_managed_browser_persistent] success:", result.success)
print("HTML length:", len(result.html if result.html else "")) print("HTML length:", len(result.html if result.html else ""))
async def test_session_reuse(): async def test_session_reuse():
# Test creating a session, using it for multiple calls # Test creating a session, using it for multiple calls
session_id = "my_session" session_id = "my_session"
@@ -62,25 +65,25 @@ async def test_session_reuse():
use_managed_browser=False, use_managed_browser=False,
use_persistent_context=False, use_persistent_context=False,
) as crawler: ) as crawler:
# First call: create session # First call: create session
result1 = await crawler.arun( result1 = await crawler.arun(
url='https://www.example.com', url="https://www.example.com",
cache_mode=CacheMode.BYPASS, cache_mode=CacheMode.BYPASS,
session_id=session_id, session_id=session_id,
markdown_generator=DefaultMarkdownGenerator(options={"ignore_links": True}) markdown_generator=DefaultMarkdownGenerator(options={"ignore_links": True}),
) )
print("[test_session_reuse first call] success:", result1.success) print("[test_session_reuse first call] success:", result1.success)
# Second call: same session, possibly cookie retained # Second call: same session, possibly cookie retained
result2 = await crawler.arun( result2 = await crawler.arun(
url='https://www.example.com/about', url="https://www.example.com/about",
cache_mode=CacheMode.BYPASS, cache_mode=CacheMode.BYPASS,
session_id=session_id, session_id=session_id,
markdown_generator=DefaultMarkdownGenerator(options={"ignore_links": True}) markdown_generator=DefaultMarkdownGenerator(options={"ignore_links": True}),
) )
print("[test_session_reuse second call] success:", result2.success) print("[test_session_reuse second call] success:", result2.success)
async def test_magic_mode(): async def test_magic_mode():
# Test magic mode with override_navigator and simulate_user # Test magic mode with override_navigator and simulate_user
async with AsyncWebCrawler( async with AsyncWebCrawler(
@@ -95,13 +98,14 @@ async def test_magic_mode():
simulate_user=True, simulate_user=True,
) as crawler: ) as crawler:
result = await crawler.arun( result = await crawler.arun(
url='https://www.kidocode.com/degrees/business', url="https://www.kidocode.com/degrees/business",
cache_mode=CacheMode.BYPASS, cache_mode=CacheMode.BYPASS,
markdown_generator=DefaultMarkdownGenerator(options={"ignore_links": True}) markdown_generator=DefaultMarkdownGenerator(options={"ignore_links": True}),
) )
print("[test_magic_mode] success:", result.success) print("[test_magic_mode] success:", result.success)
print("HTML length:", len(result.html if result.html else "")) print("HTML length:", len(result.html if result.html else ""))
async def test_proxy_settings(): async def test_proxy_settings():
# Test with a proxy (if available) to ensure code runs with proxy # Test with a proxy (if available) to ensure code runs with proxy
async with AsyncWebCrawler( async with AsyncWebCrawler(
@@ -113,14 +117,15 @@ async def test_proxy_settings():
use_persistent_context=False, use_persistent_context=False,
) as crawler: ) as crawler:
result = await crawler.arun( result = await crawler.arun(
url='https://httpbin.org/ip', url="https://httpbin.org/ip",
cache_mode=CacheMode.BYPASS, cache_mode=CacheMode.BYPASS,
markdown_generator=DefaultMarkdownGenerator(options={"ignore_links": True}) markdown_generator=DefaultMarkdownGenerator(options={"ignore_links": True}),
) )
print("[test_proxy_settings] success:", result.success) print("[test_proxy_settings] success:", result.success)
if result.success: if result.success:
print("HTML preview:", result.html[:200] if result.html else "") print("HTML preview:", result.html[:200] if result.html else "")
async def test_ignore_https_errors(): async def test_ignore_https_errors():
# Test ignore HTTPS errors with a self-signed or invalid cert domain # Test ignore HTTPS errors with a self-signed or invalid cert domain
# This is just conceptual, the domain should be one that triggers SSL error. # This is just conceptual, the domain should be one that triggers SSL error.
@@ -134,12 +139,13 @@ async def test_ignore_https_errors():
use_persistent_context=False, use_persistent_context=False,
) as crawler: ) as crawler:
result = await crawler.arun( result = await crawler.arun(
url='https://self-signed.badssl.com/', url="https://self-signed.badssl.com/",
cache_mode=CacheMode.BYPASS, cache_mode=CacheMode.BYPASS,
markdown_generator=DefaultMarkdownGenerator(options={"ignore_links": True}) markdown_generator=DefaultMarkdownGenerator(options={"ignore_links": True}),
) )
print("[test_ignore_https_errors] success:", result.success) print("[test_ignore_https_errors] success:", result.success)
async def main(): async def main():
print("Running tests...") print("Running tests...")
# await test_default_headless() # await test_default_headless()
@@ -149,5 +155,6 @@ async def main():
# await test_proxy_settings() # await test_proxy_settings()
await test_ignore_https_errors() await test_ignore_https_errors()
if __name__ == "__main__": if __name__ == "__main__":
asyncio.run(main()) asyncio.run(main())

View File

@@ -1,4 +1,5 @@
import os, sys import os, sys
parent_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) parent_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
sys.path.append(parent_dir) sys.path.append(parent_dir)
__location__ = os.path.realpath(os.path.join(os.getcwd(), os.path.dirname(__file__))) __location__ = os.path.realpath(os.path.join(os.getcwd(), os.path.dirname(__file__)))
@@ -9,7 +10,7 @@ from crawl4ai.async_configs import BrowserConfig, CrawlerRunConfig
from crawl4ai.content_filter_strategy import PruningContentFilter from crawl4ai.content_filter_strategy import PruningContentFilter
from crawl4ai.extraction_strategy import JsonCssExtractionStrategy from crawl4ai.extraction_strategy import JsonCssExtractionStrategy
from crawl4ai.chunking_strategy import RegexChunking from crawl4ai.chunking_strategy import RegexChunking
from crawl4ai.markdown_generation_strategy import DefaultMarkdownGenerator
# Category 1: Browser Configuration Tests # Category 1: Browser Configuration Tests
async def test_browser_config_object(): async def test_browser_config_object():
@@ -21,29 +22,31 @@ async def test_browser_config_object():
viewport_height=1080, viewport_height=1080,
use_managed_browser=True, use_managed_browser=True,
user_agent_mode="random", user_agent_mode="random",
user_agent_generator_config={"device_type": "desktop", "os_type": "windows"} user_agent_generator_config={"device_type": "desktop", "os_type": "windows"},
) )
async with AsyncWebCrawler(config=browser_config, verbose=True) as crawler: async with AsyncWebCrawler(config=browser_config, verbose=True) as crawler:
result = await crawler.arun('https://example.com', cache_mode=CacheMode.BYPASS) result = await crawler.arun("https://example.com", cache_mode=CacheMode.BYPASS)
assert result.success, "Browser config crawl failed" assert result.success, "Browser config crawl failed"
assert len(result.html) > 0, "No HTML content retrieved" assert len(result.html) > 0, "No HTML content retrieved"
async def test_browser_performance_config(): async def test_browser_performance_config():
"""Test browser configurations focused on performance""" """Test browser configurations focused on performance"""
browser_config = BrowserConfig( browser_config = BrowserConfig(
text_mode=True, text_mode=True,
light_mode=True, light_mode=True,
extra_args=['--disable-gpu', '--disable-software-rasterizer'], extra_args=["--disable-gpu", "--disable-software-rasterizer"],
ignore_https_errors=True, ignore_https_errors=True,
java_script_enabled=False java_script_enabled=False,
) )
async with AsyncWebCrawler(config=browser_config) as crawler: async with AsyncWebCrawler(config=browser_config) as crawler:
result = await crawler.arun('https://example.com') result = await crawler.arun("https://example.com")
assert result.success, "Performance optimized crawl failed" assert result.success, "Performance optimized crawl failed"
assert result.status_code == 200, "Unexpected status code" assert result.status_code == 200, "Unexpected status code"
# Category 2: Content Processing Tests # Category 2: Content Processing Tests
async def test_content_extraction_config(): async def test_content_extraction_config():
"""Test content extraction with various strategies""" """Test content extraction with various strategies"""
@@ -53,24 +56,20 @@ async def test_content_extraction_config():
schema={ schema={
"name": "article", "name": "article",
"baseSelector": "div", "baseSelector": "div",
"fields": [{ "fields": [{"name": "title", "selector": "h1", "type": "text"}],
"name": "title",
"selector": "h1",
"type": "text"
}]
} }
), ),
chunking_strategy=RegexChunking(), chunking_strategy=RegexChunking(),
content_filter=PruningContentFilter() content_filter=PruningContentFilter(),
) )
async with AsyncWebCrawler() as crawler: async with AsyncWebCrawler() as crawler:
result = await crawler.arun( result = await crawler.arun(
'https://example.com/article', "https://example.com/article", config=crawler_config
config=crawler_config
) )
assert result.extracted_content is not None, "Content extraction failed" assert result.extracted_content is not None, "Content extraction failed"
assert 'title' in result.extracted_content, "Missing expected content field" assert "title" in result.extracted_content, "Missing expected content field"
# Category 3: Cache and Session Management Tests # Category 3: Cache and Session Management Tests
async def test_cache_and_session_management(): async def test_cache_and_session_management():
@@ -79,25 +78,20 @@ async def test_cache_and_session_management():
crawler_config = CrawlerRunConfig( crawler_config = CrawlerRunConfig(
cache_mode=CacheMode.WRITE_ONLY, cache_mode=CacheMode.WRITE_ONLY,
process_iframes=True, process_iframes=True,
remove_overlay_elements=True remove_overlay_elements=True,
) )
async with AsyncWebCrawler(config=browser_config) as crawler: async with AsyncWebCrawler(config=browser_config) as crawler:
# First request - should write to cache # First request - should write to cache
result1 = await crawler.arun( result1 = await crawler.arun("https://example.com", config=crawler_config)
'https://example.com',
config=crawler_config
)
# Second request - should use fresh fetch due to WRITE_ONLY mode # Second request - should use fresh fetch due to WRITE_ONLY mode
result2 = await crawler.arun( result2 = await crawler.arun("https://example.com", config=crawler_config)
'https://example.com',
config=crawler_config
)
assert result1.success and result2.success, "Cache mode crawl failed" assert result1.success and result2.success, "Cache mode crawl failed"
assert result1.html == result2.html, "Inconsistent results between requests" assert result1.html == result2.html, "Inconsistent results between requests"
# Category 4: Media Handling Tests # Category 4: Media Handling Tests
async def test_media_handling_config(): async def test_media_handling_config():
"""Test configurations related to media handling""" """Test configurations related to media handling"""
@@ -107,24 +101,22 @@ async def test_media_handling_config():
viewport_width=1920, viewport_width=1920,
viewport_height=1080, viewport_height=1080,
accept_downloads=True, accept_downloads=True,
downloads_path= os.path.expanduser("~/.crawl4ai/downloads") downloads_path=os.path.expanduser("~/.crawl4ai/downloads"),
) )
crawler_config = CrawlerRunConfig( crawler_config = CrawlerRunConfig(
screenshot=True, screenshot=True,
pdf=True, pdf=True,
adjust_viewport_to_content=True, adjust_viewport_to_content=True,
wait_for_images=True, wait_for_images=True,
screenshot_height_threshold=20000 screenshot_height_threshold=20000,
) )
async with AsyncWebCrawler(config=browser_config) as crawler: async with AsyncWebCrawler(config=browser_config) as crawler:
result = await crawler.arun( result = await crawler.arun("https://example.com", config=crawler_config)
'https://example.com',
config=crawler_config
)
assert result.screenshot is not None, "Screenshot capture failed" assert result.screenshot is not None, "Screenshot capture failed"
assert result.pdf is not None, "PDF generation failed" assert result.pdf is not None, "PDF generation failed"
# Category 5: Anti-Bot and Site Interaction Tests # Category 5: Anti-Bot and Site Interaction Tests
async def test_antibot_config(): async def test_antibot_config():
"""Test configurations for handling anti-bot measures""" """Test configurations for handling anti-bot measures"""
@@ -135,57 +127,43 @@ async def test_antibot_config():
wait_for="js:()=>document.querySelector('body')", wait_for="js:()=>document.querySelector('body')",
delay_before_return_html=1.0, delay_before_return_html=1.0,
log_console=True, log_console=True,
cache_mode=CacheMode.BYPASS cache_mode=CacheMode.BYPASS,
) )
async with AsyncWebCrawler() as crawler: async with AsyncWebCrawler() as crawler:
result = await crawler.arun( result = await crawler.arun("https://example.com", config=crawler_config)
'https://example.com',
config=crawler_config
)
assert result.success, "Anti-bot measure handling failed" assert result.success, "Anti-bot measure handling failed"
# Category 6: Parallel Processing Tests # Category 6: Parallel Processing Tests
async def test_parallel_processing(): async def test_parallel_processing():
"""Test parallel processing capabilities""" """Test parallel processing capabilities"""
crawler_config = CrawlerRunConfig( crawler_config = CrawlerRunConfig(mean_delay=0.5, max_range=1.0, semaphore_count=5)
mean_delay=0.5,
max_range=1.0,
semaphore_count=5
)
urls = [ urls = ["https://example.com/1", "https://example.com/2", "https://example.com/3"]
'https://example.com/1',
'https://example.com/2',
'https://example.com/3'
]
async with AsyncWebCrawler() as crawler: async with AsyncWebCrawler() as crawler:
results = await crawler.arun_many( results = await crawler.arun_many(urls, config=crawler_config)
urls,
config=crawler_config
)
assert len(results) == len(urls), "Not all URLs were processed" assert len(results) == len(urls), "Not all URLs were processed"
assert all(r.success for r in results), "Some parallel requests failed" assert all(r.success for r in results), "Some parallel requests failed"
# Category 7: Backwards Compatibility Tests # Category 7: Backwards Compatibility Tests
async def test_legacy_parameter_support(): async def test_legacy_parameter_support():
"""Test that legacy parameters still work""" """Test that legacy parameters still work"""
async with AsyncWebCrawler( async with AsyncWebCrawler(
headless=True, headless=True, browser_type="chromium", viewport_width=1024, viewport_height=768
browser_type="chromium",
viewport_width=1024,
viewport_height=768
) as crawler: ) as crawler:
result = await crawler.arun( result = await crawler.arun(
'https://example.com', "https://example.com",
screenshot=True, screenshot=True,
word_count_threshold=200, word_count_threshold=200,
bypass_cache=True, bypass_cache=True,
css_selector=".main-content" css_selector=".main-content",
) )
assert result.success, "Legacy parameter support failed" assert result.success, "Legacy parameter support failed"
# Category 8: Mixed Configuration Tests # Category 8: Mixed Configuration Tests
async def test_mixed_config_usage(): async def test_mixed_config_usage():
"""Test mixing new config objects with legacy parameters""" """Test mixing new config objects with legacy parameters"""
@@ -194,17 +172,19 @@ async def test_mixed_config_usage():
async with AsyncWebCrawler( async with AsyncWebCrawler(
config=browser_config, config=browser_config,
verbose=True # legacy parameter verbose=True, # legacy parameter
) as crawler: ) as crawler:
result = await crawler.arun( result = await crawler.arun(
'https://example.com', "https://example.com",
config=crawler_config, config=crawler_config,
cache_mode=CacheMode.BYPASS, # legacy parameter cache_mode=CacheMode.BYPASS, # legacy parameter
css_selector="body" # legacy parameter css_selector="body", # legacy parameter
) )
assert result.success, "Mixed configuration usage failed" assert result.success, "Mixed configuration usage failed"
if __name__ == "__main__": if __name__ == "__main__":
async def run_tests(): async def run_tests():
test_functions = [ test_functions = [
test_browser_config_object, test_browser_config_object,

View File

@@ -4,7 +4,6 @@ import asyncio
import shutil import shutil
from typing import List from typing import List
import tempfile import tempfile
import time
# Add the parent directory to the Python path # Add the parent directory to the Python path
parent_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) parent_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
@@ -12,6 +11,7 @@ sys.path.append(parent_dir)
from crawl4ai.async_webcrawler import AsyncWebCrawler from crawl4ai.async_webcrawler import AsyncWebCrawler
class TestDownloads: class TestDownloads:
def __init__(self): def __init__(self):
self.temp_dir = tempfile.mkdtemp(prefix="crawl4ai_test_") self.temp_dir = tempfile.mkdtemp(prefix="crawl4ai_test_")
@@ -31,9 +31,7 @@ class TestDownloads:
"""Test basic file download functionality""" """Test basic file download functionality"""
try: try:
async with AsyncWebCrawler( async with AsyncWebCrawler(
accept_downloads=True, accept_downloads=True, downloads_path=self.download_dir, verbose=True
downloads_path=self.download_dir,
verbose=True
) as crawler: ) as crawler:
# Python.org downloads page typically has stable download links # Python.org downloads page typically has stable download links
result = await crawler.arun( result = await crawler.arun(
@@ -42,14 +40,19 @@ class TestDownloads:
// Click first download link // Click first download link
const downloadLink = document.querySelector('a[href$=".exe"]'); const downloadLink = document.querySelector('a[href$=".exe"]');
if (downloadLink) downloadLink.click(); if (downloadLink) downloadLink.click();
""" """,
) )
success = result.downloaded_files is not None and len(result.downloaded_files) > 0 success = (
result.downloaded_files is not None
and len(result.downloaded_files) > 0
)
self.log_result( self.log_result(
"Basic Download", "Basic Download",
success, success,
f"Downloaded {len(result.downloaded_files or [])} files" if success else "No files downloaded" f"Downloaded {len(result.downloaded_files or [])} files"
if success
else "No files downloaded",
) )
except Exception as e: except Exception as e:
self.log_result("Basic Download", False, str(e)) self.log_result("Basic Download", False, str(e))
@@ -65,21 +68,26 @@ class TestDownloads:
downloads_path=self.download_dir, downloads_path=self.download_dir,
use_persistent_context=True, use_persistent_context=True,
user_data_dir=user_data_dir, user_data_dir=user_data_dir,
verbose=True verbose=True,
) as crawler: ) as crawler:
result = await crawler.arun( result = await crawler.arun(
url="https://www.python.org/downloads/", url="https://www.python.org/downloads/",
js_code=""" js_code="""
const downloadLink = document.querySelector('a[href$=".exe"]'); const downloadLink = document.querySelector('a[href$=".exe"]');
if (downloadLink) downloadLink.click(); if (downloadLink) downloadLink.click();
""" """,
) )
success = result.downloaded_files is not None and len(result.downloaded_files) > 0 success = (
result.downloaded_files is not None
and len(result.downloaded_files) > 0
)
self.log_result( self.log_result(
"Persistent Context Download", "Persistent Context Download",
success, success,
f"Downloaded {len(result.downloaded_files or [])} files" if success else "No files downloaded" f"Downloaded {len(result.downloaded_files or [])} files"
if success
else "No files downloaded",
) )
except Exception as e: except Exception as e:
self.log_result("Persistent Context Download", False, str(e)) self.log_result("Persistent Context Download", False, str(e))
@@ -88,9 +96,7 @@ class TestDownloads:
"""Test multiple simultaneous downloads""" """Test multiple simultaneous downloads"""
try: try:
async with AsyncWebCrawler( async with AsyncWebCrawler(
accept_downloads=True, accept_downloads=True, downloads_path=self.download_dir, verbose=True
downloads_path=self.download_dir,
verbose=True
) as crawler: ) as crawler:
result = await crawler.arun( result = await crawler.arun(
url="https://www.python.org/downloads/", url="https://www.python.org/downloads/",
@@ -98,14 +104,19 @@ class TestDownloads:
// Click multiple download links // Click multiple download links
const downloadLinks = document.querySelectorAll('a[href$=".exe"]'); const downloadLinks = document.querySelectorAll('a[href$=".exe"]');
downloadLinks.forEach(link => link.click()); downloadLinks.forEach(link => link.click());
""" """,
) )
success = result.downloaded_files is not None and len(result.downloaded_files) > 1 success = (
result.downloaded_files is not None
and len(result.downloaded_files) > 1
)
self.log_result( self.log_result(
"Multiple Downloads", "Multiple Downloads",
success, success,
f"Downloaded {len(result.downloaded_files or [])} files" if success else "Not enough files downloaded" f"Downloaded {len(result.downloaded_files or [])} files"
if success
else "Not enough files downloaded",
) )
except Exception as e: except Exception as e:
self.log_result("Multiple Downloads", False, str(e)) self.log_result("Multiple Downloads", False, str(e))
@@ -120,21 +131,26 @@ class TestDownloads:
accept_downloads=True, accept_downloads=True,
downloads_path=self.download_dir, downloads_path=self.download_dir,
browser_type=browser_type, browser_type=browser_type,
verbose=True verbose=True,
) as crawler: ) as crawler:
result = await crawler.arun( result = await crawler.arun(
url="https://www.python.org/downloads/", url="https://www.python.org/downloads/",
js_code=""" js_code="""
const downloadLink = document.querySelector('a[href$=".exe"]'); const downloadLink = document.querySelector('a[href$=".exe"]');
if (downloadLink) downloadLink.click(); if (downloadLink) downloadLink.click();
""" """,
) )
success = result.downloaded_files is not None and len(result.downloaded_files) > 0 success = (
result.downloaded_files is not None
and len(result.downloaded_files) > 0
)
self.log_result( self.log_result(
f"{browser_type.title()} Download", f"{browser_type.title()} Download",
success, success,
f"Downloaded {len(result.downloaded_files or [])} files" if success else "No files downloaded" f"Downloaded {len(result.downloaded_files or [])} files"
if success
else "No files downloaded",
) )
except Exception as e: except Exception as e:
self.log_result(f"{browser_type.title()} Download", False, str(e)) self.log_result(f"{browser_type.title()} Download", False, str(e))
@@ -144,18 +160,15 @@ class TestDownloads:
# Test 1: Downloads without specifying download path # Test 1: Downloads without specifying download path
try: try:
async with AsyncWebCrawler( async with AsyncWebCrawler(accept_downloads=True, verbose=True) as crawler:
accept_downloads=True,
verbose=True
) as crawler:
result = await crawler.arun( result = await crawler.arun(
url="https://www.python.org/downloads/", url="https://www.python.org/downloads/",
js_code="document.querySelector('a[href$=\".exe\"]').click()" js_code="document.querySelector('a[href$=\".exe\"]').click()",
) )
self.log_result( self.log_result(
"Default Download Path", "Default Download Path",
True, True,
f"Downloaded to default path: {result.downloaded_files[0] if result.downloaded_files else 'None'}" f"Downloaded to default path: {result.downloaded_files[0] if result.downloaded_files else 'None'}",
) )
except Exception as e: except Exception as e:
self.log_result("Default Download Path", False, str(e)) self.log_result("Default Download Path", False, str(e))
@@ -165,31 +178,34 @@ class TestDownloads:
async with AsyncWebCrawler( async with AsyncWebCrawler(
accept_downloads=True, accept_downloads=True,
downloads_path="/invalid/path/that/doesnt/exist", downloads_path="/invalid/path/that/doesnt/exist",
verbose=True verbose=True,
) as crawler: ) as crawler:
result = await crawler.arun( result = await crawler.arun(
url="https://www.python.org/downloads/", url="https://www.python.org/downloads/",
js_code="document.querySelector('a[href$=\".exe\"]').click()" js_code="document.querySelector('a[href$=\".exe\"]').click()",
)
self.log_result(
"Invalid Download Path", False, "Should have raised an error"
)
except Exception:
self.log_result(
"Invalid Download Path", True, "Correctly handled invalid path"
) )
self.log_result("Invalid Download Path", False, "Should have raised an error")
except Exception as e:
self.log_result("Invalid Download Path", True, "Correctly handled invalid path")
# Test 3: Download with accept_downloads=False # Test 3: Download with accept_downloads=False
try: try:
async with AsyncWebCrawler( async with AsyncWebCrawler(accept_downloads=False, verbose=True) as crawler:
accept_downloads=False,
verbose=True
) as crawler:
result = await crawler.arun( result = await crawler.arun(
url="https://www.python.org/downloads/", url="https://www.python.org/downloads/",
js_code="document.querySelector('a[href$=\".exe\"]').click()" js_code="document.querySelector('a[href$=\".exe\"]').click()",
) )
success = result.downloaded_files is None success = result.downloaded_files is None
self.log_result( self.log_result(
"Disabled Downloads", "Disabled Downloads",
success, success,
"Correctly ignored downloads" if success else "Unexpectedly downloaded files" "Correctly ignored downloads"
if success
else "Unexpectedly downloaded files",
) )
except Exception as e: except Exception as e:
self.log_result("Disabled Downloads", False, str(e)) self.log_result("Disabled Downloads", False, str(e))
@@ -203,7 +219,7 @@ class TestDownloads:
self.test_persistent_context_download, self.test_persistent_context_download,
self.test_multiple_downloads, self.test_multiple_downloads,
self.test_different_browsers, self.test_different_browsers,
self.test_edge_cases self.test_edge_cases,
] ]
for test in test_methods: for test in test_methods:
@@ -215,15 +231,17 @@ class TestDownloads:
for result in self.results: for result in self.results:
print(result) print(result)
successes = len([r for r in self.results if '' in r]) successes = len([r for r in self.results if "" in r])
total = len(self.results) total = len(self.results)
print(f"\nTotal: {successes}/{total} tests passed") print(f"\nTotal: {successes}/{total} tests passed")
self.cleanup() self.cleanup()
async def main(): async def main():
tester = TestDownloads() tester = TestDownloads()
await tester.run_all_tests() await tester.run_all_tests()
if __name__ == "__main__": if __name__ == "__main__":
asyncio.run(main()) asyncio.run(main())

View File

@@ -1,15 +1,17 @@
import os import os
import sys import sys
import pytest import pytest
import asyncio
import time import time
# Add the parent directory to the Python path # Add the parent directory to the Python path
parent_dir = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) parent_dir = os.path.dirname(
os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
)
sys.path.append(parent_dir) sys.path.append(parent_dir)
from crawl4ai.async_webcrawler import AsyncWebCrawler from crawl4ai.async_webcrawler import AsyncWebCrawler
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_successful_crawl(): async def test_successful_crawl():
async with AsyncWebCrawler(verbose=True) as crawler: async with AsyncWebCrawler(verbose=True) as crawler:
@@ -21,6 +23,7 @@ async def test_successful_crawl():
assert result.markdown assert result.markdown
assert result.cleaned_html assert result.cleaned_html
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_invalid_url(): async def test_invalid_url():
async with AsyncWebCrawler(verbose=True) as crawler: async with AsyncWebCrawler(verbose=True) as crawler:
@@ -29,19 +32,21 @@ async def test_invalid_url():
assert not result.success assert not result.success
assert result.error_message assert result.error_message
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_multiple_urls(): async def test_multiple_urls():
async with AsyncWebCrawler(verbose=True) as crawler: async with AsyncWebCrawler(verbose=True) as crawler:
urls = [ urls = [
"https://www.nbcnews.com/business", "https://www.nbcnews.com/business",
"https://www.example.com", "https://www.example.com",
"https://www.python.org" "https://www.python.org",
] ]
results = await crawler.arun_many(urls=urls, bypass_cache=True) results = await crawler.arun_many(urls=urls, bypass_cache=True)
assert len(results) == len(urls) assert len(results) == len(urls)
assert all(result.success for result in results) assert all(result.success for result in results)
assert all(result.html for result in results) assert all(result.html for result in results)
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_javascript_execution(): async def test_javascript_execution():
async with AsyncWebCrawler(verbose=True) as crawler: async with AsyncWebCrawler(verbose=True) as crawler:
@@ -51,6 +56,7 @@ async def test_javascript_execution():
assert result.success assert result.success
assert "<h1>Modified by JS</h1>" in result.html assert "<h1>Modified by JS</h1>" in result.html
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_concurrent_crawling_performance(): async def test_concurrent_crawling_performance():
async with AsyncWebCrawler(verbose=True) as crawler: async with AsyncWebCrawler(verbose=True) as crawler:
@@ -59,7 +65,7 @@ async def test_concurrent_crawling_performance():
"https://www.example.com", "https://www.example.com",
"https://www.python.org", "https://www.python.org",
"https://www.github.com", "https://www.github.com",
"https://www.stackoverflow.com" "https://www.stackoverflow.com",
] ]
start_time = time.time() start_time = time.time()
@@ -74,7 +80,10 @@ async def test_concurrent_crawling_performance():
# Assert that concurrent crawling is faster than sequential # Assert that concurrent crawling is faster than sequential
# This multiplier may need adjustment based on the number of URLs and their complexity # This multiplier may need adjustment based on the number of URLs and their complexity
assert total_time < len(urls) * 5, f"Concurrent crawling not significantly faster: {total_time:.2f} seconds" assert (
total_time < len(urls) * 5
), f"Concurrent crawling not significantly faster: {total_time:.2f} seconds"
# Entry point for debugging # Entry point for debugging
if __name__ == "__main__": if __name__ == "__main__":

View File

@@ -9,6 +9,7 @@ sys.path.append(parent_dir)
from crawl4ai.async_webcrawler import AsyncWebCrawler from crawl4ai.async_webcrawler import AsyncWebCrawler
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_caching(): async def test_caching():
async with AsyncWebCrawler(verbose=True) as crawler: async with AsyncWebCrawler(verbose=True) as crawler:
@@ -31,6 +32,7 @@ async def test_caching():
assert result2.success assert result2.success
assert time_taken2 < time_taken1 # Cached result should be faster assert time_taken2 < time_taken1 # Cached result should be faster
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_bypass_cache(): async def test_bypass_cache():
async with AsyncWebCrawler(verbose=True) as crawler: async with AsyncWebCrawler(verbose=True) as crawler:
@@ -47,6 +49,7 @@ async def test_bypass_cache():
# Content should be different (or at least, not guaranteed to be the same) # Content should be different (or at least, not guaranteed to be the same)
assert result1.html != result2.html or result1.markdown != result2.markdown assert result1.html != result2.html or result1.markdown != result2.markdown
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_clear_cache(): async def test_clear_cache():
async with AsyncWebCrawler(verbose=True) as crawler: async with AsyncWebCrawler(verbose=True) as crawler:
@@ -62,6 +65,7 @@ async def test_clear_cache():
cache_size = await crawler.aget_cache_size() cache_size = await crawler.aget_cache_size()
assert cache_size == 0 assert cache_size == 0
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_flush_cache(): async def test_flush_cache():
async with AsyncWebCrawler(verbose=True) as crawler: async with AsyncWebCrawler(verbose=True) as crawler:
@@ -77,6 +81,7 @@ async def test_flush_cache():
cache_size = await crawler.aget_cache_size() cache_size = await crawler.aget_cache_size()
assert cache_size == 0 assert cache_size == 0
# Entry point for debugging # Entry point for debugging
if __name__ == "__main__": if __name__ == "__main__":
pytest.main([__file__, "-v"]) pytest.main([__file__, "-v"])

View File

@@ -1,7 +1,6 @@
import os import os
import sys import sys
import pytest import pytest
import asyncio
import json import json
# Add the parent directory to the Python path # Add the parent directory to the Python path
@@ -9,8 +8,9 @@ parent_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
sys.path.append(parent_dir) sys.path.append(parent_dir)
from crawl4ai.async_webcrawler import AsyncWebCrawler from crawl4ai.async_webcrawler import AsyncWebCrawler
from crawl4ai.chunking_strategy import RegexChunking, NlpSentenceChunking from crawl4ai.chunking_strategy import RegexChunking
from crawl4ai.extraction_strategy import CosineStrategy, LLMExtractionStrategy from crawl4ai.extraction_strategy import LLMExtractionStrategy
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_regex_chunking(): async def test_regex_chunking():
@@ -18,15 +18,14 @@ async def test_regex_chunking():
url = "https://www.nbcnews.com/business" url = "https://www.nbcnews.com/business"
chunking_strategy = RegexChunking(patterns=["\n\n"]) chunking_strategy = RegexChunking(patterns=["\n\n"])
result = await crawler.arun( result = await crawler.arun(
url=url, url=url, chunking_strategy=chunking_strategy, bypass_cache=True
chunking_strategy=chunking_strategy,
bypass_cache=True
) )
assert result.success assert result.success
assert result.extracted_content assert result.extracted_content
chunks = json.loads(result.extracted_content) chunks = json.loads(result.extracted_content)
assert len(chunks) > 1 # Ensure multiple chunks were created assert len(chunks) > 1 # Ensure multiple chunks were created
# @pytest.mark.asyncio # @pytest.mark.asyncio
# async def test_cosine_strategy(): # async def test_cosine_strategy():
# async with AsyncWebCrawler(verbose=True) as crawler: # async with AsyncWebCrawler(verbose=True) as crawler:
@@ -43,25 +42,25 @@ async def test_regex_chunking():
# assert len(extracted_data) > 0 # assert len(extracted_data) > 0
# assert all('tags' in item for item in extracted_data) # assert all('tags' in item for item in extracted_data)
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_llm_extraction_strategy(): async def test_llm_extraction_strategy():
async with AsyncWebCrawler(verbose=True) as crawler: async with AsyncWebCrawler(verbose=True) as crawler:
url = "https://www.nbcnews.com/business" url = "https://www.nbcnews.com/business"
extraction_strategy = LLMExtractionStrategy( extraction_strategy = LLMExtractionStrategy(
provider="openai/gpt-4o-mini", provider="openai/gpt-4o-mini",
api_token=os.getenv('OPENAI_API_KEY'), api_token=os.getenv("OPENAI_API_KEY"),
instruction="Extract only content related to technology" instruction="Extract only content related to technology",
) )
result = await crawler.arun( result = await crawler.arun(
url=url, url=url, extraction_strategy=extraction_strategy, bypass_cache=True
extraction_strategy=extraction_strategy,
bypass_cache=True
) )
assert result.success assert result.success
assert result.extracted_content assert result.extracted_content
extracted_data = json.loads(result.extracted_content) extracted_data = json.loads(result.extracted_content)
assert len(extracted_data) > 0 assert len(extracted_data) > 0
assert all('content' in item for item in extracted_data) assert all("content" in item for item in extracted_data)
# @pytest.mark.asyncio # @pytest.mark.asyncio
# async def test_combined_chunking_and_extraction(): # async def test_combined_chunking_and_extraction():

View File

@@ -1,8 +1,6 @@
import os import os
import sys import sys
import pytest import pytest
import asyncio
import json
# Add the parent directory to the Python path # Add the parent directory to the Python path
parent_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) parent_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
@@ -10,6 +8,7 @@ sys.path.append(parent_dir)
from crawl4ai.async_webcrawler import AsyncWebCrawler from crawl4ai.async_webcrawler import AsyncWebCrawler
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_extract_markdown(): async def test_extract_markdown():
async with AsyncWebCrawler(verbose=True) as crawler: async with AsyncWebCrawler(verbose=True) as crawler:
@@ -20,6 +19,7 @@ async def test_extract_markdown():
assert isinstance(result.markdown, str) assert isinstance(result.markdown, str)
assert len(result.markdown) > 0 assert len(result.markdown) > 0
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_extract_cleaned_html(): async def test_extract_cleaned_html():
async with AsyncWebCrawler(verbose=True) as crawler: async with AsyncWebCrawler(verbose=True) as crawler:
@@ -30,6 +30,7 @@ async def test_extract_cleaned_html():
assert isinstance(result.cleaned_html, str) assert isinstance(result.cleaned_html, str)
assert len(result.cleaned_html) > 0 assert len(result.cleaned_html) > 0
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_extract_media(): async def test_extract_media():
async with AsyncWebCrawler(verbose=True) as crawler: async with AsyncWebCrawler(verbose=True) as crawler:
@@ -46,6 +47,7 @@ async def test_extract_media():
assert "alt" in image assert "alt" in image
assert "type" in image assert "type" in image
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_extract_links(): async def test_extract_links():
async with AsyncWebCrawler(verbose=True) as crawler: async with AsyncWebCrawler(verbose=True) as crawler:
@@ -63,6 +65,7 @@ async def test_extract_links():
assert "href" in link assert "href" in link
assert "text" in link assert "text" in link
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_extract_metadata(): async def test_extract_metadata():
async with AsyncWebCrawler(verbose=True) as crawler: async with AsyncWebCrawler(verbose=True) as crawler:
@@ -75,16 +78,20 @@ async def test_extract_metadata():
assert "title" in metadata assert "title" in metadata
assert isinstance(metadata["title"], str) assert isinstance(metadata["title"], str)
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_css_selector_extraction(): async def test_css_selector_extraction():
async with AsyncWebCrawler(verbose=True) as crawler: async with AsyncWebCrawler(verbose=True) as crawler:
url = "https://www.nbcnews.com/business" url = "https://www.nbcnews.com/business"
css_selector = "h1, h2, h3" css_selector = "h1, h2, h3"
result = await crawler.arun(url=url, bypass_cache=True, css_selector=css_selector) result = await crawler.arun(
url=url, bypass_cache=True, css_selector=css_selector
)
assert result.success assert result.success
assert result.markdown assert result.markdown
assert all(heading in result.markdown for heading in ["#", "##", "###"]) assert all(heading in result.markdown for heading in ["#", "##", "###"])
# Entry point for debugging # Entry point for debugging
if __name__ == "__main__": if __name__ == "__main__":
pytest.main([__file__, "-v"]) pytest.main([__file__, "-v"])

View File

@@ -1,7 +1,6 @@
import os, sys import os, sys
import pytest import pytest
from bs4 import BeautifulSoup from bs4 import BeautifulSoup
from typing import List
# Add the parent directory to the Python path # Add the parent directory to the Python path
parent_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) parent_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
@@ -9,6 +8,7 @@ sys.path.append(parent_dir)
from crawl4ai.content_filter_strategy import BM25ContentFilter from crawl4ai.content_filter_strategy import BM25ContentFilter
@pytest.fixture @pytest.fixture
def basic_html(): def basic_html():
return """ return """
@@ -28,6 +28,7 @@ def basic_html():
</html> </html>
""" """
@pytest.fixture @pytest.fixture
def wiki_html(): def wiki_html():
return """ return """
@@ -46,6 +47,7 @@ def wiki_html():
</html> </html>
""" """
@pytest.fixture @pytest.fixture
def no_meta_html(): def no_meta_html():
return """ return """
@@ -57,6 +59,7 @@ def no_meta_html():
</html> </html>
""" """
class TestBM25ContentFilter: class TestBM25ContentFilter:
def test_basic_extraction(self, basic_html): def test_basic_extraction(self, basic_html):
"""Test basic content extraction functionality""" """Test basic content extraction functionality"""
@@ -65,8 +68,8 @@ class TestBM25ContentFilter:
assert contents, "Should extract content" assert contents, "Should extract content"
assert len(contents) >= 1, "Should extract at least one content block" assert len(contents) >= 1, "Should extract at least one content block"
assert "long paragraph" in ' '.join(contents).lower() assert "long paragraph" in " ".join(contents).lower()
assert "navigation" not in ' '.join(contents).lower() assert "navigation" not in " ".join(contents).lower()
def test_user_query_override(self, basic_html): def test_user_query_override(self, basic_html):
"""Test that user query overrides metadata extraction""" """Test that user query overrides metadata extraction"""
@@ -74,8 +77,8 @@ class TestBM25ContentFilter:
filter = BM25ContentFilter(user_query=user_query) filter = BM25ContentFilter(user_query=user_query)
# Access internal state to verify query usage # Access internal state to verify query usage
soup = BeautifulSoup(basic_html, 'lxml') soup = BeautifulSoup(basic_html, "lxml")
extracted_query = filter.extract_page_query(soup.find('head')) extracted_query = filter.extract_page_query(soup.find("head"))
assert extracted_query == user_query assert extracted_query == user_query
assert "Test description" not in extracted_query assert "Test description" not in extracted_query
@@ -85,7 +88,7 @@ class TestBM25ContentFilter:
filter = BM25ContentFilter() filter = BM25ContentFilter()
contents = filter.filter_content(wiki_html) contents = filter.filter_content(wiki_html)
combined_content = ' '.join(contents).lower() combined_content = " ".join(contents).lower()
assert "section 1" in combined_content, "Should include section header" assert "section 1" in combined_content, "Should include section header"
assert "article title" in combined_content, "Should include main title" assert "article title" in combined_content, "Should include main title"
@@ -95,7 +98,9 @@ class TestBM25ContentFilter:
contents = filter.filter_content(no_meta_html) contents = filter.filter_content(no_meta_html)
assert contents, "Should extract content even without metadata" assert contents, "Should extract content even without metadata"
assert "First paragraph" in ' '.join(contents), "Should use first paragraph content" assert "First paragraph" in " ".join(
contents
), "Should use first paragraph content"
def test_empty_input(self): def test_empty_input(self):
"""Test handling of empty input""" """Test handling of empty input"""
@@ -119,18 +124,19 @@ class TestBM25ContentFilter:
strict_contents = strict_filter.filter_content(basic_html) strict_contents = strict_filter.filter_content(basic_html)
lenient_contents = lenient_filter.filter_content(basic_html) lenient_contents = lenient_filter.filter_content(basic_html)
assert len(strict_contents) <= len(lenient_contents), \ assert len(strict_contents) <= len(
"Strict threshold should extract fewer elements" lenient_contents
), "Strict threshold should extract fewer elements"
def test_html_cleaning(self, basic_html): def test_html_cleaning(self, basic_html):
"""Test HTML cleaning functionality""" """Test HTML cleaning functionality"""
filter = BM25ContentFilter() filter = BM25ContentFilter()
contents = filter.filter_content(basic_html) contents = filter.filter_content(basic_html)
cleaned_content = ' '.join(contents) cleaned_content = " ".join(contents)
assert 'class=' not in cleaned_content, "Should remove class attributes" assert "class=" not in cleaned_content, "Should remove class attributes"
assert 'style=' not in cleaned_content, "Should remove style attributes" assert "style=" not in cleaned_content, "Should remove style attributes"
assert '<script' not in cleaned_content, "Should remove script tags" assert "<script" not in cleaned_content, "Should remove script tags"
def test_large_content(self): def test_large_content(self):
"""Test handling of large content blocks""" """Test handling of large content blocks"""
@@ -143,9 +149,9 @@ class TestBM25ContentFilter:
contents = filter.filter_content(large_html) contents = filter.filter_content(large_html)
assert contents, "Should handle large content blocks" assert contents, "Should handle large content blocks"
@pytest.mark.parametrize("unwanted_tag", [ @pytest.mark.parametrize(
'script', 'style', 'nav', 'footer', 'header' "unwanted_tag", ["script", "style", "nav", "footer", "header"]
]) )
def test_excluded_tags(self, unwanted_tag): def test_excluded_tags(self, unwanted_tag):
"""Test that specific tags are properly excluded""" """Test that specific tags are properly excluded"""
html = f""" html = f"""
@@ -157,7 +163,7 @@ class TestBM25ContentFilter:
filter = BM25ContentFilter() filter = BM25ContentFilter()
contents = filter.filter_content(html) contents = filter.filter_content(html)
combined_content = ' '.join(contents).lower() combined_content = " ".join(contents).lower()
assert "should not appear" not in combined_content assert "should not appear" not in combined_content
def test_performance(self, basic_html): def test_performance(self, basic_html):
@@ -165,11 +171,13 @@ class TestBM25ContentFilter:
filter = BM25ContentFilter() filter = BM25ContentFilter()
import time import time
start = time.perf_counter() start = time.perf_counter()
filter.filter_content(basic_html) filter.filter_content(basic_html)
duration = time.perf_counter() - start duration = time.perf_counter() - start
assert duration < 1.0, f"Processing took too long: {duration:.2f} seconds" assert duration < 1.0, f"Processing took too long: {duration:.2f} seconds"
if __name__ == "__main__": if __name__ == "__main__":
pytest.main([__file__]) pytest.main([__file__])

View File

@@ -1,12 +1,12 @@
import os, sys import os, sys
import pytest import pytest
from bs4 import BeautifulSoup
parent_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) parent_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
sys.path.append(parent_dir) sys.path.append(parent_dir)
from crawl4ai.content_filter_strategy import PruningContentFilter from crawl4ai.content_filter_strategy import PruningContentFilter
@pytest.fixture @pytest.fixture
def basic_html(): def basic_html():
return """ return """
@@ -22,6 +22,7 @@ def basic_html():
</html> </html>
""" """
@pytest.fixture @pytest.fixture
def link_heavy_html(): def link_heavy_html():
return """ return """
@@ -40,6 +41,7 @@ def link_heavy_html():
</html> </html>
""" """
@pytest.fixture @pytest.fixture
def mixed_content_html(): def mixed_content_html():
return """ return """
@@ -60,13 +62,14 @@ def mixed_content_html():
</html> </html>
""" """
class TestPruningContentFilter: class TestPruningContentFilter:
def test_basic_pruning(self, basic_html): def test_basic_pruning(self, basic_html):
"""Test basic content pruning functionality""" """Test basic content pruning functionality"""
filter = PruningContentFilter(min_word_threshold=5) filter = PruningContentFilter(min_word_threshold=5)
contents = filter.filter_content(basic_html) contents = filter.filter_content(basic_html)
combined_content = ' '.join(contents).lower() combined_content = " ".join(contents).lower()
assert "high-quality paragraph" in combined_content assert "high-quality paragraph" in combined_content
assert "sidebar content" not in combined_content assert "sidebar content" not in combined_content
assert "share buttons" not in combined_content assert "share buttons" not in combined_content
@@ -76,39 +79,41 @@ class TestPruningContentFilter:
filter = PruningContentFilter(min_word_threshold=10) filter = PruningContentFilter(min_word_threshold=10)
contents = filter.filter_content(mixed_content_html) contents = filter.filter_content(mixed_content_html)
combined_content = ' '.join(contents).lower() combined_content = " ".join(contents).lower()
assert "short summary" not in combined_content assert "short summary" not in combined_content
assert "long high-quality paragraph" in combined_content assert "long high-quality paragraph" in combined_content
assert "short comment" not in combined_content assert "short comment" not in combined_content
def test_threshold_types(self, basic_html): def test_threshold_types(self, basic_html):
"""Test fixed vs dynamic thresholds""" """Test fixed vs dynamic thresholds"""
fixed_filter = PruningContentFilter(threshold_type='fixed', threshold=0.48) fixed_filter = PruningContentFilter(threshold_type="fixed", threshold=0.48)
dynamic_filter = PruningContentFilter(threshold_type='dynamic', threshold=0.45) dynamic_filter = PruningContentFilter(threshold_type="dynamic", threshold=0.45)
fixed_contents = fixed_filter.filter_content(basic_html) fixed_contents = fixed_filter.filter_content(basic_html)
dynamic_contents = dynamic_filter.filter_content(basic_html) dynamic_contents = dynamic_filter.filter_content(basic_html)
assert len(fixed_contents) != len(dynamic_contents), \ assert len(fixed_contents) != len(
"Fixed and dynamic thresholds should yield different results" dynamic_contents
), "Fixed and dynamic thresholds should yield different results"
def test_link_density_impact(self, link_heavy_html): def test_link_density_impact(self, link_heavy_html):
"""Test handling of link-heavy content""" """Test handling of link-heavy content"""
filter = PruningContentFilter(threshold_type='dynamic') filter = PruningContentFilter(threshold_type="dynamic")
contents = filter.filter_content(link_heavy_html) contents = filter.filter_content(link_heavy_html)
combined_content = ' '.join(contents).lower() combined_content = " ".join(contents).lower()
assert "good content paragraph" in combined_content assert "good content paragraph" in combined_content
assert len([c for c in contents if 'href' in c]) < 2, \ assert (
"Should prune link-heavy sections" len([c for c in contents if "href" in c]) < 2
), "Should prune link-heavy sections"
def test_tag_importance(self, mixed_content_html): def test_tag_importance(self, mixed_content_html):
"""Test tag importance in scoring""" """Test tag importance in scoring"""
filter = PruningContentFilter(threshold_type='dynamic') filter = PruningContentFilter(threshold_type="dynamic")
contents = filter.filter_content(mixed_content_html) contents = filter.filter_content(mixed_content_html)
has_article = any('article' in c.lower() for c in contents) has_article = any("article" in c.lower() for c in contents)
has_h1 = any('h1' in c.lower() for c in contents) has_h1 = any("h1" in c.lower() for c in contents)
assert has_article or has_h1, "Should retain important tags" assert has_article or has_h1, "Should retain important tags"
def test_empty_input(self): def test_empty_input(self):
@@ -129,6 +134,7 @@ class TestPruningContentFilter:
filter = PruningContentFilter() filter = PruningContentFilter()
import time import time
start = time.perf_counter() start = time.perf_counter()
filter.filter_content(basic_html) filter.filter_content(basic_html)
duration = time.perf_counter() - start duration = time.perf_counter() - start
@@ -136,17 +142,21 @@ class TestPruningContentFilter:
# Extra strict on performance since you mentioned milliseconds matter # Extra strict on performance since you mentioned milliseconds matter
assert duration < 0.1, f"Processing took too long: {duration:.3f} seconds" assert duration < 0.1, f"Processing took too long: {duration:.3f} seconds"
@pytest.mark.parametrize("threshold,expected_count", [ @pytest.mark.parametrize(
"threshold,expected_count",
[
(0.3, 4), # Very lenient (0.3, 4), # Very lenient
(0.48, 2), # Default (0.48, 2), # Default
(0.7, 1), # Very strict (0.7, 1), # Very strict
]) ],
)
def test_threshold_levels(self, mixed_content_html, threshold, expected_count): def test_threshold_levels(self, mixed_content_html, threshold, expected_count):
"""Test different threshold levels""" """Test different threshold levels"""
filter = PruningContentFilter(threshold_type='fixed', threshold=threshold) filter = PruningContentFilter(threshold_type="fixed", threshold=threshold)
contents = filter.filter_content(mixed_content_html) contents = filter.filter_content(mixed_content_html)
assert len(contents) <= expected_count, \ assert (
f"Expected {expected_count} or fewer elements with threshold {threshold}" len(contents) <= expected_count
), f"Expected {expected_count} or fewer elements with threshold {threshold}"
def test_consistent_output(self, basic_html): def test_consistent_output(self, basic_html):
"""Test output consistency across multiple runs""" """Test output consistency across multiple runs"""
@@ -155,5 +165,6 @@ class TestPruningContentFilter:
second_run = filter.filter_content(basic_html) second_run = filter.filter_content(basic_html)
assert first_run == second_run, "Output should be consistent" assert first_run == second_run, "Output should be consistent"
if __name__ == "__main__": if __name__ == "__main__":
pytest.main([__file__]) pytest.main([__file__])

View File

@@ -1,22 +1,24 @@
import asyncio
from bs4 import BeautifulSoup
from typing import Dict, Any
import os import os
import sys import sys
import time import time
import csv import csv
from tabulate import tabulate from tabulate import tabulate
from dataclasses import dataclass from dataclasses import dataclass
from typing import List, Dict from typing import List
parent_dir = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) parent_dir = os.path.dirname(
os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
)
sys.path.append(parent_dir) sys.path.append(parent_dir)
__location__ = os.path.realpath(os.path.join(os.getcwd(), os.path.dirname(__file__))) __location__ = os.path.realpath(os.path.join(os.getcwd(), os.path.dirname(__file__)))
from crawl4ai.content_scraping_strategy import WebScrapingStrategy from crawl4ai.content_scraping_strategy import WebScrapingStrategy
from crawl4ai.content_scraping_strategy import WebScrapingStrategy as WebScrapingStrategyCurrent from crawl4ai.content_scraping_strategy import (
WebScrapingStrategy as WebScrapingStrategyCurrent,
)
# from crawl4ai.content_scrapping_strategy_current import WebScrapingStrategy as WebScrapingStrategyCurrent # from crawl4ai.content_scrapping_strategy_current import WebScrapingStrategy as WebScrapingStrategyCurrent
@dataclass @dataclass
class TestResult: class TestResult:
name: str name: str
@@ -27,33 +29,32 @@ class TestResult:
markdown_length: int markdown_length: int
execution_time: float execution_time: float
class StrategyTester: class StrategyTester:
def __init__(self): def __init__(self):
self.new_scraper = WebScrapingStrategy() self.new_scraper = WebScrapingStrategy()
self.current_scraper = WebScrapingStrategyCurrent() self.current_scraper = WebScrapingStrategyCurrent()
with open(__location__ + '/sample_wikipedia.html', 'r', encoding='utf-8') as f: with open(__location__ + "/sample_wikipedia.html", "r", encoding="utf-8") as f:
self.WIKI_HTML = f.read() self.WIKI_HTML = f.read()
self.results = {'new': [], 'current': []} self.results = {"new": [], "current": []}
def run_test(self, name: str, **kwargs) -> tuple[TestResult, TestResult]: def run_test(self, name: str, **kwargs) -> tuple[TestResult, TestResult]:
results = [] results = []
for scraper in [self.new_scraper, self.current_scraper]: for scraper in [self.new_scraper, self.current_scraper]:
start_time = time.time() start_time = time.time()
result = scraper._get_content_of_website_optimized( result = scraper._get_content_of_website_optimized(
url="https://en.wikipedia.org/wiki/Test", url="https://en.wikipedia.org/wiki/Test", html=self.WIKI_HTML, **kwargs
html=self.WIKI_HTML,
**kwargs
) )
execution_time = time.time() - start_time execution_time = time.time() - start_time
test_result = TestResult( test_result = TestResult(
name=name, name=name,
success=result['success'], success=result["success"],
images=len(result['media']['images']), images=len(result["media"]["images"]),
internal_links=len(result['links']['internal']), internal_links=len(result["links"]["internal"]),
external_links=len(result['links']['external']), external_links=len(result["links"]["external"]),
markdown_length=len(result['markdown']), markdown_length=len(result["markdown"]),
execution_time=execution_time execution_time=execution_time,
) )
results.append(test_result) results.append(test_result)
@@ -62,34 +63,37 @@ class StrategyTester:
def run_all_tests(self): def run_all_tests(self):
test_cases = [ test_cases = [
("Basic Extraction", {}), ("Basic Extraction", {}),
("Exclude Tags", {'excluded_tags': ['table', 'div.infobox', 'div.navbox']}), ("Exclude Tags", {"excluded_tags": ["table", "div.infobox", "div.navbox"]}),
("Word Threshold", {'word_count_threshold': 50}), ("Word Threshold", {"word_count_threshold": 50}),
("CSS Selector", {'css_selector': 'div.mw-parser-output > p'}), ("CSS Selector", {"css_selector": "div.mw-parser-output > p"}),
("Link Exclusions", { (
'exclude_external_links': True, "Link Exclusions",
'exclude_social_media_links': True, {
'exclude_domains': ['facebook.com', 'twitter.com'] "exclude_external_links": True,
}), "exclude_social_media_links": True,
("Media Handling", { "exclude_domains": ["facebook.com", "twitter.com"],
'exclude_external_images': True, },
'image_description_min_word_threshold': 20 ),
}), (
("Text Only", { "Media Handling",
'only_text': True, {
'remove_forms': True "exclude_external_images": True,
}), "image_description_min_word_threshold": 20,
("HTML Cleaning", { },
'clean_html': True, ),
'keep_data_attributes': True ("Text Only", {"only_text": True, "remove_forms": True}),
}), ("HTML Cleaning", {"clean_html": True, "keep_data_attributes": True}),
("HTML2Text Options", { (
'html2text': { "HTML2Text Options",
'skip_internal_links': True, {
'single_line_break': True, "html2text": {
'mark_code': True, "skip_internal_links": True,
'preserve_tags': ['pre', 'code'] "single_line_break": True,
"mark_code": True,
"preserve_tags": ["pre", "code"],
} }
}) },
),
] ]
all_results = [] all_results = []
@@ -104,58 +108,111 @@ class StrategyTester:
self.print_comparison_table(all_results) self.print_comparison_table(all_results)
def save_results_to_csv(self, all_results: List[tuple]): def save_results_to_csv(self, all_results: List[tuple]):
csv_file = os.path.join(__location__, 'strategy_comparison_results.csv') csv_file = os.path.join(__location__, "strategy_comparison_results.csv")
with open(csv_file, 'w', newline='') as f: with open(csv_file, "w", newline="") as f:
writer = csv.writer(f) writer = csv.writer(f)
writer.writerow(['Test Name', 'Strategy', 'Success', 'Images', 'Internal Links', writer.writerow(
'External Links', 'Markdown Length', 'Execution Time']) [
"Test Name",
"Strategy",
"Success",
"Images",
"Internal Links",
"External Links",
"Markdown Length",
"Execution Time",
]
)
for name, new_result, current_result in all_results: for name, new_result, current_result in all_results:
writer.writerow([name, 'New', new_result.success, new_result.images, writer.writerow(
new_result.internal_links, new_result.external_links, [
new_result.markdown_length, f"{new_result.execution_time:.3f}"]) name,
writer.writerow([name, 'Current', current_result.success, current_result.images, "New",
current_result.internal_links, current_result.external_links, new_result.success,
current_result.markdown_length, f"{current_result.execution_time:.3f}"]) new_result.images,
new_result.internal_links,
new_result.external_links,
new_result.markdown_length,
f"{new_result.execution_time:.3f}",
]
)
writer.writerow(
[
name,
"Current",
current_result.success,
current_result.images,
current_result.internal_links,
current_result.external_links,
current_result.markdown_length,
f"{current_result.execution_time:.3f}",
]
)
def print_comparison_table(self, all_results: List[tuple]): def print_comparison_table(self, all_results: List[tuple]):
table_data = [] table_data = []
headers = ['Test Name', 'Strategy', 'Success', 'Images', 'Internal Links', headers = [
'External Links', 'Markdown Length', 'Time (s)'] "Test Name",
"Strategy",
"Success",
"Images",
"Internal Links",
"External Links",
"Markdown Length",
"Time (s)",
]
for name, new_result, current_result in all_results: for name, new_result, current_result in all_results:
# Check for differences # Check for differences
differences = [] differences = []
if new_result.images != current_result.images: differences.append('images') if new_result.images != current_result.images:
if new_result.internal_links != current_result.internal_links: differences.append('internal_links') differences.append("images")
if new_result.external_links != current_result.external_links: differences.append('external_links') if new_result.internal_links != current_result.internal_links:
if new_result.markdown_length != current_result.markdown_length: differences.append('markdown') differences.append("internal_links")
if new_result.external_links != current_result.external_links:
differences.append("external_links")
if new_result.markdown_length != current_result.markdown_length:
differences.append("markdown")
# Add row for new strategy # Add row for new strategy
new_row = [ new_row = [
name, 'New', new_result.success, new_result.images, name,
new_result.internal_links, new_result.external_links, "New",
new_result.markdown_length, f"{new_result.execution_time:.3f}" new_result.success,
new_result.images,
new_result.internal_links,
new_result.external_links,
new_result.markdown_length,
f"{new_result.execution_time:.3f}",
] ]
table_data.append(new_row) table_data.append(new_row)
# Add row for current strategy # Add row for current strategy
current_row = [ current_row = [
'', 'Current', current_result.success, current_result.images, "",
current_result.internal_links, current_result.external_links, "Current",
current_result.markdown_length, f"{current_result.execution_time:.3f}" current_result.success,
current_result.images,
current_result.internal_links,
current_result.external_links,
current_result.markdown_length,
f"{current_result.execution_time:.3f}",
] ]
table_data.append(current_row) table_data.append(current_row)
# Add difference summary if any # Add difference summary if any
if differences: if differences:
table_data.append(['', '⚠️ Differences', ', '.join(differences), '', '', '', '', '']) table_data.append(
["", "⚠️ Differences", ", ".join(differences), "", "", "", "", ""]
)
# Add empty row for better readability # Add empty row for better readability
table_data.append([''] * len(headers)) table_data.append([""] * len(headers))
print("\nStrategy Comparison Results:") print("\nStrategy Comparison Results:")
print(tabulate(table_data, headers=headers, tablefmt='grid')) print(tabulate(table_data, headers=headers, tablefmt="grid"))
if __name__ == "__main__": if __name__ == "__main__":
tester = StrategyTester() tester = StrategyTester()

View File

@@ -1,14 +1,13 @@
import os import os
import sys import sys
import pytest import pytest
import asyncio
# Add the parent directory to the Python path # Add the parent directory to the Python path
parent_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) parent_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
sys.path.append(parent_dir) sys.path.append(parent_dir)
from crawl4ai.async_webcrawler import AsyncWebCrawler from crawl4ai.async_webcrawler import AsyncWebCrawler
from crawl4ai.async_crawler_strategy import AsyncPlaywrightCrawlerStrategy
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_custom_user_agent(): async def test_custom_user_agent():
@@ -20,6 +19,7 @@ async def test_custom_user_agent():
assert result.success assert result.success
assert custom_user_agent in result.html assert custom_user_agent in result.html
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_custom_headers(): async def test_custom_headers():
async with AsyncWebCrawler(verbose=True) as crawler: async with AsyncWebCrawler(verbose=True) as crawler:
@@ -31,6 +31,7 @@ async def test_custom_headers():
assert "X-Test-Header" in result.html assert "X-Test-Header" in result.html
assert "TestValue" in result.html assert "TestValue" in result.html
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_javascript_execution(): async def test_javascript_execution():
async with AsyncWebCrawler(verbose=True) as crawler: async with AsyncWebCrawler(verbose=True) as crawler:
@@ -40,19 +41,22 @@ async def test_javascript_execution():
assert result.success assert result.success
assert "<h1>Modified by JS</h1>" in result.html assert "<h1>Modified by JS</h1>" in result.html
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_hook_execution(): async def test_hook_execution():
async with AsyncWebCrawler(verbose=True) as crawler: async with AsyncWebCrawler(verbose=True) as crawler:
async def test_hook(page): async def test_hook(page):
await page.evaluate("document.body.style.backgroundColor = 'red';") await page.evaluate("document.body.style.backgroundColor = 'red';")
return page return page
crawler.crawler_strategy.set_hook('after_goto', test_hook) crawler.crawler_strategy.set_hook("after_goto", test_hook)
url = "https://www.example.com" url = "https://www.example.com"
result = await crawler.arun(url=url, bypass_cache=True) result = await crawler.arun(url=url, bypass_cache=True)
assert result.success assert result.success
assert "background-color: red" in result.html assert "background-color: red" in result.html
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_screenshot(): async def test_screenshot():
async with AsyncWebCrawler(verbose=True) as crawler: async with AsyncWebCrawler(verbose=True) as crawler:
@@ -63,6 +67,7 @@ async def test_screenshot():
assert isinstance(result.screenshot, str) assert isinstance(result.screenshot, str)
assert len(result.screenshot) > 0 assert len(result.screenshot) > 0
# Entry point for debugging # Entry point for debugging
if __name__ == "__main__": if __name__ == "__main__":
pytest.main([__file__, "-v"]) pytest.main([__file__, "-v"])

View File

@@ -1,8 +1,6 @@
import os import os
import sys import sys
import pytest import pytest
import asyncio
import json
# Add the parent directory to the Python path # Add the parent directory to the Python path
parent_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) parent_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
@@ -10,6 +8,7 @@ sys.path.append(parent_dir)
from crawl4ai.async_webcrawler import AsyncWebCrawler from crawl4ai.async_webcrawler import AsyncWebCrawler
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_cache_url(): async def test_cache_url():
async with AsyncWebCrawler(verbose=True) as crawler: async with AsyncWebCrawler(verbose=True) as crawler:
@@ -23,6 +22,7 @@ async def test_cache_url():
assert result2.success assert result2.success
assert result2.html == result1.html assert result2.html == result1.html
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_bypass_cache(): async def test_bypass_cache():
async with AsyncWebCrawler(verbose=True) as crawler: async with AsyncWebCrawler(verbose=True) as crawler:
@@ -34,7 +34,10 @@ async def test_bypass_cache():
# Second run bypassing cache # Second run bypassing cache
result2 = await crawler.arun(url=url, bypass_cache=True) result2 = await crawler.arun(url=url, bypass_cache=True)
assert result2.success assert result2.success
assert result2.html != result1.html # Content might be different due to dynamic nature of websites assert (
result2.html != result1.html
) # Content might be different due to dynamic nature of websites
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_cache_size(): async def test_cache_size():
@@ -47,6 +50,7 @@ async def test_cache_size():
new_size = await crawler.aget_cache_size() new_size = await crawler.aget_cache_size()
assert new_size == initial_size + 1 assert new_size == initial_size + 1
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_clear_cache(): async def test_clear_cache():
async with AsyncWebCrawler(verbose=True) as crawler: async with AsyncWebCrawler(verbose=True) as crawler:
@@ -60,6 +64,7 @@ async def test_clear_cache():
new_size = await crawler.aget_cache_size() new_size = await crawler.aget_cache_size()
assert new_size == 0 assert new_size == 0
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_flush_cache(): async def test_flush_cache():
async with AsyncWebCrawler(verbose=True) as crawler: async with AsyncWebCrawler(verbose=True) as crawler:
@@ -75,7 +80,10 @@ async def test_flush_cache():
# Try to retrieve the previously cached URL # Try to retrieve the previously cached URL
result = await crawler.arun(url=url, bypass_cache=False) result = await crawler.arun(url=url, bypass_cache=False)
assert result.success # The crawler should still succeed, but it will fetch the content anew assert (
result.success
) # The crawler should still succeed, but it will fetch the content anew
# Entry point for debugging # Entry point for debugging
if __name__ == "__main__": if __name__ == "__main__":

View File

@@ -1,114 +1,133 @@
import pytest import pytest
import asyncio, time import time
from crawl4ai import ( from crawl4ai import (
AsyncWebCrawler, BrowserConfig, CrawlerRunConfig, AsyncWebCrawler,
MemoryAdaptiveDispatcher, SemaphoreDispatcher, BrowserConfig,
RateLimiter, CrawlerMonitor, DisplayMode, CacheMode CrawlerRunConfig,
MemoryAdaptiveDispatcher,
SemaphoreDispatcher,
RateLimiter,
CrawlerMonitor,
DisplayMode,
CacheMode,
) )
@pytest.fixture @pytest.fixture
def browser_config(): def browser_config():
return BrowserConfig( return BrowserConfig(headless=True, verbose=False)
headless=True,
verbose=False
)
@pytest.fixture @pytest.fixture
def run_config(): def run_config():
return CrawlerRunConfig( return CrawlerRunConfig(cache_mode=CacheMode.BYPASS, verbose=False)
cache_mode=CacheMode.BYPASS,
verbose=False
)
@pytest.fixture @pytest.fixture
def test_urls(): def test_urls():
return [ return [
"http://example.com", "http://example.com",
"http://example.com/page1", "http://example.com/page1",
"http://example.com/page2" "http://example.com/page2",
] ]
@pytest.mark.asyncio @pytest.mark.asyncio
class TestDispatchStrategies: class TestDispatchStrategies:
async def test_memory_adaptive_basic(self, browser_config, run_config, test_urls): async def test_memory_adaptive_basic(self, browser_config, run_config, test_urls):
async with AsyncWebCrawler(config=browser_config) as crawler: async with AsyncWebCrawler(config=browser_config) as crawler:
dispatcher = MemoryAdaptiveDispatcher( dispatcher = MemoryAdaptiveDispatcher(
memory_threshold_percent=70.0, memory_threshold_percent=70.0, max_session_permit=2, check_interval=0.1
max_session_permit=2, )
check_interval=0.1 results = await crawler.arun_many(
test_urls, config=run_config, dispatcher=dispatcher
) )
results = await crawler.arun_many(test_urls, config=run_config, dispatcher=dispatcher)
assert len(results) == len(test_urls) assert len(results) == len(test_urls)
assert all(r.success for r in results) assert all(r.success for r in results)
async def test_memory_adaptive_with_rate_limit(self, browser_config, run_config, test_urls): async def test_memory_adaptive_with_rate_limit(
self, browser_config, run_config, test_urls
):
async with AsyncWebCrawler(config=browser_config) as crawler: async with AsyncWebCrawler(config=browser_config) as crawler:
dispatcher = MemoryAdaptiveDispatcher( dispatcher = MemoryAdaptiveDispatcher(
memory_threshold_percent=70.0, memory_threshold_percent=70.0,
max_session_permit=2, max_session_permit=2,
check_interval=0.1, check_interval=0.1,
rate_limiter=RateLimiter( rate_limiter=RateLimiter(
base_delay=(0.1, 0.2), base_delay=(0.1, 0.2), max_delay=1.0, max_retries=2
max_delay=1.0, ),
max_retries=2
) )
results = await crawler.arun_many(
test_urls, config=run_config, dispatcher=dispatcher
) )
results = await crawler.arun_many(test_urls, config=run_config, dispatcher=dispatcher)
assert len(results) == len(test_urls) assert len(results) == len(test_urls)
assert all(r.success for r in results) assert all(r.success for r in results)
async def test_semaphore_basic(self, browser_config, run_config, test_urls): async def test_semaphore_basic(self, browser_config, run_config, test_urls):
async with AsyncWebCrawler(config=browser_config) as crawler: async with AsyncWebCrawler(config=browser_config) as crawler:
dispatcher = SemaphoreDispatcher( dispatcher = SemaphoreDispatcher(semaphore_count=2)
semaphore_count=2 results = await crawler.arun_many(
test_urls, config=run_config, dispatcher=dispatcher
) )
results = await crawler.arun_many(test_urls, config=run_config, dispatcher=dispatcher)
assert len(results) == len(test_urls) assert len(results) == len(test_urls)
assert all(r.success for r in results) assert all(r.success for r in results)
async def test_semaphore_with_rate_limit(self, browser_config, run_config, test_urls): async def test_semaphore_with_rate_limit(
self, browser_config, run_config, test_urls
):
async with AsyncWebCrawler(config=browser_config) as crawler: async with AsyncWebCrawler(config=browser_config) as crawler:
dispatcher = SemaphoreDispatcher( dispatcher = SemaphoreDispatcher(
semaphore_count=2, semaphore_count=2,
rate_limiter=RateLimiter( rate_limiter=RateLimiter(
base_delay=(0.1, 0.2), base_delay=(0.1, 0.2), max_delay=1.0, max_retries=2
max_delay=1.0, ),
max_retries=2
) )
results = await crawler.arun_many(
test_urls, config=run_config, dispatcher=dispatcher
) )
results = await crawler.arun_many(test_urls, config=run_config, dispatcher=dispatcher)
assert len(results) == len(test_urls) assert len(results) == len(test_urls)
assert all(r.success for r in results) assert all(r.success for r in results)
async def test_memory_adaptive_memory_error(self, browser_config, run_config, test_urls): async def test_memory_adaptive_memory_error(
self, browser_config, run_config, test_urls
):
async with AsyncWebCrawler(config=browser_config) as crawler: async with AsyncWebCrawler(config=browser_config) as crawler:
dispatcher = MemoryAdaptiveDispatcher( dispatcher = MemoryAdaptiveDispatcher(
memory_threshold_percent=1.0, # Set unrealistically low threshold memory_threshold_percent=1.0, # Set unrealistically low threshold
max_session_permit=2, max_session_permit=2,
check_interval=0.1, check_interval=0.1,
memory_wait_timeout=1.0 # Short timeout for testing memory_wait_timeout=1.0, # Short timeout for testing
) )
with pytest.raises(MemoryError): with pytest.raises(MemoryError):
await crawler.arun_many(test_urls, config=run_config, dispatcher=dispatcher) await crawler.arun_many(
test_urls, config=run_config, dispatcher=dispatcher
)
async def test_empty_urls(self, browser_config, run_config): async def test_empty_urls(self, browser_config, run_config):
async with AsyncWebCrawler(config=browser_config) as crawler: async with AsyncWebCrawler(config=browser_config) as crawler:
dispatcher = MemoryAdaptiveDispatcher(max_session_permit=2) dispatcher = MemoryAdaptiveDispatcher(max_session_permit=2)
results = await crawler.arun_many([], config=run_config, dispatcher=dispatcher) results = await crawler.arun_many(
[], config=run_config, dispatcher=dispatcher
)
assert len(results) == 0 assert len(results) == 0
async def test_single_url(self, browser_config, run_config): async def test_single_url(self, browser_config, run_config):
async with AsyncWebCrawler(config=browser_config) as crawler: async with AsyncWebCrawler(config=browser_config) as crawler:
dispatcher = MemoryAdaptiveDispatcher(max_session_permit=2) dispatcher = MemoryAdaptiveDispatcher(max_session_permit=2)
results = await crawler.arun_many(["http://example.com"], config=run_config, dispatcher=dispatcher) results = await crawler.arun_many(
["http://example.com"], config=run_config, dispatcher=dispatcher
)
assert len(results) == 1 assert len(results) == 1
assert results[0].success assert results[0].success
async def test_invalid_urls(self, browser_config, run_config): async def test_invalid_urls(self, browser_config, run_config):
async with AsyncWebCrawler(config=browser_config) as crawler: async with AsyncWebCrawler(config=browser_config) as crawler:
dispatcher = MemoryAdaptiveDispatcher(max_session_permit=2) dispatcher = MemoryAdaptiveDispatcher(max_session_permit=2)
results = await crawler.arun_many(["http://invalid.url.that.doesnt.exist"], config=run_config, dispatcher=dispatcher) results = await crawler.arun_many(
["http://invalid.url.that.doesnt.exist"],
config=run_config,
dispatcher=dispatcher,
)
assert len(results) == 1 assert len(results) == 1
assert not results[0].success assert not results[0].success
@@ -121,27 +140,31 @@ class TestDispatchStrategies:
base_delay=(0.1, 0.2), base_delay=(0.1, 0.2),
max_delay=1.0, max_delay=1.0,
max_retries=2, max_retries=2,
rate_limit_codes=[200] # Force rate limiting for testing rate_limit_codes=[200], # Force rate limiting for testing
) ),
) )
start_time = time.time() start_time = time.time()
results = await crawler.arun_many(urls, config=run_config, dispatcher=dispatcher) results = await crawler.arun_many(
urls, config=run_config, dispatcher=dispatcher
)
duration = time.time() - start_time duration = time.time() - start_time
assert len(results) == len(urls) assert len(results) == len(urls)
assert duration > 1.0 # Ensure rate limiting caused delays assert duration > 1.0 # Ensure rate limiting caused delays
async def test_monitor_integration(self, browser_config, run_config, test_urls): async def test_monitor_integration(self, browser_config, run_config, test_urls):
async with AsyncWebCrawler(config=browser_config) as crawler: async with AsyncWebCrawler(config=browser_config) as crawler:
monitor = CrawlerMonitor(max_visible_rows=5, display_mode=DisplayMode.DETAILED) monitor = CrawlerMonitor(
dispatcher = MemoryAdaptiveDispatcher( max_visible_rows=5, display_mode=DisplayMode.DETAILED
max_session_permit=2, )
monitor=monitor dispatcher = MemoryAdaptiveDispatcher(max_session_permit=2, monitor=monitor)
results = await crawler.arun_many(
test_urls, config=run_config, dispatcher=dispatcher
) )
results = await crawler.arun_many(test_urls, config=run_config, dispatcher=dispatcher)
assert len(results) == len(test_urls) assert len(results) == len(test_urls)
# Check monitor stats # Check monitor stats
assert len(monitor.stats) == len(test_urls) assert len(monitor.stats) == len(test_urls)
assert all(stat.end_time is not None for stat in monitor.stats.values()) assert all(stat.end_time is not None for stat in monitor.stats.values())
if __name__ == "__main__": if __name__ == "__main__":
pytest.main([__file__, "-v", "--asyncio-mode=auto"]) pytest.main([__file__, "-v", "--asyncio-mode=auto"])

View File

@@ -2,9 +2,9 @@ import os
import re import re
import sys import sys
import pytest import pytest
import json
from bs4 import BeautifulSoup from bs4 import BeautifulSoup
import asyncio import asyncio
# Add the parent directory to the Python path # Add the parent directory to the Python path
parent_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) parent_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
sys.path.append(parent_dir) sys.path.append(parent_dir)
@@ -59,19 +59,21 @@ from crawl4ai.async_webcrawler import AsyncWebCrawler
# assert result.success # assert result.success
# assert "github" in result.html.lower() # assert "github" in result.html.lower()
# Add this test to your existing test file # Add this test to your existing test file
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_typescript_commits_multi_page(): async def test_typescript_commits_multi_page():
first_commit = "" first_commit = ""
async def on_execution_started(page): async def on_execution_started(page):
nonlocal first_commit nonlocal first_commit
try: try:
# Check if the page firct commit h4 text is different from the first commit (use document.querySelector('li.Box-sc-g0xbh4-0 h4')) # Check if the page firct commit h4 text is different from the first commit (use document.querySelector('li.Box-sc-g0xbh4-0 h4'))
while True: while True:
await page.wait_for_selector('li.Box-sc-g0xbh4-0 h4') await page.wait_for_selector("li.Box-sc-g0xbh4-0 h4")
commit = await page.query_selector('li.Box-sc-g0xbh4-0 h4') commit = await page.query_selector("li.Box-sc-g0xbh4-0 h4")
commit = await commit.evaluate('(element) => element.textContent') commit = await commit.evaluate("(element) => element.textContent")
commit = re.sub(r'\s+', '', commit) commit = re.sub(r"\s+", "", commit)
if commit and commit != first_commit: if commit and commit != first_commit:
first_commit = commit first_commit = commit
break break
@@ -79,9 +81,8 @@ async def test_typescript_commits_multi_page():
except Exception as e: except Exception as e:
print(f"Warning: New content didn't appear after JavaScript execution: {e}") print(f"Warning: New content didn't appear after JavaScript execution: {e}")
async with AsyncWebCrawler(verbose=True) as crawler: async with AsyncWebCrawler(verbose=True) as crawler:
crawler.crawler_strategy.set_hook('on_execution_started', on_execution_started) crawler.crawler_strategy.set_hook("on_execution_started", on_execution_started)
url = "https://github.com/microsoft/TypeScript/commits/main" url = "https://github.com/microsoft/TypeScript/commits/main"
session_id = "typescript_commits_session" session_id = "typescript_commits_session"
@@ -97,19 +98,21 @@ async def test_typescript_commits_multi_page():
url=url, # Only use URL for the first page url=url, # Only use URL for the first page
session_id=session_id, session_id=session_id,
css_selector="li.Box-sc-g0xbh4-0", css_selector="li.Box-sc-g0xbh4-0",
js=js_next_page if page > 0 else None, # Don't click 'next' on the first page js=js_next_page
if page > 0
else None, # Don't click 'next' on the first page
bypass_cache=True, bypass_cache=True,
js_only=page > 0 # Use js_only for subsequent pages js_only=page > 0, # Use js_only for subsequent pages
) )
assert result.success, f"Failed to crawl page {page + 1}" assert result.success, f"Failed to crawl page {page + 1}"
# Parse the HTML and extract commits # Parse the HTML and extract commits
soup = BeautifulSoup(result.cleaned_html, 'html.parser') soup = BeautifulSoup(result.cleaned_html, "html.parser")
commits = soup.select("li") commits = soup.select("li")
# Take first commit find h4 extract text # Take first commit find h4 extract text
first_commit = commits[0].find("h4").text first_commit = commits[0].find("h4").text
first_commit = re.sub(r'\s+', '', first_commit) first_commit = re.sub(r"\s+", "", first_commit)
all_commits.extend(commits) all_commits.extend(commits)
print(f"Page {page + 1}: Found {len(commits)} commits") print(f"Page {page + 1}: Found {len(commits)} commits")
@@ -118,10 +121,13 @@ async def test_typescript_commits_multi_page():
await crawler.crawler_strategy.kill_session(session_id) await crawler.crawler_strategy.kill_session(session_id)
# Assertions # Assertions
assert len(all_commits) >= 90, f"Expected at least 90 commits, but got {len(all_commits)}" assert (
len(all_commits) >= 90
), f"Expected at least 90 commits, but got {len(all_commits)}"
print(f"Successfully crawled {len(all_commits)} commits across 3 pages") print(f"Successfully crawled {len(all_commits)} commits across 3 pages")
# Entry point for debugging # Entry point for debugging
if __name__ == "__main__": if __name__ == "__main__":
pytest.main([__file__, "-v"]) pytest.main([__file__, "-v"])

View File

@@ -1,11 +1,15 @@
import json import json
import time import time
from bs4 import BeautifulSoup from bs4 import BeautifulSoup
from crawl4ai.content_scraping_strategy import WebScrapingStrategy, LXMLWebScrapingStrategy from crawl4ai.content_scraping_strategy import (
from typing import Dict, Any, List, Tuple WebScrapingStrategy,
LXMLWebScrapingStrategy,
)
from typing import Dict, List, Tuple
import difflib import difflib
from lxml import html as lhtml, etree from lxml import html as lhtml, etree
def normalize_dom(element): def normalize_dom(element):
""" """
Recursively normalizes an lxml HTML element: Recursively normalizes an lxml HTML element:
@@ -15,7 +19,7 @@ def normalize_dom(element):
Returns the same element (mutated). Returns the same element (mutated).
""" """
# Remove comment nodes # Remove comment nodes
comments = element.xpath('//comment()') comments = element.xpath("//comment()")
for c in comments: for c in comments:
p = c.getparent() p = c.getparent()
if p is not None: if p is not None:
@@ -53,8 +57,8 @@ def strip_html_body(root):
tag_name = (root.tag or "").lower() tag_name = (root.tag or "").lower()
# Case 1: The root is <html> # Case 1: The root is <html>
if tag_name == 'html': if tag_name == "html":
bodies = root.xpath('./body') bodies = root.xpath("./body")
if bodies: if bodies:
body = bodies[0] body = bodies[0]
new_div = lhtml.Element("div") new_div = lhtml.Element("div")
@@ -66,7 +70,7 @@ def strip_html_body(root):
return root return root
# Case 2: The root is <body> # Case 2: The root is <body>
elif tag_name == 'body': elif tag_name == "body":
new_div = lhtml.Element("div") new_div = lhtml.Element("div")
for child in root: for child in root:
new_div.append(child) new_div.append(child)
@@ -92,7 +96,9 @@ def compare_nodes(node1, node2, differences, path="/"):
attrs1 = list(node1.attrib.items()) attrs1 = list(node1.attrib.items())
attrs2 = list(node2.attrib.items()) attrs2 = list(node2.attrib.items())
if attrs1 != attrs2: if attrs1 != attrs2:
differences.append(f"Attribute mismatch at {path}/{node1.tag}: {attrs1} vs. {attrs2}") differences.append(
f"Attribute mismatch at {path}/{node1.tag}: {attrs1} vs. {attrs2}"
)
# 3) Compare text (trim or unify whitespace as needed) # 3) Compare text (trim or unify whitespace as needed)
text1 = (node1.text or "").strip() text1 = (node1.text or "").strip()
@@ -102,7 +108,9 @@ def compare_nodes(node1, node2, differences, path="/"):
text2 = " ".join(text2.split()) text2 = " ".join(text2.split())
if text1 != text2: if text1 != text2:
# If you prefer ignoring newlines or multiple whitespace, do a more robust cleanup # If you prefer ignoring newlines or multiple whitespace, do a more robust cleanup
differences.append(f"Text mismatch at {path}/{node1.tag}: '{text1}' vs. '{text2}'") differences.append(
f"Text mismatch at {path}/{node1.tag}: '{text1}' vs. '{text2}'"
)
# 4) Compare number of children # 4) Compare number of children
children1 = list(node1) children1 = list(node1)
@@ -123,7 +131,9 @@ def compare_nodes(node1, node2, differences, path="/"):
tail1 = (node1.tail or "").strip() tail1 = (node1.tail or "").strip()
tail2 = (node2.tail or "").strip() tail2 = (node2.tail or "").strip()
if tail1 != tail2: if tail1 != tail2:
differences.append(f"Tail mismatch after {path}/{node1.tag}: '{tail1}' vs. '{tail2}'") differences.append(
f"Tail mismatch after {path}/{node1.tag}: '{tail1}' vs. '{tail2}'"
)
def compare_html_structurally(html1, html2): def compare_html_structurally(html1, html2):
@@ -156,11 +166,11 @@ def compare_html_structurally(html1, html2):
return differences return differences
def generate_large_html(n_elements=1000): def generate_large_html(n_elements=1000):
html = ['<!DOCTYPE html><html><head></head><body>'] html = ["<!DOCTYPE html><html><head></head><body>"]
for i in range(n_elements): for i in range(n_elements):
html.append(f''' html.append(
f"""
<div class="article"> <div class="article">
<h2>Heading {i}</h2> <h2>Heading {i}</h2>
<p>This is paragraph {i} with some content and a <a href="http://example.com/{i}">link</a></p> <p>This is paragraph {i} with some content and a <a href="http://example.com/{i}">link</a></p>
@@ -170,9 +180,11 @@ def generate_large_html(n_elements=1000):
<li>List item {i}.2</li> <li>List item {i}.2</li>
</ul> </ul>
</div> </div>
''') """
html.append('</body></html>') )
return ''.join(html) html.append("</body></html>")
return "".join(html)
def generate_complicated_html(): def generate_complicated_html():
""" """
@@ -352,13 +364,12 @@ def get_test_scenarios():
return TEST_SCENARIOS return TEST_SCENARIOS
class ScraperEquivalenceTester: class ScraperEquivalenceTester:
def __init__(self): def __init__(self):
self.test_cases = { self.test_cases = {
'basic': self.generate_basic_html(), "basic": self.generate_basic_html(),
'complex': self.generate_complex_html(), "complex": self.generate_complex_html(),
'malformed': self.generate_malformed_html(), "malformed": self.generate_malformed_html(),
# 'real_world': self.load_real_samples() # 'real_world': self.load_real_samples()
} }
@@ -399,20 +410,19 @@ class ScraperEquivalenceTester:
def load_real_samples(self): def load_real_samples(self):
# Load some real-world HTML samples you've collected # Load some real-world HTML samples you've collected
samples = { samples = {
'article': open('tests/samples/article.html').read(), "article": open("tests/samples/article.html").read(),
'product': open('tests/samples/product.html').read(), "product": open("tests/samples/product.html").read(),
'blog': open('tests/samples/blog.html').read() "blog": open("tests/samples/blog.html").read(),
} }
return samples return samples
def deep_compare_links(self, old_links: Dict, new_links: Dict) -> List[str]: def deep_compare_links(self, old_links: Dict, new_links: Dict) -> List[str]:
"""Detailed comparison of link structures""" """Detailed comparison of link structures"""
differences = [] differences = []
for category in ['internal', 'external']: for category in ["internal", "external"]:
old_urls = {link['href'] for link in old_links[category]} old_urls = {link["href"] for link in old_links[category]}
new_urls = {link['href'] for link in new_links[category]} new_urls = {link["href"] for link in new_links[category]}
missing = old_urls - new_urls missing = old_urls - new_urls
extra = new_urls - old_urls extra = new_urls - old_urls
@@ -425,10 +435,10 @@ class ScraperEquivalenceTester:
# Compare link attributes for common URLs # Compare link attributes for common URLs
common = old_urls & new_urls common = old_urls & new_urls
for url in common: for url in common:
old_link = next(l for l in old_links[category] if l['href'] == url) old_link = next(l for l in old_links[category] if l["href"] == url)
new_link = next(l for l in new_links[category] if l['href'] == url) new_link = next(l for l in new_links[category] if l["href"] == url)
for attr in ['text', 'title']: for attr in ["text", "title"]:
if old_link[attr] != new_link[attr]: if old_link[attr] != new_link[attr]:
differences.append( differences.append(
f"Link attribute mismatch for {url} - {attr}:" f"Link attribute mismatch for {url} - {attr}:"
@@ -441,9 +451,9 @@ class ScraperEquivalenceTester:
"""Detailed comparison of media elements""" """Detailed comparison of media elements"""
differences = [] differences = []
for media_type in ['images', 'videos', 'audios']: for media_type in ["images", "videos", "audios"]:
old_srcs = {item['src'] for item in old_media[media_type]} old_srcs = {item["src"] for item in old_media[media_type]}
new_srcs = {item['src'] for item in new_media[media_type]} new_srcs = {item["src"] for item in new_media[media_type]}
missing = old_srcs - new_srcs missing = old_srcs - new_srcs
extra = new_srcs - old_srcs extra = new_srcs - old_srcs
@@ -456,10 +466,10 @@ class ScraperEquivalenceTester:
# Compare media attributes for common sources # Compare media attributes for common sources
common = old_srcs & new_srcs common = old_srcs & new_srcs
for src in common: for src in common:
old_item = next(m for m in old_media[media_type] if m['src'] == src) old_item = next(m for m in old_media[media_type] if m["src"] == src)
new_item = next(m for m in new_media[media_type] if m['src'] == src) new_item = next(m for m in new_media[media_type] if m["src"] == src)
for attr in ['alt', 'description']: for attr in ["alt", "description"]:
if old_item.get(attr) != new_item.get(attr): if old_item.get(attr) != new_item.get(attr):
differences.append( differences.append(
f"{media_type} attribute mismatch for {src} - {attr}:" f"{media_type} attribute mismatch for {src} - {attr}:"
@@ -474,10 +484,10 @@ class ScraperEquivalenceTester:
differences = [] differences = []
def normalize_html(html: str) -> Tuple[str, str]: def normalize_html(html: str) -> Tuple[str, str]:
soup = BeautifulSoup(html, 'lxml') soup = BeautifulSoup(html, "lxml")
# Get both structure and text # Get both structure and text
structure = ' '.join(tag.name for tag in soup.find_all()) structure = " ".join(tag.name for tag in soup.find_all())
text = ' '.join(soup.get_text().split()) text = " ".join(soup.get_text().split())
return structure, text return structure, text
old_structure, old_text = normalize_html(old_html) old_structure, old_text = normalize_html(old_html)
@@ -487,46 +497,47 @@ class ScraperEquivalenceTester:
if abs(len(old_structure) - len(new_structure)) > 100: if abs(len(old_structure) - len(new_structure)) > 100:
# if old_structure != new_structure: # if old_structure != new_structure:
diff = difflib.unified_diff( diff = difflib.unified_diff(
old_structure.split(), old_structure.split(), new_structure.split(), lineterm=""
new_structure.split(),
lineterm=''
) )
differences.append("HTML structure differences:\n" + '\n'.join(diff)) differences.append("HTML structure differences:\n" + "\n".join(diff))
# Compare text content # Compare text content
if abs(len(old_text) - len(new_text)) > 100: if abs(len(old_text) - len(new_text)) > 100:
# if old_text != new_text: # if old_text != new_text:
# Show detailed text differences # Show detailed text differences
text_diff = difflib.unified_diff( text_diff = difflib.unified_diff(
old_text.split(), old_text.split(), new_text.split(), lineterm=""
new_text.split(),
lineterm=''
) )
differences.append("Text content differences:\n" + '\n'.join(text_diff)) differences.append("Text content differences:\n" + "\n".join(text_diff))
return differences return differences
def compare_results(self, old_result: Dict, new_result: Dict) -> Dict[str, List[str]]: def compare_results(
self, old_result: Dict, new_result: Dict
) -> Dict[str, List[str]]:
"""Comprehensive comparison of scraper outputs""" """Comprehensive comparison of scraper outputs"""
differences = {} differences = {}
# Compare links # Compare links
link_differences = self.deep_compare_links(old_result['links'], new_result['links']) link_differences = self.deep_compare_links(
old_result["links"], new_result["links"]
)
if link_differences: if link_differences:
differences['links'] = link_differences differences["links"] = link_differences
# Compare media # Compare media
media_differences = self.deep_compare_media(old_result['media'], new_result['media']) media_differences = self.deep_compare_media(
old_result["media"], new_result["media"]
)
if media_differences: if media_differences:
differences['media'] = media_differences differences["media"] = media_differences
# Compare HTML # Compare HTML
html_differences = self.compare_html_content( html_differences = self.compare_html_content(
old_result['cleaned_html'], old_result["cleaned_html"], new_result["cleaned_html"]
new_result['cleaned_html']
) )
if html_differences: if html_differences:
differences['html'] = html_differences differences["html"] = html_differences
return differences return differences
@@ -535,10 +546,7 @@ class ScraperEquivalenceTester:
# We'll still keep some "test_cases" logic from above (basic, complex, malformed). # We'll still keep some "test_cases" logic from above (basic, complex, malformed).
# But we add a new section for the complicated HTML scenarios. # But we add a new section for the complicated HTML scenarios.
results = { results = {"tests": [], "summary": {"passed": 0, "failed": 0}}
'tests': [],
'summary': {'passed': 0, 'failed': 0}
}
# 1) First, run the existing 3 built-in test cases (basic, complex, malformed). # 1) First, run the existing 3 built-in test cases (basic, complex, malformed).
# for case_name, html in self.test_cases.items(): # for case_name, html in self.test_cases.items():
@@ -616,33 +624,38 @@ class ScraperEquivalenceTester:
lxml_time = time.time() - start lxml_time = time.time() - start
diffs = {} diffs = {}
link_diff = self.deep_compare_links(orig_result['links'], lxml_result['links']) link_diff = self.deep_compare_links(
orig_result["links"], lxml_result["links"]
)
if link_diff: if link_diff:
diffs['links'] = link_diff diffs["links"] = link_diff
media_diff = self.deep_compare_media(orig_result['media'], lxml_result['media']) media_diff = self.deep_compare_media(
orig_result["media"], lxml_result["media"]
)
if media_diff: if media_diff:
diffs['media'] = media_diff diffs["media"] = media_diff
html_diff = self.compare_html_content(orig_result['cleaned_html'], lxml_result['cleaned_html']) html_diff = self.compare_html_content(
orig_result["cleaned_html"], lxml_result["cleaned_html"]
)
if html_diff: if html_diff:
diffs['html'] = html_diff diffs["html"] = html_diff
test_result = { test_result = {
'case': f"complicated_{scenario_name}", "case": f"complicated_{scenario_name}",
'lxml_mode': { "lxml_mode": {"differences": diffs, "execution_time": lxml_time},
'differences': diffs, "original_time": orig_time,
'execution_time': lxml_time
},
'original_time': orig_time
} }
results['tests'].append(test_result) results["tests"].append(test_result)
if not diffs: if not diffs:
results['summary']['passed'] += 1 results["summary"]["passed"] += 1
print(f"✅ [OK] No differences found. Time(Orig: {orig_time:.3f}s, LXML: {lxml_time:.3f}s)") print(
f"✅ [OK] No differences found. Time(Orig: {orig_time:.3f}s, LXML: {lxml_time:.3f}s)"
)
else: else:
results['summary']['failed'] += 1 results["summary"]["failed"] += 1
print("❌ Differences found:") print("❌ Differences found:")
for category, dlist in diffs.items(): for category, dlist in diffs.items():
print(f" {category}:") print(f" {category}:")
@@ -658,19 +671,21 @@ class ScraperEquivalenceTester:
print(f"Passed: {results['summary']['passed']}") print(f"Passed: {results['summary']['passed']}")
print(f"Failed: {results['summary']['failed']}") print(f"Failed: {results['summary']['failed']}")
for test in results['tests']: for test in results["tests"]:
print(f"\nTest Case: {test['case']}") print(f"\nTest Case: {test['case']}")
if not test['lxml_mode']['differences']: if not test["lxml_mode"]["differences"]:
print("✅ All implementations produced identical results") print("✅ All implementations produced identical results")
print(f"Times - Original: {test['original_time']:.3f}s, " print(
f"LXML: {test['lxml_mode']['execution_time']:.3f}s") f"Times - Original: {test['original_time']:.3f}s, "
f"LXML: {test['lxml_mode']['execution_time']:.3f}s"
)
else: else:
print("❌ Differences found:") print("❌ Differences found:")
if test['lxml_mode']['differences']: if test["lxml_mode"]["differences"]:
print("\nLXML Mode Differences:") print("\nLXML Mode Differences:")
for category, diffs in test['lxml_mode']['differences'].items(): for category, diffs in test["lxml_mode"]["differences"].items():
print(f"\n{category}:") print(f"\n{category}:")
for diff in diffs: for diff in diffs:
print(f" - {diff}") print(f" - {diff}")
@@ -682,7 +697,7 @@ def main():
tester.print_report(results) tester.print_report(results)
# Save detailed results for debugging # Save detailed results for debugging
with open('scraper_equivalence_results.json', 'w') as f: with open("scraper_equivalence_results.json", "w") as f:
json.dump(results, f, indent=2) json.dump(results, f, indent=2)

View File

@@ -4,10 +4,10 @@
# - **State:** open # - **State:** open
import os, sys, time import os, sys, time
parent_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) parent_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
sys.path.append(parent_dir) sys.path.append(parent_dir)
__location__ = os.path.realpath(os.path.join(os.getcwd(), os.path.dirname(__file__))) __location__ = os.path.realpath(os.path.join(os.getcwd(), os.path.dirname(__file__)))
import asyncio
import os import os
import time import time
from typing import Dict, Any from typing import Dict, Any
@@ -16,12 +16,12 @@ from crawl4ai.markdown_generation_strategy import DefaultMarkdownGenerator
# Get current directory # Get current directory
__location__ = os.path.realpath(os.path.join(os.getcwd(), os.path.dirname(__file__))) __location__ = os.path.realpath(os.path.join(os.getcwd(), os.path.dirname(__file__)))
def print_test_result(name: str, result: Dict[str, Any], execution_time: float): def print_test_result(name: str, result: Dict[str, Any], execution_time: float):
"""Helper function to print test results.""" """Helper function to print test results."""
print(f"\n{'='*20} {name} {'='*20}") print(f"\n{'='*20} {name} {'='*20}")
print(f"Execution time: {execution_time:.4f} seconds") print(f"Execution time: {execution_time:.4f} seconds")
# Save markdown to files # Save markdown to files
for key, content in result.items(): for key, content in result.items():
if isinstance(content, str): if isinstance(content, str):
@@ -36,6 +36,7 @@ def print_test_result(name: str, result: Dict[str, Any], execution_time: float):
# print(preview) # print(preview)
# print(f"Total length: {len(content)} characters") # print(f"Total length: {len(content)} characters")
def test_basic_markdown_conversion(): def test_basic_markdown_conversion():
"""Test basic markdown conversion with links.""" """Test basic markdown conversion with links."""
with open(__location__ + "/data/wikipedia.html", "r") as f: with open(__location__ + "/data/wikipedia.html", "r") as f:
@@ -45,23 +46,29 @@ def test_basic_markdown_conversion():
start_time = time.perf_counter() start_time = time.perf_counter()
result = generator.generate_markdown( result = generator.generate_markdown(
cleaned_html=cleaned_html, cleaned_html=cleaned_html, base_url="https://en.wikipedia.org"
base_url="https://en.wikipedia.org"
) )
execution_time = time.perf_counter() - start_time execution_time = time.perf_counter() - start_time
print_test_result("Basic Markdown Conversion", { print_test_result(
'raw': result.raw_markdown, "Basic Markdown Conversion",
'with_citations': result.markdown_with_citations, {
'references': result.references_markdown "raw": result.raw_markdown,
}, execution_time) "with_citations": result.markdown_with_citations,
"references": result.references_markdown,
},
execution_time,
)
# Basic assertions # Basic assertions
assert result.raw_markdown, "Raw markdown should not be empty" assert result.raw_markdown, "Raw markdown should not be empty"
assert result.markdown_with_citations, "Markdown with citations should not be empty" assert result.markdown_with_citations, "Markdown with citations should not be empty"
assert result.references_markdown, "References should not be empty" assert result.references_markdown, "References should not be empty"
assert "" in result.markdown_with_citations, "Citations should use ⟨⟩ brackets" assert "" in result.markdown_with_citations, "Citations should use ⟨⟩ brackets"
assert "## References" in result.references_markdown, "Should contain references section" assert (
"## References" in result.references_markdown
), "Should contain references section"
def test_relative_links(): def test_relative_links():
"""Test handling of relative links with base URL.""" """Test handling of relative links with base URL."""
@@ -72,14 +79,14 @@ def test_relative_links():
generator = DefaultMarkdownGenerator() generator = DefaultMarkdownGenerator()
result = generator.generate_markdown( result = generator.generate_markdown(
cleaned_html=markdown, cleaned_html=markdown, base_url="https://en.wikipedia.org"
base_url="https://en.wikipedia.org"
) )
assert "https://en.wikipedia.org/wiki/Apple" in result.references_markdown assert "https://en.wikipedia.org/wiki/Apple" in result.references_markdown
assert "https://example.com" in result.references_markdown assert "https://example.com" in result.references_markdown
assert "https://en.wikipedia.org/images/test.png" in result.references_markdown assert "https://en.wikipedia.org/images/test.png" in result.references_markdown
def test_duplicate_links(): def test_duplicate_links():
"""Test handling of duplicate links.""" """Test handling of duplicate links."""
markdown = """ markdown = """
@@ -88,14 +95,14 @@ def test_duplicate_links():
generator = DefaultMarkdownGenerator() generator = DefaultMarkdownGenerator()
result = generator.generate_markdown( result = generator.generate_markdown(
cleaned_html=markdown, cleaned_html=markdown, base_url="https://example.com"
base_url="https://example.com"
) )
# Count citations in markdown # Count citations in markdown
citations = result.markdown_with_citations.count("⟨1⟩") citations = result.markdown_with_citations.count("⟨1⟩")
assert citations == 2, "Same link should use same citation number" assert citations == 2, "Same link should use same citation number"
def test_link_descriptions(): def test_link_descriptions():
"""Test handling of link titles and descriptions.""" """Test handling of link titles and descriptions."""
markdown = """ markdown = """
@@ -104,12 +111,16 @@ def test_link_descriptions():
generator = DefaultMarkdownGenerator() generator = DefaultMarkdownGenerator()
result = generator.generate_markdown( result = generator.generate_markdown(
cleaned_html=markdown, cleaned_html=markdown, base_url="https://example.com"
base_url="https://example.com"
) )
assert "Test Title" in result.references_markdown, "Link title should be in references" assert (
assert "link with description" in result.references_markdown, "Link text should be in references" "Test Title" in result.references_markdown
), "Link title should be in references"
assert (
"link with description" in result.references_markdown
), "Link text should be in references"
def test_performance_large_document(): def test_performance_large_document():
"""Test performance with large document.""" """Test performance with large document."""
@@ -125,18 +136,20 @@ def test_performance_large_document():
for i in range(iterations): for i in range(iterations):
start_time = time.perf_counter() start_time = time.perf_counter()
result = generator.generate_markdown( result = generator.generate_markdown(
cleaned_html=markdown, cleaned_html=markdown, base_url="https://en.wikipedia.org"
base_url="https://en.wikipedia.org"
) )
end_time = time.perf_counter() end_time = time.perf_counter()
times.append(end_time - start_time) times.append(end_time - start_time)
avg_time = sum(times) / len(times) avg_time = sum(times) / len(times)
print(f"\n{'='*20} Performance Test {'='*20}") print(f"\n{'='*20} Performance Test {'='*20}")
print(f"Average execution time over {iterations} iterations: {avg_time:.4f} seconds") print(
f"Average execution time over {iterations} iterations: {avg_time:.4f} seconds"
)
print(f"Min time: {min(times):.4f} seconds") print(f"Min time: {min(times):.4f} seconds")
print(f"Max time: {max(times):.4f} seconds") print(f"Max time: {max(times):.4f} seconds")
def test_image_links(): def test_image_links():
"""Test handling of image links.""" """Test handling of image links."""
markdown = """ markdown = """
@@ -146,12 +159,16 @@ def test_image_links():
generator = DefaultMarkdownGenerator() generator = DefaultMarkdownGenerator()
result = generator.generate_markdown( result = generator.generate_markdown(
cleaned_html=markdown, cleaned_html=markdown, base_url="https://example.com"
base_url="https://example.com"
) )
assert "![" in result.markdown_with_citations, "Image markdown syntax should be preserved" assert (
assert "Image Title" in result.references_markdown, "Image title should be in references" "![" in result.markdown_with_citations
), "Image markdown syntax should be preserved"
assert (
"Image Title" in result.references_markdown
), "Image title should be in references"
if __name__ == "__main__": if __name__ == "__main__":
print("Running markdown generation strategy tests...") print("Running markdown generation strategy tests...")
@@ -162,4 +179,3 @@ if __name__ == "__main__":
test_link_descriptions() test_link_descriptions()
test_performance_large_document() test_performance_large_document()
test_image_links() test_image_links()

View File

@@ -1,8 +1,6 @@
import os import os
import sys import sys
import pytest import pytest
import asyncio
import json
# Add the parent directory to the Python path # Add the parent directory to the Python path
parent_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) parent_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
@@ -10,24 +8,37 @@ sys.path.append(parent_dir)
from crawl4ai.async_webcrawler import AsyncWebCrawler from crawl4ai.async_webcrawler import AsyncWebCrawler
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_word_count_threshold(): async def test_word_count_threshold():
async with AsyncWebCrawler(verbose=True) as crawler: async with AsyncWebCrawler(verbose=True) as crawler:
url = "https://www.nbcnews.com/business" url = "https://www.nbcnews.com/business"
result_no_threshold = await crawler.arun(url=url, word_count_threshold=0, bypass_cache=True) result_no_threshold = await crawler.arun(
result_with_threshold = await crawler.arun(url=url, word_count_threshold=50, bypass_cache=True) url=url, word_count_threshold=0, bypass_cache=True
)
result_with_threshold = await crawler.arun(
url=url, word_count_threshold=50, bypass_cache=True
)
assert len(result_no_threshold.markdown) > len(result_with_threshold.markdown) assert len(result_no_threshold.markdown) > len(result_with_threshold.markdown)
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_css_selector(): async def test_css_selector():
async with AsyncWebCrawler(verbose=True) as crawler: async with AsyncWebCrawler(verbose=True) as crawler:
url = "https://www.nbcnews.com/business" url = "https://www.nbcnews.com/business"
css_selector = "h1, h2, h3" css_selector = "h1, h2, h3"
result = await crawler.arun(url=url, css_selector=css_selector, bypass_cache=True) result = await crawler.arun(
url=url, css_selector=css_selector, bypass_cache=True
)
assert result.success assert result.success
assert "<h1" in result.cleaned_html or "<h2" in result.cleaned_html or "<h3" in result.cleaned_html assert (
"<h1" in result.cleaned_html
or "<h2" in result.cleaned_html
or "<h3" in result.cleaned_html
)
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_javascript_execution(): async def test_javascript_execution():
@@ -37,12 +48,15 @@ async def test_javascript_execution():
# Crawl without JS # Crawl without JS
result_without_more = await crawler.arun(url=url, bypass_cache=True) result_without_more = await crawler.arun(url=url, bypass_cache=True)
js_code = ["const loadMoreButton = Array.from(document.querySelectorAll('button')).find(button => button.textContent.includes('Load More')); loadMoreButton && loadMoreButton.click();"] js_code = [
"const loadMoreButton = Array.from(document.querySelectorAll('button')).find(button => button.textContent.includes('Load More')); loadMoreButton && loadMoreButton.click();"
]
result_with_more = await crawler.arun(url=url, js=js_code, bypass_cache=True) result_with_more = await crawler.arun(url=url, js=js_code, bypass_cache=True)
assert result_with_more.success assert result_with_more.success
assert len(result_with_more.markdown) > len(result_without_more.markdown) assert len(result_with_more.markdown) > len(result_without_more.markdown)
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_screenshot(): async def test_screenshot():
async with AsyncWebCrawler(verbose=True) as crawler: async with AsyncWebCrawler(verbose=True) as crawler:
@@ -53,16 +67,20 @@ async def test_screenshot():
assert result.screenshot assert result.screenshot
assert isinstance(result.screenshot, str) # Should be a base64 encoded string assert isinstance(result.screenshot, str) # Should be a base64 encoded string
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_custom_user_agent(): async def test_custom_user_agent():
async with AsyncWebCrawler(verbose=True) as crawler: async with AsyncWebCrawler(verbose=True) as crawler:
url = "https://www.nbcnews.com/business" url = "https://www.nbcnews.com/business"
custom_user_agent = "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.124 Safari/537.36 Crawl4AI/1.0" custom_user_agent = "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.124 Safari/537.36 Crawl4AI/1.0"
result = await crawler.arun(url=url, user_agent=custom_user_agent, bypass_cache=True) result = await crawler.arun(
url=url, user_agent=custom_user_agent, bypass_cache=True
)
assert result.success assert result.success
# Note: We can't directly verify the user agent in the result, but we can check if the crawl was successful # Note: We can't directly verify the user agent in the result, but we can check if the crawl was successful
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_extract_media_and_links(): async def test_extract_media_and_links():
async with AsyncWebCrawler(verbose=True) as crawler: async with AsyncWebCrawler(verbose=True) as crawler:
@@ -72,10 +90,11 @@ async def test_extract_media_and_links():
assert result.success assert result.success
assert result.media assert result.media
assert isinstance(result.media, dict) assert isinstance(result.media, dict)
assert 'images' in result.media assert "images" in result.media
assert result.links assert result.links
assert isinstance(result.links, dict) assert isinstance(result.links, dict)
assert 'internal' in result.links and 'external' in result.links assert "internal" in result.links and "external" in result.links
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_metadata_extraction(): async def test_metadata_extraction():
@@ -87,7 +106,10 @@ async def test_metadata_extraction():
assert result.metadata assert result.metadata
assert isinstance(result.metadata, dict) assert isinstance(result.metadata, dict)
# Check for common metadata fields # Check for common metadata fields
assert any(key in result.metadata for key in ['title', 'description', 'keywords']) assert any(
key in result.metadata for key in ["title", "description", "keywords"]
)
# Entry point for debugging # Entry point for debugging
if __name__ == "__main__": if __name__ == "__main__":

View File

@@ -1,7 +1,6 @@
import os import os
import sys import sys
import pytest import pytest
import asyncio
import time import time
# Add the parent directory to the Python path # Add the parent directory to the Python path
@@ -10,6 +9,7 @@ sys.path.append(parent_dir)
from crawl4ai.async_webcrawler import AsyncWebCrawler from crawl4ai.async_webcrawler import AsyncWebCrawler
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_crawl_speed(): async def test_crawl_speed():
async with AsyncWebCrawler(verbose=True) as crawler: async with AsyncWebCrawler(verbose=True) as crawler:
@@ -24,6 +24,7 @@ async def test_crawl_speed():
assert crawl_time < 10, f"Crawl took too long: {crawl_time:.2f} seconds" assert crawl_time < 10, f"Crawl took too long: {crawl_time:.2f} seconds"
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_concurrent_crawling_performance(): async def test_concurrent_crawling_performance():
async with AsyncWebCrawler(verbose=True) as crawler: async with AsyncWebCrawler(verbose=True) as crawler:
@@ -32,7 +33,7 @@ async def test_concurrent_crawling_performance():
"https://www.example.com", "https://www.example.com",
"https://www.python.org", "https://www.python.org",
"https://www.github.com", "https://www.github.com",
"https://www.stackoverflow.com" "https://www.stackoverflow.com",
] ]
start_time = time.time() start_time = time.time()
@@ -45,7 +46,10 @@ async def test_concurrent_crawling_performance():
assert all(result.success for result in results) assert all(result.success for result in results)
assert len(results) == len(urls) assert len(results) == len(urls)
assert total_time < len(urls) * 5, f"Concurrent crawling not significantly faster: {total_time:.2f} seconds" assert (
total_time < len(urls) * 5
), f"Concurrent crawling not significantly faster: {total_time:.2f} seconds"
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_crawl_speed_with_caching(): async def test_crawl_speed_with_caching():
@@ -66,7 +70,10 @@ async def test_crawl_speed_with_caching():
print(f"First crawl time: {first_crawl_time:.2f} seconds") print(f"First crawl time: {first_crawl_time:.2f} seconds")
print(f"Second crawl time (cached): {second_crawl_time:.2f} seconds") print(f"Second crawl time (cached): {second_crawl_time:.2f} seconds")
assert second_crawl_time < first_crawl_time / 2, "Cached crawl not significantly faster" assert (
second_crawl_time < first_crawl_time / 2
), "Cached crawl not significantly faster"
if __name__ == "__main__": if __name__ == "__main__":
pytest.main([__file__, "-v"]) pytest.main([__file__, "-v"])

View File

@@ -1,7 +1,6 @@
import os import os
import sys import sys
import pytest import pytest
import asyncio
import base64 import base64
from PIL import Image from PIL import Image
import io import io
@@ -12,6 +11,7 @@ sys.path.append(parent_dir)
from crawl4ai.async_webcrawler import AsyncWebCrawler from crawl4ai.async_webcrawler import AsyncWebCrawler
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_basic_screenshot(): async def test_basic_screenshot():
async with AsyncWebCrawler(verbose=True) as crawler: async with AsyncWebCrawler(verbose=True) as crawler:
@@ -26,6 +26,7 @@ async def test_basic_screenshot():
image = Image.open(io.BytesIO(image_data)) image = Image.open(io.BytesIO(image_data))
assert image.format == "PNG" assert image.format == "PNG"
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_screenshot_with_wait_for(): async def test_screenshot_with_wait_for():
async with AsyncWebCrawler(verbose=True) as crawler: async with AsyncWebCrawler(verbose=True) as crawler:
@@ -34,10 +35,7 @@ async def test_screenshot_with_wait_for():
wait_for = "css:#content" # Wait for the main content to load wait_for = "css:#content" # Wait for the main content to load
result = await crawler.arun( result = await crawler.arun(
url=url, url=url, bypass_cache=True, screenshot=True, wait_for=wait_for
bypass_cache=True,
screenshot=True,
wait_for=wait_for
) )
assert result.success assert result.success
@@ -51,6 +49,7 @@ async def test_screenshot_with_wait_for():
# You might want to add more specific checks here, like image dimensions # You might want to add more specific checks here, like image dimensions
# or even use image recognition to verify certain elements are present # or even use image recognition to verify certain elements are present
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_screenshot_with_js_wait_for(): async def test_screenshot_with_js_wait_for():
async with AsyncWebCrawler(verbose=True) as crawler: async with AsyncWebCrawler(verbose=True) as crawler:
@@ -58,10 +57,7 @@ async def test_screenshot_with_js_wait_for():
wait_for = "js:() => document.querySelector('#nav-logo-sprites') !== null" wait_for = "js:() => document.querySelector('#nav-logo-sprites') !== null"
result = await crawler.arun( result = await crawler.arun(
url=url, url=url, bypass_cache=True, screenshot=True, wait_for=wait_for
bypass_cache=True,
screenshot=True,
wait_for=wait_for
) )
assert result.success assert result.success
@@ -71,6 +67,7 @@ async def test_screenshot_with_js_wait_for():
image = Image.open(io.BytesIO(image_data)) image = Image.open(io.BytesIO(image_data))
assert image.format == "PNG" assert image.format == "PNG"
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_screenshot_without_wait_for(): async def test_screenshot_without_wait_for():
async with AsyncWebCrawler(verbose=True) as crawler: async with AsyncWebCrawler(verbose=True) as crawler:
@@ -85,6 +82,7 @@ async def test_screenshot_without_wait_for():
image = Image.open(io.BytesIO(image_data)) image = Image.open(io.BytesIO(image_data))
assert image.format == "PNG" assert image.format == "PNG"
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_screenshot_comparison(): async def test_screenshot_comparison():
async with AsyncWebCrawler(verbose=True) as crawler: async with AsyncWebCrawler(verbose=True) as crawler:
@@ -93,17 +91,12 @@ async def test_screenshot_comparison():
# Take screenshot without wait_for # Take screenshot without wait_for
result_without_wait = await crawler.arun( result_without_wait = await crawler.arun(
url=url, url=url, bypass_cache=True, screenshot=True
bypass_cache=True,
screenshot=True
) )
# Take screenshot with wait_for # Take screenshot with wait_for
result_with_wait = await crawler.arun( result_with_wait = await crawler.arun(
url=url, url=url, bypass_cache=True, screenshot=True, wait_for=wait_for
bypass_cache=True,
screenshot=True,
wait_for=wait_for
) )
assert result_without_wait.success and result_with_wait.success assert result_without_wait.success and result_with_wait.success
@@ -111,14 +104,19 @@ async def test_screenshot_comparison():
assert result_with_wait.screenshot is not None assert result_with_wait.screenshot is not None
# Compare the two screenshots # Compare the two screenshots
image_without_wait = Image.open(io.BytesIO(base64.b64decode(result_without_wait.screenshot))) image_without_wait = Image.open(
image_with_wait = Image.open(io.BytesIO(base64.b64decode(result_with_wait.screenshot))) io.BytesIO(base64.b64decode(result_without_wait.screenshot))
)
image_with_wait = Image.open(
io.BytesIO(base64.b64decode(result_with_wait.screenshot))
)
# This is a simple size comparison. In a real-world scenario, you might want to use # This is a simple size comparison. In a real-world scenario, you might want to use
# more sophisticated image comparison techniques. # more sophisticated image comparison techniques.
assert image_with_wait.size[0] >= image_without_wait.size[0] assert image_with_wait.size[0] >= image_without_wait.size[0]
assert image_with_wait.size[1] >= image_without_wait.size[1] assert image_with_wait.size[1] >= image_without_wait.size[1]
# Entry point for debugging # Entry point for debugging
if __name__ == "__main__": if __name__ == "__main__":
pytest.main([__file__, "-v"]) pytest.main([__file__, "-v"])

View File

@@ -6,15 +6,24 @@ import base64
import os import os
from typing import Dict, Any from typing import Dict, Any
class Crawl4AiTester: class Crawl4AiTester:
def __init__(self, base_url: str = "http://localhost:11235", api_token: str = None): def __init__(self, base_url: str = "http://localhost:11235", api_token: str = None):
self.base_url = base_url self.base_url = base_url
self.api_token = api_token or os.getenv('CRAWL4AI_API_TOKEN') # Check environment variable as fallback self.api_token = api_token or os.getenv(
self.headers = {'Authorization': f'Bearer {self.api_token}'} if self.api_token else {} "CRAWL4AI_API_TOKEN"
) # Check environment variable as fallback
self.headers = (
{"Authorization": f"Bearer {self.api_token}"} if self.api_token else {}
)
def submit_and_wait(self, request_data: Dict[str, Any], timeout: int = 300) -> Dict[str, Any]: def submit_and_wait(
self, request_data: Dict[str, Any], timeout: int = 300
) -> Dict[str, Any]:
# Submit crawl job # Submit crawl job
response = requests.post(f"{self.base_url}/crawl", json=request_data, headers=self.headers) response = requests.post(
f"{self.base_url}/crawl", json=request_data, headers=self.headers
)
if response.status_code == 403: if response.status_code == 403:
raise Exception("API token is invalid or missing") raise Exception("API token is invalid or missing")
task_id = response.json()["task_id"] task_id = response.json()["task_id"]
@@ -24,9 +33,13 @@ class Crawl4AiTester:
start_time = time.time() start_time = time.time()
while True: while True:
if time.time() - start_time > timeout: if time.time() - start_time > timeout:
raise TimeoutError(f"Task {task_id} did not complete within {timeout} seconds") raise TimeoutError(
f"Task {task_id} did not complete within {timeout} seconds"
)
result = requests.get(f"{self.base_url}/task/{task_id}", headers=self.headers) result = requests.get(
f"{self.base_url}/task/{task_id}", headers=self.headers
)
status = result.json() status = result.json()
if status["status"] == "failed": if status["status"] == "failed":
@@ -39,17 +52,23 @@ class Crawl4AiTester:
time.sleep(2) time.sleep(2)
def submit_sync(self, request_data: Dict[str, Any]) -> Dict[str, Any]: def submit_sync(self, request_data: Dict[str, Any]) -> Dict[str, Any]:
response = requests.post(f"{self.base_url}/crawl_sync", json=request_data, headers=self.headers, timeout=60) response = requests.post(
f"{self.base_url}/crawl_sync",
json=request_data,
headers=self.headers,
timeout=60,
)
if response.status_code == 408: if response.status_code == 408:
raise TimeoutError("Task did not complete within server timeout") raise TimeoutError("Task did not complete within server timeout")
response.raise_for_status() response.raise_for_status()
return response.json() return response.json()
def test_docker_deployment(version="basic"): def test_docker_deployment(version="basic"):
tester = Crawl4AiTester( tester = Crawl4AiTester(
# base_url="http://localhost:11235" , # base_url="http://localhost:11235" ,
base_url="https://crawl4ai-sby74.ondigitalocean.app", base_url="https://crawl4ai-sby74.ondigitalocean.app",
api_token="test" api_token="test",
) )
print(f"Testing Crawl4AI Docker {version} version") print(f"Testing Crawl4AI Docker {version} version")
@@ -60,7 +79,7 @@ def test_docker_deployment(version="basic"):
health = requests.get(f"{tester.base_url}/health", timeout=10) health = requests.get(f"{tester.base_url}/health", timeout=10)
print("Health check:", health.json()) print("Health check:", health.json())
break break
except requests.exceptions.RequestException as e: except requests.exceptions.RequestException:
if i == max_retries - 1: if i == max_retries - 1:
print(f"Failed to connect after {max_retries} attempts") print(f"Failed to connect after {max_retries} attempts")
sys.exit(1) sys.exit(1)
@@ -88,7 +107,7 @@ def test_basic_crawl(tester: Crawl4AiTester):
request = { request = {
"urls": "https://www.nbcnews.com/business", "urls": "https://www.nbcnews.com/business",
"priority": 10, "priority": 10,
"session_id": "test" "session_id": "test",
} }
result = tester.submit_and_wait(request) result = tester.submit_and_wait(request)
@@ -96,19 +115,21 @@ def test_basic_crawl(tester: Crawl4AiTester):
assert result["result"]["success"] assert result["result"]["success"]
assert len(result["result"]["markdown"]) > 0 assert len(result["result"]["markdown"]) > 0
def test_basic_crawl_sync(tester: Crawl4AiTester): def test_basic_crawl_sync(tester: Crawl4AiTester):
print("\n=== Testing Basic Crawl (Sync) ===") print("\n=== Testing Basic Crawl (Sync) ===")
request = { request = {
"urls": "https://www.nbcnews.com/business", "urls": "https://www.nbcnews.com/business",
"priority": 10, "priority": 10,
"session_id": "test" "session_id": "test",
} }
result = tester.submit_sync(request) result = tester.submit_sync(request)
print(f"Basic crawl result length: {len(result['result']['markdown'])}") print(f"Basic crawl result length: {len(result['result']['markdown'])}")
assert result['status'] == 'completed' assert result["status"] == "completed"
assert result['result']['success'] assert result["result"]["success"]
assert len(result['result']['markdown']) > 0 assert len(result["result"]["markdown"]) > 0
def test_js_execution(tester: Crawl4AiTester): def test_js_execution(tester: Crawl4AiTester):
print("\n=== Testing JS Execution ===") print("\n=== Testing JS Execution ===")
@@ -119,32 +140,29 @@ def test_js_execution(tester: Crawl4AiTester):
"const loadMoreButton = Array.from(document.querySelectorAll('button')).find(button => button.textContent.includes('Load More')); loadMoreButton && loadMoreButton.click();" "const loadMoreButton = Array.from(document.querySelectorAll('button')).find(button => button.textContent.includes('Load More')); loadMoreButton && loadMoreButton.click();"
], ],
"wait_for": "article.tease-card:nth-child(10)", "wait_for": "article.tease-card:nth-child(10)",
"crawler_params": { "crawler_params": {"headless": True},
"headless": True
}
} }
result = tester.submit_and_wait(request) result = tester.submit_and_wait(request)
print(f"JS execution result length: {len(result['result']['markdown'])}") print(f"JS execution result length: {len(result['result']['markdown'])}")
assert result["result"]["success"] assert result["result"]["success"]
def test_css_selector(tester: Crawl4AiTester): def test_css_selector(tester: Crawl4AiTester):
print("\n=== Testing CSS Selector ===") print("\n=== Testing CSS Selector ===")
request = { request = {
"urls": "https://www.nbcnews.com/business", "urls": "https://www.nbcnews.com/business",
"priority": 7, "priority": 7,
"css_selector": ".wide-tease-item__description", "css_selector": ".wide-tease-item__description",
"crawler_params": { "crawler_params": {"headless": True},
"headless": True "extra": {"word_count_threshold": 10},
},
"extra": {"word_count_threshold": 10}
} }
result = tester.submit_and_wait(request) result = tester.submit_and_wait(request)
print(f"CSS selector result length: {len(result['result']['markdown'])}") print(f"CSS selector result length: {len(result['result']['markdown'])}")
assert result["result"]["success"] assert result["result"]["success"]
def test_structured_extraction(tester: Crawl4AiTester): def test_structured_extraction(tester: Crawl4AiTester):
print("\n=== Testing Structured Extraction ===") print("\n=== Testing Structured Extraction ===")
schema = { schema = {
@@ -165,19 +183,14 @@ def test_structured_extraction(tester: Crawl4AiTester):
"name": "price", "name": "price",
"selector": "td:nth-child(2)", "selector": "td:nth-child(2)",
"type": "text", "type": "text",
} },
], ],
} }
request = { request = {
"urls": "https://www.coinbase.com/explore", "urls": "https://www.coinbase.com/explore",
"priority": 9, "priority": 9,
"extraction_config": { "extraction_config": {"type": "json_css", "params": {"schema": schema}},
"type": "json_css",
"params": {
"schema": schema
}
}
} }
result = tester.submit_and_wait(request) result = tester.submit_and_wait(request)
@@ -187,6 +200,7 @@ def test_structured_extraction(tester: Crawl4AiTester):
assert result["result"]["success"] assert result["result"]["success"]
assert len(extracted) > 0 assert len(extracted) > 0
def test_llm_extraction(tester: Crawl4AiTester): def test_llm_extraction(tester: Crawl4AiTester):
print("\n=== Testing LLM Extraction ===") print("\n=== Testing LLM Extraction ===")
schema = { schema = {
@@ -194,18 +208,18 @@ def test_llm_extraction(tester: Crawl4AiTester):
"properties": { "properties": {
"model_name": { "model_name": {
"type": "string", "type": "string",
"description": "Name of the OpenAI model." "description": "Name of the OpenAI model.",
}, },
"input_fee": { "input_fee": {
"type": "string", "type": "string",
"description": "Fee for input token for the OpenAI model." "description": "Fee for input token for the OpenAI model.",
}, },
"output_fee": { "output_fee": {
"type": "string", "type": "string",
"description": "Fee for output token for the OpenAI model." "description": "Fee for output token for the OpenAI model.",
}
}, },
"required": ["model_name", "input_fee", "output_fee"] },
"required": ["model_name", "input_fee", "output_fee"],
} }
request = { request = {
@@ -218,10 +232,10 @@ def test_llm_extraction(tester: Crawl4AiTester):
"api_token": os.getenv("OPENAI_API_KEY"), "api_token": os.getenv("OPENAI_API_KEY"),
"schema": schema, "schema": schema,
"extraction_type": "schema", "extraction_type": "schema",
"instruction": """From the crawled content, extract all mentioned model names along with their fees for input and output tokens.""" "instruction": """From the crawled content, extract all mentioned model names along with their fees for input and output tokens.""",
}
}, },
"crawler_params": {"word_count_threshold": 1} },
"crawler_params": {"word_count_threshold": 1},
} }
try: try:
@@ -233,6 +247,7 @@ def test_llm_extraction(tester: Crawl4AiTester):
except Exception as e: except Exception as e:
print(f"LLM extraction test failed (might be due to missing API key): {str(e)}") print(f"LLM extraction test failed (might be due to missing API key): {str(e)}")
def test_llm_with_ollama(tester: Crawl4AiTester): def test_llm_with_ollama(tester: Crawl4AiTester):
print("\n=== Testing LLM with Ollama ===") print("\n=== Testing LLM with Ollama ===")
schema = { schema = {
@@ -240,18 +255,18 @@ def test_llm_with_ollama(tester: Crawl4AiTester):
"properties": { "properties": {
"article_title": { "article_title": {
"type": "string", "type": "string",
"description": "The main title of the news article" "description": "The main title of the news article",
}, },
"summary": { "summary": {
"type": "string", "type": "string",
"description": "A brief summary of the article content" "description": "A brief summary of the article content",
}, },
"main_topics": { "main_topics": {
"type": "array", "type": "array",
"items": {"type": "string"}, "items": {"type": "string"},
"description": "Main topics or themes discussed in the article" "description": "Main topics or themes discussed in the article",
} },
} },
} }
request = { request = {
@@ -263,11 +278,11 @@ def test_llm_with_ollama(tester: Crawl4AiTester):
"provider": "ollama/llama2", "provider": "ollama/llama2",
"schema": schema, "schema": schema,
"extraction_type": "schema", "extraction_type": "schema",
"instruction": "Extract the main article information including title, summary, and main topics." "instruction": "Extract the main article information including title, summary, and main topics.",
} },
}, },
"extra": {"word_count_threshold": 1}, "extra": {"word_count_threshold": 1},
"crawler_params": {"verbose": True} "crawler_params": {"verbose": True},
} }
try: try:
@@ -278,6 +293,7 @@ def test_llm_with_ollama(tester: Crawl4AiTester):
except Exception as e: except Exception as e:
print(f"Ollama extraction test failed: {str(e)}") print(f"Ollama extraction test failed: {str(e)}")
def test_cosine_extraction(tester: Crawl4AiTester): def test_cosine_extraction(tester: Crawl4AiTester):
print("\n=== Testing Cosine Extraction ===") print("\n=== Testing Cosine Extraction ===")
request = { request = {
@@ -289,9 +305,9 @@ def test_cosine_extraction(tester: Crawl4AiTester):
"semantic_filter": "business finance economy", "semantic_filter": "business finance economy",
"word_count_threshold": 10, "word_count_threshold": 10,
"max_dist": 0.2, "max_dist": 0.2,
"top_k": 3 "top_k": 3,
} },
} },
} }
try: try:
@@ -303,15 +319,14 @@ def test_cosine_extraction(tester: Crawl4AiTester):
except Exception as e: except Exception as e:
print(f"Cosine extraction test failed: {str(e)}") print(f"Cosine extraction test failed: {str(e)}")
def test_screenshot(tester: Crawl4AiTester): def test_screenshot(tester: Crawl4AiTester):
print("\n=== Testing Screenshot ===") print("\n=== Testing Screenshot ===")
request = { request = {
"urls": "https://www.nbcnews.com/business", "urls": "https://www.nbcnews.com/business",
"priority": 5, "priority": 5,
"screenshot": True, "screenshot": True,
"crawler_params": { "crawler_params": {"headless": True},
"headless": True
}
} }
result = tester.submit_and_wait(request) result = tester.submit_and_wait(request)
@@ -326,6 +341,7 @@ def test_screenshot(tester: Crawl4AiTester):
assert result["result"]["success"] assert result["result"]["success"]
if __name__ == "__main__": if __name__ == "__main__":
version = sys.argv[1] if len(sys.argv) > 1 else "basic" version = sys.argv[1] if len(sys.argv) > 1 else "basic"
# version = "full" # version = "full"

View File

@@ -1,9 +1,9 @@
import asyncio import asyncio
from pathlib import Path
from crawl4ai.docs_manager import DocsManager from crawl4ai.docs_manager import DocsManager
from click.testing import CliRunner from click.testing import CliRunner
from crawl4ai.cli import cli from crawl4ai.cli import cli
def test_cli(): def test_cli():
"""Test all CLI commands""" """Test all CLI commands"""
runner = CliRunner() runner = CliRunner()
@@ -35,9 +35,10 @@ def test_cli():
# print(f"First 200 chars: {result.output[:200]}...") # print(f"First 200 chars: {result.output[:200]}...")
print("\n5. Testing combine all sections...") print("\n5. Testing combine all sections...")
result = runner.invoke(cli, ['docs', 'combine', '--mode', 'condensed']) result = runner.invoke(cli, ["docs", "combine", "--mode", "condensed"])
print(f"Status: {'' if result.exit_code == 0 else ''}") print(f"Status: {'' if result.exit_code == 0 else ''}")
print(f"First 200 chars: {result.output[:200]}...") print(f"First 200 chars: {result.output[:200]}...")
if __name__ == "__main__": if __name__ == "__main__":
test_cli() test_cli()

View File

@@ -6,11 +6,14 @@ import base64
import os import os
from typing import Dict, Any from typing import Dict, Any
class Crawl4AiTester: class Crawl4AiTester:
def __init__(self, base_url: str = "http://localhost:11235"): def __init__(self, base_url: str = "http://localhost:11235"):
self.base_url = base_url self.base_url = base_url
def submit_and_wait(self, request_data: Dict[str, Any], timeout: int = 300) -> Dict[str, Any]: def submit_and_wait(
self, request_data: Dict[str, Any], timeout: int = 300
) -> Dict[str, Any]:
# Submit crawl job # Submit crawl job
response = requests.post(f"{self.base_url}/crawl", json=request_data) response = requests.post(f"{self.base_url}/crawl", json=request_data)
task_id = response.json()["task_id"] task_id = response.json()["task_id"]
@@ -20,7 +23,9 @@ class Crawl4AiTester:
start_time = time.time() start_time = time.time()
while True: while True:
if time.time() - start_time > timeout: if time.time() - start_time > timeout:
raise TimeoutError(f"Task {task_id} did not complete within {timeout} seconds") raise TimeoutError(
f"Task {task_id} did not complete within {timeout} seconds"
)
result = requests.get(f"{self.base_url}/task/{task_id}") result = requests.get(f"{self.base_url}/task/{task_id}")
status = result.json() status = result.json()
@@ -34,6 +39,7 @@ class Crawl4AiTester:
time.sleep(2) time.sleep(2)
def test_docker_deployment(version="basic"): def test_docker_deployment(version="basic"):
tester = Crawl4AiTester() tester = Crawl4AiTester()
print(f"Testing Crawl4AI Docker {version} version") print(f"Testing Crawl4AI Docker {version} version")
@@ -45,7 +51,7 @@ def test_docker_deployment(version="basic"):
health = requests.get(f"{tester.base_url}/health", timeout=10) health = requests.get(f"{tester.base_url}/health", timeout=10)
print("Health check:", health.json()) print("Health check:", health.json())
break break
except requests.exceptions.RequestException as e: except requests.exceptions.RequestException:
if i == max_retries - 1: if i == max_retries - 1:
print(f"Failed to connect after {max_retries} attempts") print(f"Failed to connect after {max_retries} attempts")
sys.exit(1) sys.exit(1)
@@ -68,16 +74,14 @@ def test_docker_deployment(version="basic"):
def test_basic_crawl(tester: Crawl4AiTester): def test_basic_crawl(tester: Crawl4AiTester):
print("\n=== Testing Basic Crawl ===") print("\n=== Testing Basic Crawl ===")
request = { request = {"urls": "https://www.nbcnews.com/business", "priority": 10}
"urls": "https://www.nbcnews.com/business",
"priority": 10
}
result = tester.submit_and_wait(request) result = tester.submit_and_wait(request)
print(f"Basic crawl result length: {len(result['result']['markdown'])}") print(f"Basic crawl result length: {len(result['result']['markdown'])}")
assert result["result"]["success"] assert result["result"]["success"]
assert len(result["result"]["markdown"]) > 0 assert len(result["result"]["markdown"]) > 0
def test_js_execution(tester: Crawl4AiTester): def test_js_execution(tester: Crawl4AiTester):
print("\n=== Testing JS Execution ===") print("\n=== Testing JS Execution ===")
request = { request = {
@@ -87,32 +91,29 @@ def test_js_execution(tester: Crawl4AiTester):
"const loadMoreButton = Array.from(document.querySelectorAll('button')).find(button => button.textContent.includes('Load More')); loadMoreButton && loadMoreButton.click();" "const loadMoreButton = Array.from(document.querySelectorAll('button')).find(button => button.textContent.includes('Load More')); loadMoreButton && loadMoreButton.click();"
], ],
"wait_for": "article.tease-card:nth-child(10)", "wait_for": "article.tease-card:nth-child(10)",
"crawler_params": { "crawler_params": {"headless": True},
"headless": True
}
} }
result = tester.submit_and_wait(request) result = tester.submit_and_wait(request)
print(f"JS execution result length: {len(result['result']['markdown'])}") print(f"JS execution result length: {len(result['result']['markdown'])}")
assert result["result"]["success"] assert result["result"]["success"]
def test_css_selector(tester: Crawl4AiTester): def test_css_selector(tester: Crawl4AiTester):
print("\n=== Testing CSS Selector ===") print("\n=== Testing CSS Selector ===")
request = { request = {
"urls": "https://www.nbcnews.com/business", "urls": "https://www.nbcnews.com/business",
"priority": 7, "priority": 7,
"css_selector": ".wide-tease-item__description", "css_selector": ".wide-tease-item__description",
"crawler_params": { "crawler_params": {"headless": True},
"headless": True "extra": {"word_count_threshold": 10},
},
"extra": {"word_count_threshold": 10}
} }
result = tester.submit_and_wait(request) result = tester.submit_and_wait(request)
print(f"CSS selector result length: {len(result['result']['markdown'])}") print(f"CSS selector result length: {len(result['result']['markdown'])}")
assert result["result"]["success"] assert result["result"]["success"]
def test_structured_extraction(tester: Crawl4AiTester): def test_structured_extraction(tester: Crawl4AiTester):
print("\n=== Testing Structured Extraction ===") print("\n=== Testing Structured Extraction ===")
schema = { schema = {
@@ -133,19 +134,14 @@ def test_structured_extraction(tester: Crawl4AiTester):
"name": "price", "name": "price",
"selector": "td:nth-child(2)", "selector": "td:nth-child(2)",
"type": "text", "type": "text",
} },
], ],
} }
request = { request = {
"urls": "https://www.coinbase.com/explore", "urls": "https://www.coinbase.com/explore",
"priority": 9, "priority": 9,
"extraction_config": { "extraction_config": {"type": "json_css", "params": {"schema": schema}},
"type": "json_css",
"params": {
"schema": schema
}
}
} }
result = tester.submit_and_wait(request) result = tester.submit_and_wait(request)
@@ -155,6 +151,7 @@ def test_structured_extraction(tester: Crawl4AiTester):
assert result["result"]["success"] assert result["result"]["success"]
assert len(extracted) > 0 assert len(extracted) > 0
def test_llm_extraction(tester: Crawl4AiTester): def test_llm_extraction(tester: Crawl4AiTester):
print("\n=== Testing LLM Extraction ===") print("\n=== Testing LLM Extraction ===")
schema = { schema = {
@@ -162,18 +159,18 @@ def test_llm_extraction(tester: Crawl4AiTester):
"properties": { "properties": {
"model_name": { "model_name": {
"type": "string", "type": "string",
"description": "Name of the OpenAI model." "description": "Name of the OpenAI model.",
}, },
"input_fee": { "input_fee": {
"type": "string", "type": "string",
"description": "Fee for input token for the OpenAI model." "description": "Fee for input token for the OpenAI model.",
}, },
"output_fee": { "output_fee": {
"type": "string", "type": "string",
"description": "Fee for output token for the OpenAI model." "description": "Fee for output token for the OpenAI model.",
}
}, },
"required": ["model_name", "input_fee", "output_fee"] },
"required": ["model_name", "input_fee", "output_fee"],
} }
request = { request = {
@@ -186,10 +183,10 @@ def test_llm_extraction(tester: Crawl4AiTester):
"api_token": os.getenv("OPENAI_API_KEY"), "api_token": os.getenv("OPENAI_API_KEY"),
"schema": schema, "schema": schema,
"extraction_type": "schema", "extraction_type": "schema",
"instruction": """From the crawled content, extract all mentioned model names along with their fees for input and output tokens.""" "instruction": """From the crawled content, extract all mentioned model names along with their fees for input and output tokens.""",
}
}, },
"crawler_params": {"word_count_threshold": 1} },
"crawler_params": {"word_count_threshold": 1},
} }
try: try:
@@ -201,6 +198,7 @@ def test_llm_extraction(tester: Crawl4AiTester):
except Exception as e: except Exception as e:
print(f"LLM extraction test failed (might be due to missing API key): {str(e)}") print(f"LLM extraction test failed (might be due to missing API key): {str(e)}")
def test_llm_with_ollama(tester: Crawl4AiTester): def test_llm_with_ollama(tester: Crawl4AiTester):
print("\n=== Testing LLM with Ollama ===") print("\n=== Testing LLM with Ollama ===")
schema = { schema = {
@@ -208,18 +206,18 @@ def test_llm_with_ollama(tester: Crawl4AiTester):
"properties": { "properties": {
"article_title": { "article_title": {
"type": "string", "type": "string",
"description": "The main title of the news article" "description": "The main title of the news article",
}, },
"summary": { "summary": {
"type": "string", "type": "string",
"description": "A brief summary of the article content" "description": "A brief summary of the article content",
}, },
"main_topics": { "main_topics": {
"type": "array", "type": "array",
"items": {"type": "string"}, "items": {"type": "string"},
"description": "Main topics or themes discussed in the article" "description": "Main topics or themes discussed in the article",
} },
} },
} }
request = { request = {
@@ -231,11 +229,11 @@ def test_llm_with_ollama(tester: Crawl4AiTester):
"provider": "ollama/llama2", "provider": "ollama/llama2",
"schema": schema, "schema": schema,
"extraction_type": "schema", "extraction_type": "schema",
"instruction": "Extract the main article information including title, summary, and main topics." "instruction": "Extract the main article information including title, summary, and main topics.",
} },
}, },
"extra": {"word_count_threshold": 1}, "extra": {"word_count_threshold": 1},
"crawler_params": {"verbose": True} "crawler_params": {"verbose": True},
} }
try: try:
@@ -246,6 +244,7 @@ def test_llm_with_ollama(tester: Crawl4AiTester):
except Exception as e: except Exception as e:
print(f"Ollama extraction test failed: {str(e)}") print(f"Ollama extraction test failed: {str(e)}")
def test_cosine_extraction(tester: Crawl4AiTester): def test_cosine_extraction(tester: Crawl4AiTester):
print("\n=== Testing Cosine Extraction ===") print("\n=== Testing Cosine Extraction ===")
request = { request = {
@@ -257,9 +256,9 @@ def test_cosine_extraction(tester: Crawl4AiTester):
"semantic_filter": "business finance economy", "semantic_filter": "business finance economy",
"word_count_threshold": 10, "word_count_threshold": 10,
"max_dist": 0.2, "max_dist": 0.2,
"top_k": 3 "top_k": 3,
} },
} },
} }
try: try:
@@ -271,15 +270,14 @@ def test_cosine_extraction(tester: Crawl4AiTester):
except Exception as e: except Exception as e:
print(f"Cosine extraction test failed: {str(e)}") print(f"Cosine extraction test failed: {str(e)}")
def test_screenshot(tester: Crawl4AiTester): def test_screenshot(tester: Crawl4AiTester):
print("\n=== Testing Screenshot ===") print("\n=== Testing Screenshot ===")
request = { request = {
"urls": "https://www.nbcnews.com/business", "urls": "https://www.nbcnews.com/business",
"priority": 5, "priority": 5,
"screenshot": True, "screenshot": True,
"crawler_params": { "crawler_params": {"headless": True},
"headless": True
}
} }
result = tester.submit_and_wait(request) result = tester.submit_and_wait(request)
@@ -294,6 +292,7 @@ def test_screenshot(tester: Crawl4AiTester):
assert result["result"]["success"] assert result["result"]["success"]
if __name__ == "__main__": if __name__ == "__main__":
version = sys.argv[1] if len(sys.argv) > 1 else "basic" version = sys.argv[1] if len(sys.argv) > 1 else "basic"
# version = "full" # version = "full"

View File

@@ -3,6 +3,7 @@ from crawl4ai.async_logger import AsyncLogger
from pathlib import Path from pathlib import Path
import asyncio import asyncio
async def main(): async def main():
current_file = Path(__file__).resolve() current_file = Path(__file__).resolve()
# base_dir = current_file.parent.parent / "local/_docs/llm.txt/test_docs" # base_dir = current_file.parent.parent / "local/_docs/llm.txt/test_docs"
@@ -26,8 +27,7 @@ async def main():
# Generate index files # Generate index files
print("\nGenerating index files...") print("\nGenerating index files...")
await manager.generate_index_files( await manager.generate_index_files(
force_generate_facts=False, force_generate_facts=False, clear_bm25_cache=False
clear_bm25_cache=False
) )
# Test some relevant queries about Crawl4AI # Test some relevant queries about Crawl4AI
@@ -41,9 +41,12 @@ async def main():
results = manager.search(query, top_k=2) results = manager.search(query, top_k=2)
print(f"Results length: {len(results)} characters") print(f"Results length: {len(results)} characters")
if results: if results:
print("First 200 chars of results:", results[:200].replace('\n', ' '), "...") print(
"First 200 chars of results:", results[:200].replace("\n", " "), "..."
)
else: else:
print("No results found") print("No results found")
if __name__ == "__main__": if __name__ == "__main__":
asyncio.run(main()) asyncio.run(main())

View File

@@ -3,8 +3,8 @@ import aiohttp
import json import json
import time import time
import os import os
from typing import Optional, Dict, Any from typing import Dict, Any
from pydantic import BaseModel, HttpUrl
class NBCNewsAPITest: class NBCNewsAPITest:
def __init__(self, base_url: str = "http://localhost:8000"): def __init__(self, base_url: str = "http://localhost:8000"):
@@ -20,7 +20,9 @@ class NBCNewsAPITest:
await self.session.close() await self.session.close()
async def submit_crawl(self, request_data: Dict[str, Any]) -> str: async def submit_crawl(self, request_data: Dict[str, Any]) -> str:
async with self.session.post(f"{self.base_url}/crawl", json=request_data) as response: async with self.session.post(
f"{self.base_url}/crawl", json=request_data
) as response:
result = await response.json() result = await response.json()
return result["task_id"] return result["task_id"]
@@ -28,11 +30,15 @@ class NBCNewsAPITest:
async with self.session.get(f"{self.base_url}/task/{task_id}") as response: async with self.session.get(f"{self.base_url}/task/{task_id}") as response:
return await response.json() return await response.json()
async def wait_for_task(self, task_id: str, timeout: int = 300, poll_interval: int = 2) -> Dict[str, Any]: async def wait_for_task(
self, task_id: str, timeout: int = 300, poll_interval: int = 2
) -> Dict[str, Any]:
start_time = time.time() start_time = time.time()
while True: while True:
if time.time() - start_time > timeout: if time.time() - start_time > timeout:
raise TimeoutError(f"Task {task_id} did not complete within {timeout} seconds") raise TimeoutError(
f"Task {task_id} did not complete within {timeout} seconds"
)
status = await self.get_task_status(task_id) status = await self.get_task_status(task_id)
if status["status"] in ["completed", "failed"]: if status["status"] in ["completed", "failed"]:
@@ -44,13 +50,11 @@ class NBCNewsAPITest:
async with self.session.get(f"{self.base_url}/health") as response: async with self.session.get(f"{self.base_url}/health") as response:
return await response.json() return await response.json()
async def test_basic_crawl(): async def test_basic_crawl():
print("\n=== Testing Basic Crawl ===") print("\n=== Testing Basic Crawl ===")
async with NBCNewsAPITest() as api: async with NBCNewsAPITest() as api:
request = { request = {"urls": "https://www.nbcnews.com/business", "priority": 10}
"urls": "https://www.nbcnews.com/business",
"priority": 10
}
task_id = await api.submit_crawl(request) task_id = await api.submit_crawl(request)
result = await api.wait_for_task(task_id) result = await api.wait_for_task(task_id)
print(f"Basic crawl result length: {len(result['result']['markdown'])}") print(f"Basic crawl result length: {len(result['result']['markdown'])}")
@@ -58,6 +62,7 @@ async def test_basic_crawl():
assert "result" in result assert "result" in result
assert result["result"]["success"] assert result["result"]["success"]
async def test_js_execution(): async def test_js_execution():
print("\n=== Testing JS Execution ===") print("\n=== Testing JS Execution ===")
async with NBCNewsAPITest() as api: async with NBCNewsAPITest() as api:
@@ -68,9 +73,7 @@ async def test_js_execution():
"const loadMoreButton = Array.from(document.querySelectorAll('button')).find(button => button.textContent.includes('Load More')); loadMoreButton && loadMoreButton.click();" "const loadMoreButton = Array.from(document.querySelectorAll('button')).find(button => button.textContent.includes('Load More')); loadMoreButton && loadMoreButton.click();"
], ],
"wait_for": "article.tease-card:nth-child(10)", "wait_for": "article.tease-card:nth-child(10)",
"crawler_params": { "crawler_params": {"headless": True},
"headless": True
}
} }
task_id = await api.submit_crawl(request) task_id = await api.submit_crawl(request)
result = await api.wait_for_task(task_id) result = await api.wait_for_task(task_id)
@@ -78,13 +81,14 @@ async def test_js_execution():
assert result["status"] == "completed" assert result["status"] == "completed"
assert result["result"]["success"] assert result["result"]["success"]
async def test_css_selector(): async def test_css_selector():
print("\n=== Testing CSS Selector ===") print("\n=== Testing CSS Selector ===")
async with NBCNewsAPITest() as api: async with NBCNewsAPITest() as api:
request = { request = {
"urls": "https://www.nbcnews.com/business", "urls": "https://www.nbcnews.com/business",
"priority": 7, "priority": 7,
"css_selector": ".wide-tease-item__description" "css_selector": ".wide-tease-item__description",
} }
task_id = await api.submit_crawl(request) task_id = await api.submit_crawl(request)
result = await api.wait_for_task(task_id) result = await api.wait_for_task(task_id)
@@ -92,6 +96,7 @@ async def test_css_selector():
assert result["status"] == "completed" assert result["status"] == "completed"
assert result["result"]["success"] assert result["result"]["success"]
async def test_structured_extraction(): async def test_structured_extraction():
print("\n=== Testing Structured Extraction ===") print("\n=== Testing Structured Extraction ===")
async with NBCNewsAPITest() as api: async with NBCNewsAPITest() as api:
@@ -99,34 +104,25 @@ async def test_structured_extraction():
"name": "NBC News Articles", "name": "NBC News Articles",
"baseSelector": "article.tease-card", "baseSelector": "article.tease-card",
"fields": [ "fields": [
{ {"name": "title", "selector": "h2", "type": "text"},
"name": "title",
"selector": "h2",
"type": "text"
},
{ {
"name": "description", "name": "description",
"selector": ".tease-card__description", "selector": ".tease-card__description",
"type": "text" "type": "text",
}, },
{ {
"name": "link", "name": "link",
"selector": "a", "selector": "a",
"type": "attribute", "type": "attribute",
"attribute": "href" "attribute": "href",
} },
] ],
} }
request = { request = {
"urls": "https://www.nbcnews.com/business", "urls": "https://www.nbcnews.com/business",
"priority": 9, "priority": 9,
"extraction_config": { "extraction_config": {"type": "json_css", "params": {"schema": schema}},
"type": "json_css",
"params": {
"schema": schema
}
}
} }
task_id = await api.submit_crawl(request) task_id = await api.submit_crawl(request)
result = await api.wait_for_task(task_id) result = await api.wait_for_task(task_id)
@@ -136,6 +132,7 @@ async def test_structured_extraction():
assert result["result"]["success"] assert result["result"]["success"]
assert len(extracted) > 0 assert len(extracted) > 0
async def test_batch_crawl(): async def test_batch_crawl():
print("\n=== Testing Batch Crawl ===") print("\n=== Testing Batch Crawl ===")
async with NBCNewsAPITest() as api: async with NBCNewsAPITest() as api:
@@ -143,12 +140,10 @@ async def test_batch_crawl():
"urls": [ "urls": [
"https://www.nbcnews.com/business", "https://www.nbcnews.com/business",
"https://www.nbcnews.com/business/consumer", "https://www.nbcnews.com/business/consumer",
"https://www.nbcnews.com/business/economy" "https://www.nbcnews.com/business/economy",
], ],
"priority": 6, "priority": 6,
"crawler_params": { "crawler_params": {"headless": True},
"headless": True
}
} }
task_id = await api.submit_crawl(request) task_id = await api.submit_crawl(request)
result = await api.wait_for_task(task_id) result = await api.wait_for_task(task_id)
@@ -157,6 +152,7 @@ async def test_batch_crawl():
assert "results" in result assert "results" in result
assert len(result["results"]) == 3 assert len(result["results"]) == 3
async def test_llm_extraction(): async def test_llm_extraction():
print("\n=== Testing LLM Extraction with Ollama ===") print("\n=== Testing LLM Extraction with Ollama ===")
async with NBCNewsAPITest() as api: async with NBCNewsAPITest() as api:
@@ -165,19 +161,19 @@ async def test_llm_extraction():
"properties": { "properties": {
"article_title": { "article_title": {
"type": "string", "type": "string",
"description": "The main title of the news article" "description": "The main title of the news article",
}, },
"summary": { "summary": {
"type": "string", "type": "string",
"description": "A brief summary of the article content" "description": "A brief summary of the article content",
}, },
"main_topics": { "main_topics": {
"type": "array", "type": "array",
"items": {"type": "string"}, "items": {"type": "string"},
"description": "Main topics or themes discussed in the article" "description": "Main topics or themes discussed in the article",
}
}, },
"required": ["article_title", "summary", "main_topics"] },
"required": ["article_title", "summary", "main_topics"],
} }
request = { request = {
@@ -191,13 +187,10 @@ async def test_llm_extraction():
"schema": schema, "schema": schema,
"extraction_type": "schema", "extraction_type": "schema",
"instruction": """Extract the main article information including title, a brief summary, and main topics discussed. "instruction": """Extract the main article information including title, a brief summary, and main topics discussed.
Focus on the primary business news article on the page.""" Focus on the primary business news article on the page.""",
}
}, },
"crawler_params": { },
"headless": True, "crawler_params": {"headless": True, "word_count_threshold": 1},
"word_count_threshold": 1
}
} }
task_id = await api.submit_crawl(request) task_id = await api.submit_crawl(request)
@@ -205,12 +198,13 @@ async def test_llm_extraction():
if result["status"] == "completed": if result["status"] == "completed":
extracted = json.loads(result["result"]["extracted_content"]) extracted = json.loads(result["result"]["extracted_content"])
print(f"Extracted article analysis:") print("Extracted article analysis:")
print(json.dumps(extracted, indent=2)) print(json.dumps(extracted, indent=2))
assert result["status"] == "completed" assert result["status"] == "completed"
assert result["result"]["success"] assert result["result"]["success"]
async def test_screenshot(): async def test_screenshot():
print("\n=== Testing Screenshot ===") print("\n=== Testing Screenshot ===")
async with NBCNewsAPITest() as api: async with NBCNewsAPITest() as api:
@@ -218,9 +212,7 @@ async def test_screenshot():
"urls": "https://www.nbcnews.com/business", "urls": "https://www.nbcnews.com/business",
"priority": 5, "priority": 5,
"screenshot": True, "screenshot": True,
"crawler_params": { "crawler_params": {"headless": True},
"headless": True
}
} }
task_id = await api.submit_crawl(request) task_id = await api.submit_crawl(request)
result = await api.wait_for_task(task_id) result = await api.wait_for_task(task_id)
@@ -229,6 +221,7 @@ async def test_screenshot():
assert result["result"]["success"] assert result["result"]["success"]
assert result["result"]["screenshot"] is not None assert result["result"]["screenshot"] is not None
async def test_priority_handling(): async def test_priority_handling():
print("\n=== Testing Priority Handling ===") print("\n=== Testing Priority Handling ===")
async with NBCNewsAPITest() as api: async with NBCNewsAPITest() as api:
@@ -236,7 +229,7 @@ async def test_priority_handling():
low_priority = { low_priority = {
"urls": "https://www.nbcnews.com/business", "urls": "https://www.nbcnews.com/business",
"priority": 1, "priority": 1,
"crawler_params": {"headless": True} "crawler_params": {"headless": True},
} }
low_task_id = await api.submit_crawl(low_priority) low_task_id = await api.submit_crawl(low_priority)
@@ -244,7 +237,7 @@ async def test_priority_handling():
high_priority = { high_priority = {
"urls": "https://www.nbcnews.com/business/consumer", "urls": "https://www.nbcnews.com/business/consumer",
"priority": 10, "priority": 10,
"crawler_params": {"headless": True} "crawler_params": {"headless": True},
} }
high_task_id = await api.submit_crawl(high_priority) high_task_id = await api.submit_crawl(high_priority)
@@ -256,6 +249,7 @@ async def test_priority_handling():
assert high_result["status"] == "completed" assert high_result["status"] == "completed"
assert low_result["status"] == "completed" assert low_result["status"] == "completed"
async def main(): async def main():
try: try:
# Start with health check # Start with health check
@@ -277,5 +271,6 @@ async def main():
print(f"Test failed: {str(e)}") print(f"Test failed: {str(e)}")
raise raise
if __name__ == "__main__": if __name__ == "__main__":
asyncio.run(main()) asyncio.run(main())

View File

@@ -1,21 +1,26 @@
import nest_asyncio import nest_asyncio
nest_asyncio.apply() nest_asyncio.apply()
import asyncio import asyncio
from crawl4ai import AsyncWebCrawler, CrawlerRunConfig, LXMLWebScrapingStrategy, CacheMode from crawl4ai import (
AsyncWebCrawler,
CrawlerRunConfig,
LXMLWebScrapingStrategy,
CacheMode,
)
async def main(): async def main():
config = CrawlerRunConfig( config = CrawlerRunConfig(
cache_mode=CacheMode.BYPASS, cache_mode=CacheMode.BYPASS,
scraping_strategy=LXMLWebScrapingStrategy() # Faster alternative to default BeautifulSoup scraping_strategy=LXMLWebScrapingStrategy(), # Faster alternative to default BeautifulSoup
) )
async with AsyncWebCrawler() as crawler: async with AsyncWebCrawler() as crawler:
result = await crawler.arun( result = await crawler.arun(url="https://example.com", config=config)
url="https://example.com",
config=config
)
print(f"Success: {result.success}") print(f"Success: {result.success}")
print(f"Markdown length: {len(result.markdown_v2.raw_markdown)}") print(f"Markdown length: {len(result.markdown_v2.raw_markdown)}")
if __name__ == "__main__": if __name__ == "__main__":
asyncio.run(main()) asyncio.run(main())

View File

@@ -1,10 +1,19 @@
import unittest, os import unittest, os
from crawl4ai.web_crawler import WebCrawler from crawl4ai.web_crawler import WebCrawler
from crawl4ai.chunking_strategy import RegexChunking, FixedLengthWordChunking, SlidingWindowChunking from crawl4ai.chunking_strategy import (
from crawl4ai.extraction_strategy import CosineStrategy, LLMExtractionStrategy, TopicExtractionStrategy, NoExtractionStrategy RegexChunking,
FixedLengthWordChunking,
SlidingWindowChunking,
)
from crawl4ai.extraction_strategy import (
CosineStrategy,
LLMExtractionStrategy,
TopicExtractionStrategy,
NoExtractionStrategy,
)
class TestWebCrawler(unittest.TestCase): class TestWebCrawler(unittest.TestCase):
def setUp(self): def setUp(self):
self.crawler = WebCrawler() self.crawler = WebCrawler()
@@ -14,52 +23,72 @@ class TestWebCrawler(unittest.TestCase):
def test_run_default_strategies(self): def test_run_default_strategies(self):
result = self.crawler.run( result = self.crawler.run(
url='https://www.nbcnews.com/business', url="https://www.nbcnews.com/business",
word_count_threshold=5, word_count_threshold=5,
chunking_strategy=RegexChunking(), chunking_strategy=RegexChunking(),
extraction_strategy=CosineStrategy(), bypass_cache=True extraction_strategy=CosineStrategy(),
bypass_cache=True,
)
self.assertTrue(
result.success, "Failed to crawl and extract using default strategies"
) )
self.assertTrue(result.success, "Failed to crawl and extract using default strategies")
def test_run_different_strategies(self): def test_run_different_strategies(self):
url = 'https://www.nbcnews.com/business' url = "https://www.nbcnews.com/business"
# Test with FixedLengthWordChunking and LLMExtractionStrategy # Test with FixedLengthWordChunking and LLMExtractionStrategy
result = self.crawler.run( result = self.crawler.run(
url=url, url=url,
word_count_threshold=5, word_count_threshold=5,
chunking_strategy=FixedLengthWordChunking(chunk_size=100), chunking_strategy=FixedLengthWordChunking(chunk_size=100),
extraction_strategy=LLMExtractionStrategy(provider="openai/gpt-3.5-turbo", api_token=os.getenv('OPENAI_API_KEY')), bypass_cache=True extraction_strategy=LLMExtractionStrategy(
provider="openai/gpt-3.5-turbo", api_token=os.getenv("OPENAI_API_KEY")
),
bypass_cache=True,
)
self.assertTrue(
result.success,
"Failed to crawl and extract with FixedLengthWordChunking and LLMExtractionStrategy",
) )
self.assertTrue(result.success, "Failed to crawl and extract with FixedLengthWordChunking and LLMExtractionStrategy")
# Test with SlidingWindowChunking and TopicExtractionStrategy # Test with SlidingWindowChunking and TopicExtractionStrategy
result = self.crawler.run( result = self.crawler.run(
url=url, url=url,
word_count_threshold=5, word_count_threshold=5,
chunking_strategy=SlidingWindowChunking(window_size=100, step=50), chunking_strategy=SlidingWindowChunking(window_size=100, step=50),
extraction_strategy=TopicExtractionStrategy(num_keywords=5), bypass_cache=True extraction_strategy=TopicExtractionStrategy(num_keywords=5),
bypass_cache=True,
)
self.assertTrue(
result.success,
"Failed to crawl and extract with SlidingWindowChunking and TopicExtractionStrategy",
) )
self.assertTrue(result.success, "Failed to crawl and extract with SlidingWindowChunking and TopicExtractionStrategy")
def test_invalid_url(self): def test_invalid_url(self):
with self.assertRaises(Exception) as context: with self.assertRaises(Exception) as context:
self.crawler.run(url='invalid_url', bypass_cache=True) self.crawler.run(url="invalid_url", bypass_cache=True)
self.assertIn("Invalid URL", str(context.exception)) self.assertIn("Invalid URL", str(context.exception))
def test_unsupported_extraction_strategy(self): def test_unsupported_extraction_strategy(self):
with self.assertRaises(Exception) as context: with self.assertRaises(Exception) as context:
self.crawler.run(url='https://www.nbcnews.com/business', extraction_strategy="UnsupportedStrategy", bypass_cache=True) self.crawler.run(
url="https://www.nbcnews.com/business",
extraction_strategy="UnsupportedStrategy",
bypass_cache=True,
)
self.assertIn("Unsupported extraction strategy", str(context.exception)) self.assertIn("Unsupported extraction strategy", str(context.exception))
def test_invalid_css_selector(self): def test_invalid_css_selector(self):
with self.assertRaises(ValueError) as context: with self.assertRaises(ValueError) as context:
self.crawler.run(url='https://www.nbcnews.com/business', css_selector="invalid_selector", bypass_cache=True) self.crawler.run(
url="https://www.nbcnews.com/business",
css_selector="invalid_selector",
bypass_cache=True,
)
self.assertIn("Invalid CSS selector", str(context.exception)) self.assertIn("Invalid CSS selector", str(context.exception))
def test_crawl_with_cache_and_bypass_cache(self): def test_crawl_with_cache_and_bypass_cache(self):
url = 'https://www.nbcnews.com/business' url = "https://www.nbcnews.com/business"
# First crawl with cache enabled # First crawl with cache enabled
result = self.crawler.run(url=url, bypass_cache=False) result = self.crawler.run(url=url, bypass_cache=False)
@@ -70,10 +99,7 @@ class TestWebCrawler(unittest.TestCase):
self.assertTrue(result.success, "Failed to bypass cache and fetch fresh data") self.assertTrue(result.success, "Failed to bypass cache and fetch fresh data")
def test_fetch_multiple_pages(self): def test_fetch_multiple_pages(self):
urls = [ urls = ["https://www.nbcnews.com/business", "https://www.bbc.com/news"]
'https://www.nbcnews.com/business',
'https://www.bbc.com/news'
]
results = [] results = []
for url in urls: for url in urls:
result = self.crawler.run( result = self.crawler.run(
@@ -81,31 +107,42 @@ class TestWebCrawler(unittest.TestCase):
word_count_threshold=5, word_count_threshold=5,
chunking_strategy=RegexChunking(), chunking_strategy=RegexChunking(),
extraction_strategy=CosineStrategy(), extraction_strategy=CosineStrategy(),
bypass_cache=True bypass_cache=True,
) )
results.append(result) results.append(result)
self.assertEqual(len(results), 2, "Failed to crawl and extract multiple pages") self.assertEqual(len(results), 2, "Failed to crawl and extract multiple pages")
for result in results: for result in results:
self.assertTrue(result.success, "Failed to crawl and extract a page in the list") self.assertTrue(
result.success, "Failed to crawl and extract a page in the list"
)
def test_run_fixed_length_word_chunking_and_no_extraction(self): def test_run_fixed_length_word_chunking_and_no_extraction(self):
result = self.crawler.run( result = self.crawler.run(
url='https://www.nbcnews.com/business', url="https://www.nbcnews.com/business",
word_count_threshold=5, word_count_threshold=5,
chunking_strategy=FixedLengthWordChunking(chunk_size=100), chunking_strategy=FixedLengthWordChunking(chunk_size=100),
extraction_strategy=NoExtractionStrategy(), bypass_cache=True extraction_strategy=NoExtractionStrategy(),
bypass_cache=True,
)
self.assertTrue(
result.success,
"Failed to crawl and extract with FixedLengthWordChunking and NoExtractionStrategy",
) )
self.assertTrue(result.success, "Failed to crawl and extract with FixedLengthWordChunking and NoExtractionStrategy")
def test_run_sliding_window_and_no_extraction(self): def test_run_sliding_window_and_no_extraction(self):
result = self.crawler.run( result = self.crawler.run(
url='https://www.nbcnews.com/business', url="https://www.nbcnews.com/business",
word_count_threshold=5, word_count_threshold=5,
chunking_strategy=SlidingWindowChunking(window_size=100, step=50), chunking_strategy=SlidingWindowChunking(window_size=100, step=50),
extraction_strategy=NoExtractionStrategy(), bypass_cache=True extraction_strategy=NoExtractionStrategy(),
bypass_cache=True,
)
self.assertTrue(
result.success,
"Failed to crawl and extract with SlidingWindowChunking and NoExtractionStrategy",
) )
self.assertTrue(result.success, "Failed to crawl and extract with SlidingWindowChunking and NoExtractionStrategy")
if __name__ == '__main__':
if __name__ == "__main__":
unittest.main() unittest.main()