Apply Ruff Corrections

This commit is contained in:
UncleCode
2025-01-13 19:19:58 +08:00
parent c3370ec5da
commit 8ec12d7d68
84 changed files with 6861 additions and 5076 deletions

8
.pre-commit-config.yaml Normal file
View File

@@ -0,0 +1,8 @@
# .pre-commit-config.yaml
repos:
- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.1.11
hooks:
- id: ruff
args: [--fix]
- id: ruff-format

View File

@@ -2,14 +2,28 @@
from .async_webcrawler import AsyncWebCrawler, CacheMode
from .async_configs import BrowserConfig, CrawlerRunConfig
from .content_scraping_strategy import ContentScrapingStrategy, WebScrapingStrategy, LXMLWebScrapingStrategy
from .extraction_strategy import ExtractionStrategy, LLMExtractionStrategy, CosineStrategy, JsonCssExtractionStrategy
from .content_scraping_strategy import (
ContentScrapingStrategy,
WebScrapingStrategy,
LXMLWebScrapingStrategy,
)
from .extraction_strategy import (
ExtractionStrategy,
LLMExtractionStrategy,
CosineStrategy,
JsonCssExtractionStrategy,
)
from .chunking_strategy import ChunkingStrategy, RegexChunking
from .markdown_generation_strategy import DefaultMarkdownGenerator
from .content_filter_strategy import PruningContentFilter, BM25ContentFilter
from .models import CrawlResult, MarkdownGenerationResult
from .async_dispatcher import MemoryAdaptiveDispatcher, SemaphoreDispatcher, RateLimiter, CrawlerMonitor, DisplayMode
from .__version__ import __version__
from .async_dispatcher import (
MemoryAdaptiveDispatcher,
SemaphoreDispatcher,
RateLimiter,
CrawlerMonitor,
DisplayMode,
)
__all__ = [
"AsyncWebCrawler",
@@ -18,39 +32,44 @@ __all__ = [
"ContentScrapingStrategy",
"WebScrapingStrategy",
"LXMLWebScrapingStrategy",
'BrowserConfig',
'CrawlerRunConfig',
'ExtractionStrategy',
'LLMExtractionStrategy',
'CosineStrategy',
'JsonCssExtractionStrategy',
'ChunkingStrategy',
'RegexChunking',
'DefaultMarkdownGenerator',
'PruningContentFilter',
'BM25ContentFilter',
'MemoryAdaptiveDispatcher',
'SemaphoreDispatcher',
'RateLimiter',
'CrawlerMonitor',
'DisplayMode',
'MarkdownGenerationResult',
"BrowserConfig",
"CrawlerRunConfig",
"ExtractionStrategy",
"LLMExtractionStrategy",
"CosineStrategy",
"JsonCssExtractionStrategy",
"ChunkingStrategy",
"RegexChunking",
"DefaultMarkdownGenerator",
"PruningContentFilter",
"BM25ContentFilter",
"MemoryAdaptiveDispatcher",
"SemaphoreDispatcher",
"RateLimiter",
"CrawlerMonitor",
"DisplayMode",
"MarkdownGenerationResult",
]
def is_sync_version_installed():
try:
import selenium
return True
except ImportError:
return False
if is_sync_version_installed():
try:
from .web_crawler import WebCrawler
__all__.append("WebCrawler")
except ImportError:
import warnings
print("Warning: Failed to import WebCrawler even though selenium is installed. This might be due to other missing dependencies.")
print(
"Warning: Failed to import WebCrawler even though selenium is installed. This might be due to other missing dependencies."
)
else:
WebCrawler = None
# import warnings

View File

@@ -5,7 +5,6 @@ from .config import (
PAGE_TIMEOUT,
IMAGE_SCORE_THRESHOLD,
SOCIAL_MEDIA_DOMAINS,
)
from .user_agent_generator import UserAgentGenerator
from .extraction_strategy import ExtractionStrategy
@@ -14,6 +13,7 @@ from .markdown_generation_strategy import MarkdownGenerationStrategy
from .content_scraping_strategy import ContentScrapingStrategy, WebScrapingStrategy
from typing import Union, List
class BrowserConfig:
"""
Configuration class for setting up a browser instance and its context in AsyncPlaywrightCrawlerStrategy.
@@ -103,7 +103,7 @@ class BrowserConfig:
text_mode: bool = False,
light_mode: bool = False,
extra_args: list = None,
debugging_port : int = 9222,
debugging_port: int = 9222,
):
self.browser_type = browser_type
self.headless = headless
@@ -335,10 +335,8 @@ class CrawlerRunConfig:
prettiify: bool = False,
parser_type: str = "lxml",
scraping_strategy: ContentScrapingStrategy = None,
# SSL Parameters
fetch_ssl_certificate: bool = False,
# Caching Parameters
cache_mode=None,
session_id: str = None,
@@ -346,7 +344,6 @@ class CrawlerRunConfig:
disable_cache: bool = False,
no_cache_read: bool = False,
no_cache_write: bool = False,
# Page Navigation and Timing Parameters
wait_until: str = "domcontentloaded",
page_timeout: int = PAGE_TIMEOUT,
@@ -356,7 +353,6 @@ class CrawlerRunConfig:
mean_delay: float = 0.1,
max_range: float = 0.3,
semaphore_count: int = 5,
# Page Interaction Parameters
js_code: Union[str, List[str]] = None,
js_only: bool = False,
@@ -369,7 +365,6 @@ class CrawlerRunConfig:
override_navigator: bool = False,
magic: bool = False,
adjust_viewport_to_content: bool = False,
# Media Handling Parameters
screenshot: bool = False,
screenshot_wait_for: float = None,
@@ -378,17 +373,14 @@ class CrawlerRunConfig:
image_description_min_word_threshold: int = IMAGE_DESCRIPTION_MIN_WORD_THRESHOLD,
image_score_threshold: int = IMAGE_SCORE_THRESHOLD,
exclude_external_images: bool = False,
# Link and Domain Handling Parameters
exclude_social_media_domains: list = None,
exclude_external_links: bool = False,
exclude_social_media_links: bool = False,
exclude_domains: list = None,
# Debugging and Logging Parameters
verbose: bool = True,
log_console: bool = False,
url: str = None,
):
self.url = url
@@ -453,7 +445,9 @@ class CrawlerRunConfig:
self.exclude_external_images = exclude_external_images
# Link and Domain Handling Parameters
self.exclude_social_media_domains = exclude_social_media_domains or SOCIAL_MEDIA_DOMAINS
self.exclude_social_media_domains = (
exclude_social_media_domains or SOCIAL_MEDIA_DOMAINS
)
self.exclude_external_links = exclude_external_links
self.exclude_social_media_links = exclude_social_media_links
self.exclude_domains = exclude_domains or []
@@ -466,11 +460,15 @@ class CrawlerRunConfig:
if self.extraction_strategy is not None and not isinstance(
self.extraction_strategy, ExtractionStrategy
):
raise ValueError("extraction_strategy must be an instance of ExtractionStrategy")
raise ValueError(
"extraction_strategy must be an instance of ExtractionStrategy"
)
if self.chunking_strategy is not None and not isinstance(
self.chunking_strategy, ChunkingStrategy
):
raise ValueError("chunking_strategy must be an instance of ChunkingStrategy")
raise ValueError(
"chunking_strategy must be an instance of ChunkingStrategy"
)
# Set default chunking strategy if None
if self.chunking_strategy is None:
@@ -494,10 +492,8 @@ class CrawlerRunConfig:
prettiify=kwargs.get("prettiify", False),
parser_type=kwargs.get("parser_type", "lxml"),
scraping_strategy=kwargs.get("scraping_strategy"),
# SSL Parameters
fetch_ssl_certificate=kwargs.get("fetch_ssl_certificate", False),
# Caching Parameters
cache_mode=kwargs.get("cache_mode"),
session_id=kwargs.get("session_id"),
@@ -505,7 +501,6 @@ class CrawlerRunConfig:
disable_cache=kwargs.get("disable_cache", False),
no_cache_read=kwargs.get("no_cache_read", False),
no_cache_write=kwargs.get("no_cache_write", False),
# Page Navigation and Timing Parameters
wait_until=kwargs.get("wait_until", "domcontentloaded"),
page_timeout=kwargs.get("page_timeout", 60000),
@@ -515,7 +510,6 @@ class CrawlerRunConfig:
mean_delay=kwargs.get("mean_delay", 0.1),
max_range=kwargs.get("max_range", 0.3),
semaphore_count=kwargs.get("semaphore_count", 5),
# Page Interaction Parameters
js_code=kwargs.get("js_code"),
js_only=kwargs.get("js_only", False),
@@ -528,26 +522,31 @@ class CrawlerRunConfig:
override_navigator=kwargs.get("override_navigator", False),
magic=kwargs.get("magic", False),
adjust_viewport_to_content=kwargs.get("adjust_viewport_to_content", False),
# Media Handling Parameters
screenshot=kwargs.get("screenshot", False),
screenshot_wait_for=kwargs.get("screenshot_wait_for"),
screenshot_height_threshold=kwargs.get("screenshot_height_threshold", SCREENSHOT_HEIGHT_TRESHOLD),
screenshot_height_threshold=kwargs.get(
"screenshot_height_threshold", SCREENSHOT_HEIGHT_TRESHOLD
),
pdf=kwargs.get("pdf", False),
image_description_min_word_threshold=kwargs.get("image_description_min_word_threshold", IMAGE_DESCRIPTION_MIN_WORD_THRESHOLD),
image_score_threshold=kwargs.get("image_score_threshold", IMAGE_SCORE_THRESHOLD),
image_description_min_word_threshold=kwargs.get(
"image_description_min_word_threshold",
IMAGE_DESCRIPTION_MIN_WORD_THRESHOLD,
),
image_score_threshold=kwargs.get(
"image_score_threshold", IMAGE_SCORE_THRESHOLD
),
exclude_external_images=kwargs.get("exclude_external_images", False),
# Link and Domain Handling Parameters
exclude_social_media_domains=kwargs.get("exclude_social_media_domains", SOCIAL_MEDIA_DOMAINS),
exclude_social_media_domains=kwargs.get(
"exclude_social_media_domains", SOCIAL_MEDIA_DOMAINS
),
exclude_external_links=kwargs.get("exclude_external_links", False),
exclude_social_media_links=kwargs.get("exclude_social_media_links", False),
exclude_domains=kwargs.get("exclude_domains", []),
# Debugging and Logging Parameters
verbose=kwargs.get("verbose", True),
log_console=kwargs.get("log_console", False),
url=kwargs.get("url"),
)

View File

@@ -2,27 +2,25 @@ import asyncio
import base64
import time
from abc import ABC, abstractmethod
from typing import Callable, Dict, Any, List, Optional, Awaitable, Union
import os, sys, shutil
import tempfile, subprocess
from playwright.async_api import async_playwright, Page, Browser, Error, BrowserContext
from typing import Callable, Dict, Any, List, Optional, Union
import os
import sys
import shutil
import tempfile
import subprocess
from playwright.async_api import Page, Error, BrowserContext
from playwright.async_api import TimeoutError as PlaywrightTimeoutError
from io import BytesIO
from PIL import Image, ImageDraw, ImageFont
from pathlib import Path
from playwright.async_api import ProxySettings
from pydantic import BaseModel
import hashlib
import json
import uuid
from .js_snippet import load_js_script
from .models import AsyncCrawlResponse
from .utils import get_error_context
from .user_agent_generator import UserAgentGenerator
from .config import SCREENSHOT_HEIGHT_TRESHOLD, DOWNLOAD_PAGE_TIMEOUT
from .async_configs import BrowserConfig, CrawlerRunConfig
from .async_logger import AsyncLogger
from playwright_stealth import StealthConfig, stealth_async
from playwright_stealth import StealthConfig
from .ssl_certificate import SSLCertificate
stealth_config = StealthConfig(
@@ -94,6 +92,7 @@ class ManagedBrowser:
temp_dir: str
debugging_port: int
host: str
def __init__(
self,
browser_type: str = "chromium",
@@ -139,7 +138,7 @@ class ManagedBrowser:
self.user_data_dir = self.temp_dir
# Get browser path and args based on OS and browser type
browser_path = self._get_browser_path()
# browser_path = self._get_browser_path()
args = self._get_browser_args()
# Start browser process
@@ -300,6 +299,7 @@ class BrowserManager:
sessions (dict): Dictionary to store session information
session_ttl (int): Session timeout in seconds
"""
def __init__(self, browser_config: BrowserConfig, logger=None):
"""
Initialize the BrowserManager with a browser configuration.
@@ -453,7 +453,12 @@ class BrowserManager:
return browser_args
async def setup_context(self, context: BrowserContext, crawlerRunConfig: CrawlerRunConfig = None, is_default=False):
async def setup_context(
self,
context: BrowserContext,
crawlerRunConfig: CrawlerRunConfig = None,
is_default=False,
):
"""
Set up a browser context with the configured options.
@@ -496,9 +501,9 @@ class BrowserManager:
context.set_default_navigation_timeout(DOWNLOAD_PAGE_TIMEOUT)
if self.config.downloads_path:
context._impl_obj._options["accept_downloads"] = True
context._impl_obj._options["downloads_path"] = (
self.config.downloads_path
)
context._impl_obj._options[
"downloads_path"
] = self.config.downloads_path
# Handle user agent and browser hints
if self.config.user_agent:
@@ -511,7 +516,15 @@ class BrowserManager:
# Add default cookie
await context.add_cookies(
[{"name": "cookiesEnabled", "value": "true", "url": crawlerRunConfig.url if crawlerRunConfig else "https://crawl4ai.com/"}]
[
{
"name": "cookiesEnabled",
"value": "true",
"url": crawlerRunConfig.url
if crawlerRunConfig
else "https://crawl4ai.com/",
}
]
)
# Handle navigator overrides
@@ -541,20 +554,57 @@ class BrowserManager:
blocked_extensions = [
# Images
'jpg', 'jpeg', 'png', 'gif', 'webp', 'svg', 'ico', 'bmp', 'tiff', 'psd',
"jpg",
"jpeg",
"png",
"gif",
"webp",
"svg",
"ico",
"bmp",
"tiff",
"psd",
# Fonts
'woff', 'woff2', 'ttf', 'otf', 'eot',
"woff",
"woff2",
"ttf",
"otf",
"eot",
# Styles
# 'css', 'less', 'scss', 'sass',
# Media
'mp4', 'webm', 'ogg', 'avi', 'mov', 'wmv', 'flv', 'm4v',
'mp3', 'wav', 'aac', 'm4a', 'opus', 'flac',
"mp4",
"webm",
"ogg",
"avi",
"mov",
"wmv",
"flv",
"m4v",
"mp3",
"wav",
"aac",
"m4a",
"opus",
"flac",
# Documents
'pdf', 'doc', 'docx', 'xls', 'xlsx', 'ppt', 'pptx',
"pdf",
"doc",
"docx",
"xls",
"xlsx",
"ppt",
"pptx",
# Archives
'zip', 'rar', '7z', 'tar', 'gz',
"zip",
"rar",
"7z",
"tar",
"gz",
# Scripts and data
'xml', 'swf', 'wasm'
"xml",
"swf",
"wasm",
]
# Common context settings
@@ -672,12 +722,12 @@ class AsyncCrawlerStrategy(ABC):
Abstract base class for crawler strategies.
Subclasses must implement the crawl method.
"""
@abstractmethod
async def crawl(self, url: str, **kwargs) -> AsyncCrawlResponse:
pass # 4 + 3
class AsyncPlaywrightCrawlerStrategy(AsyncCrawlerStrategy):
"""
Crawler strategy using Playwright.
@@ -706,6 +756,7 @@ class AsyncPlaywrightCrawlerStrategy(AsyncCrawlerStrategy):
Run the crawler for a single URL.
"""
def __init__(
self, browser_config: BrowserConfig = None, logger: AsyncLogger = None, **kwargs
):
@@ -917,7 +968,9 @@ class AsyncPlaywrightCrawlerStrategy(AsyncCrawlerStrategy):
"or explicitly prefixed with 'js:' or 'css:'."
)
async def csp_compliant_wait( self, page: Page, user_wait_function: str, timeout: float = 30000 ):
async def csp_compliant_wait(
self, page: Page, user_wait_function: str, timeout: float = 30000
):
"""
Wait for a condition in a CSP-compliant way.
@@ -1045,7 +1098,9 @@ class AsyncPlaywrightCrawlerStrategy(AsyncCrawlerStrategy):
page, context = await self.browser_manager.get_page(session_id, user_agent)
return session_id
async def crawl( self, url: str, config: CrawlerRunConfig, **kwargs ) -> AsyncCrawlResponse:
async def crawl(
self, url: str, config: CrawlerRunConfig, **kwargs
) -> AsyncCrawlResponse:
"""
Crawls a given URL or processes raw HTML/local file content based on the URL prefix.
@@ -1104,7 +1159,9 @@ class AsyncPlaywrightCrawlerStrategy(AsyncCrawlerStrategy):
"URL must start with 'http://', 'https://', 'file://', or 'raw:'"
)
async def _crawl_web( self, url: str, config: CrawlerRunConfig ) -> AsyncCrawlResponse:
async def _crawl_web(
self, url: str, config: CrawlerRunConfig
) -> AsyncCrawlResponse:
"""
Internal method to crawl web URLs with the specified configuration.
@@ -1190,9 +1247,11 @@ class AsyncPlaywrightCrawlerStrategy(AsyncCrawlerStrategy):
nonce = hashlib.sha256(os.urandom(32)).hexdigest()
# Add CSP headers to the request
await page.set_extra_http_headers({
'Content-Security-Policy': f"default-src 'self'; script-src 'self' 'nonce-{nonce}' 'strict-dynamic'"
})
await page.set_extra_http_headers(
{
"Content-Security-Policy": f"default-src 'self'; script-src 'self' 'nonce-{nonce}' 'strict-dynamic'"
}
)
response = await page.goto(
url, wait_until=config.wait_until, timeout=config.page_timeout
@@ -1200,7 +1259,9 @@ class AsyncPlaywrightCrawlerStrategy(AsyncCrawlerStrategy):
except Error as e:
raise RuntimeError(f"Failed on navigating ACS-GOTO:\n{str(e)}")
await self.execute_hook("after_goto", page, context=context, url=url, response=response)
await self.execute_hook(
"after_goto", page, context=context, url=url, response=response
)
if response is None:
status_code = 200
@@ -1229,14 +1290,14 @@ class AsyncPlaywrightCrawlerStrategy(AsyncCrawlerStrategy):
style.opacity !== '0';
return isVisible;
}""",
timeout=30000
timeout=30000,
)
if not is_visible and not config.ignore_body_visibility:
visibility_info = await self.check_visibility(page)
raise Error(f"Body element is hidden: {visibility_info}")
except Error as e:
except Error:
visibility_info = await self.check_visibility(page)
if self.config.verbose:
@@ -1249,7 +1310,6 @@ class AsyncPlaywrightCrawlerStrategy(AsyncCrawlerStrategy):
if not config.ignore_body_visibility:
raise Error(f"Body element is hidden: {visibility_info}")
# try:
# await page.wait_for_selector("body", state="attached", timeout=30000)
@@ -1303,7 +1363,7 @@ class AsyncPlaywrightCrawlerStrategy(AsyncCrawlerStrategy):
images_loaded = await self.csp_compliant_wait(
page,
"() => Array.from(document.getElementsByTagName('img')).every(img => img.complete)",
timeout=1000
timeout=1000,
)
if not images_loaded and self.logger:
@@ -1316,8 +1376,8 @@ class AsyncPlaywrightCrawlerStrategy(AsyncCrawlerStrategy):
if not self.browser_config.text_mode and config.adjust_viewport_to_content:
try:
dimensions = await self.get_page_dimensions(page)
page_height = dimensions['height']
page_width = dimensions['width']
page_height = dimensions["height"]
page_width = dimensions["width"]
# page_width = await page.evaluate(
# "document.documentElement.scrollWidth"
# )
@@ -1364,12 +1424,14 @@ class AsyncPlaywrightCrawlerStrategy(AsyncCrawlerStrategy):
if config.js_code:
# execution_result = await self.execute_user_script(page, config.js_code)
execution_result = await self.robust_execute_user_script(page, config.js_code)
execution_result = await self.robust_execute_user_script(
page, config.js_code
)
if not execution_result["success"]:
self.logger.warning(
message="User script execution had issues: {error}",
tag="JS_EXEC",
params={"error": execution_result.get("error")}
params={"error": execution_result.get("error")},
)
await self.execute_hook("on_execution_started", page, context=context)
@@ -1425,7 +1487,9 @@ class AsyncPlaywrightCrawlerStrategy(AsyncCrawlerStrategy):
# Get final HTML content
html = await page.content()
await self.execute_hook("before_return_html", page = page, html = html, context=context)
await self.execute_hook(
"before_return_html", page=page, html=html, context=context
)
# Handle PDF and screenshot generation
start_export_time = time.perf_counter()
@@ -1511,7 +1575,7 @@ class AsyncPlaywrightCrawlerStrategy(AsyncCrawlerStrategy):
# total_height = await page.evaluate("document.documentElement.scrollHeight")
dimensions = await self.get_page_dimensions(page)
total_height = dimensions['height']
total_height = dimensions["height"]
while current_position < total_height:
current_position = min(current_position + viewport_height, total_height)
@@ -1521,7 +1585,7 @@ class AsyncPlaywrightCrawlerStrategy(AsyncCrawlerStrategy):
# new_height = await page.evaluate("document.documentElement.scrollHeight")
dimensions = await self.get_page_dimensions(page)
new_height = dimensions['height']
new_height = dimensions["height"]
if new_height > total_height:
total_height = new_height
@@ -1598,7 +1662,8 @@ class AsyncPlaywrightCrawlerStrategy(AsyncCrawlerStrategy):
remove_overlays_js = load_js_script("remove_overlay_elements")
try:
await page.evaluate(f"""
await page.evaluate(
f"""
(() => {{
try {{
{remove_overlays_js}
@@ -1611,7 +1676,8 @@ class AsyncPlaywrightCrawlerStrategy(AsyncCrawlerStrategy):
}};
}}
}})()
""")
"""
)
await page.wait_for_timeout(500) # Wait for any animations to complete
except Exception as e:
self.logger.warning(
@@ -1707,8 +1773,8 @@ class AsyncPlaywrightCrawlerStrategy(AsyncCrawlerStrategy):
try:
# Get page height
dimensions = await self.get_page_dimensions(page)
page_width = dimensions['width']
page_height = dimensions['height']
page_width = dimensions["width"]
page_height = dimensions["height"]
# page_height = await page.evaluate("document.documentElement.scrollHeight")
# page_width = await page.evaluate("document.documentElement.scrollWidth")
@@ -1826,7 +1892,9 @@ class AsyncPlaywrightCrawlerStrategy(AsyncCrawlerStrategy):
tag="WARNING",
)
async def robust_execute_user_script(self, page: Page, js_code: Union[str, List[str]]) -> Dict[str, Any]:
async def robust_execute_user_script(
self, page: Page, js_code: Union[str, List[str]]
) -> Dict[str, Any]:
"""
Executes user-provided JavaScript code with proper error handling and context,
supporting both synchronous and async user code, plus navigations.
@@ -1846,7 +1914,7 @@ class AsyncPlaywrightCrawlerStrategy(AsyncCrawlerStrategy):
Dict[str, Any]: The results of the execution
"""
try:
await page.wait_for_load_state('domcontentloaded')
await page.wait_for_load_state("domcontentloaded")
if isinstance(js_code, str):
scripts = [js_code]
@@ -1861,7 +1929,8 @@ class AsyncPlaywrightCrawlerStrategy(AsyncCrawlerStrategy):
# then wait for the new page to load before continuing
result = None
try:
result = await page.evaluate(f"""
result = await page.evaluate(
f"""
(async () => {{
try {{
{script}
@@ -1870,51 +1939,60 @@ class AsyncPlaywrightCrawlerStrategy(AsyncCrawlerStrategy):
return {{ success: false, error: err.toString(), stack: err.stack }};
}}
}})();
""")
"""
)
except Error as e:
# If it's due to navigation destroying the context, handle gracefully
if "Execution context was destroyed" in str(e):
self.logger.info("Navigation triggered by script, waiting for load state", tag="JS_EXEC")
self.logger.info(
"Navigation triggered by script, waiting for load state",
tag="JS_EXEC",
)
try:
await page.wait_for_load_state('load', timeout=30000)
await page.wait_for_load_state("load", timeout=30000)
except Error as nav_err:
self.logger.warning(
message="Navigation wait failed: {error}",
tag="JS_EXEC",
params={"error": str(nav_err)}
params={"error": str(nav_err)},
)
try:
await page.wait_for_load_state('networkidle', timeout=30000)
await page.wait_for_load_state(
"networkidle", timeout=30000
)
except Error as nav_err:
self.logger.warning(
message="Network idle wait failed: {error}",
tag="JS_EXEC",
params={"error": str(nav_err)}
params={"error": str(nav_err)},
)
# Return partial success, or adapt as you see fit
result = {
"success": True,
"info": "Navigation triggered, ignoring context destroyed error"
"info": "Navigation triggered, ignoring context destroyed error",
}
else:
# It's some other error, log and continue
self.logger.error(
message="Playwright execution error: {error}",
tag="JS_EXEC",
params={"error": str(e)}
params={"error": str(e)},
)
result = {"success": False, "error": str(e)}
# If we made it this far with no repeated error, do post-load waits
t1 = time.time()
try:
await page.wait_for_load_state('domcontentloaded', timeout=5000)
print("DOM content loaded after script execution in", time.time() - t1)
await page.wait_for_load_state("domcontentloaded", timeout=5000)
print(
"DOM content loaded after script execution in",
time.time() - t1,
)
except Error as e:
self.logger.warning(
message="DOM content load timeout: {error}",
tag="JS_EXEC",
params={"error": str(e)}
params={"error": str(e)},
)
# t1 = time.time()
@@ -1935,7 +2013,7 @@ class AsyncPlaywrightCrawlerStrategy(AsyncCrawlerStrategy):
self.logger.error(
message="Script chunk failed: {error}",
tag="JS_EXEC",
params={"error": str(e)}
params={"error": str(e)},
)
results.append({"success": False, "error": str(e)})
@@ -1945,11 +2023,13 @@ class AsyncPlaywrightCrawlerStrategy(AsyncCrawlerStrategy):
self.logger.error(
message="Script execution failed: {error}",
tag="JS_EXEC",
params={"error": str(e)}
params={"error": str(e)},
)
return {"success": False, "error": str(e)}
async def execute_user_script(self, page: Page, js_code: Union[str, List[str]]) -> Dict[str, Any]:
async def execute_user_script(
self, page: Page, js_code: Union[str, List[str]]
) -> Dict[str, Any]:
"""
Executes user-provided JavaScript code with proper error handling and context.
@@ -1962,7 +2042,7 @@ class AsyncPlaywrightCrawlerStrategy(AsyncCrawlerStrategy):
"""
try:
# Ensure the page is ready for script execution
await page.wait_for_load_state('domcontentloaded')
await page.wait_for_load_state("domcontentloaded")
# Handle single script or multiple scripts
if isinstance(js_code, str):
@@ -1974,7 +2054,8 @@ class AsyncPlaywrightCrawlerStrategy(AsyncCrawlerStrategy):
for script in scripts:
try:
# Execute the script and wait for network idle
result = await page.evaluate(f"""
result = await page.evaluate(
f"""
(() => {{
return new Promise((resolve) => {{
try {{
@@ -2007,15 +2088,18 @@ class AsyncPlaywrightCrawlerStrategy(AsyncCrawlerStrategy):
}}
}});
}})()
""")
"""
)
# Wait for network idle after script execution
t1 = time.time()
await page.wait_for_load_state('domcontentloaded', timeout=5000)
print("DOM content loaded after script execution in", time.time() - t1)
await page.wait_for_load_state("domcontentloaded", timeout=5000)
print(
"DOM content loaded after script execution in", time.time() - t1
)
t1 = time.time()
await page.wait_for_load_state('networkidle', timeout=5000)
await page.wait_for_load_state("networkidle", timeout=5000)
print("Network idle after script execution in", time.time() - t1)
results.append(result if result else {"success": True})
@@ -2025,7 +2109,7 @@ class AsyncPlaywrightCrawlerStrategy(AsyncCrawlerStrategy):
self.logger.error(
message="Playwright execution error: {error}",
tag="JS_EXEC",
params={"error": str(e)}
params={"error": str(e)},
)
results.append({"success": False, "error": str(e)})
@@ -2035,7 +2119,7 @@ class AsyncPlaywrightCrawlerStrategy(AsyncCrawlerStrategy):
self.logger.error(
message="Script execution failed: {error}",
tag="JS_EXEC",
params={"error": str(e)}
params={"error": str(e)},
)
return {"success": False, "error": str(e)}
@@ -2043,7 +2127,7 @@ class AsyncPlaywrightCrawlerStrategy(AsyncCrawlerStrategy):
self.logger.error(
message="Script execution failed: {error}",
tag="JS_EXEC",
params={"error": str(e)}
params={"error": str(e)},
)
return {"success": False, "error": str(e)}
@@ -2057,7 +2141,8 @@ class AsyncPlaywrightCrawlerStrategy(AsyncCrawlerStrategy):
Returns:
Boolean indicating visibility
"""
return await page.evaluate("""
return await page.evaluate(
"""
() => {
const element = document.body;
if (!element) return false;
@@ -2067,7 +2152,8 @@ class AsyncPlaywrightCrawlerStrategy(AsyncCrawlerStrategy):
style.opacity !== '0';
return isVisible;
}
""")
"""
)
async def safe_scroll(self, page: Page, x: int, y: int, delay: float = 0.1):
"""
@@ -2079,7 +2165,7 @@ class AsyncPlaywrightCrawlerStrategy(AsyncCrawlerStrategy):
y: Vertical scroll position
"""
result = await self.csp_scroll_to(page, x, y)
if result['success']:
if result["success"]:
await page.wait_for_timeout(delay * 1000)
return result
@@ -2126,11 +2212,11 @@ class AsyncPlaywrightCrawlerStrategy(AsyncCrawlerStrategy):
}}"""
)
if not result['success']:
if not result["success"]:
self.logger.warning(
message="Scroll operation failed: {error}",
tag="SCROLL",
params={"error": result.get('error')}
params={"error": result.get("error")},
)
return result
@@ -2139,12 +2225,9 @@ class AsyncPlaywrightCrawlerStrategy(AsyncCrawlerStrategy):
self.logger.error(
message="Failed to execute scroll: {error}",
tag="SCROLL",
params={"error": str(e)}
params={"error": str(e)},
)
return {
"success": False,
"error": str(e)
}
return {"success": False, "error": str(e)}
async def get_page_dimensions(self, page: Page):
"""
@@ -2156,12 +2239,14 @@ class AsyncPlaywrightCrawlerStrategy(AsyncCrawlerStrategy):
Returns:
Dict containing width and height of the page
"""
return await page.evaluate("""
return await page.evaluate(
"""
() => {
const {scrollWidth, scrollHeight} = document.documentElement;
return {width: scrollWidth, height: scrollHeight};
}
""")
"""
)
async def page_need_scroll(self, page: Page) -> bool:
"""
@@ -2174,18 +2259,20 @@ class AsyncPlaywrightCrawlerStrategy(AsyncCrawlerStrategy):
bool: True if page needs scrolling
"""
try:
need_scroll = await page.evaluate("""
need_scroll = await page.evaluate(
"""
() => {
const scrollHeight = document.documentElement.scrollHeight;
const viewportHeight = window.innerHeight;
return scrollHeight > viewportHeight;
}
""")
"""
)
return need_scroll
except Exception as e:
self.logger.warning(
message="Failed to check scroll need: {error}. Defaulting to True for safety.",
tag="SCROLL",
params={"error": str(e)}
params={"error": str(e)},
)
return True # Default to scrolling if check fails

View File

@@ -1,27 +1,29 @@
import os, sys
import os
from pathlib import Path
import aiosqlite
import asyncio
from typing import Optional, Tuple, Dict
from typing import Optional, Dict
from contextlib import asynccontextmanager
import logging
import json # Added for serialization/deserialization
from .utils import ensure_content_dirs, generate_content_hash
from .models import CrawlResult, MarkdownGenerationResult
import xxhash
import aiofiles
from .config import NEED_MIGRATION
from .version_manager import VersionManager
from .async_logger import AsyncLogger
from .utils import get_error_context, create_box_message
# Set up logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
base_directory = DB_PATH = os.path.join(os.getenv("CRAWL4_AI_BASE_DIRECTORY", Path.home()), ".crawl4ai")
base_directory = DB_PATH = os.path.join(
os.getenv("CRAWL4_AI_BASE_DIRECTORY", Path.home()), ".crawl4ai"
)
os.makedirs(DB_PATH, exist_ok=True)
DB_PATH = os.path.join(base_directory, "crawl4ai.db")
class AsyncDatabaseManager:
def __init__(self, pool_size: int = 10, max_retries: int = 3):
self.db_path = DB_PATH
@@ -37,10 +39,9 @@ class AsyncDatabaseManager:
self.logger = AsyncLogger(
log_file=os.path.join(base_directory, ".crawl4ai", "crawler_db.log"),
verbose=False,
tag_width=10
tag_width=10,
)
async def initialize(self):
"""Initialize the database and connection pool"""
try:
@@ -67,28 +68,32 @@ class AsyncDatabaseManager:
if needs_update:
self.logger.info("New version detected, running updates", tag="INIT")
await self.update_db_schema()
from .migrations import run_migration # Import here to avoid circular imports
from .migrations import (
run_migration,
) # Import here to avoid circular imports
await run_migration()
self.version_manager.update_version() # Update stored version after successful migration
self.logger.success("Version update completed successfully", tag="COMPLETE")
self.logger.success(
"Version update completed successfully", tag="COMPLETE"
)
else:
self.logger.success("Database initialization completed successfully", tag="COMPLETE")
self.logger.success(
"Database initialization completed successfully", tag="COMPLETE"
)
except Exception as e:
self.logger.error(
message="Database initialization error: {error}",
tag="ERROR",
params={"error": str(e)}
params={"error": str(e)},
)
self.logger.info(
message="Database will be initialized on first use",
tag="INIT"
message="Database will be initialized on first use", tag="INIT"
)
raise
async def cleanup(self):
"""Cleanup connections when shutting down"""
async with self.pool_lock:
@@ -107,6 +112,7 @@ class AsyncDatabaseManager:
self._initialized = True
except Exception as e:
import sys
error_context = get_error_context(sys.exc_info())
self.logger.error(
message="Database initialization failed:\n{error}\n\nContext:\n{context}\n\nTraceback:\n{traceback}",
@@ -115,8 +121,8 @@ class AsyncDatabaseManager:
params={
"error": str(e),
"context": error_context["code_context"],
"traceback": error_context["full_traceback"]
}
"traceback": error_context["full_traceback"],
},
)
raise
@@ -127,29 +133,40 @@ class AsyncDatabaseManager:
async with self.pool_lock:
if task_id not in self.connection_pool:
try:
conn = await aiosqlite.connect(
self.db_path,
timeout=30.0
)
await conn.execute('PRAGMA journal_mode = WAL')
await conn.execute('PRAGMA busy_timeout = 5000')
conn = await aiosqlite.connect(self.db_path, timeout=30.0)
await conn.execute("PRAGMA journal_mode = WAL")
await conn.execute("PRAGMA busy_timeout = 5000")
# Verify database structure
async with conn.execute("PRAGMA table_info(crawled_data)") as cursor:
async with conn.execute(
"PRAGMA table_info(crawled_data)"
) as cursor:
columns = await cursor.fetchall()
column_names = [col[1] for col in columns]
expected_columns = {
'url', 'html', 'cleaned_html', 'markdown', 'extracted_content',
'success', 'media', 'links', 'metadata', 'screenshot',
'response_headers', 'downloaded_files'
"url",
"html",
"cleaned_html",
"markdown",
"extracted_content",
"success",
"media",
"links",
"metadata",
"screenshot",
"response_headers",
"downloaded_files",
}
missing_columns = expected_columns - set(column_names)
if missing_columns:
raise ValueError(f"Database missing columns: {missing_columns}")
raise ValueError(
f"Database missing columns: {missing_columns}"
)
self.connection_pool[task_id] = conn
except Exception as e:
import sys
error_context = get_error_context(sys.exc_info())
error_message = (
f"Unexpected error in db get_connection at line {error_context['line_no']} "
@@ -158,7 +175,7 @@ class AsyncDatabaseManager:
f"Code context:\n{error_context['code_context']}"
)
self.logger.error(
message=create_box_message(error_message, type= "error"),
message=create_box_message(error_message, type="error"),
)
raise
@@ -167,6 +184,7 @@ class AsyncDatabaseManager:
except Exception as e:
import sys
error_context = get_error_context(sys.exc_info())
error_message = (
f"Unexpected error in db get_connection at line {error_context['line_no']} "
@@ -175,7 +193,7 @@ class AsyncDatabaseManager:
f"Code context:\n{error_context['code_context']}"
)
self.logger.error(
message=create_box_message(error_message, type= "error"),
message=create_box_message(error_message, type="error"),
)
raise
finally:
@@ -185,7 +203,6 @@ class AsyncDatabaseManager:
del self.connection_pool[task_id]
self.connection_semaphore.release()
async def execute_with_retry(self, operation, *args):
"""Execute database operations with retry logic"""
for attempt in range(self.max_retries):
@@ -200,10 +217,7 @@ class AsyncDatabaseManager:
message="Operation failed after {retries} attempts: {error}",
tag="ERROR",
force_verbose=True,
params={
"retries": self.max_retries,
"error": str(e)
}
params={"retries": self.max_retries, "error": str(e)},
)
raise
await asyncio.sleep(1 * (attempt + 1)) # Exponential backoff
@@ -211,7 +225,8 @@ class AsyncDatabaseManager:
async def ainit_db(self):
"""Initialize database schema"""
async with aiosqlite.connect(self.db_path, timeout=30.0) as db:
await db.execute('''
await db.execute(
"""
CREATE TABLE IF NOT EXISTS crawled_data (
url TEXT PRIMARY KEY,
html TEXT,
@@ -226,11 +241,10 @@ class AsyncDatabaseManager:
response_headers TEXT DEFAULT "{}",
downloaded_files TEXT DEFAULT "{}" -- New column added
)
''')
"""
)
await db.commit()
async def update_db_schema(self):
"""Update database schema if needed"""
async with aiosqlite.connect(self.db_path, timeout=30.0) as db:
@@ -239,7 +253,14 @@ class AsyncDatabaseManager:
column_names = [column[1] for column in columns]
# List of new columns to add
new_columns = ['media', 'links', 'metadata', 'screenshot', 'response_headers', 'downloaded_files']
new_columns = [
"media",
"links",
"metadata",
"screenshot",
"response_headers",
"downloaded_files",
]
for column in new_columns:
if column not in column_names:
@@ -248,22 +269,26 @@ class AsyncDatabaseManager:
async def aalter_db_add_column(self, new_column: str, db):
"""Add new column to the database"""
if new_column == 'response_headers':
await db.execute(f'ALTER TABLE crawled_data ADD COLUMN {new_column} TEXT DEFAULT "{{}}"')
if new_column == "response_headers":
await db.execute(
f'ALTER TABLE crawled_data ADD COLUMN {new_column} TEXT DEFAULT "{{}}"'
)
else:
await db.execute(f'ALTER TABLE crawled_data ADD COLUMN {new_column} TEXT DEFAULT ""')
await db.execute(
f'ALTER TABLE crawled_data ADD COLUMN {new_column} TEXT DEFAULT ""'
)
self.logger.info(
message="Added column '{column}' to the database",
tag="INIT",
params={"column": new_column}
params={"column": new_column},
)
async def aget_cached_url(self, url: str) -> Optional[CrawlResult]:
"""Retrieve cached URL data as CrawlResult"""
async def _get(db):
async with db.execute(
'SELECT * FROM crawled_data WHERE url = ?', (url,)
"SELECT * FROM crawled_data WHERE url = ?", (url,)
) as cursor:
row = await cursor.fetchone()
if not row:
@@ -276,42 +301,54 @@ class AsyncDatabaseManager:
# Load content from files using stored hashes
content_fields = {
'html': row_dict['html'],
'cleaned_html': row_dict['cleaned_html'],
'markdown': row_dict['markdown'],
'extracted_content': row_dict['extracted_content'],
'screenshot': row_dict['screenshot'],
'screenshots': row_dict['screenshot'],
"html": row_dict["html"],
"cleaned_html": row_dict["cleaned_html"],
"markdown": row_dict["markdown"],
"extracted_content": row_dict["extracted_content"],
"screenshot": row_dict["screenshot"],
"screenshots": row_dict["screenshot"],
}
for field, hash_value in content_fields.items():
if hash_value:
content = await self._load_content(
hash_value,
field.split('_')[0] # Get content type from field name
field.split("_")[0], # Get content type from field name
)
row_dict[field] = content or ""
else:
row_dict[field] = ""
# Parse JSON fields
json_fields = ['media', 'links', 'metadata', 'response_headers', 'markdown']
json_fields = [
"media",
"links",
"metadata",
"response_headers",
"markdown",
]
for field in json_fields:
try:
row_dict[field] = json.loads(row_dict[field]) if row_dict[field] else {}
row_dict[field] = (
json.loads(row_dict[field]) if row_dict[field] else {}
)
except json.JSONDecodeError:
row_dict[field] = {}
if isinstance(row_dict['markdown'], Dict):
row_dict['markdown_v2'] = row_dict['markdown']
if row_dict['markdown'].get('raw_markdown'):
row_dict['markdown'] = row_dict['markdown']['raw_markdown']
if isinstance(row_dict["markdown"], Dict):
row_dict["markdown_v2"] = row_dict["markdown"]
if row_dict["markdown"].get("raw_markdown"):
row_dict["markdown"] = row_dict["markdown"]["raw_markdown"]
# Parse downloaded_files
try:
row_dict['downloaded_files'] = json.loads(row_dict['downloaded_files']) if row_dict['downloaded_files'] else []
row_dict["downloaded_files"] = (
json.loads(row_dict["downloaded_files"])
if row_dict["downloaded_files"]
else []
)
except json.JSONDecodeError:
row_dict['downloaded_files'] = []
row_dict["downloaded_files"] = []
# Remove any fields not in CrawlResult model
valid_fields = CrawlResult.__annotations__.keys()
@@ -326,7 +363,7 @@ class AsyncDatabaseManager:
message="Error retrieving cached URL: {error}",
tag="ERROR",
force_verbose=True,
params={"error": str(e)}
params={"error": str(e)},
)
return None
@@ -334,37 +371,52 @@ class AsyncDatabaseManager:
"""Cache CrawlResult data"""
# Store content files and get hashes
content_map = {
'html': (result.html, 'html'),
'cleaned_html': (result.cleaned_html or "", 'cleaned'),
'markdown': None,
'extracted_content': (result.extracted_content or "", 'extracted'),
'screenshot': (result.screenshot or "", 'screenshots')
"html": (result.html, "html"),
"cleaned_html": (result.cleaned_html or "", "cleaned"),
"markdown": None,
"extracted_content": (result.extracted_content or "", "extracted"),
"screenshot": (result.screenshot or "", "screenshots"),
}
try:
if isinstance(result.markdown, MarkdownGenerationResult):
content_map['markdown'] = (result.markdown.model_dump_json(), 'markdown')
elif hasattr(result, 'markdown_v2'):
content_map['markdown'] = (result.markdown_v2.model_dump_json(), 'markdown')
content_map["markdown"] = (
result.markdown.model_dump_json(),
"markdown",
)
elif hasattr(result, "markdown_v2"):
content_map["markdown"] = (
result.markdown_v2.model_dump_json(),
"markdown",
)
elif isinstance(result.markdown, str):
markdown_result = MarkdownGenerationResult(raw_markdown=result.markdown)
content_map['markdown'] = (markdown_result.model_dump_json(), 'markdown')
content_map["markdown"] = (
markdown_result.model_dump_json(),
"markdown",
)
else:
content_map['markdown'] = (MarkdownGenerationResult().model_dump_json(), 'markdown')
content_map["markdown"] = (
MarkdownGenerationResult().model_dump_json(),
"markdown",
)
except Exception as e:
self.logger.warning(
message=f"Error processing markdown content: {str(e)}",
tag="WARNING"
message=f"Error processing markdown content: {str(e)}", tag="WARNING"
)
# Fallback to empty markdown result
content_map['markdown'] = (MarkdownGenerationResult().model_dump_json(), 'markdown')
content_map["markdown"] = (
MarkdownGenerationResult().model_dump_json(),
"markdown",
)
content_hashes = {}
for field, (content, content_type) in content_map.items():
content_hashes[field] = await self._store_content(content, content_type)
async def _cache(db):
await db.execute('''
await db.execute(
"""
INSERT INTO crawled_data (
url, html, cleaned_html, markdown,
extracted_content, success, media, links, metadata,
@@ -383,20 +435,22 @@ class AsyncDatabaseManager:
screenshot = excluded.screenshot,
response_headers = excluded.response_headers,
downloaded_files = excluded.downloaded_files
''', (
""",
(
result.url,
content_hashes['html'],
content_hashes['cleaned_html'],
content_hashes['markdown'],
content_hashes['extracted_content'],
content_hashes["html"],
content_hashes["cleaned_html"],
content_hashes["markdown"],
content_hashes["extracted_content"],
result.success,
json.dumps(result.media),
json.dumps(result.links),
json.dumps(result.metadata or {}),
content_hashes['screenshot'],
content_hashes["screenshot"],
json.dumps(result.response_headers or {}),
json.dumps(result.downloaded_files or [])
))
json.dumps(result.downloaded_files or []),
),
)
try:
await self.execute_with_retry(_cache)
@@ -405,14 +459,14 @@ class AsyncDatabaseManager:
message="Error caching URL: {error}",
tag="ERROR",
force_verbose=True,
params={"error": str(e)}
params={"error": str(e)},
)
async def aget_total_count(self) -> int:
"""Get total number of cached URLs"""
async def _count(db):
async with db.execute('SELECT COUNT(*) FROM crawled_data') as cursor:
async with db.execute("SELECT COUNT(*) FROM crawled_data") as cursor:
result = await cursor.fetchone()
return result[0] if result else 0
@@ -423,14 +477,15 @@ class AsyncDatabaseManager:
message="Error getting total count: {error}",
tag="ERROR",
force_verbose=True,
params={"error": str(e)}
params={"error": str(e)},
)
return 0
async def aclear_db(self):
"""Clear all data from the database"""
async def _clear(db):
await db.execute('DELETE FROM crawled_data')
await db.execute("DELETE FROM crawled_data")
try:
await self.execute_with_retry(_clear)
@@ -439,13 +494,14 @@ class AsyncDatabaseManager:
message="Error clearing database: {error}",
tag="ERROR",
force_verbose=True,
params={"error": str(e)}
params={"error": str(e)},
)
async def aflush_db(self):
"""Drop the entire table"""
async def _flush(db):
await db.execute('DROP TABLE IF EXISTS crawled_data')
await db.execute("DROP TABLE IF EXISTS crawled_data")
try:
await self.execute_with_retry(_flush)
@@ -454,10 +510,9 @@ class AsyncDatabaseManager:
message="Error flushing database: {error}",
tag="ERROR",
force_verbose=True,
params={"error": str(e)}
params={"error": str(e)},
)
async def _store_content(self, content: str, content_type: str) -> str:
"""Store content in filesystem and return hash"""
if not content:
@@ -468,28 +523,31 @@ class AsyncDatabaseManager:
# Only write if file doesn't exist
if not os.path.exists(file_path):
async with aiofiles.open(file_path, 'w', encoding='utf-8') as f:
async with aiofiles.open(file_path, "w", encoding="utf-8") as f:
await f.write(content)
return content_hash
async def _load_content(self, content_hash: str, content_type: str) -> Optional[str]:
async def _load_content(
self, content_hash: str, content_type: str
) -> Optional[str]:
"""Load content from filesystem by hash"""
if not content_hash:
return None
file_path = os.path.join(self.content_paths[content_type], content_hash)
try:
async with aiofiles.open(file_path, 'r', encoding='utf-8') as f:
async with aiofiles.open(file_path, "r", encoding="utf-8") as f:
return await f.read()
except:
self.logger.error(
message="Failed to load content: {file_path}",
tag="ERROR",
force_verbose=True,
params={"file_path": file_path}
params={"file_path": file_path},
)
return None
# Create a singleton instance
async_db_manager = AsyncDatabaseManager()

View File

@@ -1,14 +1,19 @@
from typing import Dict, Optional, List
from .async_configs import *
from .models import *
from typing import Dict, Optional, List, Tuple
from .async_configs import CrawlerRunConfig
from .models import (
CrawlResult,
CrawlerTaskResult,
CrawlStatus,
DisplayMode,
CrawlStats,
DomainState,
)
from rich.live import Live
from rich.table import Table
from rich.console import Console
from rich.style import Style
from rich import box
from datetime import datetime, timedelta
from dataclasses import dataclass
import time
import psutil
@@ -26,7 +31,7 @@ class RateLimiter:
base_delay: Tuple[float, float] = (1.0, 3.0),
max_delay: float = 60.0,
max_retries: int = 3,
rate_limit_codes: List[int] = None
rate_limit_codes: List[int] = None,
):
self.base_delay = base_delay
self.max_delay = max_delay
@@ -68,21 +73,24 @@ class RateLimiter:
# Exponential backoff with random jitter
state.current_delay = min(
state.current_delay * 2 * random.uniform(0.75, 1.25),
self.max_delay
state.current_delay * 2 * random.uniform(0.75, 1.25), self.max_delay
)
else:
# Gradually reduce delay on success
state.current_delay = max(
random.uniform(*self.base_delay),
state.current_delay * 0.75
random.uniform(*self.base_delay), state.current_delay * 0.75
)
state.fail_count = 0
return True
class CrawlerMonitor:
def __init__(self, max_visible_rows: int = 15, display_mode: DisplayMode = DisplayMode.DETAILED):
def __init__(
self,
max_visible_rows: int = 15,
display_mode: DisplayMode = DisplayMode.DETAILED,
):
self.console = Console()
self.max_visible_rows = max_visible_rows
self.display_mode = display_mode
@@ -98,7 +106,9 @@ class CrawlerMonitor:
self.live.stop()
def add_task(self, task_id: str, url: str):
self.stats[task_id] = CrawlStats(task_id=task_id, url=url, status=CrawlStatus.QUEUED)
self.stats[task_id] = CrawlStats(
task_id=task_id, url=url, status=CrawlStatus.QUEUED
)
self.live.update(self._create_table())
def update_task(self, task_id: str, **kwargs):
@@ -114,20 +124,30 @@ class CrawlerMonitor:
title="Crawler Status Overview",
title_style="bold magenta",
header_style="bold blue",
show_lines=True
show_lines=True,
)
# Calculate statistics
total_tasks = len(self.stats)
queued = sum(1 for stat in self.stats.values() if stat.status == CrawlStatus.QUEUED)
in_progress = sum(1 for stat in self.stats.values() if stat.status == CrawlStatus.IN_PROGRESS)
completed = sum(1 for stat in self.stats.values() if stat.status == CrawlStatus.COMPLETED)
failed = sum(1 for stat in self.stats.values() if stat.status == CrawlStatus.FAILED)
queued = sum(
1 for stat in self.stats.values() if stat.status == CrawlStatus.QUEUED
)
in_progress = sum(
1 for stat in self.stats.values() if stat.status == CrawlStatus.IN_PROGRESS
)
completed = sum(
1 for stat in self.stats.values() if stat.status == CrawlStatus.COMPLETED
)
failed = sum(
1 for stat in self.stats.values() if stat.status == CrawlStatus.FAILED
)
# Memory statistics
current_memory = self.process.memory_info().rss / (1024 * 1024)
total_task_memory = sum(stat.memory_usage for stat in self.stats.values())
peak_memory = max((stat.peak_memory for stat in self.stats.values()), default=0.0)
peak_memory = max(
(stat.peak_memory for stat in self.stats.values()), default=0.0
)
# Duration
duration = datetime.now() - self.start_time
@@ -137,53 +157,43 @@ class CrawlerMonitor:
table.add_column("Count", justify="right")
table.add_column("Percentage", justify="right")
table.add_row(
"Total Tasks",
str(total_tasks),
"100%"
)
table.add_row("Total Tasks", str(total_tasks), "100%")
table.add_row(
"[yellow]In Queue[/yellow]",
str(queued),
f"{(queued/total_tasks*100):.1f}%" if total_tasks > 0 else "0%"
f"{(queued/total_tasks*100):.1f}%" if total_tasks > 0 else "0%",
)
table.add_row(
"[blue]In Progress[/blue]",
str(in_progress),
f"{(in_progress/total_tasks*100):.1f}%" if total_tasks > 0 else "0%"
f"{(in_progress/total_tasks*100):.1f}%" if total_tasks > 0 else "0%",
)
table.add_row(
"[green]Completed[/green]",
str(completed),
f"{(completed/total_tasks*100):.1f}%" if total_tasks > 0 else "0%"
f"{(completed/total_tasks*100):.1f}%" if total_tasks > 0 else "0%",
)
table.add_row(
"[red]Failed[/red]",
str(failed),
f"{(failed/total_tasks*100):.1f}%" if total_tasks > 0 else "0%"
f"{(failed/total_tasks*100):.1f}%" if total_tasks > 0 else "0%",
)
# Add memory information
table.add_section()
table.add_row(
"[magenta]Current Memory[/magenta]",
f"{current_memory:.1f} MB",
""
"[magenta]Current Memory[/magenta]", f"{current_memory:.1f} MB", ""
)
table.add_row(
"[magenta]Total Task Memory[/magenta]",
f"{total_task_memory:.1f} MB",
""
"[magenta]Total Task Memory[/magenta]", f"{total_task_memory:.1f} MB", ""
)
table.add_row(
"[magenta]Peak Task Memory[/magenta]",
f"{peak_memory:.1f} MB",
""
"[magenta]Peak Task Memory[/magenta]", f"{peak_memory:.1f} MB", ""
)
table.add_row(
"[yellow]Runtime[/yellow]",
str(timedelta(seconds=int(duration.total_seconds()))),
""
"",
)
return table
@@ -193,7 +203,7 @@ class CrawlerMonitor:
box=box.ROUNDED,
title="Crawler Performance Monitor",
title_style="bold magenta",
header_style="bold blue"
header_style="bold blue",
)
# Add columns
@@ -207,12 +217,15 @@ class CrawlerMonitor:
# Add summary row
total_memory = sum(stat.memory_usage for stat in self.stats.values())
active_count = sum(1 for stat in self.stats.values()
if stat.status == CrawlStatus.IN_PROGRESS)
completed_count = sum(1 for stat in self.stats.values()
if stat.status == CrawlStatus.COMPLETED)
failed_count = sum(1 for stat in self.stats.values()
if stat.status == CrawlStatus.FAILED)
active_count = sum(
1 for stat in self.stats.values() if stat.status == CrawlStatus.IN_PROGRESS
)
completed_count = sum(
1 for stat in self.stats.values() if stat.status == CrawlStatus.COMPLETED
)
failed_count = sum(
1 for stat in self.stats.values() if stat.status == CrawlStatus.FAILED
)
table.add_row(
"[bold yellow]SUMMARY",
@@ -220,9 +233,13 @@ class CrawlerMonitor:
f"Active: {active_count}",
f"{total_memory:.1f}",
f"{self.process.memory_info().rss / (1024 * 1024):.1f}",
str(timedelta(seconds=int((datetime.now() - self.start_time).total_seconds()))),
str(
timedelta(
seconds=int((datetime.now() - self.start_time).total_seconds())
)
),
f"{completed_count}{failed_count}",
style="bold"
style="bold",
)
table.add_section()
@@ -233,16 +250,16 @@ class CrawlerMonitor:
key=lambda x: (
x.status != CrawlStatus.IN_PROGRESS,
x.status != CrawlStatus.QUEUED,
x.end_time or datetime.max
)
)[:self.max_visible_rows]
x.end_time or datetime.max,
),
)[: self.max_visible_rows]
for stat in visible_stats:
status_style = {
CrawlStatus.QUEUED: "white",
CrawlStatus.IN_PROGRESS: "yellow",
CrawlStatus.COMPLETED: "green",
CrawlStatus.FAILED: "red"
CrawlStatus.FAILED: "red",
}[stat.status]
table.add_row(
@@ -252,7 +269,7 @@ class CrawlerMonitor:
f"{stat.memory_usage:.1f}",
f"{stat.peak_memory:.1f}",
stat.duration,
stat.error_message[:40] if stat.error_message else ""
stat.error_message[:40] if stat.error_message else "",
)
return table
@@ -268,7 +285,7 @@ class BaseDispatcher(ABC):
def __init__(
self,
rate_limiter: Optional[RateLimiter] = None,
monitor: Optional[CrawlerMonitor] = None
monitor: Optional[CrawlerMonitor] = None,
):
self.crawler = None
self._domain_last_hit: Dict[str, float] = {}
@@ -282,7 +299,7 @@ class BaseDispatcher(ABC):
url: str,
config: CrawlerRunConfig,
task_id: str,
monitor: Optional[CrawlerMonitor] = None
monitor: Optional[CrawlerMonitor] = None,
) -> CrawlerTaskResult:
pass
@@ -290,12 +307,13 @@ class BaseDispatcher(ABC):
async def run_urls(
self,
urls: List[str],
crawler: "AsyncWebCrawler",
crawler: "AsyncWebCrawler", # noqa: F821
config: CrawlerRunConfig,
monitor: Optional[CrawlerMonitor] = None
monitor: Optional[CrawlerMonitor] = None,
) -> List[CrawlerTaskResult]:
pass
class MemoryAdaptiveDispatcher(BaseDispatcher):
def __init__(
self,
@@ -304,7 +322,7 @@ class MemoryAdaptiveDispatcher(BaseDispatcher):
max_session_permit: int = 20,
memory_wait_timeout: float = 300.0, # 5 minutes default timeout
rate_limiter: Optional[RateLimiter] = None,
monitor: Optional[CrawlerMonitor] = None
monitor: Optional[CrawlerMonitor] = None,
):
super().__init__(rate_limiter, monitor)
self.memory_threshold_percent = memory_threshold_percent
@@ -324,7 +342,9 @@ class MemoryAdaptiveDispatcher(BaseDispatcher):
try:
if self.monitor:
self.monitor.update_task(task_id, status=CrawlStatus.IN_PROGRESS, start_time=start_time)
self.monitor.update_task(
task_id, status=CrawlStatus.IN_PROGRESS, start_time=start_time
)
self.concurrent_sessions += 1
if self.rate_limiter:
@@ -350,7 +370,7 @@ class MemoryAdaptiveDispatcher(BaseDispatcher):
peak_memory=peak_memory,
start_time=start_time,
end_time=datetime.now(),
error_message=error_message
error_message=error_message,
)
if not result.success:
@@ -364,7 +384,9 @@ class MemoryAdaptiveDispatcher(BaseDispatcher):
error_message = str(e)
if self.monitor:
self.monitor.update_task(task_id, status=CrawlStatus.FAILED)
result = CrawlResult(url=url, html="", metadata={}, success=False, error_message=str(e))
result = CrawlResult(
url=url, html="", metadata={}, success=False, error_message=str(e)
)
finally:
end_time = datetime.now()
@@ -374,7 +396,7 @@ class MemoryAdaptiveDispatcher(BaseDispatcher):
end_time=end_time,
memory_usage=memory_usage,
peak_memory=peak_memory,
error_message=error_message
error_message=error_message,
)
self.concurrent_sessions -= 1
@@ -386,13 +408,13 @@ class MemoryAdaptiveDispatcher(BaseDispatcher):
peak_memory=peak_memory,
start_time=start_time,
end_time=end_time,
error_message=error_message
error_message=error_message,
)
async def run_urls(
self,
urls: List[str],
crawler: "AsyncWebCrawler",
crawler: "AsyncWebCrawler", # noqa: F821
config: CrawlerRunConfig,
) -> List[CrawlerTaskResult]:
self.crawler = crawler
@@ -417,7 +439,9 @@ class MemoryAdaptiveDispatcher(BaseDispatcher):
if psutil.virtual_memory().percent >= self.memory_threshold_percent:
# Check if we've exceeded the timeout
if time.time() - wait_start_time > self.memory_wait_timeout:
raise MemoryError(f"Memory usage above threshold ({self.memory_threshold_percent}%) for more than {self.memory_wait_timeout} seconds")
raise MemoryError(
f"Memory usage above threshold ({self.memory_threshold_percent}%) for more than {self.memory_wait_timeout} seconds"
)
await asyncio.sleep(self.check_interval)
continue
@@ -430,8 +454,7 @@ class MemoryAdaptiveDispatcher(BaseDispatcher):
continue
done, pending = await asyncio.wait(
active_tasks,
return_when=asyncio.FIRST_COMPLETED
active_tasks, return_when=asyncio.FIRST_COMPLETED
)
pending_tasks.extend(done)
@@ -442,13 +465,14 @@ class MemoryAdaptiveDispatcher(BaseDispatcher):
if self.monitor:
self.monitor.stop()
class SemaphoreDispatcher(BaseDispatcher):
def __init__(
self,
semaphore_count: int = 5,
max_session_permit: int = 20,
rate_limiter: Optional[RateLimiter] = None,
monitor: Optional[CrawlerMonitor] = None
monitor: Optional[CrawlerMonitor] = None,
):
super().__init__(rate_limiter, monitor)
self.semaphore_count = semaphore_count
@@ -459,7 +483,7 @@ class SemaphoreDispatcher(BaseDispatcher):
url: str,
config: CrawlerRunConfig,
task_id: str,
semaphore: asyncio.Semaphore = None
semaphore: asyncio.Semaphore = None,
) -> CrawlerTaskResult:
start_time = datetime.now()
error_message = ""
@@ -467,7 +491,9 @@ class SemaphoreDispatcher(BaseDispatcher):
try:
if self.monitor:
self.monitor.update_task(task_id, status=CrawlStatus.IN_PROGRESS, start_time=start_time)
self.monitor.update_task(
task_id, status=CrawlStatus.IN_PROGRESS, start_time=start_time
)
if self.rate_limiter:
await self.rate_limiter.wait_if_needed(url)
@@ -493,7 +519,7 @@ class SemaphoreDispatcher(BaseDispatcher):
peak_memory=peak_memory,
start_time=start_time,
end_time=datetime.now(),
error_message=error_message
error_message=error_message,
)
if not result.success:
@@ -507,7 +533,9 @@ class SemaphoreDispatcher(BaseDispatcher):
error_message = str(e)
if self.monitor:
self.monitor.update_task(task_id, status=CrawlStatus.FAILED)
result = CrawlResult(url=url, html="", metadata={}, success=False, error_message=str(e))
result = CrawlResult(
url=url, html="", metadata={}, success=False, error_message=str(e)
)
finally:
end_time = datetime.now()
@@ -517,7 +545,7 @@ class SemaphoreDispatcher(BaseDispatcher):
end_time=end_time,
memory_usage=memory_usage,
peak_memory=peak_memory,
error_message=error_message
error_message=error_message,
)
return CrawlerTaskResult(
@@ -528,12 +556,12 @@ class SemaphoreDispatcher(BaseDispatcher):
peak_memory=peak_memory,
start_time=start_time,
end_time=end_time,
error_message=error_message
error_message=error_message,
)
async def run_urls(
self,
crawler: "AsyncWebCrawler",
crawler: "AsyncWebCrawler", # noqa: F821
urls: List[str],
config: CrawlerRunConfig,
) -> List[CrawlerTaskResult]:

View File

@@ -1,10 +1,10 @@
from enum import Enum
from typing import Optional, Dict, Any, Union
from colorama import Fore, Back, Style, init
import time
from typing import Optional, Dict, Any
from colorama import Fore, Style, init
import os
from datetime import datetime
class LogLevel(Enum):
DEBUG = 1
INFO = 2
@@ -12,6 +12,7 @@ class LogLevel(Enum):
WARNING = 4
ERROR = 5
class AsyncLogger:
"""
Asynchronous logger with support for colored console output and file logging.
@@ -19,16 +20,16 @@ class AsyncLogger:
"""
DEFAULT_ICONS = {
'INIT': '',
'READY': '',
'FETCH': '',
'SCRAPE': '',
'EXTRACT': '',
'COMPLETE': '',
'ERROR': '×',
'DEBUG': '',
'INFO': '',
'WARNING': '',
"INIT": "",
"READY": "",
"FETCH": "",
"SCRAPE": "",
"EXTRACT": "",
"COMPLETE": "",
"ERROR": "×",
"DEBUG": "",
"INFO": "",
"WARNING": "",
}
DEFAULT_COLORS = {
@@ -46,7 +47,7 @@ class AsyncLogger:
tag_width: int = 10,
icons: Optional[Dict[str, str]] = None,
colors: Optional[Dict[LogLevel, str]] = None,
verbose: bool = True
verbose: bool = True,
):
"""
Initialize the logger.
@@ -77,18 +78,20 @@ class AsyncLogger:
def _get_icon(self, tag: str) -> str:
"""Get the icon for a tag, defaulting to info icon if not found."""
return self.icons.get(tag, self.icons['INFO'])
return self.icons.get(tag, self.icons["INFO"])
def _write_to_file(self, message: str):
"""Write a message to the log file if configured."""
if self.log_file:
timestamp = datetime.now().strftime('%Y-%m-%d %H:%M:%S.%f')[:-3]
with open(self.log_file, 'a', encoding='utf-8') as f:
timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S.%f")[:-3]
with open(self.log_file, "a", encoding="utf-8") as f:
# Strip ANSI color codes for file output
clean_message = message.replace(Fore.RESET, '').replace(Style.RESET_ALL, '')
clean_message = message.replace(Fore.RESET, "").replace(
Style.RESET_ALL, ""
)
for color in vars(Fore).values():
if isinstance(color, str):
clean_message = clean_message.replace(color, '')
clean_message = clean_message.replace(color, "")
f.write(f"[{timestamp}] {clean_message}\n")
def _log(
@@ -99,7 +102,7 @@ class AsyncLogger:
params: Optional[Dict[str, Any]] = None,
colors: Optional[Dict[str, str]] = None,
base_color: Optional[str] = None,
**kwargs
**kwargs,
):
"""
Core logging method that handles message formatting and output.
@@ -128,12 +131,13 @@ class AsyncLogger:
if key in params:
value_str = str(params[key])
formatted_message = formatted_message.replace(
value_str,
f"{color}{value_str}{Style.RESET_ALL}"
value_str, f"{color}{value_str}{Style.RESET_ALL}"
)
except KeyError as e:
formatted_message = f"LOGGING ERROR: Missing parameter {e} in message template"
formatted_message = (
f"LOGGING ERROR: Missing parameter {e} in message template"
)
level = LogLevel.ERROR
else:
formatted_message = message
@@ -175,7 +179,7 @@ class AsyncLogger:
success: bool,
timing: float,
tag: str = "FETCH",
url_length: int = 50
url_length: int = 50,
):
"""
Convenience method for logging URL fetch status.
@@ -195,20 +199,16 @@ class AsyncLogger:
"url": url,
"url_length": url_length,
"status": success,
"timing": timing
"timing": timing,
},
colors={
"status": Fore.GREEN if success else Fore.RED,
"timing": Fore.YELLOW
}
"timing": Fore.YELLOW,
},
)
def error_status(
self,
url: str,
error: str,
tag: str = "ERROR",
url_length: int = 50
self, url: str, error: str, tag: str = "ERROR", url_length: int = 50
):
"""
Convenience method for logging error status.
@@ -223,9 +223,5 @@ class AsyncLogger:
level=LogLevel.ERROR,
message="{url:.{url_length}}... | Error: {error}",
tag=tag,
params={
"url": url,
"url_length": url_length,
"error": error
}
params={"url": url, "url_length": url_length, "error": error},
)

View File

@@ -1,42 +1,47 @@
import os, sys
import os
import sys
import time
import warnings
from enum import Enum
from colorama import init, Fore, Back, Style
from colorama import Fore
from pathlib import Path
from typing import Optional, List, Union
from typing import Optional, List
import json
import asyncio
# from contextlib import nullcontext, asynccontextmanager
from contextlib import asynccontextmanager
from .models import CrawlResult, MarkdownGenerationResult, CrawlerTaskResult
from .models import CrawlResult, MarkdownGenerationResult, CrawlerTaskResult, DispatchResult, RateLimiter
from .async_database import async_db_manager
from .chunking_strategy import *
from .content_filter_strategy import *
from .extraction_strategy import *
from .async_crawler_strategy import AsyncCrawlerStrategy, AsyncPlaywrightCrawlerStrategy, AsyncCrawlResponse
from .chunking_strategy import * # noqa: F403
from .chunking_strategy import RegexChunking, ChunkingStrategy, IdentityChunking
from .content_filter_strategy import * # noqa: F403
from .content_filter_strategy import RelevantContentFilter
from .extraction_strategy import * # noqa: F403
from .extraction_strategy import NoExtractionStrategy, ExtractionStrategy
from .async_crawler_strategy import (
AsyncCrawlerStrategy,
AsyncPlaywrightCrawlerStrategy,
AsyncCrawlResponse,
)
from .cache_context import CacheMode, CacheContext, _legacy_to_cache_mode
from .markdown_generation_strategy import DefaultMarkdownGenerator, MarkdownGenerationStrategy
from .content_scraping_strategy import WebScrapingStrategy, LXMLWebScrapingStrategy
from .markdown_generation_strategy import (
DefaultMarkdownGenerator,
MarkdownGenerationStrategy,
)
from .async_logger import AsyncLogger
from .async_configs import BrowserConfig, CrawlerRunConfig
from .async_dispatcher import *
from .async_dispatcher import * # noqa: F403
from .async_dispatcher import BaseDispatcher, MemoryAdaptiveDispatcher
from .config import (
MIN_WORD_THRESHOLD,
IMAGE_DESCRIPTION_MIN_WORD_THRESHOLD,
URL_LOG_SHORTEN_LENGTH
)
from .config import MIN_WORD_THRESHOLD
from .utils import (
sanitize_input_encode,
InvalidCSSSelectorError,
format_html,
fast_format_html,
create_box_message
create_box_message,
get_error_context,
)
from urllib.parse import urlparse
import random
from .__version__ import __version__ as crawl4ai_version
@@ -104,6 +109,7 @@ class AsyncWebCrawler:
result = await crawler.arun(url="https://example.com", config=crawler_config)
print(result.markdown)
"""
_domain_last_hit = {}
def __init__(
@@ -131,10 +137,18 @@ class AsyncWebCrawler:
# Handle browser configuration
browser_config = config
if browser_config is not None:
if any(k in kwargs for k in ["browser_type", "headless", "viewport_width", "viewport_height"]):
if any(
k in kwargs
for k in [
"browser_type",
"headless",
"viewport_width",
"viewport_height",
]
):
self.logger.warning(
message="Both browser_config and legacy browser parameters provided. browser_config will take precedence.",
tag="WARNING"
tag="WARNING",
)
else:
# Create browser config from kwargs for backwards compatibility
@@ -146,18 +160,15 @@ class AsyncWebCrawler:
self.logger = AsyncLogger(
log_file=os.path.join(base_directory, ".crawl4ai", "crawler.log"),
verbose=self.browser_config.verbose,
tag_width=10
tag_width=10,
)
# Initialize crawler strategy
params = {
k:v for k, v in kwargs.items() if k in ['browser_congig', 'logger']
}
params = {k: v for k, v in kwargs.items() if k in ["browser_congig", "logger"]}
self.crawler_strategy = crawler_strategy or AsyncPlaywrightCrawlerStrategy(
browser_config=browser_config,
logger=self.logger,
**params # Pass remaining kwargs for backwards compatibility
**params, # Pass remaining kwargs for backwards compatibility
)
# If craweler strategy doesnt have logger, use crawler logger
@@ -172,7 +183,7 @@ class AsyncWebCrawler:
"Use 'always_bypass_cache' instead. "
"Pass warning=False to suppress this warning.",
DeprecationWarning,
stacklevel=2
stacklevel=2,
)
self.always_bypass_cache = always_by_pass_cache
else:
@@ -323,7 +334,7 @@ class AsyncWebCrawler:
"screenshot": screenshot,
"pdf": pdf,
"verbose": verbose,
**kwargs
**kwargs,
}
config = CrawlerRunConfig.from_kwargs(config_kwargs)
@@ -334,7 +345,7 @@ class AsyncWebCrawler:
"Cache control boolean flags are deprecated and will be removed in version 0.5.0. "
"Use 'cache_mode' parameter instead.",
DeprecationWarning,
stacklevel=2
stacklevel=2,
)
# Convert legacy parameters if cache_mode not provided
@@ -343,7 +354,7 @@ class AsyncWebCrawler:
disable_cache=disable_cache,
bypass_cache=bypass_cache,
no_cache_read=no_cache_read,
no_cache_write=no_cache_write
no_cache_write=no_cache_write,
)
# Default to ENABLED if no cache mode specified
@@ -351,7 +362,9 @@ class AsyncWebCrawler:
config.cache_mode = CacheMode.ENABLED
# Create cache context
cache_context = CacheContext(url, config.cache_mode, self.always_bypass_cache)
cache_context = CacheContext(
url, config.cache_mode, self.always_bypass_cache
)
# Initialize processing variables
async_response: AsyncCrawlResponse = None
@@ -367,8 +380,14 @@ class AsyncWebCrawler:
if cached_result:
html = sanitize_input_encode(cached_result.html)
extracted_content = sanitize_input_encode(cached_result.extracted_content or "")
extracted_content = None if not extracted_content or extracted_content == "[]" else extracted_content
extracted_content = sanitize_input_encode(
cached_result.extracted_content or ""
)
extracted_content = (
None
if not extracted_content or extracted_content == "[]"
else extracted_content
)
# If screenshot is requested but its not in cache, then set cache_result to None
screenshot_data = cached_result.screenshot
pdf_data = cached_result.pdf
@@ -379,7 +398,7 @@ class AsyncWebCrawler:
url=cache_context.display_url,
success=bool(html),
timing=time.perf_counter() - start_time,
tag="FETCH"
tag="FETCH",
)
# Fetch fresh content if needed
@@ -392,7 +411,7 @@ class AsyncWebCrawler:
# Pass config to crawl method
async_response = await self.crawler_strategy.crawl(
url,
config=config # Pass the entire config object
config=config, # Pass the entire config object
)
html = sanitize_input_encode(async_response.html)
@@ -404,7 +423,7 @@ class AsyncWebCrawler:
url=cache_context.display_url,
success=bool(html),
timing=t2 - t1,
tag="FETCH"
tag="FETCH",
)
# Process the HTML content
@@ -416,14 +435,16 @@ class AsyncWebCrawler:
screenshot=screenshot_data,
pdf_data=pdf_data,
verbose=config.verbose,
is_raw_html = True if url.startswith("raw:") else False,
**kwargs
is_raw_html=True if url.startswith("raw:") else False,
**kwargs,
)
crawl_result.status_code = async_response.status_code
crawl_result.response_headers = async_response.response_headers
crawl_result.downloaded_files = async_response.downloaded_files
crawl_result.ssl_certificate = async_response.ssl_certificate # Add SSL certificate
crawl_result.ssl_certificate = (
async_response.ssl_certificate
) # Add SSL certificate
# # Check and set values from async_response to crawl_result
# try:
@@ -446,7 +467,7 @@ class AsyncWebCrawler:
# )
crawl_result.success = bool(html)
crawl_result.session_id = getattr(config, 'session_id', None)
crawl_result.session_id = getattr(config, "session_id", None)
self.logger.success(
message="{url:.50}... | Status: {status} | Total: {timing}",
@@ -454,12 +475,12 @@ class AsyncWebCrawler:
params={
"url": cache_context.display_url,
"status": crawl_result.success,
"timing": f"{time.perf_counter() - start_time:.2f}s"
"timing": f"{time.perf_counter() - start_time:.2f}s",
},
colors={
"status": Fore.GREEN if crawl_result.success else Fore.RED,
"timing": Fore.YELLOW
}
"timing": Fore.YELLOW,
},
)
# Update cache if appropriate
@@ -475,16 +496,13 @@ class AsyncWebCrawler:
params={
"url": cache_context.display_url,
"status": True,
"timing": f"{time.perf_counter() - start_time:.2f}s"
"timing": f"{time.perf_counter() - start_time:.2f}s",
},
colors={
"status": Fore.GREEN,
"timing": Fore.YELLOW
}
colors={"status": Fore.GREEN, "timing": Fore.YELLOW},
)
cached_result.success = bool(html)
cached_result.session_id = getattr(config, 'session_id', None)
cached_result.session_id = getattr(config, "session_id", None)
return cached_result
except Exception as e:
@@ -502,14 +520,11 @@ class AsyncWebCrawler:
self.logger.error_status(
url=url,
error=create_box_message(error_message, type="error"),
tag="ERROR"
tag="ERROR",
)
return CrawlResult(
url=url,
html="",
success=False,
error_message=error_message
url=url, html="", success=False, error_message=error_message
)
async def aprocess_html(
@@ -549,25 +564,23 @@ class AsyncWebCrawler:
scraping_strategy.logger = self.logger
# Process HTML content
params = {k:v for k, v in config.to_dict().items() if k not in ["url"]}
params = {k: v for k, v in config.to_dict().items() if k not in ["url"]}
# add keys from kwargs to params that doesn't exist in params
params.update({k:v for k, v in kwargs.items() if k not in params.keys()})
params.update({k: v for k, v in kwargs.items() if k not in params.keys()})
result = scraping_strategy.scrap(
url,
html,
**params
)
result = scraping_strategy.scrap(url, html, **params)
if result is None:
raise ValueError(f"Process HTML, Failed to extract content from the website: {url}")
raise ValueError(
f"Process HTML, Failed to extract content from the website: {url}"
)
except InvalidCSSSelectorError as e:
raise ValueError(str(e))
except Exception as e:
raise ValueError(f"Process HTML, Failed to extract content from the website: {url}, error: {str(e)}")
raise ValueError(
f"Process HTML, Failed to extract content from the website: {url}, error: {str(e)}"
)
# Extract results - handle both dict and ScrapingResult
if isinstance(result, dict):
@@ -582,17 +595,21 @@ class AsyncWebCrawler:
metadata = result.metadata
# Markdown Generation
markdown_generator: Optional[MarkdownGenerationStrategy] = config.markdown_generator or DefaultMarkdownGenerator()
markdown_generator: Optional[MarkdownGenerationStrategy] = (
config.markdown_generator or DefaultMarkdownGenerator()
)
# Uncomment if by default we want to use PruningContentFilter
# if not config.content_filter and not markdown_generator.content_filter:
# markdown_generator.content_filter = PruningContentFilter()
markdown_result: MarkdownGenerationResult = markdown_generator.generate_markdown(
markdown_result: MarkdownGenerationResult = (
markdown_generator.generate_markdown(
cleaned_html=cleaned_html,
base_url=url,
# html2text_options=kwargs.get('html2text', {})
)
)
markdown_v2 = markdown_result
markdown = sanitize_input_encode(markdown_result.raw_markdown)
@@ -600,15 +617,15 @@ class AsyncWebCrawler:
self.logger.info(
message="Processed {url:.50}... | Time: {timing}ms",
tag="SCRAPE",
params={
"url": _url,
"timing": int((time.perf_counter() - t1) * 1000)
}
params={"url": _url, "timing": int((time.perf_counter() - t1) * 1000)},
)
# Handle content extraction if needed
if (not bool(extracted_content) and config.extraction_strategy and not isinstance(config.extraction_strategy, NoExtractionStrategy)):
if (
not bool(extracted_content)
and config.extraction_strategy
and not isinstance(config.extraction_strategy, NoExtractionStrategy)
):
t1 = time.perf_counter()
# Choose content based on input_format
@@ -617,30 +634,33 @@ class AsyncWebCrawler:
self.logger.warning(
message="Fit markdown requested but not available. Falling back to raw markdown.",
tag="EXTRACT",
params={"url": _url}
params={"url": _url},
)
content_format = "markdown"
content = {
"markdown": markdown,
"html": html,
"fit_markdown": markdown_result.raw_markdown
"fit_markdown": markdown_result.raw_markdown,
}.get(content_format, markdown)
# Use IdentityChunking for HTML input, otherwise use provided chunking strategy
chunking = IdentityChunking() if content_format == "html" else config.chunking_strategy
chunking = (
IdentityChunking()
if content_format == "html"
else config.chunking_strategy
)
sections = chunking.chunk(content)
extracted_content = config.extraction_strategy.run(url, sections)
extracted_content = json.dumps(extracted_content, indent=4, default=str, ensure_ascii=False)
extracted_content = json.dumps(
extracted_content, indent=4, default=str, ensure_ascii=False
)
# Log extraction completion
self.logger.info(
message="Completed for {url:.50}... | Time: {timing}s",
tag="EXTRACT",
params={
"url": _url,
"timing": time.perf_counter() - t1
}
params={"url": _url, "timing": time.perf_counter() - t1},
)
# Handle screenshot and PDF data
@@ -739,7 +759,7 @@ class AsyncWebCrawler:
screenshot=screenshot,
pdf=pdf,
verbose=verbose,
**kwargs
**kwargs,
)
# # Initialize the dispatcher with the selected strategy
@@ -754,29 +774,25 @@ class AsyncWebCrawler:
dispatcher = MemoryAdaptiveDispatcher(
self,
rate_limiter=RateLimiter(
base_delay=(1.0, 3.0),
max_delay=60.0,
max_retries=3
)
base_delay=(1.0, 3.0), max_delay=60.0, max_retries=3
),
)
# Run the URLs through the dispatcher
_results: List[CrawlerTaskResult] = await dispatcher.run_urls(
crawler=self,
urls=urls,
config=config
crawler=self, urls=urls, config=config
)
results: CrawlResult = []
for res in _results:
_res : CrawlResult = res.result
_res: CrawlResult = res.result
dispatch_result: DispatchResult = DispatchResult(
task_id=res.task_id,
memory_usage=res.memory_usage,
peak_memory=res.peak_memory,
start_time=res.start_time,
end_time=res.end_time,
error_message=res.error_message
error_message=res.error_message,
)
_res.dispatch_result = dispatch_result
results.append(_res)

View File

@@ -12,6 +12,7 @@ class CacheMode(Enum):
- WRITE_ONLY: Only write to cache, don't read
- BYPASS: Bypass cache for this operation
"""
ENABLED = "enabled"
DISABLED = "disabled"
READ_ONLY = "read_only"
@@ -36,6 +37,7 @@ class CacheContext:
is_raw_html (bool): True if the URL is raw HTML, False otherwise.
_url_display (str): The display name for the URL (web, local file, or raw HTML).
"""
def __init__(self, url: str, cache_mode: CacheMode, always_bypass: bool = False):
"""
Initializes the CacheContext with the provided URL and cache mode.
@@ -48,8 +50,8 @@ class CacheContext:
self.url = url
self.cache_mode = cache_mode
self.always_bypass = always_bypass
self.is_cacheable = url.startswith(('http://', 'https://', 'file://'))
self.is_web_url = url.startswith(('http://', 'https://'))
self.is_cacheable = url.startswith(("http://", "https://", "file://"))
self.is_web_url = url.startswith(("http://", "https://"))
self.is_local_file = url.startswith("file://")
self.is_raw_html = url.startswith("raw:")
self._url_display = url if not self.is_raw_html else "Raw HTML"
@@ -94,7 +96,7 @@ def _legacy_to_cache_mode(
disable_cache: bool = False,
bypass_cache: bool = False,
no_cache_read: bool = False,
no_cache_write: bool = False
no_cache_write: bool = False,
) -> CacheMode:
"""
Converts legacy cache parameters to the new CacheMode enum.

View File

@@ -3,7 +3,7 @@ import re
from collections import Counter
import string
from .model_loader import load_nltk_punkt
from .utils import *
# Define the abstract base class for chunking strategies
class ChunkingStrategy(ABC):
@@ -24,19 +24,23 @@ class ChunkingStrategy(ABC):
"""
pass
# Create an identity chunking strategy f(x) = [x]
class IdentityChunking(ChunkingStrategy):
"""
Chunking strategy that returns the input text as a single chunk.
"""
def chunk(self, text: str) -> list:
return [text]
# Regex-based chunking
class RegexChunking(ChunkingStrategy):
"""
Chunking strategy that splits text based on regular expression patterns.
"""
def __init__(self, patterns=None, **kwargs):
"""
Initialize the RegexChunking object.
@@ -45,7 +49,7 @@ class RegexChunking(ChunkingStrategy):
patterns (list): A list of regular expression patterns to split text.
"""
if patterns is None:
patterns = [r'\n\n'] # Default split pattern
patterns = [r"\n\n"] # Default split pattern
self.patterns = patterns
def chunk(self, text: str) -> list:
@@ -57,18 +61,19 @@ class RegexChunking(ChunkingStrategy):
paragraphs = new_paragraphs
return paragraphs
# NLP-based sentence chunking
class NlpSentenceChunking(ChunkingStrategy):
"""
Chunking strategy that splits text into sentences using NLTK's sentence tokenizer.
"""
def __init__(self, **kwargs):
"""
Initialize the NlpSentenceChunking object.
"""
load_nltk_punkt()
def chunk(self, text: str) -> list:
# Improved regex for sentence splitting
# sentence_endings = re.compile(
@@ -77,11 +82,13 @@ class NlpSentenceChunking(ChunkingStrategy):
# sentences = sentence_endings.split(text)
# sens = [sent.strip() for sent in sentences if sent]
from nltk.tokenize import sent_tokenize
sentences = sent_tokenize(text)
sens = [sent.strip() for sent in sentences]
return list(set(sens))
# Topic-based segmentation using TextTiling
class TopicSegmentationChunking(ChunkingStrategy):
"""
@@ -100,6 +107,7 @@ class TopicSegmentationChunking(ChunkingStrategy):
num_keywords (int): The number of keywords to extract for each topic segment.
"""
import nltk as nl
self.tokenizer = nl.tokenize.TextTilingTokenizer()
self.num_keywords = num_keywords
@@ -111,8 +119,14 @@ class TopicSegmentationChunking(ChunkingStrategy):
def extract_keywords(self, text: str) -> list:
# Tokenize and remove stopwords and punctuation
import nltk as nl
tokens = nl.toknize.word_tokenize(text)
tokens = [token.lower() for token in tokens if token not in nl.corpus.stopwords.words('english') and token not in string.punctuation]
tokens = [
token.lower()
for token in tokens
if token not in nl.corpus.stopwords.words("english")
and token not in string.punctuation
]
# Calculate frequency distribution
freq_dist = Counter(tokens)
@@ -123,9 +137,12 @@ class TopicSegmentationChunking(ChunkingStrategy):
# Segment the text into topics
segments = self.chunk(text)
# Extract keywords for each topic segment
segments_with_topics = [(segment, self.extract_keywords(segment)) for segment in segments]
segments_with_topics = [
(segment, self.extract_keywords(segment)) for segment in segments
]
return segments_with_topics
# Fixed-length word chunks
class FixedLengthWordChunking(ChunkingStrategy):
"""
@@ -136,6 +153,7 @@ class FixedLengthWordChunking(ChunkingStrategy):
2. Create chunks of fixed length
3. Return the list of chunks
"""
def __init__(self, chunk_size=100, **kwargs):
"""
Initialize the fixed-length word chunking strategy with the given chunk size.
@@ -147,7 +165,11 @@ class FixedLengthWordChunking(ChunkingStrategy):
def chunk(self, text: str) -> list:
words = text.split()
return [' '.join(words[i:i + self.chunk_size]) for i in range(0, len(words), self.chunk_size)]
return [
" ".join(words[i : i + self.chunk_size])
for i in range(0, len(words), self.chunk_size)
]
# Sliding window chunking
class SlidingWindowChunking(ChunkingStrategy):
@@ -159,6 +181,7 @@ class SlidingWindowChunking(ChunkingStrategy):
2. Create chunks of fixed length
3. Return the list of chunks
"""
def __init__(self, window_size=100, step=50, **kwargs):
"""
Initialize the sliding window chunking strategy with the given window size and
@@ -179,15 +202,16 @@ class SlidingWindowChunking(ChunkingStrategy):
return [text]
for i in range(0, len(words) - self.window_size + 1, self.step):
chunk = ' '.join(words[i:i + self.window_size])
chunk = " ".join(words[i : i + self.window_size])
chunks.append(chunk)
# Handle the last chunk if it doesn't align perfectly
if i + self.window_size < len(words):
chunks.append(' '.join(words[-self.window_size:]))
chunks.append(" ".join(words[-self.window_size :]))
return chunks
class OverlappingWindowChunking(ChunkingStrategy):
"""
Chunking strategy that splits text into overlapping word chunks.
@@ -198,6 +222,7 @@ class OverlappingWindowChunking(ChunkingStrategy):
3. Slide the window by the overlap size
4. Return the list of chunks
"""
def __init__(self, window_size=1000, overlap=100, **kwargs):
"""
Initialize the overlapping window chunking strategy with the given window size and
@@ -220,7 +245,7 @@ class OverlappingWindowChunking(ChunkingStrategy):
start = 0
while start < len(words):
end = start + self.window_size
chunk = ' '.join(words[start:end])
chunk = " ".join(words[start:end])
chunks.append(chunk)
if end >= len(words):

View File

@@ -8,14 +8,21 @@ from .async_logger import AsyncLogger
logger = AsyncLogger(verbose=True)
docs_manager = DocsManager(logger)
def print_table(headers: List[str], rows: List[List[str]], padding: int = 2):
"""Print formatted table with headers and rows"""
widths = [max(len(str(cell)) for cell in col) for col in zip(headers, *rows)]
border = '+' + '+'.join('-' * (w + 2 * padding) for w in widths) + '+'
border = "+" + "+".join("-" * (w + 2 * padding) for w in widths) + "+"
def format_row(row):
return '|' + '|'.join(f"{' ' * padding}{str(cell):<{w}}{' ' * padding}"
for cell, w in zip(row, widths)) + '|'
return (
"|"
+ "|".join(
f"{' ' * padding}{str(cell):<{w}}{' ' * padding}"
for cell, w in zip(row, widths)
)
+ "|"
)
click.echo(border)
click.echo(format_row(headers))
@@ -24,19 +31,24 @@ def print_table(headers: List[str], rows: List[List[str]], padding: int = 2):
click.echo(format_row(row))
click.echo(border)
@click.group()
def cli():
"""Crawl4AI Command Line Interface"""
pass
@cli.group()
def docs():
"""Documentation operations"""
pass
@docs.command()
@click.argument('sections', nargs=-1)
@click.option('--mode', type=click.Choice(['extended', 'condensed']), default='extended')
@click.argument("sections", nargs=-1)
@click.option(
"--mode", type=click.Choice(["extended", "condensed"]), default="extended"
)
def combine(sections: tuple, mode: str):
"""Combine documentation sections"""
try:
@@ -46,16 +58,17 @@ def combine(sections: tuple, mode: str):
logger.error(str(e), tag="ERROR")
sys.exit(1)
@docs.command()
@click.argument('query')
@click.option('--top-k', '-k', default=5)
@click.option('--build-index', is_flag=True, help='Build index if missing')
@click.argument("query")
@click.option("--top-k", "-k", default=5)
@click.option("--build-index", is_flag=True, help="Build index if missing")
def search(query: str, top_k: int, build_index: bool):
"""Search documentation"""
try:
result = docs_manager.search(query, top_k)
if result == "No search index available. Call build_search_index() first.":
if build_index or click.confirm('No search index found. Build it now?'):
if build_index or click.confirm("No search index found. Build it now?"):
asyncio.run(docs_manager.llm_text.generate_index_files())
result = docs_manager.search(query, top_k)
click.echo(result)
@@ -63,6 +76,7 @@ def search(query: str, top_k: int, build_index: bool):
click.echo(f"Error: {str(e)}", err=True)
sys.exit(1)
@docs.command()
def update():
"""Update docs from GitHub"""
@@ -73,22 +87,25 @@ def update():
click.echo(f"Error: {str(e)}", err=True)
sys.exit(1)
@docs.command()
@click.option('--force-facts', is_flag=True, help='Force regenerate fact files')
@click.option('--clear-cache', is_flag=True, help='Clear BM25 cache')
@click.option("--force-facts", is_flag=True, help="Force regenerate fact files")
@click.option("--clear-cache", is_flag=True, help="Clear BM25 cache")
def index(force_facts: bool, clear_cache: bool):
"""Build or rebuild search indexes"""
try:
asyncio.run(docs_manager.ensure_docs_exist())
asyncio.run(docs_manager.llm_text.generate_index_files(
force_generate_facts=force_facts,
clear_bm25_cache=clear_cache
))
asyncio.run(
docs_manager.llm_text.generate_index_files(
force_generate_facts=force_facts, clear_bm25_cache=clear_cache
)
)
click.echo("Search indexes built successfully")
except Exception as e:
click.echo(f"Error: {str(e)}", err=True)
sys.exit(1)
# Add docs list command
@docs.command()
def list():
@@ -101,5 +118,6 @@ def list():
click.echo(f"Error: {str(e)}", err=True)
sys.exit(1)
if __name__ == '__main__':
if __name__ == "__main__":
cli()

View File

@@ -22,7 +22,7 @@ PROVIDER_MODELS = {
}
# Chunk token threshold
CHUNK_TOKEN_THRESHOLD = 2 ** 11 # 2048 tokens
CHUNK_TOKEN_THRESHOLD = 2**11 # 2048 tokens
OVERLAP_RATE = 0.1
WORD_TOKEN_RATE = 1.3
@@ -30,19 +30,41 @@ WORD_TOKEN_RATE = 1.3
MIN_WORD_THRESHOLD = 1
IMAGE_DESCRIPTION_MIN_WORD_THRESHOLD = 1
IMPORTANT_ATTRS = ['src', 'href', 'alt', 'title', 'width', 'height']
ONLY_TEXT_ELIGIBLE_TAGS = ['b', 'i', 'u', 'span', 'del', 'ins', 'sub', 'sup', 'strong', 'em', 'code', 'kbd', 'var', 's', 'q', 'abbr', 'cite', 'dfn', 'time', 'small', 'mark']
IMPORTANT_ATTRS = ["src", "href", "alt", "title", "width", "height"]
ONLY_TEXT_ELIGIBLE_TAGS = [
"b",
"i",
"u",
"span",
"del",
"ins",
"sub",
"sup",
"strong",
"em",
"code",
"kbd",
"var",
"s",
"q",
"abbr",
"cite",
"dfn",
"time",
"small",
"mark",
]
SOCIAL_MEDIA_DOMAINS = [
'facebook.com',
'twitter.com',
'x.com',
'linkedin.com',
'instagram.com',
'pinterest.com',
'tiktok.com',
'snapchat.com',
'reddit.com',
]
"facebook.com",
"twitter.com",
"x.com",
"linkedin.com",
"instagram.com",
"pinterest.com",
"tiktok.com",
"snapchat.com",
"reddit.com",
]
# Threshold for the Image extraction - Range is 1 to 6
# Images are scored based on point based system, to filter based on usefulness. Points are assigned
@@ -60,5 +82,5 @@ NEED_MIGRATION = True
URL_LOG_SHORTEN_LENGTH = 30
SHOW_DEPRECATION_WARNINGS = True
SCREENSHOT_HEIGHT_TRESHOLD = 10000
PAGE_TIMEOUT=60000
DOWNLOAD_PAGE_TIMEOUT=60000
PAGE_TIMEOUT = 60000
DOWNLOAD_PAGE_TIMEOUT = 60000

View File

@@ -1,44 +1,85 @@
import re
from bs4 import BeautifulSoup, Tag
from typing import List, Tuple, Dict
from typing import List, Tuple
from rank_bm25 import BM25Okapi
from time import perf_counter
from collections import deque
from bs4 import BeautifulSoup, NavigableString, Tag, Comment
from bs4 import NavigableString, Comment
from .utils import clean_tokens
from abc import ABC, abstractmethod
import math
from snowballstemmer import stemmer
class RelevantContentFilter(ABC):
"""Abstract base class for content filtering strategies"""
def __init__(self, user_query: str = None):
self.user_query = user_query
self.included_tags = {
# Primary structure
'article', 'main', 'section', 'div',
"article",
"main",
"section",
"div",
# List structures
'ul', 'ol', 'li', 'dl', 'dt', 'dd',
"ul",
"ol",
"li",
"dl",
"dt",
"dd",
# Text content
'p', 'span', 'blockquote', 'pre', 'code',
"p",
"span",
"blockquote",
"pre",
"code",
# Headers
'h1', 'h2', 'h3', 'h4', 'h5', 'h6',
"h1",
"h2",
"h3",
"h4",
"h5",
"h6",
# Tables
'table', 'thead', 'tbody', 'tr', 'td', 'th',
"table",
"thead",
"tbody",
"tr",
"td",
"th",
# Other semantic elements
'figure', 'figcaption', 'details', 'summary',
"figure",
"figcaption",
"details",
"summary",
# Text formatting
'em', 'strong', 'b', 'i', 'mark', 'small',
"em",
"strong",
"b",
"i",
"mark",
"small",
# Rich content
'time', 'address', 'cite', 'q'
"time",
"address",
"cite",
"q",
}
self.excluded_tags = {
'nav', 'footer', 'header', 'aside', 'script',
'style', 'form', 'iframe', 'noscript'
"nav",
"footer",
"header",
"aside",
"script",
"style",
"form",
"iframe",
"noscript",
}
self.header_tags = {'h1', 'h2', 'h3', 'h4', 'h5', 'h6'}
self.header_tags = {"h1", "h2", "h3", "h4", "h5", "h6"}
self.negative_patterns = re.compile(
r'nav|footer|header|sidebar|ads|comment|promo|advert|social|share',
re.I
r"nav|footer|header|sidebar|ads|comment|promo|advert|social|share", re.I
)
self.min_word_count = 2
@@ -62,28 +103,30 @@ class RelevantContentFilter(ABC):
except Exception:
pass
if soup.find('h1'):
query_parts.append(soup.find('h1').get_text())
if soup.find("h1"):
query_parts.append(soup.find("h1").get_text())
# Meta tags
temp = ""
for meta_name in ['keywords', 'description']:
meta = soup.find('meta', attrs={'name': meta_name})
if meta and meta.get('content'):
query_parts.append(meta['content'])
temp += meta['content']
for meta_name in ["keywords", "description"]:
meta = soup.find("meta", attrs={"name": meta_name})
if meta and meta.get("content"):
query_parts.append(meta["content"])
temp += meta["content"]
# If still empty, grab first significant paragraph
if not temp:
# Find the first tag P thatits text contains more than 50 characters
for p in body.find_all('p'):
for p in body.find_all("p"):
if len(p.get_text()) > 150:
query_parts.append(p.get_text()[:150])
break
return ' '.join(filter(None, query_parts))
return " ".join(filter(None, query_parts))
def extract_text_chunks(self, body: Tag, min_word_threshold: int = None) -> List[Tuple[str, str]]:
def extract_text_chunks(
self, body: Tag, min_word_threshold: int = None
) -> List[Tuple[str, str]]:
"""
Extracts text chunks from a BeautifulSoup body element while preserving order.
Returns list of tuples (text, tag_name) for classification.
@@ -96,14 +139,42 @@ class RelevantContentFilter(ABC):
"""
# Tags to ignore - inline elements that shouldn't break text flow
INLINE_TAGS = {
'a', 'abbr', 'acronym', 'b', 'bdo', 'big', 'br', 'button', 'cite', 'code',
'dfn', 'em', 'i', 'img', 'input', 'kbd', 'label', 'map', 'object', 'q',
'samp', 'script', 'select', 'small', 'span', 'strong', 'sub', 'sup',
'textarea', 'time', 'tt', 'var'
"a",
"abbr",
"acronym",
"b",
"bdo",
"big",
"br",
"button",
"cite",
"code",
"dfn",
"em",
"i",
"img",
"input",
"kbd",
"label",
"map",
"object",
"q",
"samp",
"script",
"select",
"small",
"span",
"strong",
"sub",
"sup",
"textarea",
"time",
"tt",
"var",
}
# Tags that typically contain meaningful headers
HEADER_TAGS = {'h1', 'h2', 'h3', 'h4', 'h5', 'h6', 'header'}
HEADER_TAGS = {"h1", "h2", "h3", "h4", "h5", "h6", "header"}
chunks = []
current_text = []
@@ -111,9 +182,8 @@ class RelevantContentFilter(ABC):
def should_break_chunk(tag: Tag) -> bool:
"""Determine if a tag should cause a break in the current text chunk"""
return (
tag.name not in INLINE_TAGS
and not (tag.name == 'p' and len(current_text) == 0)
return tag.name not in INLINE_TAGS and not (
tag.name == "p" and len(current_text) == 0
)
# Use deque for efficient push/pop operations
@@ -125,9 +195,11 @@ class RelevantContentFilter(ABC):
if visited:
# End of block element - flush accumulated text
if current_text and should_break_chunk(element):
text = ' '.join(''.join(current_text).split())
text = " ".join("".join(current_text).split())
if text:
tag_type = 'header' if element.name in HEADER_TAGS else 'content'
tag_type = (
"header" if element.name in HEADER_TAGS else "content"
)
chunks.append((chunk_index, text, tag_type, element))
chunk_index += 1
current_text = []
@@ -153,18 +225,23 @@ class RelevantContentFilter(ABC):
# Handle any remaining text
if current_text:
text = ' '.join(''.join(current_text).split())
text = " ".join("".join(current_text).split())
if text:
chunks.append((chunk_index, text, 'content', body))
chunks.append((chunk_index, text, "content", body))
if min_word_threshold:
chunks = [chunk for chunk in chunks if len(chunk[1].split()) >= min_word_threshold]
chunks = [
chunk for chunk in chunks if len(chunk[1].split()) >= min_word_threshold
]
return chunks
def _deprecated_extract_text_chunks(self, soup: BeautifulSoup) -> List[Tuple[int, str, Tag]]:
def _deprecated_extract_text_chunks(
self, soup: BeautifulSoup
) -> List[Tuple[int, str, Tag]]:
"""Common method for extracting text chunks"""
_text_cache = {}
def fast_text(element: Tag) -> str:
elem_id = id(element)
if elem_id in _text_cache:
@@ -175,7 +252,7 @@ class RelevantContentFilter(ABC):
text = content.strip()
if text:
texts.append(text)
result = ' '.join(texts)
result = " ".join(texts)
_text_cache[elem_id] = result
return result
@@ -210,10 +287,9 @@ class RelevantContentFilter(ABC):
"""Common method for exclusion logic"""
if tag.name in self.excluded_tags:
return True
class_id = ' '.join(filter(None, [
' '.join(tag.get('class', [])),
tag.get('id', '')
]))
class_id = " ".join(
filter(None, [" ".join(tag.get("class", [])), tag.get("id", "")])
)
return bool(self.negative_patterns.search(class_id))
def clean_element(self, tag: Tag) -> str:
@@ -221,8 +297,16 @@ class RelevantContentFilter(ABC):
if not tag or not isinstance(tag, Tag):
return ""
unwanted_tags = {'script', 'style', 'aside', 'form', 'iframe', 'noscript'}
unwanted_attrs = {'style', 'onclick', 'onmouseover', 'align', 'bgcolor', 'class', 'id'}
unwanted_tags = {"script", "style", "aside", "form", "iframe", "noscript"}
unwanted_attrs = {
"style",
"onclick",
"onmouseover",
"align",
"bgcolor",
"class",
"id",
}
# Use string builder pattern for better performance
builder = []
@@ -237,28 +321,29 @@ class RelevantContentFilter(ABC):
return
# Start tag
builder.append(f'<{elem.name}')
builder.append(f"<{elem.name}")
# Add cleaned attributes
attrs = {k: v for k, v in elem.attrs.items() if k not in unwanted_attrs}
for key, value in attrs.items():
builder.append(f' {key}="{value}"')
builder.append('>')
builder.append(">")
# Process children
for child in elem.children:
render_tag(child)
# Close tag
builder.append(f'</{elem.name}>')
builder.append(f"</{elem.name}>")
try:
render_tag(tag)
return ''.join(builder)
return "".join(builder)
except Exception:
return str(tag) # Fallback to original if anything fails
class BM25ContentFilter(RelevantContentFilter):
"""
Content filtering using BM25 algorithm with priority tag handling.
@@ -280,7 +365,13 @@ class BM25ContentFilter(RelevantContentFilter):
Methods:
filter_content(self, html: str, min_word_threshold: int = None)
"""
def __init__(self, user_query: str = None, bm25_threshold: float = 1.0, language: str = 'english'):
def __init__(
self,
user_query: str = None,
bm25_threshold: float = 1.0,
language: str = "english",
):
"""
Initializes the BM25ContentFilter class, if not provided, falls back to page metadata.
@@ -295,17 +386,17 @@ class BM25ContentFilter(RelevantContentFilter):
super().__init__(user_query=user_query)
self.bm25_threshold = bm25_threshold
self.priority_tags = {
'h1': 5.0,
'h2': 4.0,
'h3': 3.0,
'title': 4.0,
'strong': 2.0,
'b': 1.5,
'em': 1.5,
'blockquote': 2.0,
'code': 2.0,
'pre': 1.5,
'th': 1.5, # Table headers
"h1": 5.0,
"h2": 4.0,
"h3": 3.0,
"title": 4.0,
"strong": 2.0,
"b": 1.5,
"em": 1.5,
"blockquote": 2.0,
"code": 2.0,
"pre": 1.5,
"th": 1.5, # Table headers
}
self.stemmer = stemmer(language)
@@ -327,13 +418,13 @@ class BM25ContentFilter(RelevantContentFilter):
if not html or not isinstance(html, str):
return []
soup = BeautifulSoup(html, 'lxml')
soup = BeautifulSoup(html, "lxml")
# Check if body is present
if not soup.body:
# Wrap in body tag if missing
soup = BeautifulSoup(f'<body>{html}</body>', 'lxml')
body = soup.find('body')
soup = BeautifulSoup(f"<body>{html}</body>", "lxml")
body = soup.find("body")
query = self.extract_page_query(soup, body)
@@ -354,9 +445,13 @@ class BM25ContentFilter(RelevantContentFilter):
# for _, chunk, _, _ in candidates]
# tokenized_query = [ps.stem(word) for word in query.lower().split()]
tokenized_corpus = [[self.stemmer.stemWord(word) for word in chunk.lower().split()]
for _, chunk, _, _ in candidates]
tokenized_query = [self.stemmer.stemWord(word) for word in query.lower().split()]
tokenized_corpus = [
[self.stemmer.stemWord(word) for word in chunk.lower().split()]
for _, chunk, _, _ in candidates
]
tokenized_query = [
self.stemmer.stemWord(word) for word in query.lower().split()
]
# tokenized_corpus = [[self.stemmer.stemWord(word) for word in tokenize_text(chunk.lower())]
# for _, chunk, _, _ in candidates]
@@ -378,7 +473,8 @@ class BM25ContentFilter(RelevantContentFilter):
# Filter candidates by threshold
selected_candidates = [
(index, chunk, tag) for adjusted_score, index, chunk, tag in adjusted_candidates
(index, chunk, tag)
for adjusted_score, index, chunk, tag in adjusted_candidates
if adjusted_score >= self.bm25_threshold
]
@@ -390,6 +486,7 @@ class BM25ContentFilter(RelevantContentFilter):
return [self.clean_element(tag) for _, _, tag in selected_candidates]
class PruningContentFilter(RelevantContentFilter):
"""
Content filtering using pruning algorithm with dynamic threshold.
@@ -411,8 +508,14 @@ class PruningContentFilter(RelevantContentFilter):
Methods:
filter_content(self, html: str, min_word_threshold: int = None):
"""
def __init__(self, user_query: str = None, min_word_threshold: int = None,
threshold_type: str = 'fixed', threshold: float = 0.48):
def __init__(
self,
user_query: str = None,
min_word_threshold: int = None,
threshold_type: str = "fixed",
threshold: float = 0.48,
):
"""
Initializes the PruningContentFilter class, if not provided, falls back to page metadata.
@@ -432,49 +535,49 @@ class PruningContentFilter(RelevantContentFilter):
# Add tag importance for dynamic threshold
self.tag_importance = {
'article': 1.5,
'main': 1.4,
'section': 1.3,
'p': 1.2,
'h1': 1.4,
'h2': 1.3,
'h3': 1.2,
'div': 0.7,
'span': 0.6
"article": 1.5,
"main": 1.4,
"section": 1.3,
"p": 1.2,
"h1": 1.4,
"h2": 1.3,
"h3": 1.2,
"div": 0.7,
"span": 0.6,
}
# Metric configuration
self.metric_config = {
'text_density': True,
'link_density': True,
'tag_weight': True,
'class_id_weight': True,
'text_length': True,
"text_density": True,
"link_density": True,
"tag_weight": True,
"class_id_weight": True,
"text_length": True,
}
self.metric_weights = {
'text_density': 0.4,
'link_density': 0.2,
'tag_weight': 0.2,
'class_id_weight': 0.1,
'text_length': 0.1,
"text_density": 0.4,
"link_density": 0.2,
"tag_weight": 0.2,
"class_id_weight": 0.1,
"text_length": 0.1,
}
self.tag_weights = {
'div': 0.5,
'p': 1.0,
'article': 1.5,
'section': 1.0,
'span': 0.3,
'li': 0.5,
'ul': 0.5,
'ol': 0.5,
'h1': 1.2,
'h2': 1.1,
'h3': 1.0,
'h4': 0.9,
'h5': 0.8,
'h6': 0.7,
"div": 0.5,
"p": 1.0,
"article": 1.5,
"section": 1.0,
"span": 0.3,
"li": 0.5,
"ul": 0.5,
"ol": 0.5,
"h1": 1.2,
"h2": 1.1,
"h3": 1.0,
"h4": 0.9,
"h5": 0.8,
"h6": 0.7,
}
def filter_content(self, html: str, min_word_threshold: int = None) -> List[str]:
@@ -495,22 +598,22 @@ class PruningContentFilter(RelevantContentFilter):
if not html or not isinstance(html, str):
return []
soup = BeautifulSoup(html, 'lxml')
soup = BeautifulSoup(html, "lxml")
if not soup.body:
soup = BeautifulSoup(f'<body>{html}</body>', 'lxml')
soup = BeautifulSoup(f"<body>{html}</body>", "lxml")
# Remove comments and unwanted tags
self._remove_comments(soup)
self._remove_unwanted_tags(soup)
# Prune tree starting from body
body = soup.find('body')
body = soup.find("body")
self._prune_tree(body)
# Extract remaining content as list of HTML strings
content_blocks = []
for element in body.children:
if isinstance(element, str) or not hasattr(element, 'name'):
if isinstance(element, str) or not hasattr(element, "name"):
continue
if len(element.get_text(strip=True)) > 0:
content_blocks.append(str(element))
@@ -535,24 +638,28 @@ class PruningContentFilter(RelevantContentFilter):
Args:
node (Tag): The node from which the pruning starts.
"""
if not node or not hasattr(node, 'name') or node.name is None:
if not node or not hasattr(node, "name") or node.name is None:
return
text_len = len(node.get_text(strip=True))
tag_len = len(node.encode_contents().decode('utf-8'))
link_text_len = sum(len(s.strip()) for s in (a.string for a in node.find_all('a', recursive=False)) if s)
tag_len = len(node.encode_contents().decode("utf-8"))
link_text_len = sum(
len(s.strip())
for s in (a.string for a in node.find_all("a", recursive=False))
if s
)
metrics = {
'node': node,
'tag_name': node.name,
'text_len': text_len,
'tag_len': tag_len,
'link_text_len': link_text_len
"node": node,
"tag_name": node.name,
"text_len": text_len,
"tag_len": tag_len,
"link_text_len": link_text_len,
}
score = self._compute_composite_score(metrics, text_len, tag_len, link_text_len)
if self.threshold_type == 'fixed':
if self.threshold_type == "fixed":
should_remove = score < self.threshold
else: # dynamic
tag_importance = self.tag_importance.get(node.name, 0.7)
@@ -572,7 +679,7 @@ class PruningContentFilter(RelevantContentFilter):
if should_remove:
node.decompose()
else:
children = [child for child in node.children if hasattr(child, 'name')]
children = [child for child in node.children if hasattr(child, "name")]
for child in children:
self._prune_tree(child)
@@ -580,48 +687,48 @@ class PruningContentFilter(RelevantContentFilter):
"""Computes the composite score"""
if self.min_word_threshold:
# Get raw text from metrics node - avoid extra processing
text = metrics['node'].get_text(strip=True)
word_count = text.count(' ') + 1
text = metrics["node"].get_text(strip=True)
word_count = text.count(" ") + 1
if word_count < self.min_word_threshold:
return -1.0 # Guaranteed removal
score = 0.0
total_weight = 0.0
if self.metric_config['text_density']:
if self.metric_config["text_density"]:
density = text_len / tag_len if tag_len > 0 else 0
score += self.metric_weights['text_density'] * density
total_weight += self.metric_weights['text_density']
score += self.metric_weights["text_density"] * density
total_weight += self.metric_weights["text_density"]
if self.metric_config['link_density']:
if self.metric_config["link_density"]:
density = 1 - (link_text_len / text_len if text_len > 0 else 0)
score += self.metric_weights['link_density'] * density
total_weight += self.metric_weights['link_density']
score += self.metric_weights["link_density"] * density
total_weight += self.metric_weights["link_density"]
if self.metric_config['tag_weight']:
tag_score = self.tag_weights.get(metrics['tag_name'], 0.5)
score += self.metric_weights['tag_weight'] * tag_score
total_weight += self.metric_weights['tag_weight']
if self.metric_config["tag_weight"]:
tag_score = self.tag_weights.get(metrics["tag_name"], 0.5)
score += self.metric_weights["tag_weight"] * tag_score
total_weight += self.metric_weights["tag_weight"]
if self.metric_config['class_id_weight']:
class_score = self._compute_class_id_weight(metrics['node'])
score += self.metric_weights['class_id_weight'] * max(0, class_score)
total_weight += self.metric_weights['class_id_weight']
if self.metric_config["class_id_weight"]:
class_score = self._compute_class_id_weight(metrics["node"])
score += self.metric_weights["class_id_weight"] * max(0, class_score)
total_weight += self.metric_weights["class_id_weight"]
if self.metric_config['text_length']:
score += self.metric_weights['text_length'] * math.log(text_len + 1)
total_weight += self.metric_weights['text_length']
if self.metric_config["text_length"]:
score += self.metric_weights["text_length"] * math.log(text_len + 1)
total_weight += self.metric_weights["text_length"]
return score / total_weight if total_weight > 0 else 0
def _compute_class_id_weight(self, node):
"""Computes the class ID weight"""
class_id_score = 0
if 'class' in node.attrs:
classes = ' '.join(node['class'])
if "class" in node.attrs:
classes = " ".join(node["class"])
if self.negative_patterns.match(classes):
class_id_score -= 0.5
if 'id' in node.attrs:
element_id = node['id']
if "id" in node.attrs:
element_id = node["id"]
if self.negative_patterns.match(element_id):
class_id_score -= 0.5
return class_id_score

File diff suppressed because it is too large Load Diff

View File

@@ -15,32 +15,30 @@ import logging, time
import base64
from PIL import Image, ImageDraw, ImageFont
from io import BytesIO
from typing import List, Callable
from typing import Callable
import requests
import os
from pathlib import Path
from .utils import *
logger = logging.getLogger('selenium.webdriver.remote.remote_connection')
logger = logging.getLogger("selenium.webdriver.remote.remote_connection")
logger.setLevel(logging.WARNING)
logger_driver = logging.getLogger('selenium.webdriver.common.service')
logger_driver = logging.getLogger("selenium.webdriver.common.service")
logger_driver.setLevel(logging.WARNING)
urllib3_logger = logging.getLogger('urllib3.connectionpool')
urllib3_logger = logging.getLogger("urllib3.connectionpool")
urllib3_logger.setLevel(logging.WARNING)
# Disable http.client logging
http_client_logger = logging.getLogger('http.client')
http_client_logger = logging.getLogger("http.client")
http_client_logger.setLevel(logging.WARNING)
# Disable driver_finder and service logging
driver_finder_logger = logging.getLogger('selenium.webdriver.common.driver_finder')
driver_finder_logger = logging.getLogger("selenium.webdriver.common.driver_finder")
driver_finder_logger.setLevel(logging.WARNING)
class CrawlerStrategy(ABC):
@abstractmethod
def crawl(self, url: str, **kwargs) -> str:
@@ -58,8 +56,9 @@ class CrawlerStrategy(ABC):
def set_hook(self, hook_type: str, hook: Callable):
pass
class CloudCrawlerStrategy(CrawlerStrategy):
def __init__(self, use_cached_html = False):
def __init__(self, use_cached_html=False):
super().__init__()
self.use_cached_html = use_cached_html
@@ -76,6 +75,7 @@ class CloudCrawlerStrategy(CrawlerStrategy):
html = response["results"][0]["html"]
return sanitize_input_encode(html)
class LocalSeleniumCrawlerStrategy(CrawlerStrategy):
def __init__(self, use_cached_html=False, js_code=None, **kwargs):
super().__init__()
@@ -87,9 +87,14 @@ class LocalSeleniumCrawlerStrategy(CrawlerStrategy):
if kwargs.get("user_agent"):
self.options.add_argument("--user-agent=" + kwargs.get("user_agent"))
else:
user_agent = kwargs.get("user_agent", "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.124 Safari/537.36")
user_agent = kwargs.get(
"user_agent",
"Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.124 Safari/537.36",
)
self.options.add_argument(f"--user-agent={user_agent}")
self.options.add_argument("user-agent=Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.124 Safari/537.36")
self.options.add_argument(
"user-agent=Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.124 Safari/537.36"
)
self.options.headless = kwargs.get("headless", True)
if self.options.headless:
@@ -123,11 +128,11 @@ class LocalSeleniumCrawlerStrategy(CrawlerStrategy):
# Hooks
self.hooks = {
'on_driver_created': None,
'on_user_agent_updated': None,
'before_get_url': None,
'after_get_url': None,
'before_return_html': None
"on_driver_created": None,
"on_user_agent_updated": None,
"before_get_url": None,
"after_get_url": None,
"before_return_html": None,
}
# chromedriver_autoinstaller.install()
@@ -138,7 +143,6 @@ class LocalSeleniumCrawlerStrategy(CrawlerStrategy):
# chromedriver_path = chromedriver_autoinstaller.utils.download_chromedriver()
# self.service = Service(chromedriver_autoinstaller.install())
# chromedriver_path = ChromeDriverManager().install()
# self.service = Service(chromedriver_path)
# self.service.log_path = "NUL"
@@ -148,14 +152,12 @@ class LocalSeleniumCrawlerStrategy(CrawlerStrategy):
self.service = Service()
self.driver = webdriver.Chrome(options=self.options)
self.driver = self.execute_hook('on_driver_created', self.driver)
self.driver = self.execute_hook("on_driver_created", self.driver)
if kwargs.get("cookies"):
for cookie in kwargs.get("cookies"):
self.driver.add_cookie(cookie)
def set_hook(self, hook_type: str, hook: Callable):
if hook_type in self.hooks:
self.hooks[hook_type] = hook
@@ -170,7 +172,9 @@ class LocalSeleniumCrawlerStrategy(CrawlerStrategy):
if isinstance(result, webdriver.Chrome):
return result
else:
raise TypeError(f"Hook {hook_type} must return an instance of webdriver.Chrome or None.")
raise TypeError(
f"Hook {hook_type} must return an instance of webdriver.Chrome or None."
)
# If the hook returns None or there is no hook, return self.driver
return self.driver
@@ -178,13 +182,13 @@ class LocalSeleniumCrawlerStrategy(CrawlerStrategy):
self.options.add_argument(f"user-agent={user_agent}")
self.driver.quit()
self.driver = webdriver.Chrome(service=self.service, options=self.options)
self.driver = self.execute_hook('on_user_agent_updated', self.driver)
self.driver = self.execute_hook("on_user_agent_updated", self.driver)
def set_custom_headers(self, headers: dict):
# Enable Network domain for sending headers
self.driver.execute_cdp_cmd('Network.enable', {})
self.driver.execute_cdp_cmd("Network.enable", {})
# Set extra HTTP headers
self.driver.execute_cdp_cmd('Network.setExtraHTTPHeaders', {'headers': headers})
self.driver.execute_cdp_cmd("Network.setExtraHTTPHeaders", {"headers": headers})
def _ensure_page_load(self, max_checks=6, check_interval=0.01):
initial_length = len(self.driver.page_source)
@@ -202,36 +206,53 @@ class LocalSeleniumCrawlerStrategy(CrawlerStrategy):
def crawl(self, url: str, **kwargs) -> str:
# Create md5 hash of the URL
import hashlib
url_hash = hashlib.md5(url.encode()).hexdigest()
if self.use_cached_html:
cache_file_path = os.path.join(os.getenv("CRAWL4_AI_BASE_DIRECTORY", Path.home()), ".crawl4ai", "cache", url_hash)
cache_file_path = os.path.join(
os.getenv("CRAWL4_AI_BASE_DIRECTORY", Path.home()),
".crawl4ai",
"cache",
url_hash,
)
if os.path.exists(cache_file_path):
with open(cache_file_path, "r") as f:
return sanitize_input_encode(f.read())
try:
self.driver = self.execute_hook('before_get_url', self.driver)
self.driver = self.execute_hook("before_get_url", self.driver)
if self.verbose:
print(f"[LOG] 🕸️ Crawling {url} using LocalSeleniumCrawlerStrategy...")
self.driver.get(url) #<html><head></head><body></body></html>
self.driver.get(url) # <html><head></head><body></body></html>
WebDriverWait(self.driver, 20).until(
lambda d: d.execute_script('return document.readyState') == 'complete'
lambda d: d.execute_script("return document.readyState") == "complete"
)
WebDriverWait(self.driver, 10).until(
EC.presence_of_all_elements_located((By.TAG_NAME, "body"))
)
self.driver.execute_script("window.scrollTo(0, document.body.scrollHeight);")
self.driver.execute_script(
"window.scrollTo(0, document.body.scrollHeight);"
)
self.driver = self.execute_hook('after_get_url', self.driver)
html = sanitize_input_encode(self._ensure_page_load()) # self.driver.page_source
can_not_be_done_headless = False # Look at my creativity for naming variables
self.driver = self.execute_hook("after_get_url", self.driver)
html = sanitize_input_encode(
self._ensure_page_load()
) # self.driver.page_source
can_not_be_done_headless = (
False # Look at my creativity for naming variables
)
# TODO: Very ugly approach, but promise to change it!
if kwargs.get('bypass_headless', False) or html == "<html><head></head><body></body></html>":
print("[LOG] 🙌 Page could not be loaded in headless mode. Trying non-headless mode...")
if (
kwargs.get("bypass_headless", False)
or html == "<html><head></head><body></body></html>"
):
print(
"[LOG] 🙌 Page could not be loaded in headless mode. Trying non-headless mode..."
)
can_not_be_done_headless = True
options = Options()
options.headless = False
@@ -239,7 +260,7 @@ class LocalSeleniumCrawlerStrategy(CrawlerStrategy):
options.add_argument("--window-size=5,5")
driver = webdriver.Chrome(service=self.service, options=options)
driver.get(url)
self.driver = self.execute_hook('after_get_url', driver)
self.driver = self.execute_hook("after_get_url", driver)
html = sanitize_input_encode(driver.page_source)
driver.quit()
@@ -249,17 +270,21 @@ class LocalSeleniumCrawlerStrategy(CrawlerStrategy):
self.driver.execute_script(self.js_code)
# Optionally, wait for some condition after executing the JS code
WebDriverWait(self.driver, 10).until(
lambda driver: driver.execute_script("return document.readyState") == "complete"
lambda driver: driver.execute_script("return document.readyState")
== "complete"
)
elif self.js_code and type(self.js_code) == list:
for js in self.js_code:
self.driver.execute_script(js)
WebDriverWait(self.driver, 10).until(
lambda driver: driver.execute_script("return document.readyState") == "complete"
lambda driver: driver.execute_script(
"return document.readyState"
)
== "complete"
)
# Optionally, wait for some condition after executing the JS code : Contributed by (https://github.com/jonymusky)
wait_for = kwargs.get('wait_for', False)
wait_for = kwargs.get("wait_for", False)
if wait_for:
if callable(wait_for):
print("[LOG] 🔄 Waiting for condition...")
@@ -272,10 +297,15 @@ class LocalSeleniumCrawlerStrategy(CrawlerStrategy):
if not can_not_be_done_headless:
html = sanitize_input_encode(self.driver.page_source)
self.driver = self.execute_hook('before_return_html', self.driver, html)
self.driver = self.execute_hook("before_return_html", self.driver, html)
# Store in cache
cache_file_path = os.path.join(os.getenv("CRAWL4_AI_BASE_DIRECTORY", Path.home()), ".crawl4ai", "cache", url_hash)
cache_file_path = os.path.join(
os.getenv("CRAWL4_AI_BASE_DIRECTORY", Path.home()),
".crawl4ai",
"cache",
url_hash,
)
with open(cache_file_path, "w", encoding="utf-8") as f:
f.write(html)
@@ -284,16 +314,16 @@ class LocalSeleniumCrawlerStrategy(CrawlerStrategy):
return html
except InvalidArgumentException as e:
if not hasattr(e, 'msg'):
if not hasattr(e, "msg"):
e.msg = sanitize_input_encode(str(e))
raise InvalidArgumentException(f"Failed to crawl {url}: {e.msg}")
except WebDriverException as e:
# If e does nlt have msg attribute create it and set it to str(e)
if not hasattr(e, 'msg'):
if not hasattr(e, "msg"):
e.msg = sanitize_input_encode(str(e))
raise WebDriverException(f"Failed to crawl {url}: {e.msg}")
except Exception as e:
if not hasattr(e, 'msg'):
if not hasattr(e, "msg"):
e.msg = sanitize_input_encode(str(e))
raise Exception(f"Failed to crawl {url}: {e.msg}")
@@ -301,7 +331,9 @@ class LocalSeleniumCrawlerStrategy(CrawlerStrategy):
try:
# Get the dimensions of the page
total_width = self.driver.execute_script("return document.body.scrollWidth")
total_height = self.driver.execute_script("return document.body.scrollHeight")
total_height = self.driver.execute_script(
"return document.body.scrollHeight"
)
# Set the window size to the dimensions of the page
self.driver.set_window_size(total_width, total_height)
@@ -313,23 +345,25 @@ class LocalSeleniumCrawlerStrategy(CrawlerStrategy):
image = Image.open(BytesIO(screenshot))
# Convert image to RGB mode (this will handle both RGB and RGBA images)
rgb_image = image.convert('RGB')
rgb_image = image.convert("RGB")
# Convert to JPEG and compress
buffered = BytesIO()
rgb_image.save(buffered, format="JPEG", quality=85)
img_base64 = base64.b64encode(buffered.getvalue()).decode('utf-8')
img_base64 = base64.b64encode(buffered.getvalue()).decode("utf-8")
if self.verbose:
print(f"[LOG] 📸 Screenshot taken and converted to base64")
print("[LOG] 📸 Screenshot taken and converted to base64")
return img_base64
except Exception as e:
error_message = sanitize_input_encode(f"Failed to take screenshot: {str(e)}")
error_message = sanitize_input_encode(
f"Failed to take screenshot: {str(e)}"
)
print(error_message)
# Generate an image with black background
img = Image.new('RGB', (800, 600), color='black')
img = Image.new("RGB", (800, 600), color="black")
draw = ImageDraw.Draw(img)
# Load a font
@@ -352,7 +386,7 @@ class LocalSeleniumCrawlerStrategy(CrawlerStrategy):
# Convert to base64
buffered = BytesIO()
img.save(buffered, format="JPEG")
img_base64 = base64.b64encode(buffered.getvalue()).decode('utf-8')
img_base64 = base64.b64encode(buffered.getvalue()).decode("utf-8")
return img_base64

View File

@@ -7,11 +7,13 @@ DB_PATH = os.path.join(os.getenv("CRAWL4_AI_BASE_DIRECTORY", Path.home()), ".cra
os.makedirs(DB_PATH, exist_ok=True)
DB_PATH = os.path.join(DB_PATH, "crawl4ai.db")
def init_db():
global DB_PATH
conn = sqlite3.connect(DB_PATH)
cursor = conn.cursor()
cursor.execute('''
cursor.execute(
"""
CREATE TABLE IF NOT EXISTS crawled_data (
url TEXT PRIMARY KEY,
html TEXT,
@@ -24,31 +26,42 @@ def init_db():
metadata TEXT DEFAULT "{}",
screenshot TEXT DEFAULT ""
)
''')
"""
)
conn.commit()
conn.close()
def alter_db_add_screenshot(new_column: str = "media"):
check_db_path()
try:
conn = sqlite3.connect(DB_PATH)
cursor = conn.cursor()
cursor.execute(f'ALTER TABLE crawled_data ADD COLUMN {new_column} TEXT DEFAULT ""')
cursor.execute(
f'ALTER TABLE crawled_data ADD COLUMN {new_column} TEXT DEFAULT ""'
)
conn.commit()
conn.close()
except Exception as e:
print(f"Error altering database to add screenshot column: {e}")
def check_db_path():
if not DB_PATH:
raise ValueError("Database path is not set or is empty.")
def get_cached_url(url: str) -> Optional[Tuple[str, str, str, str, str, str, str, bool, str]]:
def get_cached_url(
url: str,
) -> Optional[Tuple[str, str, str, str, str, str, str, bool, str]]:
check_db_path()
try:
conn = sqlite3.connect(DB_PATH)
cursor = conn.cursor()
cursor.execute('SELECT url, html, cleaned_html, markdown, extracted_content, success, media, links, metadata, screenshot FROM crawled_data WHERE url = ?', (url,))
cursor.execute(
"SELECT url, html, cleaned_html, markdown, extracted_content, success, media, links, metadata, screenshot FROM crawled_data WHERE url = ?",
(url,),
)
result = cursor.fetchone()
conn.close()
return result
@@ -56,12 +69,25 @@ def get_cached_url(url: str) -> Optional[Tuple[str, str, str, str, str, str, str
print(f"Error retrieving cached URL: {e}")
return None
def cache_url(url: str, html: str, cleaned_html: str, markdown: str, extracted_content: str, success: bool, media : str = "{}", links : str = "{}", metadata : str = "{}", screenshot: str = ""):
def cache_url(
url: str,
html: str,
cleaned_html: str,
markdown: str,
extracted_content: str,
success: bool,
media: str = "{}",
links: str = "{}",
metadata: str = "{}",
screenshot: str = "",
):
check_db_path()
try:
conn = sqlite3.connect(DB_PATH)
cursor = conn.cursor()
cursor.execute('''
cursor.execute(
"""
INSERT INTO crawled_data (url, html, cleaned_html, markdown, extracted_content, success, media, links, metadata, screenshot)
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
ON CONFLICT(url) DO UPDATE SET
@@ -74,18 +100,32 @@ def cache_url(url: str, html: str, cleaned_html: str, markdown: str, extracted_c
links = excluded.links,
metadata = excluded.metadata,
screenshot = excluded.screenshot
''', (url, html, cleaned_html, markdown, extracted_content, success, media, links, metadata, screenshot))
""",
(
url,
html,
cleaned_html,
markdown,
extracted_content,
success,
media,
links,
metadata,
screenshot,
),
)
conn.commit()
conn.close()
except Exception as e:
print(f"Error caching URL: {e}")
def get_total_count() -> int:
check_db_path()
try:
conn = sqlite3.connect(DB_PATH)
cursor = conn.cursor()
cursor.execute('SELECT COUNT(*) FROM crawled_data')
cursor.execute("SELECT COUNT(*) FROM crawled_data")
result = cursor.fetchone()
conn.close()
return result[0]
@@ -93,43 +133,48 @@ def get_total_count() -> int:
print(f"Error getting total count: {e}")
return 0
def clear_db():
check_db_path()
try:
conn = sqlite3.connect(DB_PATH)
cursor = conn.cursor()
cursor.execute('DELETE FROM crawled_data')
cursor.execute("DELETE FROM crawled_data")
conn.commit()
conn.close()
except Exception as e:
print(f"Error clearing database: {e}")
def flush_db():
check_db_path()
try:
conn = sqlite3.connect(DB_PATH)
cursor = conn.cursor()
cursor.execute('DROP TABLE crawled_data')
cursor.execute("DROP TABLE crawled_data")
conn.commit()
conn.close()
except Exception as e:
print(f"Error flushing database: {e}")
def update_existing_records(new_column: str = "media", default_value: str = "{}"):
check_db_path()
try:
conn = sqlite3.connect(DB_PATH)
cursor = conn.cursor()
cursor.execute(f'UPDATE crawled_data SET {new_column} = "{default_value}" WHERE screenshot IS NULL')
cursor.execute(
f'UPDATE crawled_data SET {new_column} = "{default_value}" WHERE screenshot IS NULL'
)
conn.commit()
conn.close()
except Exception as e:
print(f"Error updating existing records: {e}")
if __name__ == "__main__":
# Delete the existing database file
if os.path.exists(DB_PATH):
os.remove(DB_PATH)
init_db()
# alter_db_add_screenshot("COL_NAME")

View File

@@ -4,6 +4,7 @@ from pathlib import Path
from crawl4ai.async_logger import AsyncLogger
from crawl4ai.llmtxt import AsyncLLMTextManager
class DocsManager:
def __init__(self, logger=None):
self.docs_dir = Path.home() / ".crawl4ai" / "docs"
@@ -21,7 +22,10 @@ class DocsManager:
"""Copy from local docs or download from GitHub"""
try:
# Try local first
if self.local_docs.exists() and (any(self.local_docs.glob("*.md")) or any(self.local_docs.glob("*.tokens"))):
if self.local_docs.exists() and (
any(self.local_docs.glob("*.md"))
or any(self.local_docs.glob("*.tokens"))
):
# Empty the local docs directory
for file_path in self.docs_dir.glob("*.md"):
file_path.unlink()
@@ -36,14 +40,14 @@ class DocsManager:
# Fallback to GitHub
response = requests.get(
"https://api.github.com/repos/unclecode/crawl4ai/contents/docs/llm.txt",
headers={'Accept': 'application/vnd.github.v3+json'}
headers={"Accept": "application/vnd.github.v3+json"},
)
response.raise_for_status()
for item in response.json():
if item['type'] == 'file' and item['name'].endswith('.md'):
content = requests.get(item['download_url']).text
with open(self.docs_dir / item['name'], 'w', encoding='utf-8') as f:
if item["type"] == "file" and item["name"].endswith(".md"):
content = requests.get(item["download_url"]).text
with open(self.docs_dir / item["name"], "w", encoding="utf-8") as f:
f.write(content)
return True
@@ -57,7 +61,11 @@ class DocsManager:
# Remove [0-9]+_ prefix
names = [name.split("_", 1)[1] if name[0].isdigit() else name for name in names]
# Exclude those end with .xs.md and .q.md
names = [name for name in names if not name.endswith(".xs") and not name.endswith(".q")]
names = [
name
for name in names
if not name.endswith(".xs") and not name.endswith(".q")
]
return names
def generate(self, sections, mode="extended"):

View File

@@ -1,20 +1,48 @@
from abc import ABC, abstractmethod
from typing import Any, List, Dict, Optional, Union
from typing import Any, List, Dict, Optional
from concurrent.futures import ThreadPoolExecutor, as_completed
import json, time
# from optimum.intel import IPEXModel
from .prompts import *
from .config import *
from .utils import *
from .models import *
import json
import time
import os
from .prompts import PROMPT_EXTRACT_BLOCKS
from .config import (
DEFAULT_PROVIDER, PROVIDER_MODELS,
CHUNK_TOKEN_THRESHOLD,
OVERLAP_RATE,
WORD_TOKEN_RATE,
PROMPT_EXTRACT_BLOCKS_WITH_INSTRUCTION,
PROMPT_EXTRACT_SCHEMA_WITH_INSTRUCTION
)
from .utils import * # noqa: F403
from .utils import (
sanitize_html,
calculate_batch_size,
escape_json_string,
perform_completion_with_backoff,
extract_xml_data,
split_and_parse_json_objects,
sanitize_input_encode,
)
from .models import * # noqa: F403
from .models import TokenUsage
from .model_loader import * # noqa: F403
from .model_loader import (
get_device,
load_HF_embedding_model,
load_text_multilabel_classifier,
)
from functools import partial
from .model_loader import *
import math
import numpy as np
import re
from bs4 import BeautifulSoup
from lxml import html, etree
from dataclasses import dataclass
class ExtractionStrategy(ABC):
"""
@@ -56,15 +84,20 @@ class ExtractionStrategy(ABC):
"""
extracted_content = []
with ThreadPoolExecutor() as executor:
futures = [executor.submit(self.extract, url, section, **kwargs) for section in sections]
futures = [
executor.submit(self.extract, url, section, **kwargs)
for section in sections
]
for future in as_completed(futures):
extracted_content.extend(future.result())
return extracted_content
class NoExtractionStrategy(ExtractionStrategy):
"""
A strategy that does not extract any meaningful content from the HTML. It simply returns the entire HTML as a single block.
"""
def extract(self, url: str, html: str, *q, **kwargs) -> List[Dict[str, Any]]:
"""
Extract meaningful blocks or chunks from the given HTML.
@@ -72,13 +105,17 @@ class NoExtractionStrategy(ExtractionStrategy):
return [{"index": 0, "content": html}]
def run(self, url: str, sections: List[str], *q, **kwargs) -> List[Dict[str, Any]]:
return [{"index": i, "tags": [], "content": section} for i, section in enumerate(sections)]
return [
{"index": i, "tags": [], "content": section}
for i, section in enumerate(sections)
]
#######################################################
# Strategies using clustering for text data extraction #
#######################################################
class CosineStrategy(ExtractionStrategy):
"""
Extract meaningful blocks or chunks from the given HTML using cosine similarity.
@@ -99,7 +136,18 @@ class CosineStrategy(ExtractionStrategy):
model_name (str): The name of the sentence-transformers model.
sim_threshold (float): The similarity threshold for clustering.
"""
def __init__(self, semantic_filter = None, word_count_threshold=10, max_dist=0.2, linkage_method='ward', top_k=3, model_name = 'sentence-transformers/all-MiniLM-L6-v2', sim_threshold = 0.3, **kwargs):
def __init__(
self,
semantic_filter=None,
word_count_threshold=10,
max_dist=0.2,
linkage_method="ward",
top_k=3,
model_name="sentence-transformers/all-MiniLM-L6-v2",
sim_threshold=0.3,
**kwargs,
):
"""
Initialize the strategy with clustering parameters.
@@ -162,7 +210,6 @@ class CosineStrategy(ExtractionStrategy):
# self.tokenizer = self.model.tokenizer
# self.get_embedding_method = "direct"
if self.verbose:
print(f"[LOG] Loading Multilabel Classifier for {self.device.type} device.")
@@ -170,9 +217,15 @@ class CosineStrategy(ExtractionStrategy):
# self.default_batch_size = 16 if self.device.type == 'cpu' else 64
if self.verbose:
print(f"[LOG] Model loaded {model_name}, models/reuters, took " + str(time.time() - self.timer) + " seconds")
print(
f"[LOG] Model loaded {model_name}, models/reuters, took "
+ str(time.time() - self.timer)
+ " seconds"
)
def filter_documents_embeddings(self, documents: List[str], semantic_filter: str, at_least_k: int = 20) -> List[str]:
def filter_documents_embeddings(
self, documents: List[str], semantic_filter: str, at_least_k: int = 20
) -> List[str]:
"""
Filter and sort documents based on the cosine similarity of their embeddings with the semantic_filter embedding.
@@ -200,23 +253,35 @@ class CosineStrategy(ExtractionStrategy):
document_embeddings = self.get_embeddings(documents)
# Calculate cosine similarity between the query embedding and document embeddings
similarities = cosine_similarity([query_embedding], document_embeddings).flatten()
similarities = cosine_similarity(
[query_embedding], document_embeddings
).flatten()
# Filter documents based on the similarity threshold
filtered_docs = [(doc, sim) for doc, sim in zip(documents, similarities) if sim >= self.sim_threshold]
filtered_docs = [
(doc, sim)
for doc, sim in zip(documents, similarities)
if sim >= self.sim_threshold
]
# If the number of filtered documents is less than at_least_k, sort remaining documents by similarity
if len(filtered_docs) < at_least_k:
remaining_docs = [(doc, sim) for doc, sim in zip(documents, similarities) if sim < self.sim_threshold]
remaining_docs = [
(doc, sim)
for doc, sim in zip(documents, similarities)
if sim < self.sim_threshold
]
remaining_docs.sort(key=lambda x: x[1], reverse=True)
filtered_docs.extend(remaining_docs[:at_least_k - len(filtered_docs)])
filtered_docs.extend(remaining_docs[: at_least_k - len(filtered_docs)])
# Extract the document texts from the tuples
filtered_docs = [doc for doc, _ in filtered_docs]
return filtered_docs[:at_least_k]
def get_embeddings(self, sentences: List[str], batch_size=None, bypass_buffer=False):
def get_embeddings(
self, sentences: List[str], batch_size=None, bypass_buffer=False
):
"""
Get BERT embeddings for a list of sentences.
@@ -229,17 +294,22 @@ class CosineStrategy(ExtractionStrategy):
# if self.buffer_embeddings.any() and not bypass_buffer:
# return self.buffer_embeddings
if self.device.type in [ "cpu", "gpu", "cuda", "mps"]:
if self.device.type in ["cpu", "gpu", "cuda", "mps"]:
import torch
# Tokenize sentences and convert to tensor
if batch_size is None:
batch_size = self.default_batch_size
all_embeddings = []
for i in range(0, len(sentences), batch_size):
batch_sentences = sentences[i:i + batch_size]
encoded_input = self.tokenizer(batch_sentences, padding=True, truncation=True, return_tensors='pt')
encoded_input = {key: tensor.to(self.device) for key, tensor in encoded_input.items()}
batch_sentences = sentences[i : i + batch_size]
encoded_input = self.tokenizer(
batch_sentences, padding=True, truncation=True, return_tensors="pt"
)
encoded_input = {
key: tensor.to(self.device) for key, tensor in encoded_input.items()
}
# Ensure no gradients are calculated
with torch.no_grad():
@@ -257,14 +327,14 @@ class CosineStrategy(ExtractionStrategy):
all_embeddings = []
for i in range(0, len(sentences), batch_size):
batch_sentences = sentences[i:i + batch_size]
batch_sentences = sentences[i : i + batch_size]
embeddings = self.model(batch_sentences)
all_embeddings.append(embeddings)
self.buffer_embeddings = np.vstack(all_embeddings)
return self.buffer_embeddings
def hierarchical_clustering(self, sentences: List[str], embeddings = None):
def hierarchical_clustering(self, sentences: List[str], embeddings=None):
"""
Perform hierarchical clustering on sentences and return cluster labels.
@@ -277,18 +347,21 @@ class CosineStrategy(ExtractionStrategy):
# Get embeddings
from scipy.cluster.hierarchy import linkage, fcluster
from scipy.spatial.distance import pdist
self.timer = time.time()
embeddings = self.get_embeddings(sentences, bypass_buffer=True)
# print(f"[LOG] 🚀 Embeddings computed in {time.time() - self.timer:.2f} seconds")
# Compute pairwise cosine distances
distance_matrix = pdist(embeddings, 'cosine')
distance_matrix = pdist(embeddings, "cosine")
# Perform agglomerative clustering respecting order
linked = linkage(distance_matrix, method=self.linkage_method)
# Form flat clusters
labels = fcluster(linked, self.max_dist, criterion='distance')
labels = fcluster(linked, self.max_dist, criterion="distance")
return labels
def filter_clusters_by_word_count(self, clusters: Dict[int, List[str]]) -> Dict[int, List[str]]:
def filter_clusters_by_word_count(
self, clusters: Dict[int, List[str]]
) -> Dict[int, List[str]]:
"""
Filter clusters to remove those with a word count below the threshold.
@@ -327,7 +400,9 @@ class CosineStrategy(ExtractionStrategy):
text_chunks = html.split(self.DEL) # Split by lines or paragraphs as needed
# Pre-filter documents using embeddings and semantic_filter
text_chunks = self.filter_documents_embeddings(text_chunks, self.semantic_filter)
text_chunks = self.filter_documents_embeddings(
text_chunks, self.semantic_filter
)
if not text_chunks:
return []
@@ -346,16 +421,19 @@ class CosineStrategy(ExtractionStrategy):
filtered_clusters = self.filter_clusters_by_word_count(clusters)
# Convert filtered clusters to a sorted list of dictionaries
cluster_list = [{"index": int(idx), "tags" : [], "content": " ".join(filtered_clusters[idx])} for idx in sorted(filtered_clusters)]
cluster_list = [
{"index": int(idx), "tags": [], "content": " ".join(filtered_clusters[idx])}
for idx in sorted(filtered_clusters)
]
if self.verbose:
print(f"[LOG] 🚀 Assign tags using {self.device}")
if self.device.type in ["gpu", "cuda", "mps", "cpu"]:
labels = self.nlp([cluster['content'] for cluster in cluster_list])
labels = self.nlp([cluster["content"] for cluster in cluster_list])
for cluster, label in zip(cluster_list, labels):
cluster['tags'] = label
cluster["tags"] = label
# elif self.device.type == "cpu":
# # Process the text with the loaded model
# texts = [cluster['content'] for cluster in cluster_list]
@@ -393,7 +471,6 @@ class CosineStrategy(ExtractionStrategy):
return self.extract(url, self.DEL.join(sections), **kwargs)
#######################################################
# Strategies using LLM-based extraction for text data #
#######################################################
@@ -419,9 +496,15 @@ class LLMExtractionStrategy(ExtractionStrategy):
total_usage: Accumulated token usage.
"""
def __init__(self,
provider: str = DEFAULT_PROVIDER, api_token: Optional[str] = None,
instruction:str = None, schema:Dict = None, extraction_type = "block", **kwargs):
def __init__(
self,
provider: str = DEFAULT_PROVIDER,
api_token: Optional[str] = None,
instruction: str = None,
schema: Dict = None,
extraction_type="block",
**kwargs,
):
"""
Initialize the strategy with clustering parameters.
@@ -445,14 +528,20 @@ class LLMExtractionStrategy(ExtractionStrategy):
"""
super().__init__(**kwargs)
self.provider = provider
self.api_token = api_token or PROVIDER_MODELS.get(provider, "no-token") or os.getenv("OPENAI_API_KEY")
self.api_token = (
api_token
or PROVIDER_MODELS.get(provider, "no-token")
or os.getenv("OPENAI_API_KEY")
)
self.instruction = instruction
self.extract_type = extraction_type
self.schema = schema
if schema:
self.extract_type = "schema"
self.chunk_token_threshold = kwargs.get("chunk_token_threshold", CHUNK_TOKEN_THRESHOLD)
self.chunk_token_threshold = kwargs.get(
"chunk_token_threshold", CHUNK_TOKEN_THRESHOLD
)
self.overlap_rate = kwargs.get("overlap_rate", OVERLAP_RATE)
self.word_token_rate = kwargs.get("word_token_rate", WORD_TOKEN_RATE)
self.apply_chunking = kwargs.get("apply_chunking", True)
@@ -467,10 +556,11 @@ class LLMExtractionStrategy(ExtractionStrategy):
self.total_usage = TokenUsage() # Accumulated usage
if not self.api_token:
raise ValueError("API token must be provided for LLMExtractionStrategy. Update the config.py or set OPENAI_API_KEY environment variable.")
raise ValueError(
"API token must be provided for LLMExtractionStrategy. Update the config.py or set OPENAI_API_KEY environment variable."
)
def extract(self, url: str, ix:int, html: str) -> List[Dict[str, Any]]:
def extract(self, url: str, ix: int, html: str) -> List[Dict[str, Any]]:
"""
Extract meaningful blocks or chunks from the given HTML using an LLM.
@@ -515,15 +605,19 @@ class LLMExtractionStrategy(ExtractionStrategy):
prompt_with_variables,
self.api_token,
base_url=self.api_base or self.base_url,
extra_args = self.extra_args
extra_args=self.extra_args,
) # , json_response=self.extract_type == "schema")
# Track usage
usage = TokenUsage(
completion_tokens=response.usage.completion_tokens,
prompt_tokens=response.usage.prompt_tokens,
total_tokens=response.usage.total_tokens,
completion_tokens_details=response.usage.completion_tokens_details.__dict__ if response.usage.completion_tokens_details else {},
prompt_tokens_details=response.usage.prompt_tokens_details.__dict__ if response.usage.prompt_tokens_details else {}
completion_tokens_details=response.usage.completion_tokens_details.__dict__
if response.usage.completion_tokens_details
else {},
prompt_tokens_details=response.usage.prompt_tokens_details.__dict__
if response.usage.prompt_tokens_details
else {},
)
self.usages.append(usage)
@@ -533,36 +627,44 @@ class LLMExtractionStrategy(ExtractionStrategy):
self.total_usage.total_tokens += usage.total_tokens
try:
blocks = extract_xml_data(["blocks"], response.choices[0].message.content)['blocks']
blocks = extract_xml_data(["blocks"], response.choices[0].message.content)[
"blocks"
]
blocks = json.loads(blocks)
for block in blocks:
block['error'] = False
except Exception as e:
parsed, unparsed = split_and_parse_json_objects(response.choices[0].message.content)
block["error"] = False
except Exception:
parsed, unparsed = split_and_parse_json_objects(
response.choices[0].message.content
)
blocks = parsed
if unparsed:
blocks.append({
"index": 0,
"error": True,
"tags": ["error"],
"content": unparsed
})
blocks.append(
{"index": 0, "error": True, "tags": ["error"], "content": unparsed}
)
if self.verbose:
print("[LOG] Extracted", len(blocks), "blocks from URL:", url, "block index:", ix)
print(
"[LOG] Extracted",
len(blocks),
"blocks from URL:",
url,
"block index:",
ix,
)
return blocks
def _merge(self, documents, chunk_token_threshold, overlap):
"""
Merge documents into sections based on chunk_token_threshold and overlap.
"""
chunks = []
# chunks = []
sections = []
total_tokens = 0
# Calculate the total tokens across all documents
for document in documents:
total_tokens += len(document.split(' ')) * self.word_token_rate
total_tokens += len(document.split(" ")) * self.word_token_rate
# Calculate the number of sections needed
num_sections = math.floor(total_tokens / chunk_token_threshold)
@@ -574,7 +676,7 @@ class LLMExtractionStrategy(ExtractionStrategy):
current_chunk = []
for document in documents:
tokens = document.split(' ')
tokens = document.split(" ")
token_count = len(tokens) * self.word_token_rate
if total_token_so_far + token_count <= adjusted_chunk_threshold:
@@ -591,17 +693,16 @@ class LLMExtractionStrategy(ExtractionStrategy):
overlap_tokens = current_chunk[-overlap:]
current_chunk.extend(overlap_tokens)
sections.append(' '.join(current_chunk))
sections.append(" ".join(current_chunk))
current_chunk = tokens
total_token_so_far = token_count
# Add the last chunk
if current_chunk:
sections.append(' '.join(current_chunk))
sections.append(" ".join(current_chunk))
return sections
def run(self, url: str, sections: List[str]) -> List[Dict[str, Any]]:
"""
Process sections sequentially with a delay for rate limiting issues, specifically for LLMExtractionStrategy.
@@ -615,15 +716,18 @@ class LLMExtractionStrategy(ExtractionStrategy):
"""
merged_sections = self._merge(
sections, self.chunk_token_threshold,
overlap= int(self.chunk_token_threshold * self.overlap_rate)
sections,
self.chunk_token_threshold,
overlap=int(self.chunk_token_threshold * self.overlap_rate),
)
extracted_content = []
if self.provider.startswith("groq/"):
# Sequential processing with a delay
for ix, section in enumerate(merged_sections):
extract_func = partial(self.extract, url)
extracted_content.extend(extract_func(ix, sanitize_input_encode(section)))
extracted_content.extend(
extract_func(ix, sanitize_input_encode(section))
)
time.sleep(0.5) # 500 ms delay between each processing
else:
# Parallel processing using ThreadPoolExecutor
@@ -633,7 +737,10 @@ class LLMExtractionStrategy(ExtractionStrategy):
with ThreadPoolExecutor(max_workers=4) as executor:
extract_func = partial(self.extract, url)
futures = [executor.submit(extract_func, ix, sanitize_input_encode(section)) for ix, section in enumerate(merged_sections)]
futures = [
executor.submit(extract_func, ix, sanitize_input_encode(section))
for ix, section in enumerate(merged_sections)
]
for future in as_completed(futures):
try:
@@ -642,17 +749,17 @@ class LLMExtractionStrategy(ExtractionStrategy):
if self.verbose:
print(f"Error in thread execution: {e}")
# Add error information to extracted_content
extracted_content.append({
extracted_content.append(
{
"index": 0,
"error": True,
"tags": ["error"],
"content": str(e)
})
"content": str(e),
}
)
return extracted_content
def show_usage(self) -> None:
"""Print a detailed token usage report showing total and per-request usage."""
print("\n=== Token Usage Summary ===")
@@ -666,14 +773,16 @@ class LLMExtractionStrategy(ExtractionStrategy):
print(f"{'Request #':<10} {'Completion':>12} {'Prompt':>12} {'Total':>12}")
print("-" * 48)
for i, usage in enumerate(self.usages, 1):
print(f"{i:<10} {usage.completion_tokens:>12,} {usage.prompt_tokens:>12,} {usage.total_tokens:>12,}")
print(
f"{i:<10} {usage.completion_tokens:>12,} {usage.prompt_tokens:>12,} {usage.total_tokens:>12,}"
)
#######################################################
# New extraction strategies for JSON-based extraction #
#######################################################
class JsonElementExtractionStrategy(ExtractionStrategy):
"""
Abstract base class for extracting structured JSON from HTML content.
@@ -706,8 +815,7 @@ class JsonElementExtractionStrategy(ExtractionStrategy):
_get_element_attribute(element, attribute): Extracts an attribute's value from an element.
"""
DEL = '\n'
DEL = "\n"
def __init__(self, schema: Dict[str, Any], **kwargs):
"""
@@ -718,9 +826,11 @@ class JsonElementExtractionStrategy(ExtractionStrategy):
"""
super().__init__(**kwargs)
self.schema = schema
self.verbose = kwargs.get('verbose', False)
self.verbose = kwargs.get("verbose", False)
def extract(self, url: str, html_content: str, *q, **kwargs) -> List[Dict[str, Any]]:
def extract(
self, url: str, html_content: str, *q, **kwargs
) -> List[Dict[str, Any]]:
"""
Extract structured data from HTML content.
@@ -740,20 +850,22 @@ class JsonElementExtractionStrategy(ExtractionStrategy):
"""
parsed_html = self._parse_html(html_content)
base_elements = self._get_base_elements(parsed_html, self.schema['baseSelector'])
base_elements = self._get_base_elements(
parsed_html, self.schema["baseSelector"]
)
results = []
for element in base_elements:
# Extract base element attributes
item = {}
if 'baseFields' in self.schema:
for field in self.schema['baseFields']:
if "baseFields" in self.schema:
for field in self.schema["baseFields"]:
value = self._extract_single_field(element, field)
if value is not None:
item[field['name']] = value
item[field["name"]] = value
# Extract child fields
field_data = self._extract_item(element, self.schema['fields'])
field_data = self._extract_item(element, self.schema["fields"])
item.update(field_data)
if item:
@@ -778,24 +890,28 @@ class JsonElementExtractionStrategy(ExtractionStrategy):
def _extract_field(self, element, field):
try:
if field['type'] == 'nested':
nested_elements = self._get_elements(element, field['selector'])
if field["type"] == "nested":
nested_elements = self._get_elements(element, field["selector"])
nested_element = nested_elements[0] if nested_elements else None
return self._extract_item(nested_element, field['fields']) if nested_element else {}
return (
self._extract_item(nested_element, field["fields"])
if nested_element
else {}
)
if field['type'] == 'list':
elements = self._get_elements(element, field['selector'])
return [self._extract_list_item(el, field['fields']) for el in elements]
if field["type"] == "list":
elements = self._get_elements(element, field["selector"])
return [self._extract_list_item(el, field["fields"]) for el in elements]
if field['type'] == 'nested_list':
elements = self._get_elements(element, field['selector'])
return [self._extract_item(el, field['fields']) for el in elements]
if field["type"] == "nested_list":
elements = self._get_elements(element, field["selector"])
return [self._extract_item(el, field["fields"]) for el in elements]
return self._extract_single_field(element, field)
except Exception as e:
if self.verbose:
print(f"Error extracting field {field['name']}: {str(e)}")
return field.get('default')
return field.get("default")
def _extract_single_field(self, element, field):
"""
@@ -814,37 +930,37 @@ class JsonElementExtractionStrategy(ExtractionStrategy):
Any: The extracted field value.
"""
if 'selector' in field:
selected = self._get_elements(element, field['selector'])
if "selector" in field:
selected = self._get_elements(element, field["selector"])
if not selected:
return field.get('default')
return field.get("default")
selected = selected[0]
else:
selected = element
value = None
if field['type'] == 'text':
if field["type"] == "text":
value = self._get_element_text(selected)
elif field['type'] == 'attribute':
value = self._get_element_attribute(selected, field['attribute'])
elif field['type'] == 'html':
elif field["type"] == "attribute":
value = self._get_element_attribute(selected, field["attribute"])
elif field["type"] == "html":
value = self._get_element_html(selected)
elif field['type'] == 'regex':
elif field["type"] == "regex":
text = self._get_element_text(selected)
match = re.search(field['pattern'], text)
match = re.search(field["pattern"], text)
value = match.group(1) if match else None
if 'transform' in field:
value = self._apply_transform(value, field['transform'])
if "transform" in field:
value = self._apply_transform(value, field["transform"])
return value if value is not None else field.get('default')
return value if value is not None else field.get("default")
def _extract_list_item(self, element, fields):
item = {}
for field in fields:
value = self._extract_single_field(element, field)
if value is not None:
item[field['name']] = value
item[field["name"]] = value
return item
def _extract_item(self, element, fields):
@@ -866,12 +982,12 @@ class JsonElementExtractionStrategy(ExtractionStrategy):
item = {}
for field in fields:
if field['type'] == 'computed':
if field["type"] == "computed":
value = self._compute_field(item, field)
else:
value = self._extract_field(element, field)
if value is not None:
item[field['name']] = value
item[field["name"]] = value
return item
def _apply_transform(self, value, transform):
@@ -891,24 +1007,24 @@ class JsonElementExtractionStrategy(ExtractionStrategy):
str: The transformed value.
"""
if transform == 'lowercase':
if transform == "lowercase":
return value.lower()
elif transform == 'uppercase':
elif transform == "uppercase":
return value.upper()
elif transform == 'strip':
elif transform == "strip":
return value.strip()
return value
def _compute_field(self, item, field):
try:
if 'expression' in field:
return eval(field['expression'], {}, item)
elif 'function' in field:
return field['function'](item)
if "expression" in field:
return eval(field["expression"], {}, item)
elif "function" in field:
return field["function"](item)
except Exception as e:
if self.verbose:
print(f"Error computing field {field['name']}: {str(e)}")
return field.get('default')
return field.get("default")
def run(self, url: str, sections: List[str], *q, **kwargs) -> List[Dict[str, Any]]:
"""
@@ -946,6 +1062,7 @@ class JsonElementExtractionStrategy(ExtractionStrategy):
"""Get attribute value from element"""
pass
class JsonCssExtractionStrategy(JsonElementExtractionStrategy):
"""
Concrete implementation of `JsonElementExtractionStrategy` using CSS selectors.
@@ -969,11 +1086,11 @@ class JsonCssExtractionStrategy(JsonElementExtractionStrategy):
"""
def __init__(self, schema: Dict[str, Any], **kwargs):
kwargs['input_format'] = 'html' # Force HTML input
kwargs["input_format"] = "html" # Force HTML input
super().__init__(schema, **kwargs)
def _parse_html(self, html_content: str):
return BeautifulSoup(html_content, 'html.parser')
return BeautifulSoup(html_content, "html.parser")
def _get_base_elements(self, parsed_html, selector: str):
return parsed_html.select(selector)
@@ -992,6 +1109,7 @@ class JsonCssExtractionStrategy(JsonElementExtractionStrategy):
def _get_element_attribute(self, element, attribute: str):
return element.get(attribute)
class JsonXPathExtractionStrategy(JsonElementExtractionStrategy):
"""
Concrete implementation of `JsonElementExtractionStrategy` using XPath selectors.
@@ -1016,7 +1134,7 @@ class JsonXPathExtractionStrategy(JsonElementExtractionStrategy):
"""
def __init__(self, schema: Dict[str, Any], **kwargs):
kwargs['input_format'] = 'html' # Force HTML input
kwargs["input_format"] = "html" # Force HTML input
super().__init__(schema, **kwargs)
def _parse_html(self, html_content: str):
@@ -1027,31 +1145,31 @@ class JsonXPathExtractionStrategy(JsonElementExtractionStrategy):
def _css_to_xpath(self, css_selector: str) -> str:
"""Convert CSS selector to XPath if needed"""
if '/' in css_selector: # Already an XPath
if "/" in css_selector: # Already an XPath
return css_selector
return self._basic_css_to_xpath(css_selector)
def _basic_css_to_xpath(self, css_selector: str) -> str:
"""Basic CSS to XPath conversion for common cases"""
if ' > ' in css_selector:
parts = css_selector.split(' > ')
return '//' + '/'.join(parts)
if ' ' in css_selector:
parts = css_selector.split(' ')
return '//' + '//'.join(parts)
return '//' + css_selector
if " > " in css_selector:
parts = css_selector.split(" > ")
return "//" + "/".join(parts)
if " " in css_selector:
parts = css_selector.split(" ")
return "//" + "//".join(parts)
return "//" + css_selector
def _get_elements(self, element, selector: str):
xpath = self._css_to_xpath(selector)
if not xpath.startswith('.'):
xpath = '.' + xpath
if not xpath.startswith("."):
xpath = "." + xpath
return element.xpath(xpath)
def _get_element_text(self, element) -> str:
return ''.join(element.xpath('.//text()')).strip()
return "".join(element.xpath(".//text()")).strip()
def _get_element_html(self, element) -> str:
return etree.tostring(element, encoding='unicode')
return etree.tostring(element, encoding="unicode")
def _get_element_attribute(self, element, attribute: str):
return element.get(attribute)

View File

@@ -903,7 +903,13 @@ class HTML2Text(html.parser.HTMLParser):
self.empty_link = False
if not self.code and not self.pre and not entity_char:
data = escape_md_section(data, snob=self.escape_snob, escape_dot=self.escape_dot, escape_plus=self.escape_plus, escape_dash=self.escape_dash)
data = escape_md_section(
data,
snob=self.escape_snob,
escape_dot=self.escape_dot,
escape_plus=self.escape_plus,
escape_dash=self.escape_dash,
)
self.preceding_data = data
self.o(data, puredata=True)
@@ -1006,6 +1012,7 @@ class HTML2Text(html.parser.HTMLParser):
newlines += 1
return result
def html2text(html: str, baseurl: str = "", bodywidth: Optional[int] = None) -> str:
if bodywidth is None:
bodywidth = config.BODY_WIDTH
@@ -1013,6 +1020,7 @@ def html2text(html: str, baseurl: str = "", bodywidth: Optional[int] = None) ->
return h.handle(html)
class CustomHTML2Text(HTML2Text):
def __init__(self, *args, handle_code_in_pre=False, **kwargs):
super().__init__(*args, **kwargs)
@@ -1041,9 +1049,9 @@ class CustomHTML2Text(HTML2Text):
def update_params(self, **kwargs):
"""Update parameters and set preserved tags."""
for key, value in kwargs.items():
if key == 'preserve_tags':
if key == "preserve_tags":
self.preserve_tags = set(value)
elif key == 'handle_code_in_pre':
elif key == "handle_code_in_pre":
self.handle_code_in_pre = value
else:
setattr(self, key, value)
@@ -1056,17 +1064,19 @@ class CustomHTML2Text(HTML2Text):
self.current_preserved_tag = tag
self.preserved_content = []
# Format opening tag with attributes
attr_str = ''.join(f' {k}="{v}"' for k, v in attrs.items() if v is not None)
self.preserved_content.append(f'<{tag}{attr_str}>')
attr_str = "".join(
f' {k}="{v}"' for k, v in attrs.items() if v is not None
)
self.preserved_content.append(f"<{tag}{attr_str}>")
self.preserve_depth += 1
return
else:
self.preserve_depth -= 1
if self.preserve_depth == 0:
self.preserved_content.append(f'</{tag}>')
self.preserved_content.append(f"</{tag}>")
# Output the preserved HTML block with proper spacing
preserved_html = ''.join(self.preserved_content)
self.o('\n' + preserved_html + '\n')
preserved_html = "".join(self.preserved_content)
self.o("\n" + preserved_html + "\n")
self.current_preserved_tag = None
return
@@ -1074,29 +1084,31 @@ class CustomHTML2Text(HTML2Text):
if self.preserve_depth > 0:
if start:
# Format nested tags with attributes
attr_str = ''.join(f' {k}="{v}"' for k, v in attrs.items() if v is not None)
self.preserved_content.append(f'<{tag}{attr_str}>')
attr_str = "".join(
f' {k}="{v}"' for k, v in attrs.items() if v is not None
)
self.preserved_content.append(f"<{tag}{attr_str}>")
else:
self.preserved_content.append(f'</{tag}>')
self.preserved_content.append(f"</{tag}>")
return
# Handle pre tags
if tag == 'pre':
if tag == "pre":
if start:
self.o('```\n') # Markdown code block start
self.o("```\n") # Markdown code block start
self.inside_pre = True
else:
self.o('\n```\n') # Markdown code block end
self.o("\n```\n") # Markdown code block end
self.inside_pre = False
elif tag == 'code':
elif tag == "code":
if self.inside_pre and not self.handle_code_in_pre:
# Ignore code tags inside pre blocks if handle_code_in_pre is False
return
if start:
self.o('`') # Markdown inline code start
self.o("`") # Markdown inline code start
self.inside_code = True
else:
self.o('`') # Markdown inline code end
self.o("`") # Markdown inline code end
self.inside_code = False
else:
super().handle_tag(tag, attrs, start)
@@ -1113,13 +1125,12 @@ class CustomHTML2Text(HTML2Text):
return
if self.inside_code:
# Inline code: no newlines allowed
self.o(data.replace('\n', ' '))
self.o(data.replace("\n", " "))
return
# Default behavior for other tags
super().handle_data(data, entity_char)
# # Handle pre tags
# if tag == 'pre':
# if start:

View File

@@ -1,2 +1,3 @@
class OutCallback:
def __call__(self, s: str) -> None: ...
def __call__(self, s: str) -> None:
...

View File

@@ -210,7 +210,7 @@ def escape_md_section(
snob: bool = False,
escape_dot: bool = True,
escape_plus: bool = True,
escape_dash: bool = True
escape_dash: bool = True,
) -> str:
"""
Escapes markdown-sensitive characters across whole document sections.
@@ -233,6 +233,7 @@ def escape_md_section(
return text
def reformat_table(lines: List[str], right_margin: int) -> List[str]:
"""
Given the lines of a table

View File

@@ -6,6 +6,7 @@ from .async_logger import AsyncLogger, LogLevel
# Initialize logger
logger = AsyncLogger(log_level=LogLevel.DEBUG, verbose=True)
def post_install():
"""Run all post-installation tasks"""
logger.info("Running post-installation setup...", tag="INIT")
@@ -13,18 +14,36 @@ def post_install():
run_migration()
logger.success("Post-installation setup completed!", tag="COMPLETE")
def install_playwright():
logger.info("Installing Playwright browsers...", tag="INIT")
try:
# subprocess.check_call([sys.executable, "-m", "playwright", "install", "--with-deps", "--force", "chrome"])
subprocess.check_call([sys.executable, "-m", "playwright", "install", "--with-deps", "--force", "chromium"])
logger.success("Playwright installation completed successfully.", tag="COMPLETE")
except subprocess.CalledProcessError as e:
subprocess.check_call(
[
sys.executable,
"-m",
"playwright",
"install",
"--with-deps",
"--force",
"chromium",
]
)
logger.success(
"Playwright installation completed successfully.", tag="COMPLETE"
)
except subprocess.CalledProcessError:
# logger.error(f"Error during Playwright installation: {e}", tag="ERROR")
logger.warning(f"Please run '{sys.executable} -m playwright install --with-deps' manually after the installation.")
except Exception as e:
logger.warning(
f"Please run '{sys.executable} -m playwright install --with-deps' manually after the installation."
)
except Exception:
# logger.error(f"Unexpected error during Playwright installation: {e}", tag="ERROR")
logger.warning(f"Please run '{sys.executable} -m playwright install --with-deps' manually after the installation.")
logger.warning(
f"Please run '{sys.executable} -m playwright install --with-deps' manually after the installation."
)
def run_migration():
"""Initialize database during installation"""
@@ -33,18 +52,26 @@ def run_migration():
from crawl4ai.async_database import async_db_manager
asyncio.run(async_db_manager.initialize())
logger.success("Database initialization completed successfully.", tag="COMPLETE")
logger.success(
"Database initialization completed successfully.", tag="COMPLETE"
)
except ImportError:
logger.warning("Database module not found. Will initialize on first use.")
except Exception as e:
logger.warning(f"Database initialization failed: {e}")
logger.warning("Database will be initialized on first use")
async def run_doctor():
"""Test if Crawl4AI is working properly"""
logger.info("Running Crawl4AI health check...", tag="INIT")
try:
from .async_webcrawler import AsyncWebCrawler, BrowserConfig, CrawlerRunConfig, CacheMode
from .async_webcrawler import (
AsyncWebCrawler,
BrowserConfig,
CrawlerRunConfig,
CacheMode,
)
browser_config = BrowserConfig(
headless=True,
@@ -52,7 +79,7 @@ async def run_doctor():
ignore_https_errors=True,
light_mode=True,
viewport_width=1280,
viewport_height=720
viewport_height=720,
)
run_config = CrawlerRunConfig(
@@ -62,10 +89,7 @@ async def run_doctor():
async with AsyncWebCrawler(config=browser_config) as crawler:
logger.info("Testing crawling capabilities...", tag="TEST")
result = await crawler.arun(
url="https://crawl4ai.com",
config=run_config
)
result = await crawler.arun(url="https://crawl4ai.com", config=run_config)
if result and result.markdown:
logger.success("✅ Crawling test passed!", tag="COMPLETE")
@@ -77,7 +101,9 @@ async def run_doctor():
logger.error(f"❌ Test failed: {e}", tag="ERROR")
return False
def doctor():
"""Entry point for the doctor command"""
import asyncio
return asyncio.run(run_doctor())

View File

@@ -1,15 +1,18 @@
import os, sys
import os
# Create a function get name of a js script, then load from the CURRENT folder of this script and return its content as string, make sure its error free
def load_js_script(script_name):
# Get the path of the current script
current_script_path = os.path.dirname(os.path.realpath(__file__))
# Get the path of the script to load
script_path = os.path.join(current_script_path, script_name + '.js')
script_path = os.path.join(current_script_path, script_name + ".js")
# Check if the script exists
if not os.path.exists(script_path):
raise ValueError(f"Script {script_name} not found in the folder {current_script_path}")
raise ValueError(
f"Script {script_name} not found in the folder {current_script_path}"
)
# Load the content of the script
with open(script_path, 'r') as f:
with open(script_path, "r") as f:
script_content = f.read()
return script_content

View File

@@ -11,16 +11,16 @@ from rank_bm25 import BM25Okapi
from nltk.tokenize import word_tokenize
from nltk.corpus import stopwords
from nltk.stem import WordNetLemmatizer
from litellm import completion, batch_completion
from litellm import batch_completion
from .async_logger import AsyncLogger
import litellm
import pickle
import hashlib # <--- ADDED for file-hash
from fnmatch import fnmatch
import glob
litellm.set_verbose = False
def _compute_file_hash(file_path: Path) -> str:
"""Compute MD5 hash for the file's entire content."""
hash_md5 = hashlib.md5()
@@ -29,13 +29,14 @@ def _compute_file_hash(file_path: Path) -> str:
hash_md5.update(chunk)
return hash_md5.hexdigest()
class AsyncLLMTextManager:
def __init__(
self,
docs_dir: Path,
logger: Optional[AsyncLogger] = None,
max_concurrent_calls: int = 5,
batch_size: int = 3
batch_size: int = 3,
) -> None:
self.docs_dir = docs_dir
self.logger = logger
@@ -51,7 +52,7 @@ class AsyncLLMTextManager:
contents = []
for file_path in doc_batch:
try:
with open(file_path, 'r', encoding='utf-8') as f:
with open(file_path, "r", encoding="utf-8") as f:
contents.append(f.read())
except Exception as e:
self.logger.error(f"Error reading {file_path}: {str(e)}")
@@ -77,43 +78,53 @@ Wrap your response in <index>...</index> tags.
# Prepare messages for batch processing
messages_list = [
[
{"role": "user", "content": f"{prompt}\n\nGenerate index for this documentation:\n\n{content}"}
{
"role": "user",
"content": f"{prompt}\n\nGenerate index for this documentation:\n\n{content}",
}
]
for content in contents if content
for content in contents
if content
]
try:
responses = batch_completion(
model="anthropic/claude-3-5-sonnet-latest",
messages=messages_list,
logger_fn=None
logger_fn=None,
)
# Process responses and save index files
for response, file_path in zip(responses, doc_batch):
try:
index_content_match = re.search(
r'<index>(.*?)</index>',
r"<index>(.*?)</index>",
response.choices[0].message.content,
re.DOTALL
re.DOTALL,
)
if not index_content_match:
self.logger.warning(f"No <index>...</index> content found for {file_path}")
self.logger.warning(
f"No <index>...</index> content found for {file_path}"
)
continue
index_content = re.sub(
r"\n\s*\n", "\n", index_content_match.group(1)
).strip()
if index_content:
index_file = file_path.with_suffix('.q.md')
with open(index_file, 'w', encoding='utf-8') as f:
index_file = file_path.with_suffix(".q.md")
with open(index_file, "w", encoding="utf-8") as f:
f.write(index_content)
self.logger.info(f"Created index file: {index_file}")
else:
self.logger.warning(f"No index content found in response for {file_path}")
self.logger.warning(
f"No index content found in response for {file_path}"
)
except Exception as e:
self.logger.error(f"Error processing response for {file_path}: {str(e)}")
self.logger.error(
f"Error processing response for {file_path}: {str(e)}"
)
except Exception as e:
self.logger.error(f"Error in batch completion: {str(e)}")
@@ -171,7 +182,12 @@ Wrap your response in <index>...</index> tags.
lemmatizer = WordNetLemmatizer()
stop_words = set(stopwords.words("english")) - {
"how", "what", "when", "where", "why", "which",
"how",
"what",
"when",
"where",
"why",
"which",
}
tokens = []
@@ -222,7 +238,9 @@ Wrap your response in <index>...</index> tags.
self.logger.info("Checking which .q.md files need (re)indexing...")
# Gather all .q.md files
q_files = [self.docs_dir / f for f in os.listdir(self.docs_dir) if f.endswith(".q.md")]
q_files = [
self.docs_dir / f for f in os.listdir(self.docs_dir) if f.endswith(".q.md")
]
# We'll store known (unchanged) facts in these lists
existing_facts: List[str] = []
@@ -243,7 +261,9 @@ Wrap your response in <index>...</index> tags.
# Otherwise, load the existing cache and compare hash
cache = self._load_or_create_token_cache(qf)
# If the .q.tokens was out of date (i.e. changed hash), we reindex
if len(cache["facts"]) == 0 or cache.get("content_hash") != _compute_file_hash(qf):
if len(cache["facts"]) == 0 or cache.get(
"content_hash"
) != _compute_file_hash(qf):
needSet.append(qf)
else:
# File is unchanged → retrieve cached token data
@@ -255,20 +275,29 @@ Wrap your response in <index>...</index> tags.
if not needSet and not clear_cache:
# If no file needs reindexing, try loading existing index
if self.maybe_load_bm25_index(clear_cache=False):
self.logger.info("No new/changed .q.md files found. Using existing BM25 index.")
self.logger.info(
"No new/changed .q.md files found. Using existing BM25 index."
)
return
else:
# If there's no existing index, we must build a fresh index from the old caches
self.logger.info("No existing BM25 index found. Building from cached facts.")
self.logger.info(
"No existing BM25 index found. Building from cached facts."
)
if existing_facts:
self.logger.info(f"Building BM25 index with {len(existing_facts)} cached facts.")
self.logger.info(
f"Building BM25 index with {len(existing_facts)} cached facts."
)
self.bm25_index = BM25Okapi(existing_tokens)
self.tokenized_facts = existing_facts
with open(self.bm25_index_file, "wb") as f:
pickle.dump({
pickle.dump(
{
"bm25_index": self.bm25_index,
"tokenized_facts": self.tokenized_facts
}, f)
"tokenized_facts": self.tokenized_facts,
},
f,
)
else:
self.logger.warning("No facts found at all. Index remains empty.")
return
@@ -311,7 +340,9 @@ Wrap your response in <index>...</index> tags.
self._save_token_cache(file, fresh_cache)
mem_usage = process.memory_info().rss / 1024 / 1024
self.logger.debug(f"Memory usage after {file.name}: {mem_usage:.2f}MB")
self.logger.debug(
f"Memory usage after {file.name}: {mem_usage:.2f}MB"
)
except Exception as e:
self.logger.error(f"Error processing {file}: {str(e)}")
@@ -328,21 +359,28 @@ Wrap your response in <index>...</index> tags.
all_tokens = existing_tokens + new_tokens
# 3) Build BM25 index from combined facts
self.logger.info(f"Building BM25 index with {len(all_facts)} total facts (old + new).")
self.logger.info(
f"Building BM25 index with {len(all_facts)} total facts (old + new)."
)
self.bm25_index = BM25Okapi(all_tokens)
self.tokenized_facts = all_facts
# 4) Save the updated BM25 index to disk
with open(self.bm25_index_file, "wb") as f:
pickle.dump({
pickle.dump(
{
"bm25_index": self.bm25_index,
"tokenized_facts": self.tokenized_facts
}, f)
"tokenized_facts": self.tokenized_facts,
},
f,
)
final_mem = process.memory_info().rss / 1024 / 1024
self.logger.info(f"Search index updated. Final memory usage: {final_mem:.2f}MB")
async def generate_index_files(self, force_generate_facts: bool = False, clear_bm25_cache: bool = False) -> None:
async def generate_index_files(
self, force_generate_facts: bool = False, clear_bm25_cache: bool = False
) -> None:
"""
Generate index files for all documents in parallel batches
@@ -353,15 +391,17 @@ Wrap your response in <index>...</index> tags.
self.logger.info("Starting index generation for documentation files.")
md_files = [
self.docs_dir / f for f in os.listdir(self.docs_dir)
if f.endswith('.md') and not any(f.endswith(x) for x in ['.q.md', '.xs.md'])
self.docs_dir / f
for f in os.listdir(self.docs_dir)
if f.endswith(".md") and not any(f.endswith(x) for x in [".q.md", ".xs.md"])
]
# Filter out files that already have .q files unless force=True
if not force_generate_facts:
md_files = [
f for f in md_files
if not (self.docs_dir / f.name.replace('.md', '.q.md')).exists()
f
for f in md_files
if not (self.docs_dir / f.name.replace(".md", ".q.md")).exists()
]
if not md_files:
@@ -369,8 +409,10 @@ Wrap your response in <index>...</index> tags.
else:
# Process documents in batches
for i in range(0, len(md_files), self.batch_size):
batch = md_files[i:i + self.batch_size]
self.logger.info(f"Processing batch {i//self.batch_size + 1}/{(len(md_files)//self.batch_size) + 1}")
batch = md_files[i : i + self.batch_size]
self.logger.info(
f"Processing batch {i//self.batch_size + 1}/{(len(md_files)//self.batch_size) + 1}"
)
await self._process_document_batch(batch)
self.logger.info("Index generation complete, building/updating search index.")
@@ -378,21 +420,31 @@ Wrap your response in <index>...</index> tags.
def generate(self, sections: List[str], mode: str = "extended") -> str:
# Get all markdown files
all_files = glob.glob(str(self.docs_dir / "[0-9]*.md")) + \
glob.glob(str(self.docs_dir / "[0-9]*.xs.md"))
all_files = glob.glob(str(self.docs_dir / "[0-9]*.md")) + glob.glob(
str(self.docs_dir / "[0-9]*.xs.md")
)
# Extract base names without extensions
base_docs = {Path(f).name.split('.')[0] for f in all_files
if not Path(f).name.endswith('.q.md')}
base_docs = {
Path(f).name.split(".")[0]
for f in all_files
if not Path(f).name.endswith(".q.md")
}
# Filter by sections if provided
if sections:
base_docs = {doc for doc in base_docs
if any(section.lower() in doc.lower() for section in sections)}
base_docs = {
doc
for doc in base_docs
if any(section.lower() in doc.lower() for section in sections)
}
# Get file paths based on mode
files = []
for doc in sorted(base_docs, key=lambda x: int(x.split('_')[0]) if x.split('_')[0].isdigit() else 999999):
for doc in sorted(
base_docs,
key=lambda x: int(x.split("_")[0]) if x.split("_")[0].isdigit() else 999999,
):
if mode == "condensed":
xs_file = self.docs_dir / f"{doc}.xs.md"
regular_file = self.docs_dir / f"{doc}.md"
@@ -404,7 +456,7 @@ Wrap your response in <index>...</index> tags.
content = []
for file in files:
try:
with open(file, 'r', encoding='utf-8') as f:
with open(file, "r", encoding="utf-8") as f:
fname = Path(file).name
content.append(f"{'#'*20}\n# {fname}\n{'#'*20}\n\n{f.read()}")
except Exception as e:
@@ -443,15 +495,9 @@ Wrap your response in <index>...</index> tags.
for file, _ in ranked_files:
main_doc = str(file).replace(".q.md", ".md")
if os.path.exists(self.docs_dir / main_doc):
with open(self.docs_dir / main_doc, "r", encoding='utf-8') as f:
with open(self.docs_dir / main_doc, "r", encoding="utf-8") as f:
only_file_name = main_doc.split("/")[-1]
content = [
"#" * 20,
f"# {only_file_name}",
"#" * 20,
"",
f.read()
]
content = ["#" * 20, f"# {only_file_name}", "#" * 20, "", f.read()]
results.append("\n".join(content))
return "\n\n---\n\n".join(results)
@@ -482,7 +528,9 @@ Wrap your response in <index>...</index> tags.
if len(components) == 3:
code_ref = components[2].strip()
code_tokens = self.preprocess_text(code_ref)
code_match_score = len(set(query_tokens) & set(code_tokens)) / len(query_tokens)
code_match_score = len(set(query_tokens) & set(code_tokens)) / len(
query_tokens
)
file_data[file_path]["total_score"] += score
file_data[file_path]["match_count"] += 1

View File

@@ -2,41 +2,51 @@ from abc import ABC, abstractmethod
from typing import Optional, Dict, Any, Tuple
from .models import MarkdownGenerationResult
from .html2text import CustomHTML2Text
from .content_filter_strategy import RelevantContentFilter, BM25ContentFilter
from .content_filter_strategy import RelevantContentFilter
import re
from urllib.parse import urljoin
# Pre-compile the regex pattern
LINK_PATTERN = re.compile(r'!?\[([^\]]+)\]\(([^)]+?)(?:\s+"([^"]*)")?\)')
def fast_urljoin(base: str, url: str) -> str:
"""Fast URL joining for common cases."""
if url.startswith(('http://', 'https://', 'mailto:', '//')):
if url.startswith(("http://", "https://", "mailto:", "//")):
return url
if url.startswith('/'):
if url.startswith("/"):
# Handle absolute paths
if base.endswith('/'):
if base.endswith("/"):
return base[:-1] + url
return base + url
return urljoin(base, url)
class MarkdownGenerationStrategy(ABC):
"""Abstract base class for markdown generation strategies."""
def __init__(self, content_filter: Optional[RelevantContentFilter] = None, options: Optional[Dict[str, Any]] = None):
def __init__(
self,
content_filter: Optional[RelevantContentFilter] = None,
options: Optional[Dict[str, Any]] = None,
):
self.content_filter = content_filter
self.options = options or {}
@abstractmethod
def generate_markdown(self,
def generate_markdown(
self,
cleaned_html: str,
base_url: str = "",
html2text_options: Optional[Dict[str, Any]] = None,
content_filter: Optional[RelevantContentFilter] = None,
citations: bool = True,
**kwargs) -> MarkdownGenerationResult:
**kwargs,
) -> MarkdownGenerationResult:
"""Generate markdown from cleaned HTML."""
pass
class DefaultMarkdownGenerator(MarkdownGenerationStrategy):
"""
Default implementation of markdown generation strategy.
@@ -54,10 +64,17 @@ class DefaultMarkdownGenerator(MarkdownGenerationStrategy):
Returns:
MarkdownGenerationResult: Result containing raw markdown, fit markdown, fit HTML, and references markdown.
"""
def __init__(self, content_filter: Optional[RelevantContentFilter] = None, options: Optional[Dict[str, Any]] = None):
def __init__(
self,
content_filter: Optional[RelevantContentFilter] = None,
options: Optional[Dict[str, Any]] = None,
):
super().__init__(content_filter, options)
def convert_links_to_citations(self, markdown: str, base_url: str = "") -> Tuple[str, str]:
def convert_links_to_citations(
self, markdown: str, base_url: str = ""
) -> Tuple[str, str]:
"""
Convert links in markdown to citations.
@@ -83,28 +100,34 @@ class DefaultMarkdownGenerator(MarkdownGenerationStrategy):
counter = 1
for match in LINK_PATTERN.finditer(markdown):
parts.append(markdown[last_end:match.start()])
parts.append(markdown[last_end : match.start()])
text, url, title = match.groups()
# Use cached URL if available, otherwise compute and cache
if base_url and not url.startswith(('http://', 'https://', 'mailto:')):
if base_url and not url.startswith(("http://", "https://", "mailto:")):
if url not in url_cache:
url_cache[url] = fast_urljoin(base_url, url)
url = url_cache[url]
if url not in link_map:
desc = []
if title: desc.append(title)
if text and text != title: desc.append(text)
if title:
desc.append(title)
if text and text != title:
desc.append(text)
link_map[url] = (counter, ": " + " - ".join(desc) if desc else "")
counter += 1
num = link_map[url][0]
parts.append(f"{text}{num}" if not match.group(0).startswith('!') else f"![{text}{num}⟩]")
parts.append(
f"{text}{num}"
if not match.group(0).startswith("!")
else f"![{text}{num}⟩]"
)
last_end = match.end()
parts.append(markdown[last_end:])
converted_text = ''.join(parts)
converted_text = "".join(parts)
# Pre-build reference strings
references = ["\n\n## References\n\n"]
@@ -113,16 +136,18 @@ class DefaultMarkdownGenerator(MarkdownGenerationStrategy):
for url, (num, desc) in sorted(link_map.items(), key=lambda x: x[1][0])
)
return converted_text, ''.join(references)
return converted_text, "".join(references)
def generate_markdown(self,
def generate_markdown(
self,
cleaned_html: str,
base_url: str = "",
html2text_options: Optional[Dict[str, Any]] = None,
options: Optional[Dict[str, Any]] = None,
content_filter: Optional[RelevantContentFilter] = None,
citations: bool = True,
**kwargs) -> MarkdownGenerationResult:
**kwargs,
) -> MarkdownGenerationResult:
"""
Generate markdown with citations from cleaned HTML.
@@ -147,14 +172,14 @@ class DefaultMarkdownGenerator(MarkdownGenerationStrategy):
# Initialize HTML2Text with default options for better conversion
h = CustomHTML2Text(baseurl=base_url)
default_options = {
'body_width': 0, # Disable text wrapping
'ignore_emphasis': False,
'ignore_links': False,
'ignore_images': False,
'protect_links': True,
'single_line_break': True,
'mark_code': True,
'escape_snob': False
"body_width": 0, # Disable text wrapping
"ignore_emphasis": False,
"ignore_links": False,
"ignore_images": False,
"protect_links": True,
"single_line_break": True,
"mark_code": True,
"escape_snob": False,
}
# Update with custom options if provided
@@ -179,16 +204,17 @@ class DefaultMarkdownGenerator(MarkdownGenerationStrategy):
except Exception as e:
raw_markdown = f"Error converting HTML to markdown: {str(e)}"
raw_markdown = raw_markdown.replace(' ```', '```')
raw_markdown = raw_markdown.replace(" ```", "```")
# Convert links to citations
markdown_with_citations: str = raw_markdown
references_markdown: str = ""
if citations:
try:
markdown_with_citations, references_markdown = self.convert_links_to_citations(
raw_markdown, base_url
)
(
markdown_with_citations,
references_markdown,
) = self.convert_links_to_citations(raw_markdown, base_url)
except Exception as e:
markdown_with_citations = raw_markdown
references_markdown = f"Error generating citations: {str(e)}"
@@ -200,7 +226,9 @@ class DefaultMarkdownGenerator(MarkdownGenerationStrategy):
try:
content_filter = content_filter or self.content_filter
filtered_html = content_filter.filter_content(cleaned_html)
filtered_html = '\n'.join('<div>{}</div>'.format(s) for s in filtered_html)
filtered_html = "\n".join(
"<div>{}</div>".format(s) for s in filtered_html
)
fit_markdown = h.handle(filtered_html)
except Exception as e:
fit_markdown = f"Error generating fit markdown: {str(e)}"

View File

@@ -1,13 +1,11 @@
import os
import asyncio
import logging
from pathlib import Path
import aiosqlite
from typing import Optional
import xxhash
import aiofiles
import shutil
import time
from datetime import datetime
from .async_logger import AsyncLogger, LogLevel
@@ -17,6 +15,7 @@ logger = AsyncLogger(log_level=LogLevel.DEBUG, verbose=True)
# logging.basicConfig(level=logging.INFO)
# logger = logging.getLogger(__name__)
class DatabaseMigration:
def __init__(self, db_path: str):
self.db_path = db_path
@@ -24,11 +23,11 @@ class DatabaseMigration:
def _ensure_content_dirs(self, base_path: str) -> dict:
dirs = {
'html': 'html_content',
'cleaned': 'cleaned_html',
'markdown': 'markdown_content',
'extracted': 'extracted_content',
'screenshots': 'screenshots'
"html": "html_content",
"cleaned": "cleaned_html",
"markdown": "markdown_content",
"extracted": "extracted_content",
"screenshots": "screenshots",
}
content_paths = {}
for key, dirname in dirs.items():
@@ -52,7 +51,7 @@ class DatabaseMigration:
file_path = os.path.join(self.content_paths[content_type], content_hash)
if not os.path.exists(file_path):
async with aiofiles.open(file_path, 'w', encoding='utf-8') as f:
async with aiofiles.open(file_path, "w", encoding="utf-8") as f:
await f.write(content)
return content_hash
@@ -66,24 +65,36 @@ class DatabaseMigration:
async with aiosqlite.connect(self.db_path) as db:
# Get all rows
async with db.execute(
'''SELECT url, html, cleaned_html, markdown,
extracted_content, screenshot FROM crawled_data'''
"""SELECT url, html, cleaned_html, markdown,
extracted_content, screenshot FROM crawled_data"""
) as cursor:
rows = await cursor.fetchall()
migrated_count = 0
for row in rows:
url, html, cleaned_html, markdown, extracted_content, screenshot = row
(
url,
html,
cleaned_html,
markdown,
extracted_content,
screenshot,
) = row
# Store content in files and get hashes
html_hash = await self._store_content(html, 'html')
cleaned_hash = await self._store_content(cleaned_html, 'cleaned')
markdown_hash = await self._store_content(markdown, 'markdown')
extracted_hash = await self._store_content(extracted_content, 'extracted')
screenshot_hash = await self._store_content(screenshot, 'screenshots')
html_hash = await self._store_content(html, "html")
cleaned_hash = await self._store_content(cleaned_html, "cleaned")
markdown_hash = await self._store_content(markdown, "markdown")
extracted_hash = await self._store_content(
extracted_content, "extracted"
)
screenshot_hash = await self._store_content(
screenshot, "screenshots"
)
# Update database with hashes
await db.execute('''
await db.execute(
"""
UPDATE crawled_data
SET html = ?,
cleaned_html = ?,
@@ -91,26 +102,37 @@ class DatabaseMigration:
extracted_content = ?,
screenshot = ?
WHERE url = ?
''', (html_hash, cleaned_hash, markdown_hash,
extracted_hash, screenshot_hash, url))
""",
(
html_hash,
cleaned_hash,
markdown_hash,
extracted_hash,
screenshot_hash,
url,
),
)
migrated_count += 1
if migrated_count % 100 == 0:
logger.info(f"Migrated {migrated_count} records...", tag="INIT")
await db.commit()
logger.success(f"Migration completed. {migrated_count} records processed.", tag="COMPLETE")
logger.success(
f"Migration completed. {migrated_count} records processed.",
tag="COMPLETE",
)
except Exception as e:
# logger.error(f"Migration failed: {e}")
logger.error(
message="Migration failed: {error}",
tag="ERROR",
params={"error": str(e)}
params={"error": str(e)},
)
raise e
async def backup_database(db_path: str) -> str:
"""Create backup of existing database"""
if not os.path.exists(db_path):
@@ -118,7 +140,7 @@ async def backup_database(db_path: str) -> str:
return None
# Create backup with timestamp
timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
backup_path = f"{db_path}.backup_{timestamp}"
try:
@@ -132,12 +154,11 @@ async def backup_database(db_path: str) -> str:
except Exception as e:
# logger.error(f"Backup failed: {e}")
logger.error(
message="Migration failed: {error}",
tag="ERROR",
params={"error": str(e)}
message="Migration failed: {error}", tag="ERROR", params={"error": str(e)}
)
raise e
async def run_migration(db_path: Optional[str] = None):
"""Run database migration"""
if db_path is None:
@@ -155,14 +176,19 @@ async def run_migration(db_path: Optional[str] = None):
migration = DatabaseMigration(db_path)
await migration.migrate_database()
def main():
"""CLI entry point for migration"""
import argparse
parser = argparse.ArgumentParser(description='Migrate Crawl4AI database to file-based storage')
parser.add_argument('--db-path', help='Custom database path')
parser = argparse.ArgumentParser(
description="Migrate Crawl4AI database to file-based storage"
)
parser.add_argument("--db-path", help="Custom database path")
args = parser.parse_args()
asyncio.run(run_migration(args.db_path))
if __name__ == "__main__":
main()

View File

@@ -2,75 +2,86 @@ from functools import lru_cache
from pathlib import Path
import subprocess, os
import shutil
import tarfile
from .model_loader import *
import argparse
import urllib.request
from crawl4ai.config import MODEL_REPO_BRANCH
__location__ = os.path.realpath(os.path.join(os.getcwd(), os.path.dirname(__file__)))
@lru_cache()
def get_available_memory(device):
import torch
if device.type == 'cuda':
if device.type == "cuda":
return torch.cuda.get_device_properties(device).total_memory
elif device.type == 'mps':
return 48 * 1024 ** 3 # Assuming 8GB for MPS, as a conservative estimate
elif device.type == "mps":
return 48 * 1024**3 # Assuming 8GB for MPS, as a conservative estimate
else:
return 0
@lru_cache()
def calculate_batch_size(device):
available_memory = get_available_memory(device)
if device.type == 'cpu':
if device.type == "cpu":
return 16
elif device.type in ['cuda', 'mps']:
elif device.type in ["cuda", "mps"]:
# Adjust these thresholds based on your model size and available memory
if available_memory >= 31 * 1024 ** 3: # > 32GB
if available_memory >= 31 * 1024**3: # > 32GB
return 256
elif available_memory >= 15 * 1024 ** 3: # > 16GB to 32GB
elif available_memory >= 15 * 1024**3: # > 16GB to 32GB
return 128
elif available_memory >= 8 * 1024 ** 3: # 8GB to 16GB
elif available_memory >= 8 * 1024**3: # 8GB to 16GB
return 64
else:
return 32
else:
return 16 # Default batch size
@lru_cache()
def get_device():
import torch
if torch.cuda.is_available():
device = torch.device('cuda')
device = torch.device("cuda")
elif torch.backends.mps.is_available():
device = torch.device('mps')
device = torch.device("mps")
else:
device = torch.device('cpu')
device = torch.device("cpu")
return device
def set_model_device(model):
device = get_device()
model.to(device)
return model, device
@lru_cache()
def get_home_folder():
home_folder = os.path.join(os.getenv("CRAWL4_AI_BASE_DIRECTORY", Path.home()), ".crawl4ai")
home_folder = os.path.join(
os.getenv("CRAWL4_AI_BASE_DIRECTORY", Path.home()), ".crawl4ai"
)
os.makedirs(home_folder, exist_ok=True)
os.makedirs(f"{home_folder}/cache", exist_ok=True)
os.makedirs(f"{home_folder}/models", exist_ok=True)
return home_folder
@lru_cache()
def load_bert_base_uncased():
from transformers import BertTokenizer, BertModel, AutoTokenizer, AutoModel
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased', resume_download=None)
model = BertModel.from_pretrained('bert-base-uncased', resume_download=None)
from transformers import BertTokenizer, BertModel
tokenizer = BertTokenizer.from_pretrained("bert-base-uncased", resume_download=None)
model = BertModel.from_pretrained("bert-base-uncased", resume_download=None)
model.eval()
model, device = set_model_device(model)
return tokenizer, model
@lru_cache()
def load_HF_embedding_model(model_name="BAAI/bge-small-en-v1.5") -> tuple:
"""Load the Hugging Face model for embedding.
@@ -81,30 +92,35 @@ def load_HF_embedding_model(model_name="BAAI/bge-small-en-v1.5") -> tuple:
Returns:
tuple: The tokenizer and model.
"""
from transformers import BertTokenizer, BertModel, AutoTokenizer, AutoModel
from transformers import AutoTokenizer, AutoModel
tokenizer = AutoTokenizer.from_pretrained(model_name, resume_download=None)
model = AutoModel.from_pretrained(model_name, resume_download=None)
model.eval()
model, device = set_model_device(model)
return tokenizer, model
@lru_cache()
def load_text_classifier():
from transformers import AutoTokenizer, AutoModelForSequenceClassification
from transformers import pipeline
import torch
tokenizer = AutoTokenizer.from_pretrained("dstefa/roberta-base_topic_classification_nyt_news")
model = AutoModelForSequenceClassification.from_pretrained("dstefa/roberta-base_topic_classification_nyt_news")
tokenizer = AutoTokenizer.from_pretrained(
"dstefa/roberta-base_topic_classification_nyt_news"
)
model = AutoModelForSequenceClassification.from_pretrained(
"dstefa/roberta-base_topic_classification_nyt_news"
)
model.eval()
model, device = set_model_device(model)
pipe = pipeline("text-classification", model=model, tokenizer=tokenizer)
return pipe
@lru_cache()
def load_text_multilabel_classifier():
from transformers import AutoModelForSequenceClassification, AutoTokenizer
import numpy as np
from scipy.special import expit
import torch
@@ -117,17 +133,26 @@ def load_text_multilabel_classifier():
# device = torch.device("cpu")
# # return load_spacy_model(), torch.device("cpu")
MODEL = "cardiffnlp/tweet-topic-21-multi"
tokenizer = AutoTokenizer.from_pretrained(MODEL, resume_download=None)
model = AutoModelForSequenceClassification.from_pretrained(MODEL, resume_download=None)
model = AutoModelForSequenceClassification.from_pretrained(
MODEL, resume_download=None
)
model.eval()
model, device = set_model_device(model)
class_mapping = model.config.id2label
def _classifier(texts, threshold=0.5, max_length=64):
tokens = tokenizer(texts, return_tensors='pt', padding=True, truncation=True, max_length=max_length)
tokens = {key: val.to(device) for key, val in tokens.items()} # Move tokens to the selected device
tokens = tokenizer(
texts,
return_tensors="pt",
padding=True,
truncation=True,
max_length=max_length,
)
tokens = {
key: val.to(device) for key, val in tokens.items()
} # Move tokens to the selected device
with torch.no_grad():
output = model(**tokens)
@@ -138,25 +163,31 @@ def load_text_multilabel_classifier():
batch_labels = []
for prediction in predictions:
labels = [class_mapping[i] for i, value in enumerate(prediction) if value == 1]
labels = [
class_mapping[i] for i, value in enumerate(prediction) if value == 1
]
batch_labels.append(labels)
return batch_labels
return _classifier, device
@lru_cache()
def load_nltk_punkt():
import nltk
try:
nltk.data.find('tokenizers/punkt')
nltk.data.find("tokenizers/punkt")
except LookupError:
nltk.download('punkt')
return nltk.data.find('tokenizers/punkt')
nltk.download("punkt")
return nltk.data.find("tokenizers/punkt")
@lru_cache()
def load_spacy_model():
import spacy
name = "models/reuters"
home_folder = get_home_folder()
model_folder = Path(home_folder) / name
@@ -176,7 +207,9 @@ def load_spacy_model():
if model_folder.exists():
shutil.rmtree(model_folder)
except PermissionError:
print("[WARNING] Unable to remove existing folders. Please manually delete the following folders and try again:")
print(
"[WARNING] Unable to remove existing folders. Please manually delete the following folders and try again:"
)
print(f"- {repo_folder}")
print(f"- {model_folder}")
return None
@@ -187,7 +220,7 @@ def load_spacy_model():
["git", "clone", "-b", branch, repo_url, str(repo_folder)],
stdout=subprocess.DEVNULL,
stderr=subprocess.DEVNULL,
check=True
check=True,
)
# Create the models directory if it doesn't exist
@@ -215,6 +248,7 @@ def load_spacy_model():
print(f"Error loading spacy model: {e}")
return None
def download_all_models(remove_existing=False):
"""Download all models required for Crawl4AI."""
if remove_existing:
@@ -243,14 +277,20 @@ def download_all_models(remove_existing=False):
load_nltk_punkt()
print("[LOG] ✅ All models downloaded successfully.")
def main():
print("[LOG] Welcome to the Crawl4AI Model Downloader!")
print("[LOG] This script will download all the models required for Crawl4AI.")
parser = argparse.ArgumentParser(description="Crawl4AI Model Downloader")
parser.add_argument('--remove-existing', action='store_true', help="Remove existing models before downloading")
parser.add_argument(
"--remove-existing",
action="store_true",
help="Remove existing models before downloading",
)
args = parser.parse_args()
download_all_models(remove_existing=args.remove_existing)
if __name__ == "__main__":
main()

View File

@@ -1,18 +1,12 @@
from pydantic import BaseModel, HttpUrl
from typing import List, Dict, Optional, Callable, Awaitable, Union, Tuple, Any
from typing import List, Dict, Optional, Callable, Awaitable, Union, Any
from enum import Enum
from dataclasses import dataclass, field
from .ssl_certificate import SSLCertificate
from dataclasses import dataclass
from .ssl_certificate import SSLCertificate
from datetime import datetime
from enum import Enum
from typing import Optional
from datetime import timedelta
###############################
# Dispatcher Models
###############################
@@ -22,6 +16,7 @@ class DomainState:
current_delay: float = 0
fail_count: int = 0
@dataclass
class CrawlerTaskResult:
task_id: str
@@ -33,12 +28,14 @@ class CrawlerTaskResult:
end_time: datetime
error_message: str = ""
class CrawlStatus(Enum):
QUEUED = "QUEUED"
IN_PROGRESS = "IN_PROGRESS"
COMPLETED = "COMPLETED"
FAILED = "FAILED"
@dataclass
class CrawlStats:
task_id: str
@@ -58,10 +55,12 @@ class CrawlStats:
duration = end - self.start_time
return str(timedelta(seconds=int(duration.total_seconds())))
class DisplayMode(Enum):
DETAILED = "DETAILED"
AGGREGATED = "AGGREGATED"
###############################
# Crawler Models
###############################
@@ -78,6 +77,7 @@ class UrlModel(BaseModel):
url: HttpUrl
forced: bool = False
class MarkdownGenerationResult(BaseModel):
raw_markdown: str
markdown_with_citations: str
@@ -85,6 +85,7 @@ class MarkdownGenerationResult(BaseModel):
fit_markdown: Optional[str] = None
fit_html: Optional[str] = None
class DispatchResult(BaseModel):
task_id: str
memory_usage: float
@@ -92,6 +93,8 @@ class DispatchResult(BaseModel):
start_time: datetime
end_time: datetime
error_message: str = ""
class CrawlResult(BaseModel):
url: str
html: str
@@ -101,7 +104,7 @@ class CrawlResult(BaseModel):
links: Dict[str, List[Dict]] = {}
downloaded_files: Optional[List[str]] = None
screenshot: Optional[str] = None
pdf : Optional[bytes] = None
pdf: Optional[bytes] = None
markdown: Optional[Union[str, MarkdownGenerationResult]] = None
markdown_v2: Optional[MarkdownGenerationResult] = None
fit_markdown: Optional[str] = None
@@ -114,9 +117,11 @@ class CrawlResult(BaseModel):
status_code: Optional[int] = None
ssl_certificate: Optional[SSLCertificate] = None
dispatch_result: Optional[DispatchResult] = None
class Config:
arbitrary_types_allowed = True
class AsyncCrawlResponse(BaseModel):
html: str
response_headers: Dict[str, str]
@@ -130,6 +135,7 @@ class AsyncCrawlResponse(BaseModel):
class Config:
arbitrary_types_allowed = True
###############################
# Scraping Models
###############################
@@ -143,21 +149,29 @@ class MediaItem(BaseModel):
format: Optional[str] = None
width: Optional[int] = None
class Link(BaseModel):
href: str
text: str
title: Optional[str] = None
base_domain: str
class Media(BaseModel):
images: List[MediaItem] = []
videos: List[MediaItem] = [] # Using MediaItem model for now, can be extended with Video model if needed
audios: List[MediaItem] = [] # Using MediaItem model for now, can be extended with Audio model if needed
videos: List[
MediaItem
] = [] # Using MediaItem model for now, can be extended with Video model if needed
audios: List[
MediaItem
] = [] # Using MediaItem model for now, can be extended with Audio model if needed
class Links(BaseModel):
internal: List[Link] = []
external: List[Link] = []
class ScrapingResult(BaseModel):
cleaned_html: str
success: bool

View File

@@ -26,11 +26,12 @@ class SSLCertificate:
export_as_json() -> Dict[str, Any]: Export the certificate as JSON format.
export_as_text() -> str: Export the certificate as text format.
"""
def __init__(self, cert_info: Dict[str, Any]):
self._cert_info = self._decode_cert_data(cert_info)
@staticmethod
def from_url(url: str, timeout: int = 10) -> Optional['SSLCertificate']:
def from_url(url: str, timeout: int = 10) -> Optional["SSLCertificate"]:
"""
Create SSLCertificate instance from a URL.
@@ -43,14 +44,16 @@ class SSLCertificate:
"""
try:
hostname = urlparse(url).netloc
if ':' in hostname:
hostname = hostname.split(':')[0]
if ":" in hostname:
hostname = hostname.split(":")[0]
context = ssl.create_default_context()
with socket.create_connection((hostname, 443), timeout=timeout) as sock:
with context.wrap_socket(sock, server_hostname=hostname) as ssock:
cert_binary = ssock.getpeercert(binary_form=True)
x509 = OpenSSL.crypto.load_certificate(OpenSSL.crypto.FILETYPE_ASN1, cert_binary)
x509 = OpenSSL.crypto.load_certificate(
OpenSSL.crypto.FILETYPE_ASN1, cert_binary
)
cert_info = {
"subject": dict(x509.get_subject().get_components()),
@@ -61,32 +64,33 @@ class SSLCertificate:
"not_after": x509.get_notAfter(),
"fingerprint": x509.digest("sha256").hex(),
"signature_algorithm": x509.get_signature_algorithm(),
"raw_cert": base64.b64encode(cert_binary)
"raw_cert": base64.b64encode(cert_binary),
}
# Add extensions
extensions = []
for i in range(x509.get_extension_count()):
ext = x509.get_extension(i)
extensions.append({
"name": ext.get_short_name(),
"value": str(ext)
})
extensions.append(
{"name": ext.get_short_name(), "value": str(ext)}
)
cert_info["extensions"] = extensions
return SSLCertificate(cert_info)
except Exception as e:
except Exception:
return None
@staticmethod
def _decode_cert_data(data: Any) -> Any:
"""Helper method to decode bytes in certificate data."""
if isinstance(data, bytes):
return data.decode('utf-8')
return data.decode("utf-8")
elif isinstance(data, dict):
return {
(k.decode('utf-8') if isinstance(k, bytes) else k): SSLCertificate._decode_cert_data(v)
(
k.decode("utf-8") if isinstance(k, bytes) else k
): SSLCertificate._decode_cert_data(v)
for k, v in data.items()
}
elif isinstance(data, list):
@@ -105,7 +109,7 @@ class SSLCertificate:
"""
json_str = json.dumps(self._cert_info, indent=2, ensure_ascii=False)
if filepath:
Path(filepath).write_text(json_str, encoding='utf-8')
Path(filepath).write_text(json_str, encoding="utf-8")
return None
return json_str
@@ -122,18 +126,17 @@ class SSLCertificate:
try:
x509 = OpenSSL.crypto.load_certificate(
OpenSSL.crypto.FILETYPE_ASN1,
base64.b64decode(self._cert_info['raw_cert'])
base64.b64decode(self._cert_info["raw_cert"]),
)
pem_data = OpenSSL.crypto.dump_certificate(
OpenSSL.crypto.FILETYPE_PEM,
x509
).decode('utf-8')
OpenSSL.crypto.FILETYPE_PEM, x509
).decode("utf-8")
if filepath:
Path(filepath).write_text(pem_data, encoding='utf-8')
Path(filepath).write_text(pem_data, encoding="utf-8")
return None
return pem_data
except Exception as e:
except Exception:
return None
def to_der(self, filepath: Optional[str] = None) -> Optional[bytes]:
@@ -147,7 +150,7 @@ class SSLCertificate:
Optional[bytes]: DER bytes if successful, None otherwise.
"""
try:
der_data = base64.b64decode(self._cert_info['raw_cert'])
der_data = base64.b64decode(self._cert_info["raw_cert"])
if filepath:
Path(filepath).write_bytes(der_data)
return None
@@ -158,24 +161,24 @@ class SSLCertificate:
@property
def issuer(self) -> Dict[str, str]:
"""Get certificate issuer information."""
return self._cert_info.get('issuer', {})
return self._cert_info.get("issuer", {})
@property
def subject(self) -> Dict[str, str]:
"""Get certificate subject information."""
return self._cert_info.get('subject', {})
return self._cert_info.get("subject", {})
@property
def valid_from(self) -> str:
"""Get certificate validity start date."""
return self._cert_info.get('not_before', '')
return self._cert_info.get("not_before", "")
@property
def valid_until(self) -> str:
"""Get certificate validity end date."""
return self._cert_info.get('not_after', '')
return self._cert_info.get("not_after", "")
@property
def fingerprint(self) -> str:
"""Get certificate fingerprint."""
return self._cert_info.get('fingerprint', '')
return self._cert_info.get("fingerprint", "")

View File

@@ -32,6 +32,7 @@ class UserAgentGenerator:
android_version: Optional[str] = None
): Generates a random user agent string based on the specified parameters.
"""
def __init__(self):
# Previous platform definitions remain the same...
self.desktop_platforms = {
@@ -47,7 +48,7 @@ class UserAgentGenerator:
"generic": "(X11; Linux x86_64)",
"ubuntu": "(X11; Ubuntu; Linux x86_64)",
"chrome_os": "(X11; CrOS x86_64 14541.0.0)",
}
},
}
self.mobile_platforms = {
@@ -60,26 +61,14 @@ class UserAgentGenerator:
"ios": {
"iphone": "(iPhone; CPU iPhone OS 16_5 like Mac OS X)",
"ipad": "(iPad; CPU OS 16_5 like Mac OS X)",
}
},
}
# Browser Combinations
self.browser_combinations = {
1: [
["chrome"],
["firefox"],
["safari"],
["edge"]
],
2: [
["gecko", "firefox"],
["chrome", "safari"],
["webkit", "safari"]
],
3: [
["chrome", "safari", "edge"],
["webkit", "chrome", "safari"]
]
1: [["chrome"], ["firefox"], ["safari"], ["edge"]],
2: [["gecko", "firefox"], ["chrome", "safari"], ["webkit", "safari"]],
3: [["chrome", "safari", "edge"], ["webkit", "chrome", "safari"]],
}
# Rendering Engines with versions
@@ -90,7 +79,7 @@ class UserAgentGenerator:
"Gecko/20100101",
"Gecko/20100101", # Firefox usually uses this constant version
"Gecko/2010010",
]
],
}
# Browser Versions
@@ -170,12 +159,14 @@ class UserAgentGenerator:
return browser_stack
def generate(self,
device_type: Optional[Literal['desktop', 'mobile']] = None,
def generate(
self,
device_type: Optional[Literal["desktop", "mobile"]] = None,
os_type: Optional[str] = None,
device_brand: Optional[str] = None,
browser_type: Optional[Literal['chrome', 'edge', 'safari', 'firefox']] = None,
num_browsers: int = 3) -> str:
browser_type: Optional[Literal["chrome", "edge", "safari", "firefox"]] = None,
num_browsers: int = 3,
) -> str:
"""
Generate a random user agent with specified constraints.
@@ -215,9 +206,13 @@ class UserAgentGenerator:
def get_random_platform(self, device_type, os_type, device_brand):
"""Helper method to get random platform based on constraints"""
platforms = self.desktop_platforms if device_type == 'desktop' else \
self.mobile_platforms if device_type == 'mobile' else \
{**self.desktop_platforms, **self.mobile_platforms}
platforms = (
self.desktop_platforms
if device_type == "desktop"
else self.mobile_platforms
if device_type == "mobile"
else {**self.desktop_platforms, **self.mobile_platforms}
)
if os_type:
for platform_group in [self.desktop_platforms, self.mobile_platforms]:
@@ -233,10 +228,10 @@ class UserAgentGenerator:
def parse_user_agent(self, user_agent: str) -> Dict[str, str]:
"""Parse a user agent string to extract browser and version information"""
browsers = {
'chrome': r'Chrome/(\d+)',
'edge': r'Edg/(\d+)',
'safari': r'Version/(\d+)',
'firefox': r'Firefox/(\d+)'
"chrome": r"Chrome/(\d+)",
"edge": r"Edg/(\d+)",
"safari": r"Version/(\d+)",
"firefox": r"Firefox/(\d+)",
}
result = {}
@@ -255,25 +250,26 @@ class UserAgentGenerator:
hints = []
# Handle different browser combinations
if 'chrome' in browsers:
if "chrome" in browsers:
hints.append(f'"Chromium";v="{browsers["chrome"]}"')
hints.append('"Not_A Brand";v="8"')
if 'edge' in browsers:
if "edge" in browsers:
hints.append(f'"Microsoft Edge";v="{browsers["edge"]}"')
else:
hints.append(f'"Google Chrome";v="{browsers["chrome"]}"')
elif 'firefox' in browsers:
elif "firefox" in browsers:
# Firefox doesn't typically send Sec-CH-UA
return '""'
elif 'safari' in browsers:
elif "safari" in browsers:
# Safari's format for client hints
hints.append(f'"Safari";v="{browsers["safari"]}"')
hints.append('"Not_A Brand";v="8"')
return ', '.join(hints)
return ", ".join(hints)
# Example usage:
if __name__ == "__main__":
@@ -281,7 +277,7 @@ if __name__ == "__main__":
print(generator.generate())
print("\nSingle browser (Chrome):")
print(generator.generate(num_browsers=1, browser_type='chrome'))
print(generator.generate(num_browsers=1, browser_type="chrome"))
print("\nTwo browsers (Gecko/Firefox):")
print(generator.generate(num_browsers=2))
@@ -290,16 +286,14 @@ if __name__ == "__main__":
print(generator.generate(num_browsers=3))
print("\nFirefox on Linux:")
print(generator.generate(
device_type='desktop',
os_type='linux',
browser_type='firefox',
num_browsers=2
))
print(
generator.generate(
device_type="desktop",
os_type="linux",
browser_type="firefox",
num_browsers=2,
)
)
print("\nChrome/Safari/Edge on Windows:")
print(generator.generate(
device_type='desktop',
os_type='windows',
num_browsers=3
))
print(generator.generate(device_type="desktop", os_type="windows", num_browsers=3))

File diff suppressed because it is too large Load Diff

View File

@@ -1,9 +1,9 @@
# version_manager.py
import os
from pathlib import Path
from packaging import version
from . import __version__
class VersionManager:
def __init__(self):
self.home_dir = Path.home() / ".crawl4ai"
@@ -27,4 +27,3 @@ class VersionManager:
installed = self.get_installed_version()
current = version.parse(__version__.__version__)
return installed is None or installed < current

View File

@@ -1,9 +1,10 @@
import os, time
os.environ["TOKENIZERS_PARALLELISM"] = "false"
from pathlib import Path
from .models import UrlModel, CrawlResult
from .database import init_db, get_cached_url, cache_url, DB_PATH, flush_db
from .database import init_db, get_cached_url, cache_url
from .utils import *
from .chunking_strategy import *
from .extraction_strategy import *
@@ -14,14 +15,27 @@ from .content_scraping_strategy import WebScrapingStrategy
from .config import *
import warnings
import json
warnings.filterwarnings("ignore", message='Field "model_name" has conflict with protected namespace "model_".')
warnings.filterwarnings(
"ignore",
message='Field "model_name" has conflict with protected namespace "model_".',
)
class WebCrawler:
def __init__(self, crawler_strategy: CrawlerStrategy = None, always_by_pass_cache: bool = False, verbose: bool = False):
self.crawler_strategy = crawler_strategy or LocalSeleniumCrawlerStrategy(verbose=verbose)
def __init__(
self,
crawler_strategy: CrawlerStrategy = None,
always_by_pass_cache: bool = False,
verbose: bool = False,
):
self.crawler_strategy = crawler_strategy or LocalSeleniumCrawlerStrategy(
verbose=verbose
)
self.always_by_pass_cache = always_by_pass_cache
self.crawl4ai_folder = os.path.join(os.getenv("CRAWL4_AI_BASE_DIRECTORY", Path.home()), ".crawl4ai")
self.crawl4ai_folder = os.path.join(
os.getenv("CRAWL4_AI_BASE_DIRECTORY", Path.home()), ".crawl4ai"
)
os.makedirs(self.crawl4ai_folder, exist_ok=True)
os.makedirs(f"{self.crawl4ai_folder}/cache", exist_ok=True)
init_db()
@@ -30,11 +44,11 @@ class WebCrawler:
def warmup(self):
print("[LOG] 🌤️ Warming up the WebCrawler")
self.run(
url='https://google.com/',
url="https://google.com/",
word_count_threshold=5,
extraction_strategy=NoExtractionStrategy(),
bypass_cache=False,
verbose=False
verbose=False,
)
self.ready = True
print("[LOG] 🌞 WebCrawler is ready to crawl")
@@ -80,6 +94,7 @@ class WebCrawler:
**kwargs,
) -> List[CrawlResult]:
extraction_strategy = extraction_strategy or NoExtractionStrategy()
def fetch_page_wrapper(url_model, *args, **kwargs):
return self.fetch_page(url_model, *args, **kwargs)
@@ -150,12 +165,25 @@ class WebCrawler:
html = sanitize_input_encode(self.crawler_strategy.crawl(url, **kwargs))
t2 = time.time()
if verbose:
print(f"[LOG] 🚀 Crawling done for {url}, success: {bool(html)}, time taken: {t2 - t1:.2f} seconds")
print(
f"[LOG] 🚀 Crawling done for {url}, success: {bool(html)}, time taken: {t2 - t1:.2f} seconds"
)
if screenshot:
screenshot_data = self.crawler_strategy.take_screenshot()
crawl_result = self.process_html(url, html, extracted_content, word_count_threshold, extraction_strategy, chunking_strategy, css_selector, screenshot_data, verbose, bool(cached), **kwargs)
crawl_result = self.process_html(
url,
html,
extracted_content,
word_count_threshold,
extraction_strategy,
chunking_strategy,
css_selector,
screenshot_data,
verbose,
bool(cached),
**kwargs,
)
crawl_result.success = bool(html)
return crawl_result
except Exception as e:
@@ -183,7 +211,11 @@ class WebCrawler:
try:
t1 = time.time()
scrapping_strategy = WebScrapingStrategy()
extra_params = {k: v for k, v in kwargs.items() if k not in ["only_text", "image_description_min_word_threshold"]}
extra_params = {
k: v
for k, v in kwargs.items()
if k not in ["only_text", "image_description_min_word_threshold"]
}
result = scrapping_strategy.scrap(
url,
html,
@@ -191,14 +223,17 @@ class WebCrawler:
css_selector=css_selector,
only_text=kwargs.get("only_text", False),
image_description_min_word_threshold=kwargs.get(
"image_description_min_word_threshold", IMAGE_DESCRIPTION_MIN_WORD_THRESHOLD
"image_description_min_word_threshold",
IMAGE_DESCRIPTION_MIN_WORD_THRESHOLD,
),
**extra_params,
)
# result = get_content_of_website_optimized(url, html, word_count_threshold, css_selector=css_selector, only_text=kwargs.get("only_text", False))
if verbose:
print(f"[LOG] 🚀 Content extracted for {url}, success: True, time taken: {time.time() - t1:.2f} seconds")
print(
f"[LOG] 🚀 Content extracted for {url}, success: True, time taken: {time.time() - t1:.2f} seconds"
)
if result is None:
raise ValueError(f"Failed to extract content from the website: {url}")
@@ -213,14 +248,20 @@ class WebCrawler:
if extracted_content is None:
if verbose:
print(f"[LOG] 🔥 Extracting semantic blocks for {url}, Strategy: {extraction_strategy.name}")
print(
f"[LOG] 🔥 Extracting semantic blocks for {url}, Strategy: {extraction_strategy.name}"
)
sections = chunking_strategy.chunk(markdown)
extracted_content = extraction_strategy.run(url, sections)
extracted_content = json.dumps(extracted_content, indent=4, default=str, ensure_ascii=False)
extracted_content = json.dumps(
extracted_content, indent=4, default=str, ensure_ascii=False
)
if verbose:
print(f"[LOG] 🚀 Extraction done for {url}, time taken: {time.time() - t:.2f} seconds.")
print(
f"[LOG] 🚀 Extraction done for {url}, time taken: {time.time() - t:.2f} seconds."
)
screenshot = None if not screenshot else screenshot

View File

@@ -9,12 +9,10 @@ from crawl4ai.extraction_strategy import JsonCssExtractionStrategy
from crawl4ai.async_configs import BrowserConfig, CrawlerRunConfig
import json
async def extract_amazon_products():
# Initialize browser config
browser_config = BrowserConfig(
browser_type="chromium",
headless=True
)
browser_config = BrowserConfig(browser_type="chromium", headless=True)
# Initialize crawler config with JSON CSS extraction strategy
crawler_config = CrawlerRunConfig(
@@ -27,57 +25,53 @@ async def extract_amazon_products():
"name": "asin",
"selector": "",
"type": "attribute",
"attribute": "data-asin"
},
{
"name": "title",
"selector": "h2 a span",
"type": "text"
"attribute": "data-asin",
},
{"name": "title", "selector": "h2 a span", "type": "text"},
{
"name": "url",
"selector": "h2 a",
"type": "attribute",
"attribute": "href"
"attribute": "href",
},
{
"name": "image",
"selector": ".s-image",
"type": "attribute",
"attribute": "src"
"attribute": "src",
},
{
"name": "rating",
"selector": ".a-icon-star-small .a-icon-alt",
"type": "text"
"type": "text",
},
{
"name": "reviews_count",
"selector": "[data-csa-c-func-deps='aui-da-a-popover'] ~ span span",
"type": "text"
"type": "text",
},
{
"name": "price",
"selector": ".a-price .a-offscreen",
"type": "text"
"type": "text",
},
{
"name": "original_price",
"selector": ".a-price.a-text-price .a-offscreen",
"type": "text"
"type": "text",
},
{
"name": "sponsored",
"selector": ".puis-sponsored-label-text",
"type": "exists"
"type": "exists",
},
{
"name": "delivery_info",
"selector": "[data-cy='delivery-recipe'] .a-color-base",
"type": "text",
"multiple": True
}
]
"multiple": True,
},
],
}
)
)
@@ -105,10 +99,12 @@ async def extract_amazon_products():
print(f"Rating: {product.get('rating')}")
print(f"Reviews: {product.get('reviews_count')}")
print(f"Sponsored: {'Yes' if product.get('sponsored') else 'No'}")
if product.get('delivery_info'):
if product.get("delivery_info"):
print(f"Delivery: {' '.join(product['delivery_info'])}")
print("-" * 80)
if __name__ == "__main__":
import asyncio
asyncio.run(extract_amazon_products())

View File

@@ -10,6 +10,7 @@ from crawl4ai.async_configs import BrowserConfig, CrawlerRunConfig
import json
from playwright.async_api import Page, BrowserContext
async def extract_amazon_products():
# Initialize browser config
browser_config = BrowserConfig(
@@ -20,7 +21,6 @@ async def extract_amazon_products():
# Initialize crawler config with JSON CSS extraction strategy nav-search-submit-button
crawler_config = CrawlerRunConfig(
cache_mode=CacheMode.BYPASS,
extraction_strategy=JsonCssExtractionStrategy(
schema={
"name": "Amazon Product Search Results",
@@ -30,82 +30,86 @@ async def extract_amazon_products():
"name": "asin",
"selector": "",
"type": "attribute",
"attribute": "data-asin"
},
{
"name": "title",
"selector": "h2 a span",
"type": "text"
"attribute": "data-asin",
},
{"name": "title", "selector": "h2 a span", "type": "text"},
{
"name": "url",
"selector": "h2 a",
"type": "attribute",
"attribute": "href"
"attribute": "href",
},
{
"name": "image",
"selector": ".s-image",
"type": "attribute",
"attribute": "src"
"attribute": "src",
},
{
"name": "rating",
"selector": ".a-icon-star-small .a-icon-alt",
"type": "text"
"type": "text",
},
{
"name": "reviews_count",
"selector": "[data-csa-c-func-deps='aui-da-a-popover'] ~ span span",
"type": "text"
"type": "text",
},
{
"name": "price",
"selector": ".a-price .a-offscreen",
"type": "text"
"type": "text",
},
{
"name": "original_price",
"selector": ".a-price.a-text-price .a-offscreen",
"type": "text"
"type": "text",
},
{
"name": "sponsored",
"selector": ".puis-sponsored-label-text",
"type": "exists"
"type": "exists",
},
{
"name": "delivery_info",
"selector": "[data-cy='delivery-recipe'] .a-color-base",
"type": "text",
"multiple": True
"multiple": True,
},
],
}
]
}
)
),
)
url = "https://www.amazon.com/"
async def after_goto(page: Page, context: BrowserContext, url: str, response: dict, **kwargs):
async def after_goto(
page: Page, context: BrowserContext, url: str, response: dict, **kwargs
):
"""Hook called after navigating to each URL"""
print(f"[HOOK] after_goto - Successfully loaded: {url}")
try:
# Wait for search box to be available
search_box = await page.wait_for_selector('#twotabsearchtextbox', timeout=1000)
search_box = await page.wait_for_selector(
"#twotabsearchtextbox", timeout=1000
)
# Type the search query
await search_box.fill('Samsung Galaxy Tab')
await search_box.fill("Samsung Galaxy Tab")
# Get the search button and prepare for navigation
search_button = await page.wait_for_selector('#nav-search-submit-button', timeout=1000)
search_button = await page.wait_for_selector(
"#nav-search-submit-button", timeout=1000
)
# Click with navigation waiting
await search_button.click()
# Wait for search results to load
await page.wait_for_selector('[data-component-type="s-search-result"]', timeout=10000)
await page.wait_for_selector(
'[data-component-type="s-search-result"]', timeout=10000
)
print("[HOOK] Search completed and results loaded!")
except Exception as e:
@@ -115,7 +119,6 @@ async def extract_amazon_products():
# Use context manager for proper resource handling
async with AsyncWebCrawler(config=browser_config) as crawler:
crawler.crawler_strategy.set_hook("after_goto", after_goto)
# Extract the data
@@ -136,10 +139,12 @@ async def extract_amazon_products():
print(f"Rating: {product.get('rating')}")
print(f"Reviews: {product.get('reviews_count')}")
print(f"Sponsored: {'Yes' if product.get('sponsored') else 'No'}")
if product.get('delivery_info'):
if product.get("delivery_info"):
print(f"Delivery: {' '.join(product['delivery_info'])}")
print("-" * 80)
if __name__ == "__main__":
import asyncio
asyncio.run(extract_amazon_products())

View File

@@ -8,7 +8,7 @@ from crawl4ai import AsyncWebCrawler, CacheMode
from crawl4ai.extraction_strategy import JsonCssExtractionStrategy
from crawl4ai.async_configs import BrowserConfig, CrawlerRunConfig
import json
from playwright.async_api import Page, BrowserContext
async def extract_amazon_products():
# Initialize browser config
@@ -30,7 +30,7 @@ async def extract_amazon_products():
"""
crawler_config = CrawlerRunConfig(
cache_mode=CacheMode.BYPASS,
js_code = js_code_to_search,
js_code=js_code_to_search,
wait_for='css:[data-component-type="s-search-result"]',
extraction_strategy=JsonCssExtractionStrategy(
schema={
@@ -41,65 +41,60 @@ async def extract_amazon_products():
"name": "asin",
"selector": "",
"type": "attribute",
"attribute": "data-asin"
},
{
"name": "title",
"selector": "h2 a span",
"type": "text"
"attribute": "data-asin",
},
{"name": "title", "selector": "h2 a span", "type": "text"},
{
"name": "url",
"selector": "h2 a",
"type": "attribute",
"attribute": "href"
"attribute": "href",
},
{
"name": "image",
"selector": ".s-image",
"type": "attribute",
"attribute": "src"
"attribute": "src",
},
{
"name": "rating",
"selector": ".a-icon-star-small .a-icon-alt",
"type": "text"
"type": "text",
},
{
"name": "reviews_count",
"selector": "[data-csa-c-func-deps='aui-da-a-popover'] ~ span span",
"type": "text"
"type": "text",
},
{
"name": "price",
"selector": ".a-price .a-offscreen",
"type": "text"
"type": "text",
},
{
"name": "original_price",
"selector": ".a-price.a-text-price .a-offscreen",
"type": "text"
"type": "text",
},
{
"name": "sponsored",
"selector": ".puis-sponsored-label-text",
"type": "exists"
"type": "exists",
},
{
"name": "delivery_info",
"selector": "[data-cy='delivery-recipe'] .a-color-base",
"type": "text",
"multiple": True
"multiple": True,
},
],
}
]
}
)
),
)
# Example search URL (you should replace with your actual Amazon URL)
url = "https://www.amazon.com/"
# Use context manager for proper resource handling
async with AsyncWebCrawler(config=browser_config) as crawler:
# Extract the data
@@ -120,10 +115,12 @@ async def extract_amazon_products():
print(f"Rating: {product.get('rating')}")
print(f"Reviews: {product.get('reviews_count')}")
print(f"Sponsored: {'Yes' if product.get('sponsored') else 'No'}")
if product.get('delivery_info'):
if product.get("delivery_info"):
print(f"Delivery: {' '.join(product['delivery_info'])}")
print("-" * 80)
if __name__ == "__main__":
import asyncio
asyncio.run(extract_amazon_products())

View File

@@ -1,12 +1,16 @@
# File: async_webcrawler_multiple_urls_example.py
import os, sys
# append 2 parent directories to sys.path to import crawl4ai
parent_dir = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
parent_dir = os.path.dirname(
os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
)
sys.path.append(parent_dir)
import asyncio
from crawl4ai import AsyncWebCrawler
async def main():
# Initialize the AsyncWebCrawler
async with AsyncWebCrawler(verbose=True) as crawler:
@@ -16,7 +20,7 @@ async def main():
"https://python.org",
"https://github.com",
"https://stackoverflow.com",
"https://news.ycombinator.com"
"https://news.ycombinator.com",
]
# Set up crawling parameters
@@ -27,7 +31,7 @@ async def main():
urls=urls,
word_count_threshold=word_count_threshold,
bypass_cache=True,
verbose=True
verbose=True,
)
# Process the results
@@ -36,7 +40,9 @@ async def main():
print(f"Successfully crawled: {result.url}")
print(f"Title: {result.metadata.get('title', 'N/A')}")
print(f"Word count: {len(result.markdown.split())}")
print(f"Number of links: {len(result.links.get('internal', [])) + len(result.links.get('external', []))}")
print(
f"Number of links: {len(result.links.get('internal', [])) + len(result.links.get('external', []))}"
)
print(f"Number of images: {len(result.media.get('images', []))}")
print("---")
else:
@@ -44,5 +50,6 @@ async def main():
print(f"Error: {result.error_message}")
print("---")
if __name__ == "__main__":
asyncio.run(main())

View File

@@ -6,10 +6,8 @@ This example demonstrates optimal browser usage patterns in Crawl4AI:
"""
import asyncio
import os
from typing import List
from crawl4ai import AsyncWebCrawler, BrowserConfig, CrawlerRunConfig
from crawl4ai.content_filter_strategy import PruningContentFilter
from crawl4ai.markdown_generation_strategy import DefaultMarkdownGenerator

View File

@@ -1,31 +1,32 @@
import os, time
# append the path to the root of the project
import sys
import asyncio
sys.path.append(os.path.join(os.path.dirname(__file__), '..', '..'))
sys.path.append(os.path.join(os.path.dirname(__file__), "..", ".."))
from firecrawl import FirecrawlApp
from crawl4ai import AsyncWebCrawler
__data__ = os.path.join(os.path.dirname(__file__), '..', '..') + '/.data'
__data__ = os.path.join(os.path.dirname(__file__), "..", "..") + "/.data"
async def compare():
app = FirecrawlApp(api_key=os.environ['FIRECRAWL_API_KEY'])
app = FirecrawlApp(api_key=os.environ["FIRECRAWL_API_KEY"])
# Tet Firecrawl with a simple crawl
start = time.time()
scrape_status = app.scrape_url(
'https://www.nbcnews.com/business',
params={'formats': ['markdown', 'html']}
"https://www.nbcnews.com/business", params={"formats": ["markdown", "html"]}
)
end = time.time()
print(f"Time taken: {end - start} seconds")
print(len(scrape_status['markdown']))
print(len(scrape_status["markdown"]))
# save the markdown content with provider name
with open(f"{__data__}/firecrawl_simple.md", "w") as f:
f.write(scrape_status['markdown'])
f.write(scrape_status["markdown"])
# Count how many "cldnry.s-nbcnews.com" are in the markdown
print(scrape_status['markdown'].count("cldnry.s-nbcnews.com"))
print(scrape_status["markdown"].count("cldnry.s-nbcnews.com"))
async with AsyncWebCrawler() as crawler:
start = time.time()
@@ -34,7 +35,7 @@ async def compare():
# js_code=["const loadMoreButton = Array.from(document.querySelectorAll('button')).find(button => button.textContent.includes('Load More')); loadMoreButton && loadMoreButton.click();"],
word_count_threshold=0,
bypass_cache=True,
verbose=False
verbose=False,
)
end = time.time()
print(f"Time taken: {end - start} seconds")
@@ -48,10 +49,12 @@ async def compare():
start = time.time()
result = await crawler.arun(
url="https://www.nbcnews.com/business",
js_code=["const loadMoreButton = Array.from(document.querySelectorAll('button')).find(button => button.textContent.includes('Load More')); loadMoreButton && loadMoreButton.click();"],
js_code=[
"const loadMoreButton = Array.from(document.querySelectorAll('button')).find(button => button.textContent.includes('Load More')); loadMoreButton && loadMoreButton.click();"
],
word_count_threshold=0,
bypass_cache=True,
verbose=False
verbose=False,
)
end = time.time()
print(f"Time taken: {end - start} seconds")
@@ -62,6 +65,6 @@ async def compare():
# count how many "cldnry.s-nbcnews.com" are in the markdown
print(result.markdown.count("cldnry.s-nbcnews.com"))
if __name__ == "__main__":
asyncio.run(compare())

View File

@@ -3,11 +3,18 @@ import time
from rich import print
from rich.table import Table
from crawl4ai import (
AsyncWebCrawler, BrowserConfig, CrawlerRunConfig,
MemoryAdaptiveDispatcher, SemaphoreDispatcher,
RateLimiter, CrawlerMonitor, DisplayMode, CacheMode
AsyncWebCrawler,
BrowserConfig,
CrawlerRunConfig,
MemoryAdaptiveDispatcher,
SemaphoreDispatcher,
RateLimiter,
CrawlerMonitor,
DisplayMode,
CacheMode,
)
async def memory_adaptive(urls, browser_config, run_config):
"""Memory adaptive crawler with monitoring"""
start = time.perf_counter()
@@ -16,14 +23,16 @@ async def memory_adaptive(urls, browser_config, run_config):
memory_threshold_percent=70.0,
max_session_permit=10,
monitor=CrawlerMonitor(
max_visible_rows=15,
display_mode=DisplayMode.DETAILED
max_visible_rows=15, display_mode=DisplayMode.DETAILED
),
)
results = await crawler.arun_many(
urls, config=run_config, dispatcher=dispatcher
)
results = await crawler.arun_many(urls, config=run_config, dispatcher=dispatcher)
duration = time.perf_counter() - start
return len(results), duration
async def memory_adaptive_with_rate_limit(urls, browser_config, run_config):
"""Memory adaptive crawler with rate limiting"""
start = time.perf_counter()
@@ -32,19 +41,19 @@ async def memory_adaptive_with_rate_limit(urls, browser_config, run_config):
memory_threshold_percent=70.0,
max_session_permit=10,
rate_limiter=RateLimiter(
base_delay=(1.0, 2.0),
max_delay=30.0,
max_retries=2
base_delay=(1.0, 2.0), max_delay=30.0, max_retries=2
),
monitor=CrawlerMonitor(
max_visible_rows=15,
display_mode=DisplayMode.DETAILED
max_visible_rows=15, display_mode=DisplayMode.DETAILED
),
)
results = await crawler.arun_many(
urls, config=run_config, dispatcher=dispatcher
)
results = await crawler.arun_many(urls, config=run_config, dispatcher=dispatcher)
duration = time.perf_counter() - start
return len(results), duration
async def semaphore(urls, browser_config, run_config):
"""Basic semaphore crawler"""
start = time.perf_counter()
@@ -52,14 +61,16 @@ async def semaphore(urls, browser_config, run_config):
dispatcher = SemaphoreDispatcher(
semaphore_count=5,
monitor=CrawlerMonitor(
max_visible_rows=15,
display_mode=DisplayMode.DETAILED
max_visible_rows=15, display_mode=DisplayMode.DETAILED
),
)
results = await crawler.arun_many(
urls, config=run_config, dispatcher=dispatcher
)
results = await crawler.arun_many(urls, config=run_config, dispatcher=dispatcher)
duration = time.perf_counter() - start
return len(results), duration
async def semaphore_with_rate_limit(urls, browser_config, run_config):
"""Semaphore crawler with rate limiting"""
start = time.perf_counter()
@@ -67,19 +78,19 @@ async def semaphore_with_rate_limit(urls, browser_config, run_config):
dispatcher = SemaphoreDispatcher(
semaphore_count=5,
rate_limiter=RateLimiter(
base_delay=(1.0, 2.0),
max_delay=30.0,
max_retries=2
base_delay=(1.0, 2.0), max_delay=30.0, max_retries=2
),
monitor=CrawlerMonitor(
max_visible_rows=15,
display_mode=DisplayMode.DETAILED
max_visible_rows=15, display_mode=DisplayMode.DETAILED
),
)
results = await crawler.arun_many(
urls, config=run_config, dispatcher=dispatcher
)
results = await crawler.arun_many(urls, config=run_config, dispatcher=dispatcher)
duration = time.perf_counter() - start
return len(results), duration
def create_performance_table(results):
"""Creates a rich table showing performance results"""
table = Table(title="Crawler Strategy Performance Comparison")
@@ -93,14 +104,12 @@ def create_performance_table(results):
for strategy, (urls_crawled, duration) in sorted_results:
urls_per_second = urls_crawled / duration
table.add_row(
strategy,
str(urls_crawled),
f"{duration:.2f}",
f"{urls_per_second:.2f}"
strategy, str(urls_crawled), f"{duration:.2f}", f"{urls_per_second:.2f}"
)
return table
async def main():
urls = [f"https://example.com/page{i}" for i in range(1, 20)]
browser_config = BrowserConfig(headless=True, verbose=False)
@@ -108,14 +117,19 @@ async def main():
results = {
"Memory Adaptive": await memory_adaptive(urls, browser_config, run_config),
"Memory Adaptive + Rate Limit": await memory_adaptive_with_rate_limit(urls, browser_config, run_config),
"Memory Adaptive + Rate Limit": await memory_adaptive_with_rate_limit(
urls, browser_config, run_config
),
"Semaphore": await semaphore(urls, browser_config, run_config),
"Semaphore + Rate Limit": await semaphore_with_rate_limit(urls, browser_config, run_config),
"Semaphore + Rate Limit": await semaphore_with_rate_limit(
urls, browser_config, run_config
),
}
table = create_performance_table(results)
print("\nPerformance Summary:")
print(table)
if __name__ == "__main__":
asyncio.run(main())

View File

@@ -6,15 +6,24 @@ import base64
import os
from typing import Dict, Any
class Crawl4AiTester:
def __init__(self, base_url: str = "http://localhost:11235", api_token: str = None):
self.base_url = base_url
self.api_token = api_token or os.getenv('CRAWL4AI_API_TOKEN') or "test_api_code" # Check environment variable as fallback
self.headers = {'Authorization': f'Bearer {self.api_token}'} if self.api_token else {}
self.api_token = (
api_token or os.getenv("CRAWL4AI_API_TOKEN") or "test_api_code"
) # Check environment variable as fallback
self.headers = (
{"Authorization": f"Bearer {self.api_token}"} if self.api_token else {}
)
def submit_and_wait(self, request_data: Dict[str, Any], timeout: int = 300) -> Dict[str, Any]:
def submit_and_wait(
self, request_data: Dict[str, Any], timeout: int = 300
) -> Dict[str, Any]:
# Submit crawl job
response = requests.post(f"{self.base_url}/crawl", json=request_data, headers=self.headers)
response = requests.post(
f"{self.base_url}/crawl", json=request_data, headers=self.headers
)
if response.status_code == 403:
raise Exception("API token is invalid or missing")
task_id = response.json()["task_id"]
@@ -24,9 +33,13 @@ class Crawl4AiTester:
start_time = time.time()
while True:
if time.time() - start_time > timeout:
raise TimeoutError(f"Task {task_id} did not complete within {timeout} seconds")
raise TimeoutError(
f"Task {task_id} did not complete within {timeout} seconds"
)
result = requests.get(f"{self.base_url}/task/{task_id}", headers=self.headers)
result = requests.get(
f"{self.base_url}/task/{task_id}", headers=self.headers
)
status = result.json()
if status["status"] == "failed":
@@ -39,7 +52,12 @@ class Crawl4AiTester:
time.sleep(2)
def submit_sync(self, request_data: Dict[str, Any]) -> Dict[str, Any]:
response = requests.post(f"{self.base_url}/crawl_sync", json=request_data, headers=self.headers, timeout=60)
response = requests.post(
f"{self.base_url}/crawl_sync",
json=request_data,
headers=self.headers,
timeout=60,
)
if response.status_code == 408:
raise TimeoutError("Task did not complete within server timeout")
response.raise_for_status()
@@ -48,16 +66,15 @@ class Crawl4AiTester:
def crawl_direct(self, request_data: Dict[str, Any]) -> Dict[str, Any]:
"""Directly crawl without using task queue"""
response = requests.post(
f"{self.base_url}/crawl_direct",
json=request_data,
headers=self.headers
f"{self.base_url}/crawl_direct", json=request_data, headers=self.headers
)
response.raise_for_status()
return response.json()
def test_docker_deployment(version="basic"):
tester = Crawl4AiTester(
base_url="http://localhost:11235" ,
base_url="http://localhost:11235",
# base_url="https://api.crawl4ai.com" # just for example
# api_token="test" # just for example
)
@@ -70,7 +87,7 @@ def test_docker_deployment(version="basic"):
health = requests.get(f"{tester.base_url}/health", timeout=10)
print("Health check:", health.json())
break
except requests.exceptions.RequestException as e:
except requests.exceptions.RequestException:
if i == max_retries - 1:
print(f"Failed to connect after {max_retries} attempts")
sys.exit(1)
@@ -99,7 +116,7 @@ def test_basic_crawl(tester: Crawl4AiTester):
request = {
"urls": "https://www.nbcnews.com/business",
"priority": 10,
"session_id": "test"
"session_id": "test",
}
result = tester.submit_and_wait(request)
@@ -107,19 +124,21 @@ def test_basic_crawl(tester: Crawl4AiTester):
assert result["result"]["success"]
assert len(result["result"]["markdown"]) > 0
def test_basic_crawl_sync(tester: Crawl4AiTester):
print("\n=== Testing Basic Crawl (Sync) ===")
request = {
"urls": "https://www.nbcnews.com/business",
"priority": 10,
"session_id": "test"
"session_id": "test",
}
result = tester.submit_sync(request)
print(f"Basic crawl result length: {len(result['result']['markdown'])}")
assert result['status'] == 'completed'
assert result['result']['success']
assert len(result['result']['markdown']) > 0
assert result["status"] == "completed"
assert result["result"]["success"]
assert len(result["result"]["markdown"]) > 0
def test_basic_crawl_direct(tester: Crawl4AiTester):
print("\n=== Testing Basic Crawl (Direct) ===")
@@ -127,13 +146,14 @@ def test_basic_crawl_direct(tester: Crawl4AiTester):
"urls": "https://www.nbcnews.com/business",
"priority": 10,
# "session_id": "test"
"cache_mode": "bypass" # or "enabled", "disabled", "read_only", "write_only"
"cache_mode": "bypass", # or "enabled", "disabled", "read_only", "write_only"
}
result = tester.crawl_direct(request)
print(f"Basic crawl result length: {len(result['result']['markdown'])}")
assert result['result']['success']
assert len(result['result']['markdown']) > 0
assert result["result"]["success"]
assert len(result["result"]["markdown"]) > 0
def test_js_execution(tester: Crawl4AiTester):
print("\n=== Testing JS Execution ===")
@@ -144,32 +164,29 @@ def test_js_execution(tester: Crawl4AiTester):
"const loadMoreButton = Array.from(document.querySelectorAll('button')).find(button => button.textContent.includes('Load More')); loadMoreButton && loadMoreButton.click();"
],
"wait_for": "article.tease-card:nth-child(10)",
"crawler_params": {
"headless": True
}
"crawler_params": {"headless": True},
}
result = tester.submit_and_wait(request)
print(f"JS execution result length: {len(result['result']['markdown'])}")
assert result["result"]["success"]
def test_css_selector(tester: Crawl4AiTester):
print("\n=== Testing CSS Selector ===")
request = {
"urls": "https://www.nbcnews.com/business",
"priority": 7,
"css_selector": ".wide-tease-item__description",
"crawler_params": {
"headless": True
},
"extra": {"word_count_threshold": 10}
"crawler_params": {"headless": True},
"extra": {"word_count_threshold": 10},
}
result = tester.submit_and_wait(request)
print(f"CSS selector result length: {len(result['result']['markdown'])}")
assert result["result"]["success"]
def test_structured_extraction(tester: Crawl4AiTester):
print("\n=== Testing Structured Extraction ===")
schema = {
@@ -190,19 +207,14 @@ def test_structured_extraction(tester: Crawl4AiTester):
"name": "price",
"selector": "td:nth-child(2)",
"type": "text",
}
},
],
}
request = {
"urls": "https://www.coinbase.com/explore",
"priority": 9,
"extraction_config": {
"type": "json_css",
"params": {
"schema": schema
}
}
"extraction_config": {"type": "json_css", "params": {"schema": schema}},
}
result = tester.submit_and_wait(request)
@@ -212,6 +224,7 @@ def test_structured_extraction(tester: Crawl4AiTester):
assert result["result"]["success"]
assert len(extracted) > 0
def test_llm_extraction(tester: Crawl4AiTester):
print("\n=== Testing LLM Extraction ===")
schema = {
@@ -219,18 +232,18 @@ def test_llm_extraction(tester: Crawl4AiTester):
"properties": {
"model_name": {
"type": "string",
"description": "Name of the OpenAI model."
"description": "Name of the OpenAI model.",
},
"input_fee": {
"type": "string",
"description": "Fee for input token for the OpenAI model."
"description": "Fee for input token for the OpenAI model.",
},
"output_fee": {
"type": "string",
"description": "Fee for output token for the OpenAI model."
}
"description": "Fee for output token for the OpenAI model.",
},
"required": ["model_name", "input_fee", "output_fee"]
},
"required": ["model_name", "input_fee", "output_fee"],
}
request = {
@@ -243,10 +256,10 @@ def test_llm_extraction(tester: Crawl4AiTester):
"api_token": os.getenv("OPENAI_API_KEY"),
"schema": schema,
"extraction_type": "schema",
"instruction": """From the crawled content, extract all mentioned model names along with their fees for input and output tokens."""
}
"instruction": """From the crawled content, extract all mentioned model names along with their fees for input and output tokens.""",
},
"crawler_params": {"word_count_threshold": 1}
},
"crawler_params": {"word_count_threshold": 1},
}
try:
@@ -258,6 +271,7 @@ def test_llm_extraction(tester: Crawl4AiTester):
except Exception as e:
print(f"LLM extraction test failed (might be due to missing API key): {str(e)}")
def test_llm_with_ollama(tester: Crawl4AiTester):
print("\n=== Testing LLM with Ollama ===")
schema = {
@@ -265,18 +279,18 @@ def test_llm_with_ollama(tester: Crawl4AiTester):
"properties": {
"article_title": {
"type": "string",
"description": "The main title of the news article"
"description": "The main title of the news article",
},
"summary": {
"type": "string",
"description": "A brief summary of the article content"
"description": "A brief summary of the article content",
},
"main_topics": {
"type": "array",
"items": {"type": "string"},
"description": "Main topics or themes discussed in the article"
}
}
"description": "Main topics or themes discussed in the article",
},
},
}
request = {
@@ -288,11 +302,11 @@ def test_llm_with_ollama(tester: Crawl4AiTester):
"provider": "ollama/llama2",
"schema": schema,
"extraction_type": "schema",
"instruction": "Extract the main article information including title, summary, and main topics."
}
"instruction": "Extract the main article information including title, summary, and main topics.",
},
},
"extra": {"word_count_threshold": 1},
"crawler_params": {"verbose": True}
"crawler_params": {"verbose": True},
}
try:
@@ -303,6 +317,7 @@ def test_llm_with_ollama(tester: Crawl4AiTester):
except Exception as e:
print(f"Ollama extraction test failed: {str(e)}")
def test_cosine_extraction(tester: Crawl4AiTester):
print("\n=== Testing Cosine Extraction ===")
request = {
@@ -314,9 +329,9 @@ def test_cosine_extraction(tester: Crawl4AiTester):
"semantic_filter": "business finance economy",
"word_count_threshold": 10,
"max_dist": 0.2,
"top_k": 3
}
}
"top_k": 3,
},
},
}
try:
@@ -328,15 +343,14 @@ def test_cosine_extraction(tester: Crawl4AiTester):
except Exception as e:
print(f"Cosine extraction test failed: {str(e)}")
def test_screenshot(tester: Crawl4AiTester):
print("\n=== Testing Screenshot ===")
request = {
"urls": "https://www.nbcnews.com/business",
"priority": 5,
"screenshot": True,
"crawler_params": {
"headless": True
}
"crawler_params": {"headless": True},
}
result = tester.submit_and_wait(request)
@@ -351,6 +365,7 @@ def test_screenshot(tester: Crawl4AiTester):
assert result["result"]["success"]
if __name__ == "__main__":
version = sys.argv[1] if len(sys.argv) > 1 else "basic"
# version = "full"

View File

@@ -9,18 +9,17 @@ This example shows how to:
import asyncio
import os
from typing import Dict, Any
from crawl4ai import AsyncWebCrawler, BrowserConfig, CrawlerRunConfig, CacheMode
from crawl4ai.extraction_strategy import (
LLMExtractionStrategy,
JsonCssExtractionStrategy,
JsonXPathExtractionStrategy
JsonXPathExtractionStrategy,
)
from crawl4ai.chunking_strategy import RegexChunking, IdentityChunking
from crawl4ai.content_filter_strategy import PruningContentFilter
from crawl4ai.markdown_generation_strategy import DefaultMarkdownGenerator
async def run_extraction(crawler: AsyncWebCrawler, url: str, strategy, name: str):
"""Helper function to run extraction with proper configuration"""
try:
@@ -30,7 +29,7 @@ async def run_extraction(crawler: AsyncWebCrawler, url: str, strategy, name: str
extraction_strategy=strategy,
markdown_generator=DefaultMarkdownGenerator(
content_filter=PruningContentFilter() # For fit_markdown support
)
),
)
# Run the crawler
@@ -40,22 +39,22 @@ async def run_extraction(crawler: AsyncWebCrawler, url: str, strategy, name: str
print(f"\n=== {name} Results ===")
print(f"Extracted Content: {result.extracted_content}")
print(f"Raw Markdown Length: {len(result.markdown_v2.raw_markdown)}")
print(f"Citations Markdown Length: {len(result.markdown_v2.markdown_with_citations)}")
print(
f"Citations Markdown Length: {len(result.markdown_v2.markdown_with_citations)}"
)
else:
print(f"Error in {name}: Crawl failed")
except Exception as e:
print(f"Error in {name}: {str(e)}")
async def main():
# Example URL (replace with actual URL)
url = "https://example.com/product-page"
# Configure browser settings
browser_config = BrowserConfig(
headless=True,
verbose=True
)
browser_config = BrowserConfig(headless=True, verbose=True)
# Initialize extraction strategies
@@ -63,21 +62,21 @@ async def main():
markdown_strategy = LLMExtractionStrategy(
provider="openai/gpt-4o-mini",
api_token=os.getenv("OPENAI_API_KEY"),
instruction="Extract product information including name, price, and description"
instruction="Extract product information including name, price, and description",
)
html_strategy = LLMExtractionStrategy(
input_format="html",
provider="openai/gpt-4o-mini",
api_token=os.getenv("OPENAI_API_KEY"),
instruction="Extract product information from HTML including structured data"
instruction="Extract product information from HTML including structured data",
)
fit_markdown_strategy = LLMExtractionStrategy(
input_format="fit_markdown",
provider="openai/gpt-4o-mini",
api_token=os.getenv("OPENAI_API_KEY"),
instruction="Extract product information from cleaned markdown"
instruction="Extract product information from cleaned markdown",
)
# 2. JSON CSS Extraction (automatically uses HTML input)
@@ -86,8 +85,8 @@ async def main():
"fields": [
{"name": "title", "selector": "h1.product-title", "type": "text"},
{"name": "price", "selector": ".price", "type": "text"},
{"name": "description", "selector": ".description", "type": "text"}
]
{"name": "description", "selector": ".description", "type": "text"},
],
}
css_strategy = JsonCssExtractionStrategy(schema=css_schema)
@@ -95,10 +94,22 @@ async def main():
xpath_schema = {
"baseSelector": "//div[@class='product']",
"fields": [
{"name": "title", "selector": ".//h1[@class='product-title']/text()", "type": "text"},
{"name": "price", "selector": ".//span[@class='price']/text()", "type": "text"},
{"name": "description", "selector": ".//div[@class='description']/text()", "type": "text"}
]
{
"name": "title",
"selector": ".//h1[@class='product-title']/text()",
"type": "text",
},
{
"name": "price",
"selector": ".//span[@class='price']/text()",
"type": "text",
},
{
"name": "description",
"selector": ".//div[@class='description']/text()",
"type": "text",
},
],
}
xpath_strategy = JsonXPathExtractionStrategy(schema=xpath_schema)
@@ -111,5 +122,6 @@ async def main():
await run_extraction(crawler, url, css_strategy, "CSS Extraction")
await run_extraction(crawler, url, xpath_strategy, "XPath Extraction")
if __name__ == "__main__":
asyncio.run(main())

View File

@@ -1,20 +1,23 @@
import asyncio
from crawl4ai import *
async def main():
browser_config = BrowserConfig(headless=True, verbose=True)
async with AsyncWebCrawler(config=browser_config) as crawler:
crawler_config = CrawlerRunConfig(
cache_mode=CacheMode.BYPASS,
markdown_generator=DefaultMarkdownGenerator(
content_filter=PruningContentFilter(threshold=0.48, threshold_type="fixed", min_word_threshold=0)
content_filter=PruningContentFilter(
threshold=0.48, threshold_type="fixed", min_word_threshold=0
)
),
)
result = await crawler.arun(
url="https://www.helloworld.org",
config=crawler_config
url="https://www.helloworld.org", config=crawler_config
)
print(result.markdown_v2.raw_markdown[:500])
if __name__ == "__main__":
asyncio.run(main())

View File

@@ -1,19 +1,18 @@
from crawl4ai import AsyncWebCrawler, BrowserConfig, CrawlerRunConfig, CacheMode
from playwright.async_api import Page, BrowserContext
async def main():
print("🔗 Hooks Example: Demonstrating different hook use cases")
# Configure browser settings
browser_config = BrowserConfig(
headless=True
)
browser_config = BrowserConfig(headless=True)
# Configure crawler settings
crawler_run_config = CrawlerRunConfig(
js_code="window.scrollTo(0, document.body.scrollHeight);",
wait_for="body",
cache_mode=CacheMode.BYPASS
cache_mode=CacheMode.BYPASS,
)
# Create crawler instance
@@ -30,16 +29,22 @@ async def main():
"""Hook called after a new page and context are created"""
print("[HOOK] on_page_context_created - New page created!")
# Example: Set default viewport size
await context.add_cookies([{
'name': 'session_id',
'value': 'example_session',
'domain': '.example.com',
'path': '/'
}])
await context.add_cookies(
[
{
"name": "session_id",
"value": "example_session",
"domain": ".example.com",
"path": "/",
}
]
)
await page.set_viewport_size({"width": 1080, "height": 800})
return page
async def on_user_agent_updated(page: Page, context: BrowserContext, user_agent: str, **kwargs):
async def on_user_agent_updated(
page: Page, context: BrowserContext, user_agent: str, **kwargs
):
"""Hook called when the user agent is updated"""
print(f"[HOOK] on_user_agent_updated - New user agent: {user_agent}")
return page
@@ -53,17 +58,17 @@ async def main():
"""Hook called before navigating to each URL"""
print(f"[HOOK] before_goto - About to visit: {url}")
# Example: Add custom headers for the request
await page.set_extra_http_headers({
"Custom-Header": "my-value"
})
await page.set_extra_http_headers({"Custom-Header": "my-value"})
return page
async def after_goto(page: Page, context: BrowserContext, url: str, response: dict, **kwargs):
async def after_goto(
page: Page, context: BrowserContext, url: str, response: dict, **kwargs
):
"""Hook called after navigating to each URL"""
print(f"[HOOK] after_goto - Successfully loaded: {url}")
# Example: Wait for a specific element to be loaded
try:
await page.wait_for_selector('.content', timeout=1000)
await page.wait_for_selector(".content", timeout=1000)
print("Content element found!")
except:
print("Content element not found, continuing anyway")
@@ -76,7 +81,9 @@ async def main():
await page.evaluate("window.scrollTo(0, document.body.scrollHeight);")
return page
async def before_return_html(page: Page, context: BrowserContext, html:str, **kwargs):
async def before_return_html(
page: Page, context: BrowserContext, html: str, **kwargs
):
"""Hook called before returning the HTML content"""
print(f"[HOOK] before_return_html - Got HTML content (length: {len(html)})")
# Example: You could modify the HTML content here if needed
@@ -84,7 +91,9 @@ async def main():
# Set all the hooks
crawler.crawler_strategy.set_hook("on_browser_created", on_browser_created)
crawler.crawler_strategy.set_hook("on_page_context_created", on_page_context_created)
crawler.crawler_strategy.set_hook(
"on_page_context_created", on_page_context_created
)
crawler.crawler_strategy.set_hook("on_user_agent_updated", on_user_agent_updated)
crawler.crawler_strategy.set_hook("on_execution_started", on_execution_started)
crawler.crawler_strategy.set_hook("before_goto", before_goto)
@@ -95,13 +104,15 @@ async def main():
await crawler.start()
# Example usage: crawl a simple website
url = 'https://example.com'
url = "https://example.com"
result = await crawler.arun(url, config=crawler_run_config)
print(f"\nCrawled URL: {result.url}")
print(f"HTML length: {len(result.html)}")
await crawler.close()
if __name__ == "__main__":
import asyncio
asyncio.run(main())

View File

@@ -1,6 +1,7 @@
import asyncio
from crawl4ai import AsyncWebCrawler, AsyncPlaywrightCrawlerStrategy
async def main():
# Example 1: Setting language when creating the crawler
crawler1 = AsyncWebCrawler(
@@ -9,11 +10,15 @@ async def main():
)
)
result1 = await crawler1.arun("https://www.example.com")
print("Example 1 result:", result1.extracted_content[:100]) # Print first 100 characters
print(
"Example 1 result:", result1.extracted_content[:100]
) # Print first 100 characters
# Example 2: Setting language before crawling
crawler2 = AsyncWebCrawler()
crawler2.crawler_strategy.headers["Accept-Language"] = "es-ES,es;q=0.9,en-US;q=0.8,en;q=0.7"
crawler2.crawler_strategy.headers[
"Accept-Language"
] = "es-ES,es;q=0.9,en-US;q=0.8,en;q=0.7"
result2 = await crawler2.arun("https://www.example.com")
print("Example 2 result:", result2.extracted_content[:100])
@@ -21,7 +26,7 @@ async def main():
crawler3 = AsyncWebCrawler()
result3 = await crawler3.arun(
"https://www.example.com",
headers={"Accept-Language": "de-DE,de;q=0.9,en-US;q=0.8,en;q=0.7"}
headers={"Accept-Language": "de-DE,de;q=0.9,en-US;q=0.8,en;q=0.7"},
)
print("Example 3 result:", result3.extracted_content[:100])
@@ -33,13 +38,13 @@ async def main():
]
crawler4 = AsyncWebCrawler()
results = await asyncio.gather(*[
crawler4.arun(url, headers={"Accept-Language": lang})
for url, lang in urls
])
results = await asyncio.gather(
*[crawler4.arun(url, headers={"Accept-Language": lang}) for url, lang in urls]
)
for url, result in zip([u for u, _ in urls], results):
print(f"Result for {url}:", result.extracted_content[:100])
if __name__ == "__main__":
asyncio.run(main())

View File

@@ -3,32 +3,37 @@ from crawl4ai.crawler_strategy import *
import asyncio
from pydantic import BaseModel, Field
url = r'https://openai.com/api/pricing/'
url = r"https://openai.com/api/pricing/"
class OpenAIModelFee(BaseModel):
model_name: str = Field(..., description="Name of the OpenAI model.")
input_fee: str = Field(..., description="Fee for input token for the OpenAI model.")
output_fee: str = Field(..., description="Fee for output token for the OpenAI model.")
output_fee: str = Field(
..., description="Fee for output token for the OpenAI model."
)
from crawl4ai import AsyncWebCrawler
async def main():
# Use AsyncWebCrawler
async with AsyncWebCrawler() as crawler:
result = await crawler.arun(
url=url,
word_count_threshold=1,
extraction_strategy= LLMExtractionStrategy(
extraction_strategy=LLMExtractionStrategy(
# provider= "openai/gpt-4o", api_token = os.getenv('OPENAI_API_KEY'),
provider= "groq/llama-3.1-70b-versatile", api_token = os.getenv('GROQ_API_KEY'),
provider="groq/llama-3.1-70b-versatile",
api_token=os.getenv("GROQ_API_KEY"),
schema=OpenAIModelFee.model_json_schema(),
extraction_type="schema",
instruction="From the crawled content, extract all mentioned model names along with their " \
"fees for input and output tokens. Make sure not to miss anything in the entire content. " \
'One extracted model JSON format should look like this: ' \
'{ "model_name": "GPT-4", "input_fee": "US$10.00 / 1M tokens", "output_fee": "US$30.00 / 1M tokens" }'
instruction="From the crawled content, extract all mentioned model names along with their "
"fees for input and output tokens. Make sure not to miss anything in the entire content. "
"One extracted model JSON format should look like this: "
'{ "model_name": "GPT-4", "input_fee": "US$10.00 / 1M tokens", "output_fee": "US$30.00 / 1M tokens" }',
),
)
print("Success:", result.success)
model_fees = json.loads(result.extracted_content)
@@ -37,4 +42,5 @@ async def main():
with open(".data/data.json", "w", encoding="utf-8") as f:
f.write(result.extracted_content)
asyncio.run(main())

View File

@@ -8,12 +8,12 @@ import asyncio
import time
import json
import re
from typing import Dict, List
from typing import Dict
from bs4 import BeautifulSoup
from pydantic import BaseModel, Field
from crawl4ai import AsyncWebCrawler, CacheMode, BrowserConfig, CrawlerRunConfig
from crawl4ai.markdown_generation_strategy import DefaultMarkdownGenerator
from crawl4ai.content_filter_strategy import BM25ContentFilter, PruningContentFilter
from crawl4ai.content_filter_strategy import PruningContentFilter
from crawl4ai.extraction_strategy import (
JsonCssExtractionStrategy,
LLMExtractionStrategy,
@@ -62,6 +62,7 @@ async def clean_content():
print(f"Full Markdown Length: {full_markdown_length}")
print(f"Fit Markdown Length: {fit_markdown_length}")
async def link_analysis():
crawler_config = CrawlerRunConfig(
cache_mode=CacheMode.ENABLED,
@@ -76,9 +77,10 @@ async def link_analysis():
print(f"Found {len(result.links['internal'])} internal links")
print(f"Found {len(result.links['external'])} external links")
for link in result.links['internal'][:5]:
for link in result.links["internal"][:5]:
print(f"Href: {link['href']}\nText: {link['text']}\n")
# JavaScript Execution Example
async def simple_example_with_running_js_code():
print("\n--- Executing JavaScript and Using CSS Selectors ---")
@@ -112,25 +114,29 @@ async def simple_example_with_css_selector():
)
print(result.markdown[:500])
async def media_handling():
crawler_config = CrawlerRunConfig(cache_mode=CacheMode.BYPASS, exclude_external_images=True, screenshot=True)
crawler_config = CrawlerRunConfig(
cache_mode=CacheMode.BYPASS, exclude_external_images=True, screenshot=True
)
async with AsyncWebCrawler() as crawler:
result = await crawler.arun(
url="https://www.nbcnews.com/business",
config=crawler_config
url="https://www.nbcnews.com/business", config=crawler_config
)
for img in result.media['images'][:5]:
for img in result.media["images"][:5]:
print(f"Image URL: {img['src']}, Alt: {img['alt']}, Score: {img['score']}")
async def custom_hook_workflow(verbose=True):
async with AsyncWebCrawler() as crawler:
# Set a 'before_goto' hook to run custom code just before navigation
crawler.crawler_strategy.set_hook("before_goto", lambda page, context: print("[Hook] Preparing to navigate..."))
crawler.crawler_strategy.set_hook(
"before_goto",
lambda page, context: print("[Hook] Preparing to navigate..."),
)
# Perform the crawl operation
result = await crawler.arun(
url="https://crawl4ai.com"
)
result = await crawler.arun(url="https://crawl4ai.com")
print(result.markdown_v2.raw_markdown[:500].replace("\n", " -- "))
@@ -417,16 +423,17 @@ async def cosine_similarity_extraction():
top_k=3, # Number of top keywords to extract
sim_threshold=0.3, # Similarity threshold for clustering
semantic_filter="McDonald's economic impact, American consumer trends", # Keywords to filter the content semantically using embeddings
verbose=True
verbose=True,
),
)
async with AsyncWebCrawler() as crawler:
result = await crawler.arun(
url="https://www.nbcnews.com/business/consumer/how-mcdonalds-e-coli-crisis-inflation-politics-reflect-american-story-rcna177156",
config=crawl_config
config=crawl_config,
)
print(json.loads(result.extracted_content)[:5])
# Browser Comparison
async def crawl_custom_browser_type():
print("\n--- Browser Comparison ---")
@@ -484,18 +491,16 @@ async def crawl_with_user_simulation():
result = await crawler.arun(url="YOUR-URL-HERE", config=crawler_config)
print(result.markdown)
async def ssl_certification():
# Configure crawler to fetch SSL certificate
config = CrawlerRunConfig(
fetch_ssl_certificate=True,
cache_mode=CacheMode.BYPASS # Bypass cache to always get fresh certificates
cache_mode=CacheMode.BYPASS, # Bypass cache to always get fresh certificates
)
async with AsyncWebCrawler() as crawler:
result = await crawler.arun(
url='https://example.com',
config=config
)
result = await crawler.arun(url="https://example.com", config=config)
if result.success and result.ssl_certificate:
cert = result.ssl_certificate
@@ -511,12 +516,17 @@ async def ssl_certification():
print("\nCertificate exported to:")
print(f"- JSON: {os.path.join(tmp_dir, 'certificate.json')}")
pem_data = cert.to_pem(os.path.join(tmp_dir, "certificate.pem")) # For web servers
pem_data = cert.to_pem(
os.path.join(tmp_dir, "certificate.pem")
) # For web servers
print(f"- PEM: {os.path.join(tmp_dir, 'certificate.pem')}")
der_data = cert.to_der(os.path.join(tmp_dir, "certificate.der")) # For Java apps
der_data = cert.to_der(
os.path.join(tmp_dir, "certificate.der")
) # For Java apps
print(f"- DER: {os.path.join(tmp_dir, 'certificate.der')}")
# Speed Comparison
async def speed_comparison():
print("\n--- Speed Comparison ---")

View File

@@ -1,6 +1,10 @@
import os, sys
# append parent directory to system path
sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))); os.environ['FIRECRAWL_API_KEY'] = "fc-84b370ccfad44beabc686b38f1769692";
sys.path.append(
os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
)
os.environ["FIRECRAWL_API_KEY"] = "fc-84b370ccfad44beabc686b38f1769692"
import asyncio
# import nest_asyncio
@@ -15,7 +19,7 @@ from bs4 import BeautifulSoup
from pydantic import BaseModel, Field
from crawl4ai import AsyncWebCrawler, CacheMode
from crawl4ai.markdown_generation_strategy import DefaultMarkdownGenerator
from crawl4ai.content_filter_strategy import BM25ContentFilter, PruningContentFilter
from crawl4ai.content_filter_strategy import PruningContentFilter
from crawl4ai.extraction_strategy import (
JsonCssExtractionStrategy,
LLMExtractionStrategy,
@@ -32,9 +36,12 @@ print("Website: https://crawl4ai.com")
async def simple_crawl():
print("\n--- Basic Usage ---")
async with AsyncWebCrawler(verbose=True) as crawler:
result = await crawler.arun(url="https://www.nbcnews.com/business", cache_mode= CacheMode.BYPASS)
result = await crawler.arun(
url="https://www.nbcnews.com/business", cache_mode=CacheMode.BYPASS
)
print(result.markdown[:500]) # Print first 500 characters
async def simple_example_with_running_js_code():
print("\n--- Executing JavaScript and Using CSS Selectors ---")
# New code to handle the wait_for parameter
@@ -57,6 +64,7 @@ async def simple_example_with_running_js_code():
)
print(result.markdown[:500]) # Print first 500 characters
async def simple_example_with_css_selector():
print("\n--- Using CSS Selectors ---")
async with AsyncWebCrawler(verbose=True) as crawler:
@@ -67,26 +75,27 @@ async def simple_example_with_css_selector():
)
print(result.markdown[:500]) # Print first 500 characters
async def use_proxy():
print("\n--- Using a Proxy ---")
print(
"Note: Replace 'http://your-proxy-url:port' with a working proxy to run this example."
)
# Uncomment and modify the following lines to use a proxy
async with AsyncWebCrawler(verbose=True, proxy="http://your-proxy-url:port") as crawler:
async with AsyncWebCrawler(
verbose=True, proxy="http://your-proxy-url:port"
) as crawler:
result = await crawler.arun(
url="https://www.nbcnews.com/business",
cache_mode= CacheMode.BYPASS
url="https://www.nbcnews.com/business", cache_mode=CacheMode.BYPASS
)
if result.success:
print(result.markdown[:500]) # Print first 500 characters
async def capture_and_save_screenshot(url: str, output_path: str):
async with AsyncWebCrawler(verbose=True) as crawler:
result = await crawler.arun(
url=url,
screenshot=True,
cache_mode= CacheMode.BYPASS
url=url, screenshot=True, cache_mode=CacheMode.BYPASS
)
if result.success and result.screenshot:
@@ -96,13 +105,14 @@ async def capture_and_save_screenshot(url: str, output_path: str):
screenshot_data = base64.b64decode(result.screenshot)
# Save the screenshot as a JPEG file
with open(output_path, 'wb') as f:
with open(output_path, "wb") as f:
f.write(screenshot_data)
print(f"Screenshot saved successfully to {output_path}")
else:
print("Failed to capture screenshot")
class OpenAIModelFee(BaseModel):
model_name: str = Field(..., description="Name of the OpenAI model.")
input_fee: str = Field(..., description="Fee for input token for the OpenAI model.")
@@ -110,7 +120,10 @@ class OpenAIModelFee(BaseModel):
..., description="Fee for output token for the OpenAI model."
)
async def extract_structured_data_using_llm(provider: str, api_token: str = None, extra_headers: Dict[str, str] = None):
async def extract_structured_data_using_llm(
provider: str, api_token: str = None, extra_headers: Dict[str, str] = None
):
print(f"\n--- Extracting Structured Data with {provider} ---")
if api_token is None and provider != "ollama":
@@ -118,7 +131,7 @@ async def extract_structured_data_using_llm(provider: str, api_token: str = None
return
# extra_args = {}
extra_args={
extra_args = {
"temperature": 0,
"top_p": 0.9,
"max_tokens": 2000,
@@ -139,12 +152,13 @@ async def extract_structured_data_using_llm(provider: str, api_token: str = None
instruction="""From the crawled content, extract all mentioned model names along with their fees for input and output tokens.
Do not miss any models in the entire content. One extracted model JSON format should look like this:
{"model_name": "GPT-4", "input_fee": "US$10.00 / 1M tokens", "output_fee": "US$30.00 / 1M tokens"}.""",
extra_args=extra_args
extra_args=extra_args,
),
cache_mode=CacheMode.BYPASS,
)
print(result.extracted_content)
async def extract_structured_data_using_css_extractor():
print("\n--- Using JsonCssExtractionStrategy for Fast Structured Output ---")
schema = {
@@ -175,16 +189,12 @@ async def extract_structured_data_using_css_extractor():
"name": "course_icon",
"selector": ".image-92",
"type": "attribute",
"attribute": "src"
"attribute": "src",
},
],
}
]
}
async with AsyncWebCrawler(
headless=True,
verbose=True
) as crawler:
async with AsyncWebCrawler(headless=True, verbose=True) as crawler:
# Create the JavaScript that handles clicking multiple times
js_click_tabs = """
(async () => {
@@ -204,13 +214,14 @@ async def extract_structured_data_using_css_extractor():
url="https://www.kidocode.com/degrees/technology",
extraction_strategy=JsonCssExtractionStrategy(schema, verbose=True),
js_code=[js_click_tabs],
cache_mode=CacheMode.BYPASS
cache_mode=CacheMode.BYPASS,
)
companies = json.loads(result.extracted_content)
print(f"Successfully extracted {len(companies)} companies")
print(json.dumps(companies[0], indent=2))
# Advanced Session-Based Crawling with Dynamic Content 🔄
async def crawl_dynamic_content_pages_method_1():
print("\n--- Advanced Multi-Page Crawling with JavaScript Execution ---")
@@ -267,6 +278,7 @@ async def crawl_dynamic_content_pages_method_1():
await crawler.crawler_strategy.kill_session(session_id)
print(f"Successfully crawled {len(all_commits)} commits across 3 pages")
async def crawl_dynamic_content_pages_method_2():
print("\n--- Advanced Multi-Page Crawling with JavaScript Execution ---")
@@ -334,8 +346,11 @@ async def crawl_dynamic_content_pages_method_2():
await crawler.crawler_strategy.kill_session(session_id)
print(f"Successfully crawled {len(all_commits)} commits across 3 pages")
async def crawl_dynamic_content_pages_method_3():
print("\n--- Advanced Multi-Page Crawling with JavaScript Execution using `wait_for` ---")
print(
"\n--- Advanced Multi-Page Crawling with JavaScript Execution using `wait_for` ---"
)
async with AsyncWebCrawler(verbose=True) as crawler:
url = "https://github.com/microsoft/TypeScript/commits/main"
@@ -395,41 +410,54 @@ async def crawl_dynamic_content_pages_method_3():
await crawler.crawler_strategy.kill_session(session_id)
print(f"Successfully crawled {len(all_commits)} commits across 3 pages")
async def crawl_custom_browser_type():
# Use Firefox
start = time.time()
async with AsyncWebCrawler(browser_type="firefox", verbose=True, headless = True) as crawler:
result = await crawler.arun(url="https://www.example.com", cache_mode= CacheMode.BYPASS)
async with AsyncWebCrawler(
browser_type="firefox", verbose=True, headless=True
) as crawler:
result = await crawler.arun(
url="https://www.example.com", cache_mode=CacheMode.BYPASS
)
print(result.markdown[:500])
print("Time taken: ", time.time() - start)
# Use WebKit
start = time.time()
async with AsyncWebCrawler(browser_type="webkit", verbose=True, headless = True) as crawler:
result = await crawler.arun(url="https://www.example.com", cache_mode= CacheMode.BYPASS)
async with AsyncWebCrawler(
browser_type="webkit", verbose=True, headless=True
) as crawler:
result = await crawler.arun(
url="https://www.example.com", cache_mode=CacheMode.BYPASS
)
print(result.markdown[:500])
print("Time taken: ", time.time() - start)
# Use Chromium (default)
start = time.time()
async with AsyncWebCrawler(verbose=True, headless = True) as crawler:
result = await crawler.arun(url="https://www.example.com", cache_mode= CacheMode.BYPASS)
async with AsyncWebCrawler(verbose=True, headless=True) as crawler:
result = await crawler.arun(
url="https://www.example.com", cache_mode=CacheMode.BYPASS
)
print(result.markdown[:500])
print("Time taken: ", time.time() - start)
async def crawl_with_user_simultion():
async with AsyncWebCrawler(verbose=True, headless=True) as crawler:
url = "YOUR-URL-HERE"
result = await crawler.arun(
url=url,
cache_mode=CacheMode.BYPASS,
magic = True, # Automatically detects and removes overlays, popups, and other elements that block content
magic=True, # Automatically detects and removes overlays, popups, and other elements that block content
# simulate_user = True,# Causes a series of random mouse movements and clicks to simulate user interaction
# override_navigator = True # Overrides the navigator object to make it look like a real user
)
print(result.markdown)
async def speed_comparison():
# print("\n--- Speed Comparison ---")
# print("Firecrawl (simulated):")
@@ -439,11 +467,11 @@ async def speed_comparison():
# print()
# Simulated Firecrawl performance
from firecrawl import FirecrawlApp
app = FirecrawlApp(api_key=os.environ['FIRECRAWL_API_KEY'])
app = FirecrawlApp(api_key=os.environ["FIRECRAWL_API_KEY"])
start = time.time()
scrape_status = app.scrape_url(
'https://www.nbcnews.com/business',
params={'formats': ['markdown', 'html']}
"https://www.nbcnews.com/business", params={"formats": ["markdown", "html"]}
)
end = time.time()
print("Firecrawl:")
@@ -474,7 +502,9 @@ async def speed_comparison():
url="https://www.nbcnews.com/business",
word_count_threshold=0,
markdown_generator=DefaultMarkdownGenerator(
content_filter = PruningContentFilter(threshold=0.48, threshold_type="fixed", min_word_threshold=0)
content_filter=PruningContentFilter(
threshold=0.48, threshold_type="fixed", min_word_threshold=0
)
# content_filter=BM25ContentFilter(user_query=None, bm25_threshold=1.0)
),
cache_mode=CacheMode.BYPASS,
@@ -498,7 +528,9 @@ async def speed_comparison():
word_count_threshold=0,
cache_mode=CacheMode.BYPASS,
markdown_generator=DefaultMarkdownGenerator(
content_filter = PruningContentFilter(threshold=0.48, threshold_type="fixed", min_word_threshold=0)
content_filter=PruningContentFilter(
threshold=0.48, threshold_type="fixed", min_word_threshold=0
)
# content_filter=BM25ContentFilter(user_query=None, bm25_threshold=1.0)
),
verbose=False,
@@ -520,6 +552,7 @@ async def speed_comparison():
print("If you run these tests in an environment with better network conditions,")
print("you may observe an even more significant speed advantage for Crawl4AI.")
async def generate_knowledge_graph():
class Entity(BaseModel):
name: str
@@ -536,11 +569,11 @@ async def generate_knowledge_graph():
relationships: List[Relationship]
extraction_strategy = LLMExtractionStrategy(
provider='openai/gpt-4o-mini', # Or any other provider, including Ollama and open source models
api_token=os.getenv('OPENAI_API_KEY'), # In case of Ollama just pass "no-token"
provider="openai/gpt-4o-mini", # Or any other provider, including Ollama and open source models
api_token=os.getenv("OPENAI_API_KEY"), # In case of Ollama just pass "no-token"
schema=KnowledgeGraph.model_json_schema(),
extraction_type="schema",
instruction="""Extract entities and relationships from the given text."""
instruction="""Extract entities and relationships from the given text.""",
)
async with AsyncWebCrawler() as crawler:
url = "https://paulgraham.com/love.html"
@@ -554,27 +587,22 @@ async def generate_knowledge_graph():
with open(os.path.join(__location__, "kb.json"), "w") as f:
f.write(result.extracted_content)
async def fit_markdown_remove_overlay():
async def fit_markdown_remove_overlay():
async with AsyncWebCrawler(
headless=True, # Set to False to see what is happening
verbose=True,
user_agent_mode="random",
user_agent_generator_config={
"device_type": "mobile",
"os_type": "android"
},
user_agent_generator_config={"device_type": "mobile", "os_type": "android"},
) as crawler:
result = await crawler.arun(
url='https://www.kidocode.com/degrees/technology',
url="https://www.kidocode.com/degrees/technology",
cache_mode=CacheMode.BYPASS,
markdown_generator=DefaultMarkdownGenerator(
content_filter=PruningContentFilter(
threshold=0.48, threshold_type="fixed", min_word_threshold=0
),
options={
"ignore_links": True
}
options={"ignore_links": True},
),
# markdown_generator=DefaultMarkdownGenerator(
# content_filter=BM25ContentFilter(user_query="", bm25_threshold=1.0),
@@ -593,13 +621,20 @@ async def fit_markdown_remove_overlay():
with open(os.path.join(__location__, "output/cleaned_html.html"), "w") as f:
f.write(result.cleaned_html)
with open(os.path.join(__location__, "output/output_raw_markdown.md"), "w") as f:
with open(
os.path.join(__location__, "output/output_raw_markdown.md"), "w"
) as f:
f.write(result.markdown_v2.raw_markdown)
with open(os.path.join(__location__, "output/output_markdown_with_citations.md"), "w") as f:
with open(
os.path.join(__location__, "output/output_markdown_with_citations.md"),
"w",
) as f:
f.write(result.markdown_v2.markdown_with_citations)
with open(os.path.join(__location__, "output/output_fit_markdown.md"), "w") as f:
with open(
os.path.join(__location__, "output/output_fit_markdown.md"), "w"
) as f:
f.write(result.markdown_v2.fit_markdown)
print("Done")

View File

@@ -10,15 +10,17 @@ from functools import lru_cache
console = Console()
@lru_cache()
def create_crawler():
crawler = WebCrawler(verbose=True)
crawler.warmup()
return crawler
def print_result(result):
# Print each key in one line and just the first 10 characters of each one's value and three dots
console.print(f"\t[bold]Result:[/bold]")
console.print("\t[bold]Result:[/bold]")
for key, value in result.model_dump().items():
if isinstance(value, str) and value:
console.print(f"\t{key}: [green]{value[:20]}...[/green]")
@@ -33,18 +35,27 @@ def cprint(message, press_any_key=False):
console.print("Press any key to continue...", style="")
input()
def basic_usage(crawler):
cprint("🛠️ [bold cyan]Basic Usage: Simply provide a URL and let Crawl4ai do the magic![/bold cyan]")
result = crawler.run(url="https://www.nbcnews.com/business", only_text = True)
cprint(
"🛠️ [bold cyan]Basic Usage: Simply provide a URL and let Crawl4ai do the magic![/bold cyan]"
)
result = crawler.run(url="https://www.nbcnews.com/business", only_text=True)
cprint("[LOG] 📦 [bold yellow]Basic crawl result:[/bold yellow]")
print_result(result)
def basic_usage_some_params(crawler):
cprint("🛠️ [bold cyan]Basic Usage: Simply provide a URL and let Crawl4ai do the magic![/bold cyan]")
result = crawler.run(url="https://www.nbcnews.com/business", word_count_threshold=1, only_text = True)
cprint(
"🛠️ [bold cyan]Basic Usage: Simply provide a URL and let Crawl4ai do the magic![/bold cyan]"
)
result = crawler.run(
url="https://www.nbcnews.com/business", word_count_threshold=1, only_text=True
)
cprint("[LOG] 📦 [bold yellow]Basic crawl result:[/bold yellow]")
print_result(result)
def screenshot_usage(crawler):
cprint("\n📸 [bold cyan]Let's take a screenshot of the page![/bold cyan]")
result = crawler.run(url="https://www.nbcnews.com/business", screenshot=True)
@@ -55,16 +66,23 @@ def screenshot_usage(crawler):
cprint("Screenshot saved to 'screenshot.png'!")
print_result(result)
def understanding_parameters(crawler):
cprint("\n🧠 [bold cyan]Understanding 'bypass_cache' and 'include_raw_html' parameters:[/bold cyan]")
cprint("By default, Crawl4ai caches the results of your crawls. This means that subsequent crawls of the same URL will be much faster! Let's see this in action.")
cprint(
"\n🧠 [bold cyan]Understanding 'bypass_cache' and 'include_raw_html' parameters:[/bold cyan]"
)
cprint(
"By default, Crawl4ai caches the results of your crawls. This means that subsequent crawls of the same URL will be much faster! Let's see this in action."
)
# First crawl (reads from cache)
cprint("1⃣ First crawl (caches the result):", True)
start_time = time.time()
result = crawler.run(url="https://www.nbcnews.com/business")
end_time = time.time()
cprint(f"[LOG] 📦 [bold yellow]First crawl took {end_time - start_time} seconds and result (from cache):[/bold yellow]")
cprint(
f"[LOG] 📦 [bold yellow]First crawl took {end_time - start_time} seconds and result (from cache):[/bold yellow]"
)
print_result(result)
# Force to crawl again
@@ -72,132 +90,194 @@ def understanding_parameters(crawler):
start_time = time.time()
result = crawler.run(url="https://www.nbcnews.com/business", bypass_cache=True)
end_time = time.time()
cprint(f"[LOG] 📦 [bold yellow]Second crawl took {end_time - start_time} seconds and result (forced to crawl):[/bold yellow]")
cprint(
f"[LOG] 📦 [bold yellow]Second crawl took {end_time - start_time} seconds and result (forced to crawl):[/bold yellow]"
)
print_result(result)
def add_chunking_strategy(crawler):
# Adding a chunking strategy: RegexChunking
cprint("\n🧩 [bold cyan]Let's add a chunking strategy: RegexChunking![/bold cyan]", True)
cprint("RegexChunking is a simple chunking strategy that splits the text based on a given regex pattern. Let's see it in action!")
cprint(
"\n🧩 [bold cyan]Let's add a chunking strategy: RegexChunking![/bold cyan]",
True,
)
cprint(
"RegexChunking is a simple chunking strategy that splits the text based on a given regex pattern. Let's see it in action!"
)
result = crawler.run(
url="https://www.nbcnews.com/business",
chunking_strategy=RegexChunking(patterns=["\n\n"])
chunking_strategy=RegexChunking(patterns=["\n\n"]),
)
cprint("[LOG] 📦 [bold yellow]RegexChunking result:[/bold yellow]")
print_result(result)
# Adding another chunking strategy: NlpSentenceChunking
cprint("\n🔍 [bold cyan]Time to explore another chunking strategy: NlpSentenceChunking![/bold cyan]", True)
cprint("NlpSentenceChunking uses NLP techniques to split the text into sentences. Let's see how it performs!")
cprint(
"\n🔍 [bold cyan]Time to explore another chunking strategy: NlpSentenceChunking![/bold cyan]",
True,
)
cprint(
"NlpSentenceChunking uses NLP techniques to split the text into sentences. Let's see how it performs!"
)
result = crawler.run(
url="https://www.nbcnews.com/business",
chunking_strategy=NlpSentenceChunking()
url="https://www.nbcnews.com/business", chunking_strategy=NlpSentenceChunking()
)
cprint("[LOG] 📦 [bold yellow]NlpSentenceChunking result:[/bold yellow]")
print_result(result)
def add_extraction_strategy(crawler):
# Adding an extraction strategy: CosineStrategy
cprint("\n🧠 [bold cyan]Let's get smarter with an extraction strategy: CosineStrategy![/bold cyan]", True)
cprint("CosineStrategy uses cosine similarity to extract semantically similar blocks of text. Let's see it in action!")
cprint(
"\n🧠 [bold cyan]Let's get smarter with an extraction strategy: CosineStrategy![/bold cyan]",
True,
)
cprint(
"CosineStrategy uses cosine similarity to extract semantically similar blocks of text. Let's see it in action!"
)
result = crawler.run(
url="https://www.nbcnews.com/business",
extraction_strategy=CosineStrategy(word_count_threshold=10, max_dist=0.2, linkage_method="ward", top_k=3, sim_threshold = 0.3, verbose=True)
extraction_strategy=CosineStrategy(
word_count_threshold=10,
max_dist=0.2,
linkage_method="ward",
top_k=3,
sim_threshold=0.3,
verbose=True,
),
)
cprint("[LOG] 📦 [bold yellow]CosineStrategy result:[/bold yellow]")
print_result(result)
# Using semantic_filter with CosineStrategy
cprint("You can pass other parameters like 'semantic_filter' to the CosineStrategy to extract semantically similar blocks of text. Let's see it in action!")
cprint(
"You can pass other parameters like 'semantic_filter' to the CosineStrategy to extract semantically similar blocks of text. Let's see it in action!"
)
result = crawler.run(
url="https://www.nbcnews.com/business",
extraction_strategy=CosineStrategy(
semantic_filter="inflation rent prices",
),
)
cprint(
"[LOG] 📦 [bold yellow]CosineStrategy result with semantic filter:[/bold yellow]"
)
cprint("[LOG] 📦 [bold yellow]CosineStrategy result with semantic filter:[/bold yellow]")
print_result(result)
def add_llm_extraction_strategy(crawler):
# Adding an LLM extraction strategy without instructions
cprint("\n🤖 [bold cyan]Time to bring in the big guns: LLMExtractionStrategy without instructions![/bold cyan]", True)
cprint("LLMExtractionStrategy uses a large language model to extract relevant information from the web page. Let's see it in action!")
cprint(
"\n🤖 [bold cyan]Time to bring in the big guns: LLMExtractionStrategy without instructions![/bold cyan]",
True,
)
cprint(
"LLMExtractionStrategy uses a large language model to extract relevant information from the web page. Let's see it in action!"
)
result = crawler.run(
url="https://www.nbcnews.com/business",
extraction_strategy=LLMExtractionStrategy(provider="openai/gpt-4o", api_token=os.getenv('OPENAI_API_KEY'))
extraction_strategy=LLMExtractionStrategy(
provider="openai/gpt-4o", api_token=os.getenv("OPENAI_API_KEY")
),
)
cprint(
"[LOG] 📦 [bold yellow]LLMExtractionStrategy (no instructions) result:[/bold yellow]"
)
cprint("[LOG] 📦 [bold yellow]LLMExtractionStrategy (no instructions) result:[/bold yellow]")
print_result(result)
# Adding an LLM extraction strategy with instructions
cprint("\n📜 [bold cyan]Let's make it even more interesting: LLMExtractionStrategy with instructions![/bold cyan]", True)
cprint("Let's say we are only interested in financial news. Let's see how LLMExtractionStrategy performs with instructions!")
cprint(
"\n📜 [bold cyan]Let's make it even more interesting: LLMExtractionStrategy with instructions![/bold cyan]",
True,
)
cprint(
"Let's say we are only interested in financial news. Let's see how LLMExtractionStrategy performs with instructions!"
)
result = crawler.run(
url="https://www.nbcnews.com/business",
extraction_strategy=LLMExtractionStrategy(
provider="openai/gpt-4o",
api_token=os.getenv('OPENAI_API_KEY'),
instruction="I am interested in only financial news"
api_token=os.getenv("OPENAI_API_KEY"),
instruction="I am interested in only financial news",
),
)
cprint(
"[LOG] 📦 [bold yellow]LLMExtractionStrategy (with instructions) result:[/bold yellow]"
)
cprint("[LOG] 📦 [bold yellow]LLMExtractionStrategy (with instructions) result:[/bold yellow]")
print_result(result)
result = crawler.run(
url="https://www.nbcnews.com/business",
extraction_strategy=LLMExtractionStrategy(
provider="openai/gpt-4o",
api_token=os.getenv('OPENAI_API_KEY'),
instruction="Extract only content related to technology"
api_token=os.getenv("OPENAI_API_KEY"),
instruction="Extract only content related to technology",
),
)
cprint(
"[LOG] 📦 [bold yellow]LLMExtractionStrategy (with technology instruction) result:[/bold yellow]"
)
cprint("[LOG] 📦 [bold yellow]LLMExtractionStrategy (with technology instruction) result:[/bold yellow]")
print_result(result)
def targeted_extraction(crawler):
# Using a CSS selector to extract only H2 tags
cprint("\n🎯 [bold cyan]Targeted extraction: Let's use a CSS selector to extract only H2 tags![/bold cyan]", True)
result = crawler.run(
url="https://www.nbcnews.com/business",
css_selector="h2"
cprint(
"\n🎯 [bold cyan]Targeted extraction: Let's use a CSS selector to extract only H2 tags![/bold cyan]",
True,
)
result = crawler.run(url="https://www.nbcnews.com/business", css_selector="h2")
cprint("[LOG] 📦 [bold yellow]CSS Selector (H2 tags) result:[/bold yellow]")
print_result(result)
def interactive_extraction(crawler):
# Passing JavaScript code to interact with the page
cprint("\n🖱️ [bold cyan]Let's get interactive: Passing JavaScript code to click 'Load More' button![/bold cyan]", True)
cprint("In this example we try to click the 'Load More' button on the page using JavaScript code.")
cprint(
"\n🖱️ [bold cyan]Let's get interactive: Passing JavaScript code to click 'Load More' button![/bold cyan]",
True,
)
cprint(
"In this example we try to click the 'Load More' button on the page using JavaScript code."
)
js_code = """
const loadMoreButton = Array.from(document.querySelectorAll('button')).find(button => button.textContent.includes('Load More'));
loadMoreButton && loadMoreButton.click();
"""
# crawler_strategy = LocalSeleniumCrawlerStrategy(js_code=js_code)
# crawler = WebCrawler(crawler_strategy=crawler_strategy, always_by_pass_cache=True)
result = crawler.run(
url="https://www.nbcnews.com/business",
js = js_code
result = crawler.run(url="https://www.nbcnews.com/business", js=js_code)
cprint(
"[LOG] 📦 [bold yellow]JavaScript Code (Load More button) result:[/bold yellow]"
)
cprint("[LOG] 📦 [bold yellow]JavaScript Code (Load More button) result:[/bold yellow]")
print_result(result)
def multiple_scrip(crawler):
# Passing JavaScript code to interact with the page
cprint("\n🖱️ [bold cyan]Let's get interactive: Passing JavaScript code to click 'Load More' button![/bold cyan]", True)
cprint("In this example we try to click the 'Load More' button on the page using JavaScript code.")
js_code = ["""
cprint(
"\n🖱️ [bold cyan]Let's get interactive: Passing JavaScript code to click 'Load More' button![/bold cyan]",
True,
)
cprint(
"In this example we try to click the 'Load More' button on the page using JavaScript code."
)
js_code = [
"""
const loadMoreButton = Array.from(document.querySelectorAll('button')).find(button => button.textContent.includes('Load More'));
loadMoreButton && loadMoreButton.click();
"""] * 2
"""
] * 2
# crawler_strategy = LocalSeleniumCrawlerStrategy(js_code=js_code)
# crawler = WebCrawler(crawler_strategy=crawler_strategy, always_by_pass_cache=True)
result = crawler.run(
url="https://www.nbcnews.com/business",
js = js_code
result = crawler.run(url="https://www.nbcnews.com/business", js=js_code)
cprint(
"[LOG] 📦 [bold yellow]JavaScript Code (Load More button) result:[/bold yellow]"
)
cprint("[LOG] 📦 [bold yellow]JavaScript Code (Load More button) result:[/bold yellow]")
print_result(result)
def using_crawler_hooks(crawler):
# Example usage of the hooks for authentication and setting a cookie
def on_driver_created(driver):
@@ -206,33 +286,34 @@ def using_crawler_hooks(crawler):
driver.maximize_window()
# Example customization: logging in to a hypothetical website
driver.get('https://example.com/login')
driver.get("https://example.com/login")
from selenium.webdriver.support.ui import WebDriverWait
from selenium.webdriver.common.by import By
from selenium.webdriver.support import expected_conditions as EC
WebDriverWait(driver, 10).until(
EC.presence_of_element_located((By.NAME, 'username'))
EC.presence_of_element_located((By.NAME, "username"))
)
driver.find_element(By.NAME, 'username').send_keys('testuser')
driver.find_element(By.NAME, 'password').send_keys('password123')
driver.find_element(By.NAME, 'login').click()
driver.find_element(By.NAME, "username").send_keys("testuser")
driver.find_element(By.NAME, "password").send_keys("password123")
driver.find_element(By.NAME, "login").click()
WebDriverWait(driver, 10).until(
EC.presence_of_element_located((By.ID, 'welcome'))
EC.presence_of_element_located((By.ID, "welcome"))
)
# Add a custom cookie
driver.add_cookie({'name': 'test_cookie', 'value': 'cookie_value'})
driver.add_cookie({"name": "test_cookie", "value": "cookie_value"})
return driver
def before_get_url(driver):
print("[HOOK] before_get_url")
# Example customization: add a custom header
# Enable Network domain for sending headers
driver.execute_cdp_cmd('Network.enable', {})
driver.execute_cdp_cmd("Network.enable", {})
# Add a custom header
driver.execute_cdp_cmd('Network.setExtraHTTPHeaders', {'headers': {'X-Test-Header': 'test'}})
driver.execute_cdp_cmd(
"Network.setExtraHTTPHeaders", {"headers": {"X-Test-Header": "test"}}
)
return driver
def after_get_url(driver):
@@ -247,20 +328,24 @@ def using_crawler_hooks(crawler):
print(len(html))
return driver
cprint("\n🔗 [bold cyan]Using Crawler Hooks: Let's see how we can customize the crawler using hooks![/bold cyan]", True)
cprint(
"\n🔗 [bold cyan]Using Crawler Hooks: Let's see how we can customize the crawler using hooks![/bold cyan]",
True,
)
crawler_strategy = LocalSeleniumCrawlerStrategy(verbose=True)
crawler_strategy.set_hook('on_driver_created', on_driver_created)
crawler_strategy.set_hook('before_get_url', before_get_url)
crawler_strategy.set_hook('after_get_url', after_get_url)
crawler_strategy.set_hook('before_return_html', before_return_html)
crawler_strategy.set_hook("on_driver_created", on_driver_created)
crawler_strategy.set_hook("before_get_url", before_get_url)
crawler_strategy.set_hook("after_get_url", after_get_url)
crawler_strategy.set_hook("before_return_html", before_return_html)
crawler = WebCrawler(verbose=True, crawler_strategy=crawler_strategy)
crawler.warmup()
result = crawler.run(url="https://example.com")
cprint("[LOG] 📦 [bold yellow]Crawler Hooks result:[/bold yellow]")
print_result(result= result)
print_result(result=result)
def using_crawler_hooks_dleay_example(crawler):
def delay(driver):
@@ -270,12 +355,14 @@ def using_crawler_hooks_dleay_example(crawler):
def create_crawler():
crawler_strategy = LocalSeleniumCrawlerStrategy(verbose=True)
crawler_strategy.set_hook('after_get_url', delay)
crawler_strategy.set_hook("after_get_url", delay)
crawler = WebCrawler(verbose=True, crawler_strategy=crawler_strategy)
crawler.warmup()
return crawler
cprint("\n🔗 [bold cyan]Using Crawler Hooks: Let's add a delay after fetching the url to make sure entire page is fetched.[/bold cyan]")
cprint(
"\n🔗 [bold cyan]Using Crawler Hooks: Let's add a delay after fetching the url to make sure entire page is fetched.[/bold cyan]"
)
crawler = create_crawler()
result = crawler.run(url="https://google.com", bypass_cache=True)
@@ -283,11 +370,16 @@ def using_crawler_hooks_dleay_example(crawler):
print_result(result)
def main():
cprint("🌟 [bold green]Welcome to the Crawl4ai Quickstart Guide! Let's dive into some web crawling fun! 🌐[/bold green]")
cprint("⛳️ [bold cyan]First Step: Create an instance of WebCrawler and call the `warmup()` function.[/bold cyan]")
cprint("If this is the first time you're running Crawl4ai, this might take a few seconds to load required model files.")
cprint(
"🌟 [bold green]Welcome to the Crawl4ai Quickstart Guide! Let's dive into some web crawling fun! 🌐[/bold green]"
)
cprint(
"⛳️ [bold cyan]First Step: Create an instance of WebCrawler and call the `warmup()` function.[/bold cyan]"
)
cprint(
"If this is the first time you're running Crawl4ai, this might take a few seconds to load required model files."
)
crawler = create_crawler()
@@ -305,8 +397,10 @@ def main():
interactive_extraction(crawler)
multiple_scrip(crawler)
cprint("\n🎉 [bold green]Congratulations! You've made it through the Crawl4ai Quickstart Guide! Now go forth and crawl the web like a pro! 🕸️[/bold green]")
cprint(
"\n🎉 [bold green]Congratulations! You've made it through the Crawl4ai Quickstart Guide! Now go forth and crawl the web like a pro! 🕸️[/bold green]"
)
if __name__ == "__main__":
main()

View File

@@ -11,7 +11,9 @@ from groq import Groq
# Import threadpools to run the crawl_url function in a separate thread
from concurrent.futures import ThreadPoolExecutor
client = AsyncOpenAI(base_url="https://api.groq.com/openai/v1", api_key=os.getenv("GROQ_API_KEY"))
client = AsyncOpenAI(
base_url="https://api.groq.com/openai/v1", api_key=os.getenv("GROQ_API_KEY")
)
# Instrument the OpenAI client
cl.instrument_openai()
@@ -25,32 +27,31 @@ settings = {
"presence_penalty": 0,
}
def extract_urls(text):
url_pattern = re.compile(r'(https?://\S+)')
url_pattern = re.compile(r"(https?://\S+)")
return url_pattern.findall(text)
def crawl_url(url):
data = {
"urls": [url],
"include_raw_html": True,
"word_count_threshold": 10,
"extraction_strategy": "NoExtractionStrategy",
"chunking_strategy": "RegexChunking"
"chunking_strategy": "RegexChunking",
}
response = requests.post("https://crawl4ai.com/crawl", json=data)
response_data = response.json()
response_data = response_data['results'][0]
return response_data['markdown']
response_data = response_data["results"][0]
return response_data["markdown"]
@cl.on_chat_start
async def on_chat_start():
cl.user_session.set("session", {
"history": [],
"context": {}
})
await cl.Message(
content="Welcome to the chat! How can I assist you today?"
).send()
cl.user_session.set("session", {"history": [], "context": {}})
await cl.Message(content="Welcome to the chat! How can I assist you today?").send()
@cl.on_message
async def on_message(message: cl.Message):
@@ -59,7 +60,6 @@ async def on_message(message: cl.Message):
# Extract URLs from the user's message
urls = extract_urls(message.content)
futures = []
with ThreadPoolExecutor() as executor:
for url in urls:
@@ -69,16 +69,9 @@ async def on_message(message: cl.Message):
for url, result in zip(urls, results):
ref_number = f"REF_{len(user_session['context']) + 1}"
user_session["context"][ref_number] = {
"url": url,
"content": result
}
user_session["context"][ref_number] = {"url": url, "content": result}
user_session["history"].append({
"role": "user",
"content": message.content
})
user_session["history"].append({"role": "user", "content": message.content})
# Create a system message that includes the context
context_messages = [
@@ -95,26 +88,17 @@ async def on_message(message: cl.Message):
"If not, there is no need to add a references section. "
"At the end of your response, provide a reference section listing the URLs and their REF numbers only if sources from the appendices were used.\n\n"
"\n\n".join(context_messages)
)
),
}
else:
system_message = {
"role": "system",
"content": "You are a helpful assistant."
}
system_message = {"role": "system", "content": "You are a helpful assistant."}
msg = cl.Message(content="")
await msg.send()
# Get response from the LLM
stream = await client.chat.completions.create(
messages=[
system_message,
*user_session["history"]
],
stream=True,
**settings
messages=[system_message, *user_session["history"]], stream=True, **settings
)
assistant_response = ""
@@ -124,10 +108,7 @@ async def on_message(message: cl.Message):
await msg.stream_token(token)
# Add assistant message to the history
user_session["history"].append({
"role": "assistant",
"content": assistant_response
})
user_session["history"].append({"role": "assistant", "content": assistant_response})
await msg.update()
# Append the reference section to the assistant's response
@@ -154,6 +135,7 @@ async def on_audio_chunk(chunk: cl.AudioChunk):
pass
@cl.step(type="tool")
async def speech_to_text(audio_file):
cli = Groq()
@@ -179,17 +161,12 @@ async def on_audio_end(elements: list[ElementBased]):
end_time = time.time()
print(f"Transcription took {end_time - start_time} seconds")
user_msg = cl.Message(
author="You",
type="user_message",
content=transcription
)
user_msg = cl.Message(author="You", type="user_message", content=transcription)
await user_msg.send()
await on_message(user_msg)
if __name__ == "__main__":
from chainlit.cli import run_chainlit
run_chainlit(__file__)

View File

@@ -1,4 +1,3 @@
import requests, base64, os
data = {
@@ -7,58 +6,49 @@ data = {
}
response = requests.post("https://crawl4ai.com/crawl", json=data)
result = response.json()['results'][0]
result = response.json()["results"][0]
print(result.keys())
# dict_keys(['url', 'html', 'success', 'cleaned_html', 'media',
# 'links', 'screenshot', 'markdown', 'extracted_content',
# 'metadata', 'error_message'])
with open("screenshot.png", "wb") as f:
f.write(base64.b64decode(result['screenshot']))
f.write(base64.b64decode(result["screenshot"]))
# Example of filtering the content using CSS selectors
data = {
"urls": [
"https://www.nbcnews.com/business"
],
"urls": ["https://www.nbcnews.com/business"],
"css_selector": "article",
"screenshot": True,
}
# Example of executing a JS script on the page before extracting the content
data = {
"urls": [
"https://www.nbcnews.com/business"
],
"urls": ["https://www.nbcnews.com/business"],
"screenshot": True,
'js' : ["""
"js": [
"""
const loadMoreButton = Array.from(document.querySelectorAll('button')).
find(button => button.textContent.includes('Load More'));
loadMoreButton && loadMoreButton.click();
"""]
"""
],
}
# Example of using a custom extraction strategy
data = {
"urls": [
"https://www.nbcnews.com/business"
],
"urls": ["https://www.nbcnews.com/business"],
"extraction_strategy": "CosineStrategy",
"extraction_strategy_args": {
"semantic_filter": "inflation rent prices"
},
"extraction_strategy_args": {"semantic_filter": "inflation rent prices"},
}
# Example of using LLM to extract content
data = {
"urls": [
"https://www.nbcnews.com/business"
],
"urls": ["https://www.nbcnews.com/business"],
"extraction_strategy": "LLMExtractionStrategy",
"extraction_strategy_args": {
"provider": "groq/llama3-8b-8192",
"api_token": os.environ.get("GROQ_API_KEY"),
"instruction": """I am interested in only financial news,
and translate them in French."""
and translate them in French.""",
},
}

View File

@@ -5,22 +5,22 @@ import os
from crawl4ai import AsyncWebCrawler, CrawlerRunConfig, CacheMode
# Create tmp directory if it doesn't exist
parent_dir = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
parent_dir = os.path.dirname(
os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
)
tmp_dir = os.path.join(parent_dir, "tmp")
os.makedirs(tmp_dir, exist_ok=True)
async def main():
# Configure crawler to fetch SSL certificate
config = CrawlerRunConfig(
fetch_ssl_certificate=True,
cache_mode=CacheMode.BYPASS # Bypass cache to always get fresh certificates
cache_mode=CacheMode.BYPASS, # Bypass cache to always get fresh certificates
)
async with AsyncWebCrawler() as crawler:
result = await crawler.arun(
url='https://example.com',
config=config
)
result = await crawler.arun(url="https://example.com", config=config)
if result.success and result.ssl_certificate:
cert = result.ssl_certificate
@@ -36,11 +36,16 @@ async def main():
print("\nCertificate exported to:")
print(f"- JSON: {os.path.join(tmp_dir, 'certificate.json')}")
pem_data = cert.to_pem(os.path.join(tmp_dir, "certificate.pem")) # For web servers
pem_data = cert.to_pem(
os.path.join(tmp_dir, "certificate.pem")
) # For web servers
print(f"- PEM: {os.path.join(tmp_dir, 'certificate.pem')}")
der_data = cert.to_der(os.path.join(tmp_dir, "certificate.der")) # For Java apps
der_data = cert.to_der(
os.path.join(tmp_dir, "certificate.der")
) # For Java apps
print(f"- DER: {os.path.join(tmp_dir, 'certificate.der')}")
if __name__ == "__main__":
asyncio.run(main())

View File

@@ -1,39 +1,41 @@
import os
import time
import json
from crawl4ai.web_crawler import WebCrawler
from crawl4ai.chunking_strategy import *
from crawl4ai.extraction_strategy import *
from crawl4ai.crawler_strategy import *
url = r'https://marketplace.visualstudio.com/items?itemName=Unclecode.groqopilot'
url = r"https://marketplace.visualstudio.com/items?itemName=Unclecode.groqopilot"
crawler = WebCrawler()
crawler.warmup()
from pydantic import BaseModel, Field
class PageSummary(BaseModel):
title: str = Field(..., description="Title of the page.")
summary: str = Field(..., description="Summary of the page.")
brief_summary: str = Field(..., description="Brief summary of the page.")
keywords: list = Field(..., description="Keywords assigned to the page.")
result = crawler.run(
url=url,
word_count_threshold=1,
extraction_strategy= LLMExtractionStrategy(
provider= "openai/gpt-4o", api_token = os.getenv('OPENAI_API_KEY'),
extraction_strategy=LLMExtractionStrategy(
provider="openai/gpt-4o",
api_token=os.getenv("OPENAI_API_KEY"),
schema=PageSummary.model_json_schema(),
extraction_type="schema",
apply_chunking =False,
instruction="From the crawled content, extract the following details: "\
"1. Title of the page "\
"2. Summary of the page, which is a detailed summary "\
"3. Brief summary of the page, which is a paragraph text "\
"4. Keywords assigned to the page, which is a list of keywords. "\
'The extracted JSON format should look like this: '\
'{ "title": "Page Title", "summary": "Detailed summary of the page.", "brief_summary": "Brief summary in a paragraph.", "keywords": ["keyword1", "keyword2", "keyword3"] }'
apply_chunking=False,
instruction="From the crawled content, extract the following details: "
"1. Title of the page "
"2. Summary of the page, which is a detailed summary "
"3. Brief summary of the page, which is a paragraph text "
"4. Keywords assigned to the page, which is a list of keywords. "
"The extracted JSON format should look like this: "
'{ "title": "Page Title", "summary": "Detailed summary of the page.", "brief_summary": "Brief summary in a paragraph.", "keywords": ["keyword1", "keyword2", "keyword3"] }',
),
bypass_cache=True,
)

View File

@@ -1,4 +1,5 @@
import os, sys
# append the parent directory to the sys.path
parent_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
sys.path.append(parent_dir)
@@ -13,6 +14,7 @@ import json
from crawl4ai import AsyncWebCrawler, CacheMode
from crawl4ai.content_filter_strategy import BM25ContentFilter
# 1. File Download Processing Example
async def download_example():
"""Example of downloading files from Python.org"""
@@ -23,9 +25,7 @@ async def download_example():
print(f"Downloads will be saved to: {downloads_path}")
async with AsyncWebCrawler(
accept_downloads=True,
downloads_path=downloads_path,
verbose=True
accept_downloads=True, downloads_path=downloads_path, verbose=True
) as crawler:
result = await crawler.arun(
url="https://www.python.org/downloads/",
@@ -40,7 +40,7 @@ async def download_example():
}
""",
delay_before_return_html=1, # Wait 5 seconds to ensure download starts
cache_mode=CacheMode.BYPASS
cache_mode=CacheMode.BYPASS,
)
if result.downloaded_files:
@@ -52,24 +52,25 @@ async def download_example():
else:
print("\nNo files were downloaded")
# 2. Local File and Raw HTML Processing Example
async def local_and_raw_html_example():
"""Example of processing local files and raw HTML"""
# Create a sample HTML file
sample_file = os.path.join(__data__, "sample.html")
with open(sample_file, "w") as f:
f.write("""
f.write(
"""
<html><body>
<h1>Test Content</h1>
<p>This is a test paragraph.</p>
</body></html>
""")
"""
)
async with AsyncWebCrawler(verbose=True) as crawler:
# Process local file
local_result = await crawler.arun(
url=f"file://{os.path.abspath(sample_file)}"
)
local_result = await crawler.arun(url=f"file://{os.path.abspath(sample_file)}")
# Process raw HTML
raw_html = """
@@ -78,9 +79,7 @@ async def local_and_raw_html_example():
<p>This is a test of raw HTML processing.</p>
</body></html>
"""
raw_result = await crawler.arun(
url=f"raw:{raw_html}"
)
raw_result = await crawler.arun(url=f"raw:{raw_html}")
# Clean up
os.remove(sample_file)
@@ -88,6 +87,7 @@ async def local_and_raw_html_example():
print("Local file content:", local_result.markdown)
print("\nRaw HTML content:", raw_result.markdown)
# 3. Enhanced Markdown Generation Example
async def markdown_generation_example():
"""Example of enhanced markdown generation with citations and LLM-friendly features"""
@@ -102,27 +102,32 @@ async def markdown_generation_example():
url="https://en.wikipedia.org/wiki/Apple",
css_selector="main div#bodyContent",
content_filter=content_filter,
cache_mode=CacheMode.BYPASS
cache_mode=CacheMode.BYPASS,
)
from crawl4ai import AsyncWebCrawler
from crawl4ai.content_filter_strategy import BM25ContentFilter
result = await crawler.arun(
url="https://en.wikipedia.org/wiki/Apple",
css_selector="main div#bodyContent",
content_filter=BM25ContentFilter()
content_filter=BM25ContentFilter(),
)
print(result.markdown_v2.fit_markdown)
print("\nMarkdown Generation Results:")
print(f"1. Original markdown length: {len(result.markdown)}")
print(f"2. New markdown versions (markdown_v2):")
print("2. New markdown versions (markdown_v2):")
print(f" - Raw markdown length: {len(result.markdown_v2.raw_markdown)}")
print(f" - Citations markdown length: {len(result.markdown_v2.markdown_with_citations)}")
print(f" - References section length: {len(result.markdown_v2.references_markdown)}")
print(
f" - Citations markdown length: {len(result.markdown_v2.markdown_with_citations)}"
)
print(
f" - References section length: {len(result.markdown_v2.references_markdown)}"
)
if result.markdown_v2.fit_markdown:
print(f" - Filtered markdown length: {len(result.markdown_v2.fit_markdown)}")
print(
f" - Filtered markdown length: {len(result.markdown_v2.fit_markdown)}"
)
# Save examples to files
output_dir = os.path.join(__data__, "markdown_examples")
@@ -148,7 +153,10 @@ async def markdown_generation_example():
print("\nSample of markdown with citations:")
print(result.markdown_v2.markdown_with_citations[:500] + "...\n")
print("Sample of references:")
print('\n'.join(result.markdown_v2.references_markdown.split('\n')[:10]) + "...")
print(
"\n".join(result.markdown_v2.references_markdown.split("\n")[:10]) + "..."
)
# 4. Browser Management Example
async def browser_management_example():
@@ -163,31 +171,31 @@ async def browser_management_example():
use_managed_browser=True,
user_data_dir=user_data_dir,
headless=False,
verbose=True
verbose=True,
) as crawler:
result = await crawler.arun(
url="https://crawl4ai.com",
# session_id="persistent_session_1",
cache_mode=CacheMode.BYPASS
cache_mode=CacheMode.BYPASS,
)
# Use GitHub as an example - it's a good test for browser management
# because it requires proper browser handling
result = await crawler.arun(
url="https://github.com/trending",
# session_id="persistent_session_1",
cache_mode=CacheMode.BYPASS
cache_mode=CacheMode.BYPASS,
)
print("\nBrowser session result:", result.success)
if result.success:
print("Page title:", result.metadata.get('title', 'No title found'))
print("Page title:", result.metadata.get("title", "No title found"))
# 5. API Usage Example
async def api_example():
"""Example of using the new API endpoints"""
api_token = os.getenv('CRAWL4AI_API_TOKEN') or "test_api_code"
headers = {'Authorization': f'Bearer {api_token}'}
api_token = os.getenv("CRAWL4AI_API_TOKEN") or "test_api_code"
headers = {"Authorization": f"Bearer {api_token}"}
async with aiohttp.ClientSession() as session:
# Submit crawl job
crawl_request = {
@@ -199,26 +207,18 @@ async def api_example():
"name": "Hacker News Articles",
"baseSelector": ".athing",
"fields": [
{
"name": "title",
"selector": ".title a",
"type": "text"
},
{
"name": "score",
"selector": ".score",
"type": "text"
},
{"name": "title", "selector": ".title a", "type": "text"},
{"name": "score", "selector": ".score", "type": "text"},
{
"name": "url",
"selector": ".title a",
"type": "attribute",
"attribute": "href"
}
]
}
"attribute": "href",
},
],
}
},
},
"crawler_params": {
"headless": True,
# "use_managed_browser": True
@@ -229,9 +229,7 @@ async def api_example():
}
async with session.post(
"http://localhost:11235/crawl",
json=crawl_request,
headers=headers
"http://localhost:11235/crawl", json=crawl_request, headers=headers
) as response:
task_data = await response.json()
task_id = task_data["task_id"]
@@ -239,8 +237,7 @@ async def api_example():
# Check task status
while True:
async with session.get(
f"http://localhost:11235/task/{task_id}",
headers=headers
f"http://localhost:11235/task/{task_id}", headers=headers
) as status_response:
result = await status_response.json()
print(f"Task status: {result['status']}")
@@ -248,12 +245,13 @@ async def api_example():
if result["status"] == "completed":
print("Task completed!")
print("Results:")
news = json.loads(result["results"][0]['extracted_content'])
news = json.loads(result["results"][0]["extracted_content"])
print(json.dumps(news[:4], indent=2))
break
else:
await asyncio.sleep(1)
# Main execution
async def main():
# print("Running Crawl4AI feature examples...")
@@ -273,5 +271,6 @@ async def main():
# print("\n5. Running API Example:")
await api_example()
if __name__ == "__main__":
asyncio.run(main())

View File

@@ -10,15 +10,14 @@ import asyncio
import os
import json
import re
from typing import List, Optional, Dict, Any
from pydantic import BaseModel, Field
from typing import List
from crawl4ai import (
AsyncWebCrawler,
BrowserConfig,
CrawlerRunConfig,
CacheMode,
LLMExtractionStrategy,
JsonCssExtractionStrategy
JsonCssExtractionStrategy,
)
from crawl4ai.content_filter_strategy import RelevantContentFilter
from crawl4ai.markdown_generation_strategy import DefaultMarkdownGenerator
@@ -52,6 +51,7 @@ SAMPLE_HTML = """
</div>
"""
async def demo_ssl_features():
"""
Enhanced SSL & Security Features Demo
@@ -76,14 +76,11 @@ async def demo_ssl_features():
run_config = CrawlerRunConfig(
cache_mode=CacheMode.BYPASS,
fetch_ssl_certificate=True # Enable SSL certificate fetching
fetch_ssl_certificate=True, # Enable SSL certificate fetching
)
async with AsyncWebCrawler(config=browser_config) as crawler:
result = await crawler.arun(
url="https://example.com",
config=run_config
)
result = await crawler.arun(url="https://example.com", config=run_config)
print(f"SSL Crawl Success: {result.success}")
result.ssl_certificate.to_json(
os.path.join(os.getcwd(), "ssl_certificate.json")
@@ -91,6 +88,7 @@ async def demo_ssl_features():
if not result.success:
print(f"SSL Error: {result.error_message}")
async def demo_content_filtering():
"""
Smart Content Filtering Demo
@@ -110,12 +108,14 @@ async def demo_content_filtering():
super().__init__()
# Add news-specific patterns
self.negative_patterns = re.compile(
r'nav|footer|header|sidebar|ads|comment|share|related|recommended|popular|trending',
re.I
r"nav|footer|header|sidebar|ads|comment|share|related|recommended|popular|trending",
re.I,
)
self.min_word_count = 30 # Higher threshold for news content
def filter_content(self, html: str, min_word_threshold: int = None) -> List[str]:
def filter_content(
self, html: str, min_word_threshold: int = None
) -> List[str]:
"""
Implements news-specific content filtering logic.
@@ -129,14 +129,16 @@ async def demo_content_filtering():
if not html or not isinstance(html, str):
return []
soup = BeautifulSoup(html, 'lxml')
soup = BeautifulSoup(html, "lxml")
if not soup.body:
soup = BeautifulSoup(f'<body>{html}</body>', 'lxml')
soup = BeautifulSoup(f"<body>{html}</body>", "lxml")
body = soup.find('body')
body = soup.find("body")
# Extract chunks with metadata
chunks = self.extract_text_chunks(body, min_word_threshold or self.min_word_count)
chunks = self.extract_text_chunks(
body, min_word_threshold or self.min_word_count
)
# Filter chunks based on news-specific criteria
filtered_chunks = []
@@ -146,7 +148,7 @@ async def demo_content_filtering():
continue
# Headers are important in news articles
if tag_type == 'header':
if tag_type == "header":
filtered_chunks.append(self.clean_element(element))
continue
@@ -154,7 +156,9 @@ async def demo_content_filtering():
text = element.get_text(strip=True)
if len(text.split()) >= (min_word_threshold or self.min_word_count):
# Calculate link density
links_text = ' '.join(a.get_text(strip=True) for a in element.find_all('a'))
links_text = " ".join(
a.get_text(strip=True) for a in element.find_all("a")
)
link_density = len(links_text) / len(text) if text else 1
# Accept if link density is reasonable
@@ -164,23 +168,20 @@ async def demo_content_filtering():
return filtered_chunks
# Create markdown generator with custom filter
markdown_gen = DefaultMarkdownGenerator(
content_filter=CustomNewsFilter()
)
markdown_gen = DefaultMarkdownGenerator(content_filter=CustomNewsFilter())
run_config = CrawlerRunConfig(
markdown_generator=markdown_gen,
cache_mode=CacheMode.BYPASS
markdown_generator=markdown_gen, cache_mode=CacheMode.BYPASS
)
async with AsyncWebCrawler() as crawler:
result = await crawler.arun(
url="https://news.ycombinator.com",
config=run_config
url="https://news.ycombinator.com", config=run_config
)
print("Filtered Content Sample:")
print(result.markdown[:500]) # Show first 500 chars
async def demo_json_extraction():
"""
Improved JSON Extraction Demo
@@ -206,7 +207,7 @@ async def demo_json_extraction():
"baseSelector": "div.article-list",
"baseFields": [
{"name": "list_id", "type": "attribute", "attribute": "data-list-id"},
{"name": "category", "type": "attribute", "attribute": "data-category"}
{"name": "category", "type": "attribute", "attribute": "data-category"},
],
"fields": [
{
@@ -214,8 +215,16 @@ async def demo_json_extraction():
"selector": "article.post",
"type": "nested_list",
"baseFields": [
{"name": "post_id", "type": "attribute", "attribute": "data-post-id"},
{"name": "author_id", "type": "attribute", "attribute": "data-author"}
{
"name": "post_id",
"type": "attribute",
"attribute": "data-post-id",
},
{
"name": "author_id",
"type": "attribute",
"attribute": "data-author",
},
],
"fields": [
{
@@ -223,51 +232,59 @@ async def demo_json_extraction():
"selector": "h2.title a",
"type": "text",
"baseFields": [
{"name": "url", "type": "attribute", "attribute": "href"}
]
{
"name": "url",
"type": "attribute",
"attribute": "href",
}
],
},
{
"name": "author",
"selector": "div.meta a.author",
"type": "text",
"baseFields": [
{"name": "profile_url", "type": "attribute", "attribute": "href"}
]
},
{
"name": "date",
"selector": "span.date",
"type": "text"
"name": "profile_url",
"type": "attribute",
"attribute": "href",
}
],
},
{"name": "date", "selector": "span.date", "type": "text"},
{
"name": "read_more",
"selector": "a.read-more",
"type": "nested",
"fields": [
{"name": "text", "type": "text"},
{"name": "url", "type": "attribute", "attribute": "href"}
]
{
"name": "url",
"type": "attribute",
"attribute": "href",
},
],
},
],
}
]
}
]
],
}
)
# Demonstrate extraction from raw HTML
run_config = CrawlerRunConfig(
extraction_strategy=json_strategy,
cache_mode=CacheMode.BYPASS
extraction_strategy=json_strategy, cache_mode=CacheMode.BYPASS
)
async with AsyncWebCrawler() as crawler:
result = await crawler.arun(
url="raw:" + SAMPLE_HTML, # Use raw: prefix for raw HTML
config=run_config
config=run_config,
)
print("Extracted Content:")
print(result.extracted_content)
async def demo_input_formats():
"""
Input Format Handling Demo
@@ -359,18 +376,30 @@ async def demo_input_formats():
# Define our schema using Pydantic
class JobRequirement(BaseModel):
category: str = Field(description="Category of the requirement (e.g., Technical, Soft Skills)")
items: List[str] = Field(description="List of specific requirements in this category")
priority: str = Field(description="Priority level (Required/Preferred) based on the HTML class or context")
category: str = Field(
description="Category of the requirement (e.g., Technical, Soft Skills)"
)
items: List[str] = Field(
description="List of specific requirements in this category"
)
priority: str = Field(
description="Priority level (Required/Preferred) based on the HTML class or context"
)
class JobPosting(BaseModel):
title: str = Field(description="Job title")
department: str = Field(description="Department or team")
location: str = Field(description="Job location, including remote options")
salary_range: Optional[str] = Field(description="Salary range if specified")
requirements: List[JobRequirement] = Field(description="Categorized job requirements")
application_deadline: Optional[str] = Field(description="Application deadline if specified")
contact_info: Optional[dict] = Field(description="Contact information from footer or contact section")
requirements: List[JobRequirement] = Field(
description="Categorized job requirements"
)
application_deadline: Optional[str] = Field(
description="Application deadline if specified"
)
contact_info: Optional[dict] = Field(
description="Contact information from footer or contact section"
)
# First try with markdown (default)
markdown_strategy = LLMExtractionStrategy(
@@ -382,7 +411,7 @@ async def demo_input_formats():
Extract job posting details into structured data. Focus on the visible text content
and organize requirements into categories.
""",
input_format="markdown" # default
input_format="markdown", # default
)
# Then with HTML for better structure understanding
@@ -400,34 +429,25 @@ async def demo_input_formats():
Use HTML attributes and classes to enhance extraction accuracy.
""",
input_format="html" # explicitly use HTML
input_format="html", # explicitly use HTML
)
async with AsyncWebCrawler() as crawler:
# Try with markdown first
markdown_config = CrawlerRunConfig(
extraction_strategy=markdown_strategy
)
markdown_result = await crawler.arun(
url=url,
config=markdown_config
)
markdown_config = CrawlerRunConfig(extraction_strategy=markdown_strategy)
markdown_result = await crawler.arun(url=url, config=markdown_config)
print("\nMarkdown-based Extraction Result:")
items = json.loads(markdown_result.extracted_content)
print(json.dumps(items, indent=2))
# Then with HTML for better structure understanding
html_config = CrawlerRunConfig(
extraction_strategy=html_strategy
)
html_result = await crawler.arun(
url=url,
config=html_config
)
html_config = CrawlerRunConfig(extraction_strategy=html_strategy)
html_result = await crawler.arun(url=url, config=html_config)
print("\nHTML-based Extraction Result:")
items = json.loads(html_result.extracted_content)
print(json.dumps(items, indent=2))
# Main execution
async def main():
print("Crawl4AI v0.4.24 Feature Walkthrough")
@@ -439,5 +459,6 @@ async def main():
await demo_json_extraction()
# await demo_input_formats()
if __name__ == "__main__":
asyncio.run(main())

78
main.py
View File

@@ -1,14 +1,9 @@
import asyncio, os
from fastapi import FastAPI, HTTPException, BackgroundTasks, Request
from fastapi.responses import JSONResponse
from fastapi import FastAPI, HTTPException, Request
from fastapi.responses import HTMLResponse, JSONResponse
from fastapi import FastAPI, HTTPException
from fastapi import FastAPI, HTTPException
from fastapi.staticfiles import StaticFiles
from fastapi.middleware.cors import CORSMiddleware
from fastapi.templating import Jinja2Templates
from fastapi.exceptions import RequestValidationError
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.responses import FileResponse
from fastapi.responses import RedirectResponse
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
from fastapi import Depends, Security
@@ -18,13 +13,10 @@ from typing import Optional, List, Dict, Any, Union
import psutil
import time
import uuid
from collections import defaultdict
from urllib.parse import urlparse
import math
import logging
from enum import Enum
from dataclasses import dataclass
import json
from crawl4ai import AsyncWebCrawler, CrawlResult, CacheMode
from crawl4ai.config import MIN_WORD_THRESHOLD
from crawl4ai.extraction_strategy import (
@@ -38,30 +30,36 @@ __location__ = os.path.realpath(os.path.join(os.getcwd(), os.path.dirname(__file
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
class TaskStatus(str, Enum):
PENDING = "pending"
PROCESSING = "processing"
COMPLETED = "completed"
FAILED = "failed"
class CrawlerType(str, Enum):
BASIC = "basic"
LLM = "llm"
COSINE = "cosine"
JSON_CSS = "json_css"
class ExtractionConfig(BaseModel):
type: CrawlerType
params: Dict[str, Any] = {}
class ChunkingStrategy(BaseModel):
type: str
params: Dict[str, Any] = {}
class ContentFilter(BaseModel):
type: str = "bm25"
params: Dict[str, Any] = {}
class CrawlRequest(BaseModel):
urls: Union[HttpUrl, List[HttpUrl]]
word_count_threshold: int = MIN_WORD_THRESHOLD
@@ -80,6 +78,7 @@ class CrawlRequest(BaseModel):
ttl: Optional[int] = 3600
crawler_params: Dict[str, Any] = {}
@dataclass
class TaskInfo:
id: str
@@ -89,6 +88,7 @@ class TaskInfo:
created_at: float = time.time()
ttl: int = 3600
class ResourceMonitor:
def __init__(self, max_concurrent_tasks: int = 10):
self.max_concurrent_tasks = max_concurrent_tasks
@@ -106,7 +106,9 @@ class ResourceMonitor:
mem_usage = psutil.virtual_memory().percent / 100
cpu_usage = psutil.cpu_percent() / 100
memory_factor = max(0, (self.memory_threshold - mem_usage) / self.memory_threshold)
memory_factor = max(
0, (self.memory_threshold - mem_usage) / self.memory_threshold
)
cpu_factor = max(0, (self.cpu_threshold - cpu_usage) / self.cpu_threshold)
self._last_available_slots = math.floor(
@@ -116,6 +118,7 @@ class ResourceMonitor:
return self._last_available_slots
class TaskManager:
def __init__(self, cleanup_interval: int = 300):
self.tasks: Dict[str, TaskInfo] = {}
@@ -149,12 +152,16 @@ class TaskManager:
except asyncio.TimeoutError:
try:
# Then try low priority
_, task_id = await asyncio.wait_for(self.low_priority.get(), timeout=0.1)
_, task_id = await asyncio.wait_for(
self.low_priority.get(), timeout=0.1
)
return task_id
except asyncio.TimeoutError:
return None
def update_task(self, task_id: str, status: TaskStatus, result: Any = None, error: str = None):
def update_task(
self, task_id: str, status: TaskStatus, result: Any = None, error: str = None
):
if task_id in self.tasks:
task_info = self.tasks[task_id]
task_info.status = status
@@ -180,6 +187,7 @@ class TaskManager:
except Exception as e:
logger.error(f"Error in cleanup loop: {e}")
class CrawlerPool:
def __init__(self, max_size: int = 10):
self.max_size = max_size
@@ -222,6 +230,7 @@ class CrawlerPool:
await crawler.__aexit__(None, None, None)
self.active_crawlers.clear()
class CrawlerService:
def __init__(self, max_concurrent_tasks: int = 10):
self.resource_monitor = ResourceMonitor(max_concurrent_tasks)
@@ -287,7 +296,9 @@ class CrawlerService:
try:
crawler = await self.crawler_pool.acquire(**request.crawler_params)
extraction_strategy = self._create_extraction_strategy(request.extraction_config)
extraction_strategy = self._create_extraction_strategy(
request.extraction_config
)
if isinstance(request.urls, list):
results = await crawler.arun_many(
@@ -318,16 +329,21 @@ class CrawlerService:
)
await self.crawler_pool.release(crawler)
self.task_manager.update_task(task_id, TaskStatus.COMPLETED, results)
self.task_manager.update_task(
task_id, TaskStatus.COMPLETED, results
)
except Exception as e:
logger.error(f"Error processing task {task_id}: {str(e)}")
self.task_manager.update_task(task_id, TaskStatus.FAILED, error=str(e))
self.task_manager.update_task(
task_id, TaskStatus.FAILED, error=str(e)
)
except Exception as e:
logger.error(f"Error in queue processing: {str(e)}")
await asyncio.sleep(1)
app = FastAPI(title="Crawl4AI API")
# CORS configuration
@@ -344,6 +360,7 @@ app.add_middleware(
security = HTTPBearer()
CRAWL4AI_API_TOKEN = os.getenv("CRAWL4AI_API_TOKEN")
async def verify_token(credentials: HTTPAuthorizationCredentials = Security(security)):
if not CRAWL4AI_API_TOKEN:
return credentials # No token verification if CRAWL4AI_API_TOKEN is not set
@@ -351,10 +368,12 @@ async def verify_token(credentials: HTTPAuthorizationCredentials = Security(secu
raise HTTPException(status_code=401, detail="Invalid token")
return credentials
def secure_endpoint():
"""Returns security dependency only if CRAWL4AI_API_TOKEN is set"""
return Depends(verify_token) if CRAWL4AI_API_TOKEN else None
# Check if site directory exists
if os.path.exists(__location__ + "/site"):
# Mount the site directory as a static directory
@@ -364,14 +383,17 @@ site_templates = Jinja2Templates(directory=__location__ + "/site")
crawler_service = CrawlerService()
@app.on_event("startup")
async def startup_event():
await crawler_service.start()
@app.on_event("shutdown")
async def shutdown_event():
await crawler_service.stop()
@app.get("/")
def read_root():
if os.path.exists(__location__ + "/site"):
@@ -379,12 +401,16 @@ def read_root():
# Return a json response
return {"message": "Crawl4AI API service is running"}
@app.post("/crawl", dependencies=[secure_endpoint()] if CRAWL4AI_API_TOKEN else [])
async def crawl(request: CrawlRequest) -> Dict[str, str]:
task_id = await crawler_service.submit_task(request)
return {"task_id": task_id}
@app.get("/task/{task_id}", dependencies=[secure_endpoint()] if CRAWL4AI_API_TOKEN else [])
@app.get(
"/task/{task_id}", dependencies=[secure_endpoint()] if CRAWL4AI_API_TOKEN else []
)
async def get_task_status(task_id: str):
task_info = crawler_service.task_manager.get_task(task_id)
if not task_info:
@@ -406,6 +432,7 @@ async def get_task_status(task_id: str):
return response
@app.post("/crawl_sync", dependencies=[secure_endpoint()] if CRAWL4AI_API_TOKEN else [])
async def crawl_sync(request: CrawlRequest) -> Dict[str, Any]:
task_id = await crawler_service.submit_task(request)
@@ -419,7 +446,10 @@ async def crawl_sync(request: CrawlRequest) -> Dict[str, Any]:
if task_info.status == TaskStatus.COMPLETED:
# Return same format as /task/{task_id} endpoint
if isinstance(task_info.result, list):
return {"status": task_info.status, "results": [result.dict() for result in task_info.result]}
return {
"status": task_info.status,
"results": [result.dict() for result in task_info.result],
}
return {"status": task_info.status, "result": task_info.result.dict()}
if task_info.status == TaskStatus.FAILED:
@@ -430,11 +460,16 @@ async def crawl_sync(request: CrawlRequest) -> Dict[str, Any]:
# If we get here, task didn't complete within timeout
raise HTTPException(status_code=408, detail="Task timed out")
@app.post("/crawl_direct", dependencies=[secure_endpoint()] if CRAWL4AI_API_TOKEN else [])
@app.post(
"/crawl_direct", dependencies=[secure_endpoint()] if CRAWL4AI_API_TOKEN else []
)
async def crawl_direct(request: CrawlRequest) -> Dict[str, Any]:
try:
crawler = await crawler_service.crawler_pool.acquire(**request.crawler_params)
extraction_strategy = crawler_service._create_extraction_strategy(request.extraction_config)
extraction_strategy = crawler_service._create_extraction_strategy(
request.extraction_config
)
try:
if isinstance(request.urls, list):
@@ -471,6 +506,7 @@ async def crawl_direct(request: CrawlRequest) -> Dict[str, Any]:
logger.error(f"Error in direct crawl: {str(e)}")
raise HTTPException(status_code=500, detail=str(e))
@app.get("/health")
async def health_check():
available_slots = await crawler_service.resource_monitor.get_available_slots()
@@ -482,6 +518,8 @@ async def health_check():
"cpu_usage": psutil.cpu_percent(),
}
if __name__ == "__main__":
import uvicorn
uvicorn.run(app, host="0.0.0.0", port=11235)

View File

@@ -51,9 +51,7 @@ setup(
author_email="unclecode@kidocode.com",
license="MIT",
packages=find_packages(),
package_data={
'crawl4ai': ['js_snippet/*.js']
},
package_data={"crawl4ai": ["js_snippet/*.js"]},
classifiers=[
"Development Status :: 3 - Alpha",
"Intended Audience :: Developers",

View File

@@ -1,17 +1,18 @@
import os, sys
parent_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
sys.path.append(parent_dir)
__location__ = os.path.realpath( os.path.join(os.getcwd(), os.path.dirname(__file__)))
import os, sys
import os
import sys
import asyncio
from crawl4ai import AsyncWebCrawler, CacheMode
from crawl4ai.content_filter_strategy import PruningContentFilter
from crawl4ai.markdown_generation_strategy import DefaultMarkdownGenerator
parent_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
sys.path.append(parent_dir)
__location__ = os.path.realpath(os.path.join(os.getcwd(), os.path.dirname(__file__)))
# Assuming that the changes made allow different configurations
# for managed browser, persistent context, and so forth.
async def test_default_headless():
async with AsyncWebCrawler(
headless=True,
@@ -24,13 +25,14 @@ async def test_default_headless():
# Testing normal ephemeral context
) as crawler:
result = await crawler.arun(
url='https://www.kidocode.com/degrees/technology',
url="https://www.kidocode.com/degrees/technology",
cache_mode=CacheMode.BYPASS,
markdown_generator=DefaultMarkdownGenerator(options={"ignore_links": True}),
)
print("[test_default_headless] success:", result.success)
print("HTML length:", len(result.html if result.html else ""))
async def test_managed_browser_persistent():
# Treating use_persistent_context=True as managed_browser scenario.
async with AsyncWebCrawler(
@@ -44,13 +46,14 @@ async def test_managed_browser_persistent():
# This should store and reuse profile data across runs
) as crawler:
result = await crawler.arun(
url='https://www.google.com',
url="https://www.google.com",
cache_mode=CacheMode.BYPASS,
markdown_generator=DefaultMarkdownGenerator(options={"ignore_links": True})
markdown_generator=DefaultMarkdownGenerator(options={"ignore_links": True}),
)
print("[test_managed_browser_persistent] success:", result.success)
print("HTML length:", len(result.html if result.html else ""))
async def test_session_reuse():
# Test creating a session, using it for multiple calls
session_id = "my_session"
@@ -62,25 +65,25 @@ async def test_session_reuse():
use_managed_browser=False,
use_persistent_context=False,
) as crawler:
# First call: create session
result1 = await crawler.arun(
url='https://www.example.com',
url="https://www.example.com",
cache_mode=CacheMode.BYPASS,
session_id=session_id,
markdown_generator=DefaultMarkdownGenerator(options={"ignore_links": True})
markdown_generator=DefaultMarkdownGenerator(options={"ignore_links": True}),
)
print("[test_session_reuse first call] success:", result1.success)
# Second call: same session, possibly cookie retained
result2 = await crawler.arun(
url='https://www.example.com/about',
url="https://www.example.com/about",
cache_mode=CacheMode.BYPASS,
session_id=session_id,
markdown_generator=DefaultMarkdownGenerator(options={"ignore_links": True})
markdown_generator=DefaultMarkdownGenerator(options={"ignore_links": True}),
)
print("[test_session_reuse second call] success:", result2.success)
async def test_magic_mode():
# Test magic mode with override_navigator and simulate_user
async with AsyncWebCrawler(
@@ -95,13 +98,14 @@ async def test_magic_mode():
simulate_user=True,
) as crawler:
result = await crawler.arun(
url='https://www.kidocode.com/degrees/business',
url="https://www.kidocode.com/degrees/business",
cache_mode=CacheMode.BYPASS,
markdown_generator=DefaultMarkdownGenerator(options={"ignore_links": True})
markdown_generator=DefaultMarkdownGenerator(options={"ignore_links": True}),
)
print("[test_magic_mode] success:", result.success)
print("HTML length:", len(result.html if result.html else ""))
async def test_proxy_settings():
# Test with a proxy (if available) to ensure code runs with proxy
async with AsyncWebCrawler(
@@ -113,14 +117,15 @@ async def test_proxy_settings():
use_persistent_context=False,
) as crawler:
result = await crawler.arun(
url='https://httpbin.org/ip',
url="https://httpbin.org/ip",
cache_mode=CacheMode.BYPASS,
markdown_generator=DefaultMarkdownGenerator(options={"ignore_links": True})
markdown_generator=DefaultMarkdownGenerator(options={"ignore_links": True}),
)
print("[test_proxy_settings] success:", result.success)
if result.success:
print("HTML preview:", result.html[:200] if result.html else "")
async def test_ignore_https_errors():
# Test ignore HTTPS errors with a self-signed or invalid cert domain
# This is just conceptual, the domain should be one that triggers SSL error.
@@ -134,12 +139,13 @@ async def test_ignore_https_errors():
use_persistent_context=False,
) as crawler:
result = await crawler.arun(
url='https://self-signed.badssl.com/',
url="https://self-signed.badssl.com/",
cache_mode=CacheMode.BYPASS,
markdown_generator=DefaultMarkdownGenerator(options={"ignore_links": True})
markdown_generator=DefaultMarkdownGenerator(options={"ignore_links": True}),
)
print("[test_ignore_https_errors] success:", result.success)
async def main():
print("Running tests...")
# await test_default_headless()
@@ -149,5 +155,6 @@ async def main():
# await test_proxy_settings()
await test_ignore_https_errors()
if __name__ == "__main__":
asyncio.run(main())

View File

@@ -1,4 +1,5 @@
import os, sys
parent_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
sys.path.append(parent_dir)
__location__ = os.path.realpath(os.path.join(os.getcwd(), os.path.dirname(__file__)))
@@ -9,7 +10,7 @@ from crawl4ai.async_configs import BrowserConfig, CrawlerRunConfig
from crawl4ai.content_filter_strategy import PruningContentFilter
from crawl4ai.extraction_strategy import JsonCssExtractionStrategy
from crawl4ai.chunking_strategy import RegexChunking
from crawl4ai.markdown_generation_strategy import DefaultMarkdownGenerator
# Category 1: Browser Configuration Tests
async def test_browser_config_object():
@@ -21,29 +22,31 @@ async def test_browser_config_object():
viewport_height=1080,
use_managed_browser=True,
user_agent_mode="random",
user_agent_generator_config={"device_type": "desktop", "os_type": "windows"}
user_agent_generator_config={"device_type": "desktop", "os_type": "windows"},
)
async with AsyncWebCrawler(config=browser_config, verbose=True) as crawler:
result = await crawler.arun('https://example.com', cache_mode=CacheMode.BYPASS)
result = await crawler.arun("https://example.com", cache_mode=CacheMode.BYPASS)
assert result.success, "Browser config crawl failed"
assert len(result.html) > 0, "No HTML content retrieved"
async def test_browser_performance_config():
"""Test browser configurations focused on performance"""
browser_config = BrowserConfig(
text_mode=True,
light_mode=True,
extra_args=['--disable-gpu', '--disable-software-rasterizer'],
extra_args=["--disable-gpu", "--disable-software-rasterizer"],
ignore_https_errors=True,
java_script_enabled=False
java_script_enabled=False,
)
async with AsyncWebCrawler(config=browser_config) as crawler:
result = await crawler.arun('https://example.com')
result = await crawler.arun("https://example.com")
assert result.success, "Performance optimized crawl failed"
assert result.status_code == 200, "Unexpected status code"
# Category 2: Content Processing Tests
async def test_content_extraction_config():
"""Test content extraction with various strategies"""
@@ -53,24 +56,20 @@ async def test_content_extraction_config():
schema={
"name": "article",
"baseSelector": "div",
"fields": [{
"name": "title",
"selector": "h1",
"type": "text"
}]
"fields": [{"name": "title", "selector": "h1", "type": "text"}],
}
),
chunking_strategy=RegexChunking(),
content_filter=PruningContentFilter()
content_filter=PruningContentFilter(),
)
async with AsyncWebCrawler() as crawler:
result = await crawler.arun(
'https://example.com/article',
config=crawler_config
"https://example.com/article", config=crawler_config
)
assert result.extracted_content is not None, "Content extraction failed"
assert 'title' in result.extracted_content, "Missing expected content field"
assert "title" in result.extracted_content, "Missing expected content field"
# Category 3: Cache and Session Management Tests
async def test_cache_and_session_management():
@@ -79,25 +78,20 @@ async def test_cache_and_session_management():
crawler_config = CrawlerRunConfig(
cache_mode=CacheMode.WRITE_ONLY,
process_iframes=True,
remove_overlay_elements=True
remove_overlay_elements=True,
)
async with AsyncWebCrawler(config=browser_config) as crawler:
# First request - should write to cache
result1 = await crawler.arun(
'https://example.com',
config=crawler_config
)
result1 = await crawler.arun("https://example.com", config=crawler_config)
# Second request - should use fresh fetch due to WRITE_ONLY mode
result2 = await crawler.arun(
'https://example.com',
config=crawler_config
)
result2 = await crawler.arun("https://example.com", config=crawler_config)
assert result1.success and result2.success, "Cache mode crawl failed"
assert result1.html == result2.html, "Inconsistent results between requests"
# Category 4: Media Handling Tests
async def test_media_handling_config():
"""Test configurations related to media handling"""
@@ -107,24 +101,22 @@ async def test_media_handling_config():
viewport_width=1920,
viewport_height=1080,
accept_downloads=True,
downloads_path= os.path.expanduser("~/.crawl4ai/downloads")
downloads_path=os.path.expanduser("~/.crawl4ai/downloads"),
)
crawler_config = CrawlerRunConfig(
screenshot=True,
pdf=True,
adjust_viewport_to_content=True,
wait_for_images=True,
screenshot_height_threshold=20000
screenshot_height_threshold=20000,
)
async with AsyncWebCrawler(config=browser_config) as crawler:
result = await crawler.arun(
'https://example.com',
config=crawler_config
)
result = await crawler.arun("https://example.com", config=crawler_config)
assert result.screenshot is not None, "Screenshot capture failed"
assert result.pdf is not None, "PDF generation failed"
# Category 5: Anti-Bot and Site Interaction Tests
async def test_antibot_config():
"""Test configurations for handling anti-bot measures"""
@@ -135,57 +127,43 @@ async def test_antibot_config():
wait_for="js:()=>document.querySelector('body')",
delay_before_return_html=1.0,
log_console=True,
cache_mode=CacheMode.BYPASS
cache_mode=CacheMode.BYPASS,
)
async with AsyncWebCrawler() as crawler:
result = await crawler.arun(
'https://example.com',
config=crawler_config
)
result = await crawler.arun("https://example.com", config=crawler_config)
assert result.success, "Anti-bot measure handling failed"
# Category 6: Parallel Processing Tests
async def test_parallel_processing():
"""Test parallel processing capabilities"""
crawler_config = CrawlerRunConfig(
mean_delay=0.5,
max_range=1.0,
semaphore_count=5
)
crawler_config = CrawlerRunConfig(mean_delay=0.5, max_range=1.0, semaphore_count=5)
urls = [
'https://example.com/1',
'https://example.com/2',
'https://example.com/3'
]
urls = ["https://example.com/1", "https://example.com/2", "https://example.com/3"]
async with AsyncWebCrawler() as crawler:
results = await crawler.arun_many(
urls,
config=crawler_config
)
results = await crawler.arun_many(urls, config=crawler_config)
assert len(results) == len(urls), "Not all URLs were processed"
assert all(r.success for r in results), "Some parallel requests failed"
# Category 7: Backwards Compatibility Tests
async def test_legacy_parameter_support():
"""Test that legacy parameters still work"""
async with AsyncWebCrawler(
headless=True,
browser_type="chromium",
viewport_width=1024,
viewport_height=768
headless=True, browser_type="chromium", viewport_width=1024, viewport_height=768
) as crawler:
result = await crawler.arun(
'https://example.com',
"https://example.com",
screenshot=True,
word_count_threshold=200,
bypass_cache=True,
css_selector=".main-content"
css_selector=".main-content",
)
assert result.success, "Legacy parameter support failed"
# Category 8: Mixed Configuration Tests
async def test_mixed_config_usage():
"""Test mixing new config objects with legacy parameters"""
@@ -194,17 +172,19 @@ async def test_mixed_config_usage():
async with AsyncWebCrawler(
config=browser_config,
verbose=True # legacy parameter
verbose=True, # legacy parameter
) as crawler:
result = await crawler.arun(
'https://example.com',
"https://example.com",
config=crawler_config,
cache_mode=CacheMode.BYPASS, # legacy parameter
css_selector="body" # legacy parameter
css_selector="body", # legacy parameter
)
assert result.success, "Mixed configuration usage failed"
if __name__ == "__main__":
async def run_tests():
test_functions = [
test_browser_config_object,

View File

@@ -4,7 +4,6 @@ import asyncio
import shutil
from typing import List
import tempfile
import time
# Add the parent directory to the Python path
parent_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
@@ -12,6 +11,7 @@ sys.path.append(parent_dir)
from crawl4ai.async_webcrawler import AsyncWebCrawler
class TestDownloads:
def __init__(self):
self.temp_dir = tempfile.mkdtemp(prefix="crawl4ai_test_")
@@ -31,9 +31,7 @@ class TestDownloads:
"""Test basic file download functionality"""
try:
async with AsyncWebCrawler(
accept_downloads=True,
downloads_path=self.download_dir,
verbose=True
accept_downloads=True, downloads_path=self.download_dir, verbose=True
) as crawler:
# Python.org downloads page typically has stable download links
result = await crawler.arun(
@@ -42,14 +40,19 @@ class TestDownloads:
// Click first download link
const downloadLink = document.querySelector('a[href$=".exe"]');
if (downloadLink) downloadLink.click();
"""
""",
)
success = result.downloaded_files is not None and len(result.downloaded_files) > 0
success = (
result.downloaded_files is not None
and len(result.downloaded_files) > 0
)
self.log_result(
"Basic Download",
success,
f"Downloaded {len(result.downloaded_files or [])} files" if success else "No files downloaded"
f"Downloaded {len(result.downloaded_files or [])} files"
if success
else "No files downloaded",
)
except Exception as e:
self.log_result("Basic Download", False, str(e))
@@ -65,21 +68,26 @@ class TestDownloads:
downloads_path=self.download_dir,
use_persistent_context=True,
user_data_dir=user_data_dir,
verbose=True
verbose=True,
) as crawler:
result = await crawler.arun(
url="https://www.python.org/downloads/",
js_code="""
const downloadLink = document.querySelector('a[href$=".exe"]');
if (downloadLink) downloadLink.click();
"""
""",
)
success = result.downloaded_files is not None and len(result.downloaded_files) > 0
success = (
result.downloaded_files is not None
and len(result.downloaded_files) > 0
)
self.log_result(
"Persistent Context Download",
success,
f"Downloaded {len(result.downloaded_files or [])} files" if success else "No files downloaded"
f"Downloaded {len(result.downloaded_files or [])} files"
if success
else "No files downloaded",
)
except Exception as e:
self.log_result("Persistent Context Download", False, str(e))
@@ -88,9 +96,7 @@ class TestDownloads:
"""Test multiple simultaneous downloads"""
try:
async with AsyncWebCrawler(
accept_downloads=True,
downloads_path=self.download_dir,
verbose=True
accept_downloads=True, downloads_path=self.download_dir, verbose=True
) as crawler:
result = await crawler.arun(
url="https://www.python.org/downloads/",
@@ -98,14 +104,19 @@ class TestDownloads:
// Click multiple download links
const downloadLinks = document.querySelectorAll('a[href$=".exe"]');
downloadLinks.forEach(link => link.click());
"""
""",
)
success = result.downloaded_files is not None and len(result.downloaded_files) > 1
success = (
result.downloaded_files is not None
and len(result.downloaded_files) > 1
)
self.log_result(
"Multiple Downloads",
success,
f"Downloaded {len(result.downloaded_files or [])} files" if success else "Not enough files downloaded"
f"Downloaded {len(result.downloaded_files or [])} files"
if success
else "Not enough files downloaded",
)
except Exception as e:
self.log_result("Multiple Downloads", False, str(e))
@@ -120,21 +131,26 @@ class TestDownloads:
accept_downloads=True,
downloads_path=self.download_dir,
browser_type=browser_type,
verbose=True
verbose=True,
) as crawler:
result = await crawler.arun(
url="https://www.python.org/downloads/",
js_code="""
const downloadLink = document.querySelector('a[href$=".exe"]');
if (downloadLink) downloadLink.click();
"""
""",
)
success = result.downloaded_files is not None and len(result.downloaded_files) > 0
success = (
result.downloaded_files is not None
and len(result.downloaded_files) > 0
)
self.log_result(
f"{browser_type.title()} Download",
success,
f"Downloaded {len(result.downloaded_files or [])} files" if success else "No files downloaded"
f"Downloaded {len(result.downloaded_files or [])} files"
if success
else "No files downloaded",
)
except Exception as e:
self.log_result(f"{browser_type.title()} Download", False, str(e))
@@ -144,18 +160,15 @@ class TestDownloads:
# Test 1: Downloads without specifying download path
try:
async with AsyncWebCrawler(
accept_downloads=True,
verbose=True
) as crawler:
async with AsyncWebCrawler(accept_downloads=True, verbose=True) as crawler:
result = await crawler.arun(
url="https://www.python.org/downloads/",
js_code="document.querySelector('a[href$=\".exe\"]').click()"
js_code="document.querySelector('a[href$=\".exe\"]').click()",
)
self.log_result(
"Default Download Path",
True,
f"Downloaded to default path: {result.downloaded_files[0] if result.downloaded_files else 'None'}"
f"Downloaded to default path: {result.downloaded_files[0] if result.downloaded_files else 'None'}",
)
except Exception as e:
self.log_result("Default Download Path", False, str(e))
@@ -165,31 +178,34 @@ class TestDownloads:
async with AsyncWebCrawler(
accept_downloads=True,
downloads_path="/invalid/path/that/doesnt/exist",
verbose=True
verbose=True,
) as crawler:
result = await crawler.arun(
url="https://www.python.org/downloads/",
js_code="document.querySelector('a[href$=\".exe\"]').click()"
js_code="document.querySelector('a[href$=\".exe\"]').click()",
)
self.log_result(
"Invalid Download Path", False, "Should have raised an error"
)
except Exception:
self.log_result(
"Invalid Download Path", True, "Correctly handled invalid path"
)
self.log_result("Invalid Download Path", False, "Should have raised an error")
except Exception as e:
self.log_result("Invalid Download Path", True, "Correctly handled invalid path")
# Test 3: Download with accept_downloads=False
try:
async with AsyncWebCrawler(
accept_downloads=False,
verbose=True
) as crawler:
async with AsyncWebCrawler(accept_downloads=False, verbose=True) as crawler:
result = await crawler.arun(
url="https://www.python.org/downloads/",
js_code="document.querySelector('a[href$=\".exe\"]').click()"
js_code="document.querySelector('a[href$=\".exe\"]').click()",
)
success = result.downloaded_files is None
self.log_result(
"Disabled Downloads",
success,
"Correctly ignored downloads" if success else "Unexpectedly downloaded files"
"Correctly ignored downloads"
if success
else "Unexpectedly downloaded files",
)
except Exception as e:
self.log_result("Disabled Downloads", False, str(e))
@@ -203,7 +219,7 @@ class TestDownloads:
self.test_persistent_context_download,
self.test_multiple_downloads,
self.test_different_browsers,
self.test_edge_cases
self.test_edge_cases,
]
for test in test_methods:
@@ -215,15 +231,17 @@ class TestDownloads:
for result in self.results:
print(result)
successes = len([r for r in self.results if '' in r])
successes = len([r for r in self.results if "" in r])
total = len(self.results)
print(f"\nTotal: {successes}/{total} tests passed")
self.cleanup()
async def main():
tester = TestDownloads()
await tester.run_all_tests()
if __name__ == "__main__":
asyncio.run(main())

View File

@@ -1,15 +1,17 @@
import os
import sys
import pytest
import asyncio
import time
# Add the parent directory to the Python path
parent_dir = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
parent_dir = os.path.dirname(
os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
)
sys.path.append(parent_dir)
from crawl4ai.async_webcrawler import AsyncWebCrawler
@pytest.mark.asyncio
async def test_successful_crawl():
async with AsyncWebCrawler(verbose=True) as crawler:
@@ -21,6 +23,7 @@ async def test_successful_crawl():
assert result.markdown
assert result.cleaned_html
@pytest.mark.asyncio
async def test_invalid_url():
async with AsyncWebCrawler(verbose=True) as crawler:
@@ -29,19 +32,21 @@ async def test_invalid_url():
assert not result.success
assert result.error_message
@pytest.mark.asyncio
async def test_multiple_urls():
async with AsyncWebCrawler(verbose=True) as crawler:
urls = [
"https://www.nbcnews.com/business",
"https://www.example.com",
"https://www.python.org"
"https://www.python.org",
]
results = await crawler.arun_many(urls=urls, bypass_cache=True)
assert len(results) == len(urls)
assert all(result.success for result in results)
assert all(result.html for result in results)
@pytest.mark.asyncio
async def test_javascript_execution():
async with AsyncWebCrawler(verbose=True) as crawler:
@@ -51,6 +56,7 @@ async def test_javascript_execution():
assert result.success
assert "<h1>Modified by JS</h1>" in result.html
@pytest.mark.asyncio
async def test_concurrent_crawling_performance():
async with AsyncWebCrawler(verbose=True) as crawler:
@@ -59,7 +65,7 @@ async def test_concurrent_crawling_performance():
"https://www.example.com",
"https://www.python.org",
"https://www.github.com",
"https://www.stackoverflow.com"
"https://www.stackoverflow.com",
]
start_time = time.time()
@@ -74,7 +80,10 @@ async def test_concurrent_crawling_performance():
# Assert that concurrent crawling is faster than sequential
# This multiplier may need adjustment based on the number of URLs and their complexity
assert total_time < len(urls) * 5, f"Concurrent crawling not significantly faster: {total_time:.2f} seconds"
assert (
total_time < len(urls) * 5
), f"Concurrent crawling not significantly faster: {total_time:.2f} seconds"
# Entry point for debugging
if __name__ == "__main__":

View File

@@ -9,6 +9,7 @@ sys.path.append(parent_dir)
from crawl4ai.async_webcrawler import AsyncWebCrawler
@pytest.mark.asyncio
async def test_caching():
async with AsyncWebCrawler(verbose=True) as crawler:
@@ -31,6 +32,7 @@ async def test_caching():
assert result2.success
assert time_taken2 < time_taken1 # Cached result should be faster
@pytest.mark.asyncio
async def test_bypass_cache():
async with AsyncWebCrawler(verbose=True) as crawler:
@@ -47,6 +49,7 @@ async def test_bypass_cache():
# Content should be different (or at least, not guaranteed to be the same)
assert result1.html != result2.html or result1.markdown != result2.markdown
@pytest.mark.asyncio
async def test_clear_cache():
async with AsyncWebCrawler(verbose=True) as crawler:
@@ -62,6 +65,7 @@ async def test_clear_cache():
cache_size = await crawler.aget_cache_size()
assert cache_size == 0
@pytest.mark.asyncio
async def test_flush_cache():
async with AsyncWebCrawler(verbose=True) as crawler:
@@ -77,6 +81,7 @@ async def test_flush_cache():
cache_size = await crawler.aget_cache_size()
assert cache_size == 0
# Entry point for debugging
if __name__ == "__main__":
pytest.main([__file__, "-v"])

View File

@@ -1,7 +1,6 @@
import os
import sys
import pytest
import asyncio
import json
# Add the parent directory to the Python path
@@ -9,8 +8,9 @@ parent_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
sys.path.append(parent_dir)
from crawl4ai.async_webcrawler import AsyncWebCrawler
from crawl4ai.chunking_strategy import RegexChunking, NlpSentenceChunking
from crawl4ai.extraction_strategy import CosineStrategy, LLMExtractionStrategy
from crawl4ai.chunking_strategy import RegexChunking
from crawl4ai.extraction_strategy import LLMExtractionStrategy
@pytest.mark.asyncio
async def test_regex_chunking():
@@ -18,15 +18,14 @@ async def test_regex_chunking():
url = "https://www.nbcnews.com/business"
chunking_strategy = RegexChunking(patterns=["\n\n"])
result = await crawler.arun(
url=url,
chunking_strategy=chunking_strategy,
bypass_cache=True
url=url, chunking_strategy=chunking_strategy, bypass_cache=True
)
assert result.success
assert result.extracted_content
chunks = json.loads(result.extracted_content)
assert len(chunks) > 1 # Ensure multiple chunks were created
# @pytest.mark.asyncio
# async def test_cosine_strategy():
# async with AsyncWebCrawler(verbose=True) as crawler:
@@ -43,25 +42,25 @@ async def test_regex_chunking():
# assert len(extracted_data) > 0
# assert all('tags' in item for item in extracted_data)
@pytest.mark.asyncio
async def test_llm_extraction_strategy():
async with AsyncWebCrawler(verbose=True) as crawler:
url = "https://www.nbcnews.com/business"
extraction_strategy = LLMExtractionStrategy(
provider="openai/gpt-4o-mini",
api_token=os.getenv('OPENAI_API_KEY'),
instruction="Extract only content related to technology"
api_token=os.getenv("OPENAI_API_KEY"),
instruction="Extract only content related to technology",
)
result = await crawler.arun(
url=url,
extraction_strategy=extraction_strategy,
bypass_cache=True
url=url, extraction_strategy=extraction_strategy, bypass_cache=True
)
assert result.success
assert result.extracted_content
extracted_data = json.loads(result.extracted_content)
assert len(extracted_data) > 0
assert all('content' in item for item in extracted_data)
assert all("content" in item for item in extracted_data)
# @pytest.mark.asyncio
# async def test_combined_chunking_and_extraction():

View File

@@ -1,8 +1,6 @@
import os
import sys
import pytest
import asyncio
import json
# Add the parent directory to the Python path
parent_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
@@ -10,6 +8,7 @@ sys.path.append(parent_dir)
from crawl4ai.async_webcrawler import AsyncWebCrawler
@pytest.mark.asyncio
async def test_extract_markdown():
async with AsyncWebCrawler(verbose=True) as crawler:
@@ -20,6 +19,7 @@ async def test_extract_markdown():
assert isinstance(result.markdown, str)
assert len(result.markdown) > 0
@pytest.mark.asyncio
async def test_extract_cleaned_html():
async with AsyncWebCrawler(verbose=True) as crawler:
@@ -30,6 +30,7 @@ async def test_extract_cleaned_html():
assert isinstance(result.cleaned_html, str)
assert len(result.cleaned_html) > 0
@pytest.mark.asyncio
async def test_extract_media():
async with AsyncWebCrawler(verbose=True) as crawler:
@@ -46,6 +47,7 @@ async def test_extract_media():
assert "alt" in image
assert "type" in image
@pytest.mark.asyncio
async def test_extract_links():
async with AsyncWebCrawler(verbose=True) as crawler:
@@ -63,6 +65,7 @@ async def test_extract_links():
assert "href" in link
assert "text" in link
@pytest.mark.asyncio
async def test_extract_metadata():
async with AsyncWebCrawler(verbose=True) as crawler:
@@ -75,16 +78,20 @@ async def test_extract_metadata():
assert "title" in metadata
assert isinstance(metadata["title"], str)
@pytest.mark.asyncio
async def test_css_selector_extraction():
async with AsyncWebCrawler(verbose=True) as crawler:
url = "https://www.nbcnews.com/business"
css_selector = "h1, h2, h3"
result = await crawler.arun(url=url, bypass_cache=True, css_selector=css_selector)
result = await crawler.arun(
url=url, bypass_cache=True, css_selector=css_selector
)
assert result.success
assert result.markdown
assert all(heading in result.markdown for heading in ["#", "##", "###"])
# Entry point for debugging
if __name__ == "__main__":
pytest.main([__file__, "-v"])

View File

@@ -1,7 +1,6 @@
import os, sys
import pytest
from bs4 import BeautifulSoup
from typing import List
# Add the parent directory to the Python path
parent_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
@@ -9,6 +8,7 @@ sys.path.append(parent_dir)
from crawl4ai.content_filter_strategy import BM25ContentFilter
@pytest.fixture
def basic_html():
return """
@@ -28,6 +28,7 @@ def basic_html():
</html>
"""
@pytest.fixture
def wiki_html():
return """
@@ -46,6 +47,7 @@ def wiki_html():
</html>
"""
@pytest.fixture
def no_meta_html():
return """
@@ -57,6 +59,7 @@ def no_meta_html():
</html>
"""
class TestBM25ContentFilter:
def test_basic_extraction(self, basic_html):
"""Test basic content extraction functionality"""
@@ -65,8 +68,8 @@ class TestBM25ContentFilter:
assert contents, "Should extract content"
assert len(contents) >= 1, "Should extract at least one content block"
assert "long paragraph" in ' '.join(contents).lower()
assert "navigation" not in ' '.join(contents).lower()
assert "long paragraph" in " ".join(contents).lower()
assert "navigation" not in " ".join(contents).lower()
def test_user_query_override(self, basic_html):
"""Test that user query overrides metadata extraction"""
@@ -74,8 +77,8 @@ class TestBM25ContentFilter:
filter = BM25ContentFilter(user_query=user_query)
# Access internal state to verify query usage
soup = BeautifulSoup(basic_html, 'lxml')
extracted_query = filter.extract_page_query(soup.find('head'))
soup = BeautifulSoup(basic_html, "lxml")
extracted_query = filter.extract_page_query(soup.find("head"))
assert extracted_query == user_query
assert "Test description" not in extracted_query
@@ -85,7 +88,7 @@ class TestBM25ContentFilter:
filter = BM25ContentFilter()
contents = filter.filter_content(wiki_html)
combined_content = ' '.join(contents).lower()
combined_content = " ".join(contents).lower()
assert "section 1" in combined_content, "Should include section header"
assert "article title" in combined_content, "Should include main title"
@@ -95,7 +98,9 @@ class TestBM25ContentFilter:
contents = filter.filter_content(no_meta_html)
assert contents, "Should extract content even without metadata"
assert "First paragraph" in ' '.join(contents), "Should use first paragraph content"
assert "First paragraph" in " ".join(
contents
), "Should use first paragraph content"
def test_empty_input(self):
"""Test handling of empty input"""
@@ -119,18 +124,19 @@ class TestBM25ContentFilter:
strict_contents = strict_filter.filter_content(basic_html)
lenient_contents = lenient_filter.filter_content(basic_html)
assert len(strict_contents) <= len(lenient_contents), \
"Strict threshold should extract fewer elements"
assert len(strict_contents) <= len(
lenient_contents
), "Strict threshold should extract fewer elements"
def test_html_cleaning(self, basic_html):
"""Test HTML cleaning functionality"""
filter = BM25ContentFilter()
contents = filter.filter_content(basic_html)
cleaned_content = ' '.join(contents)
assert 'class=' not in cleaned_content, "Should remove class attributes"
assert 'style=' not in cleaned_content, "Should remove style attributes"
assert '<script' not in cleaned_content, "Should remove script tags"
cleaned_content = " ".join(contents)
assert "class=" not in cleaned_content, "Should remove class attributes"
assert "style=" not in cleaned_content, "Should remove style attributes"
assert "<script" not in cleaned_content, "Should remove script tags"
def test_large_content(self):
"""Test handling of large content blocks"""
@@ -143,9 +149,9 @@ class TestBM25ContentFilter:
contents = filter.filter_content(large_html)
assert contents, "Should handle large content blocks"
@pytest.mark.parametrize("unwanted_tag", [
'script', 'style', 'nav', 'footer', 'header'
])
@pytest.mark.parametrize(
"unwanted_tag", ["script", "style", "nav", "footer", "header"]
)
def test_excluded_tags(self, unwanted_tag):
"""Test that specific tags are properly excluded"""
html = f"""
@@ -157,7 +163,7 @@ class TestBM25ContentFilter:
filter = BM25ContentFilter()
contents = filter.filter_content(html)
combined_content = ' '.join(contents).lower()
combined_content = " ".join(contents).lower()
assert "should not appear" not in combined_content
def test_performance(self, basic_html):
@@ -165,11 +171,13 @@ class TestBM25ContentFilter:
filter = BM25ContentFilter()
import time
start = time.perf_counter()
filter.filter_content(basic_html)
duration = time.perf_counter() - start
assert duration < 1.0, f"Processing took too long: {duration:.2f} seconds"
if __name__ == "__main__":
pytest.main([__file__])

View File

@@ -1,12 +1,12 @@
import os, sys
import pytest
from bs4 import BeautifulSoup
parent_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
sys.path.append(parent_dir)
from crawl4ai.content_filter_strategy import PruningContentFilter
@pytest.fixture
def basic_html():
return """
@@ -22,6 +22,7 @@ def basic_html():
</html>
"""
@pytest.fixture
def link_heavy_html():
return """
@@ -40,6 +41,7 @@ def link_heavy_html():
</html>
"""
@pytest.fixture
def mixed_content_html():
return """
@@ -60,13 +62,14 @@ def mixed_content_html():
</html>
"""
class TestPruningContentFilter:
def test_basic_pruning(self, basic_html):
"""Test basic content pruning functionality"""
filter = PruningContentFilter(min_word_threshold=5)
contents = filter.filter_content(basic_html)
combined_content = ' '.join(contents).lower()
combined_content = " ".join(contents).lower()
assert "high-quality paragraph" in combined_content
assert "sidebar content" not in combined_content
assert "share buttons" not in combined_content
@@ -76,39 +79,41 @@ class TestPruningContentFilter:
filter = PruningContentFilter(min_word_threshold=10)
contents = filter.filter_content(mixed_content_html)
combined_content = ' '.join(contents).lower()
combined_content = " ".join(contents).lower()
assert "short summary" not in combined_content
assert "long high-quality paragraph" in combined_content
assert "short comment" not in combined_content
def test_threshold_types(self, basic_html):
"""Test fixed vs dynamic thresholds"""
fixed_filter = PruningContentFilter(threshold_type='fixed', threshold=0.48)
dynamic_filter = PruningContentFilter(threshold_type='dynamic', threshold=0.45)
fixed_filter = PruningContentFilter(threshold_type="fixed", threshold=0.48)
dynamic_filter = PruningContentFilter(threshold_type="dynamic", threshold=0.45)
fixed_contents = fixed_filter.filter_content(basic_html)
dynamic_contents = dynamic_filter.filter_content(basic_html)
assert len(fixed_contents) != len(dynamic_contents), \
"Fixed and dynamic thresholds should yield different results"
assert len(fixed_contents) != len(
dynamic_contents
), "Fixed and dynamic thresholds should yield different results"
def test_link_density_impact(self, link_heavy_html):
"""Test handling of link-heavy content"""
filter = PruningContentFilter(threshold_type='dynamic')
filter = PruningContentFilter(threshold_type="dynamic")
contents = filter.filter_content(link_heavy_html)
combined_content = ' '.join(contents).lower()
combined_content = " ".join(contents).lower()
assert "good content paragraph" in combined_content
assert len([c for c in contents if 'href' in c]) < 2, \
"Should prune link-heavy sections"
assert (
len([c for c in contents if "href" in c]) < 2
), "Should prune link-heavy sections"
def test_tag_importance(self, mixed_content_html):
"""Test tag importance in scoring"""
filter = PruningContentFilter(threshold_type='dynamic')
filter = PruningContentFilter(threshold_type="dynamic")
contents = filter.filter_content(mixed_content_html)
has_article = any('article' in c.lower() for c in contents)
has_h1 = any('h1' in c.lower() for c in contents)
has_article = any("article" in c.lower() for c in contents)
has_h1 = any("h1" in c.lower() for c in contents)
assert has_article or has_h1, "Should retain important tags"
def test_empty_input(self):
@@ -129,6 +134,7 @@ class TestPruningContentFilter:
filter = PruningContentFilter()
import time
start = time.perf_counter()
filter.filter_content(basic_html)
duration = time.perf_counter() - start
@@ -136,17 +142,21 @@ class TestPruningContentFilter:
# Extra strict on performance since you mentioned milliseconds matter
assert duration < 0.1, f"Processing took too long: {duration:.3f} seconds"
@pytest.mark.parametrize("threshold,expected_count", [
@pytest.mark.parametrize(
"threshold,expected_count",
[
(0.3, 4), # Very lenient
(0.48, 2), # Default
(0.7, 1), # Very strict
])
],
)
def test_threshold_levels(self, mixed_content_html, threshold, expected_count):
"""Test different threshold levels"""
filter = PruningContentFilter(threshold_type='fixed', threshold=threshold)
filter = PruningContentFilter(threshold_type="fixed", threshold=threshold)
contents = filter.filter_content(mixed_content_html)
assert len(contents) <= expected_count, \
f"Expected {expected_count} or fewer elements with threshold {threshold}"
assert (
len(contents) <= expected_count
), f"Expected {expected_count} or fewer elements with threshold {threshold}"
def test_consistent_output(self, basic_html):
"""Test output consistency across multiple runs"""
@@ -155,5 +165,6 @@ class TestPruningContentFilter:
second_run = filter.filter_content(basic_html)
assert first_run == second_run, "Output should be consistent"
if __name__ == "__main__":
pytest.main([__file__])

View File

@@ -1,22 +1,24 @@
import asyncio
from bs4 import BeautifulSoup
from typing import Dict, Any
import os
import sys
import time
import csv
from tabulate import tabulate
from dataclasses import dataclass
from typing import List, Dict
from typing import List
parent_dir = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
parent_dir = os.path.dirname(
os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
)
sys.path.append(parent_dir)
__location__ = os.path.realpath(os.path.join(os.getcwd(), os.path.dirname(__file__)))
from crawl4ai.content_scraping_strategy import WebScrapingStrategy
from crawl4ai.content_scraping_strategy import WebScrapingStrategy as WebScrapingStrategyCurrent
from crawl4ai.content_scraping_strategy import (
WebScrapingStrategy as WebScrapingStrategyCurrent,
)
# from crawl4ai.content_scrapping_strategy_current import WebScrapingStrategy as WebScrapingStrategyCurrent
@dataclass
class TestResult:
name: str
@@ -27,33 +29,32 @@ class TestResult:
markdown_length: int
execution_time: float
class StrategyTester:
def __init__(self):
self.new_scraper = WebScrapingStrategy()
self.current_scraper = WebScrapingStrategyCurrent()
with open(__location__ + '/sample_wikipedia.html', 'r', encoding='utf-8') as f:
with open(__location__ + "/sample_wikipedia.html", "r", encoding="utf-8") as f:
self.WIKI_HTML = f.read()
self.results = {'new': [], 'current': []}
self.results = {"new": [], "current": []}
def run_test(self, name: str, **kwargs) -> tuple[TestResult, TestResult]:
results = []
for scraper in [self.new_scraper, self.current_scraper]:
start_time = time.time()
result = scraper._get_content_of_website_optimized(
url="https://en.wikipedia.org/wiki/Test",
html=self.WIKI_HTML,
**kwargs
url="https://en.wikipedia.org/wiki/Test", html=self.WIKI_HTML, **kwargs
)
execution_time = time.time() - start_time
test_result = TestResult(
name=name,
success=result['success'],
images=len(result['media']['images']),
internal_links=len(result['links']['internal']),
external_links=len(result['links']['external']),
markdown_length=len(result['markdown']),
execution_time=execution_time
success=result["success"],
images=len(result["media"]["images"]),
internal_links=len(result["links"]["internal"]),
external_links=len(result["links"]["external"]),
markdown_length=len(result["markdown"]),
execution_time=execution_time,
)
results.append(test_result)
@@ -62,34 +63,37 @@ class StrategyTester:
def run_all_tests(self):
test_cases = [
("Basic Extraction", {}),
("Exclude Tags", {'excluded_tags': ['table', 'div.infobox', 'div.navbox']}),
("Word Threshold", {'word_count_threshold': 50}),
("CSS Selector", {'css_selector': 'div.mw-parser-output > p'}),
("Link Exclusions", {
'exclude_external_links': True,
'exclude_social_media_links': True,
'exclude_domains': ['facebook.com', 'twitter.com']
}),
("Media Handling", {
'exclude_external_images': True,
'image_description_min_word_threshold': 20
}),
("Text Only", {
'only_text': True,
'remove_forms': True
}),
("HTML Cleaning", {
'clean_html': True,
'keep_data_attributes': True
}),
("HTML2Text Options", {
'html2text': {
'skip_internal_links': True,
'single_line_break': True,
'mark_code': True,
'preserve_tags': ['pre', 'code']
("Exclude Tags", {"excluded_tags": ["table", "div.infobox", "div.navbox"]}),
("Word Threshold", {"word_count_threshold": 50}),
("CSS Selector", {"css_selector": "div.mw-parser-output > p"}),
(
"Link Exclusions",
{
"exclude_external_links": True,
"exclude_social_media_links": True,
"exclude_domains": ["facebook.com", "twitter.com"],
},
),
(
"Media Handling",
{
"exclude_external_images": True,
"image_description_min_word_threshold": 20,
},
),
("Text Only", {"only_text": True, "remove_forms": True}),
("HTML Cleaning", {"clean_html": True, "keep_data_attributes": True}),
(
"HTML2Text Options",
{
"html2text": {
"skip_internal_links": True,
"single_line_break": True,
"mark_code": True,
"preserve_tags": ["pre", "code"],
}
})
},
),
]
all_results = []
@@ -104,58 +108,111 @@ class StrategyTester:
self.print_comparison_table(all_results)
def save_results_to_csv(self, all_results: List[tuple]):
csv_file = os.path.join(__location__, 'strategy_comparison_results.csv')
with open(csv_file, 'w', newline='') as f:
csv_file = os.path.join(__location__, "strategy_comparison_results.csv")
with open(csv_file, "w", newline="") as f:
writer = csv.writer(f)
writer.writerow(['Test Name', 'Strategy', 'Success', 'Images', 'Internal Links',
'External Links', 'Markdown Length', 'Execution Time'])
writer.writerow(
[
"Test Name",
"Strategy",
"Success",
"Images",
"Internal Links",
"External Links",
"Markdown Length",
"Execution Time",
]
)
for name, new_result, current_result in all_results:
writer.writerow([name, 'New', new_result.success, new_result.images,
new_result.internal_links, new_result.external_links,
new_result.markdown_length, f"{new_result.execution_time:.3f}"])
writer.writerow([name, 'Current', current_result.success, current_result.images,
current_result.internal_links, current_result.external_links,
current_result.markdown_length, f"{current_result.execution_time:.3f}"])
writer.writerow(
[
name,
"New",
new_result.success,
new_result.images,
new_result.internal_links,
new_result.external_links,
new_result.markdown_length,
f"{new_result.execution_time:.3f}",
]
)
writer.writerow(
[
name,
"Current",
current_result.success,
current_result.images,
current_result.internal_links,
current_result.external_links,
current_result.markdown_length,
f"{current_result.execution_time:.3f}",
]
)
def print_comparison_table(self, all_results: List[tuple]):
table_data = []
headers = ['Test Name', 'Strategy', 'Success', 'Images', 'Internal Links',
'External Links', 'Markdown Length', 'Time (s)']
headers = [
"Test Name",
"Strategy",
"Success",
"Images",
"Internal Links",
"External Links",
"Markdown Length",
"Time (s)",
]
for name, new_result, current_result in all_results:
# Check for differences
differences = []
if new_result.images != current_result.images: differences.append('images')
if new_result.internal_links != current_result.internal_links: differences.append('internal_links')
if new_result.external_links != current_result.external_links: differences.append('external_links')
if new_result.markdown_length != current_result.markdown_length: differences.append('markdown')
if new_result.images != current_result.images:
differences.append("images")
if new_result.internal_links != current_result.internal_links:
differences.append("internal_links")
if new_result.external_links != current_result.external_links:
differences.append("external_links")
if new_result.markdown_length != current_result.markdown_length:
differences.append("markdown")
# Add row for new strategy
new_row = [
name, 'New', new_result.success, new_result.images,
new_result.internal_links, new_result.external_links,
new_result.markdown_length, f"{new_result.execution_time:.3f}"
name,
"New",
new_result.success,
new_result.images,
new_result.internal_links,
new_result.external_links,
new_result.markdown_length,
f"{new_result.execution_time:.3f}",
]
table_data.append(new_row)
# Add row for current strategy
current_row = [
'', 'Current', current_result.success, current_result.images,
current_result.internal_links, current_result.external_links,
current_result.markdown_length, f"{current_result.execution_time:.3f}"
"",
"Current",
current_result.success,
current_result.images,
current_result.internal_links,
current_result.external_links,
current_result.markdown_length,
f"{current_result.execution_time:.3f}",
]
table_data.append(current_row)
# Add difference summary if any
if differences:
table_data.append(['', '⚠️ Differences', ', '.join(differences), '', '', '', '', ''])
table_data.append(
["", "⚠️ Differences", ", ".join(differences), "", "", "", "", ""]
)
# Add empty row for better readability
table_data.append([''] * len(headers))
table_data.append([""] * len(headers))
print("\nStrategy Comparison Results:")
print(tabulate(table_data, headers=headers, tablefmt='grid'))
print(tabulate(table_data, headers=headers, tablefmt="grid"))
if __name__ == "__main__":
tester = StrategyTester()

View File

@@ -1,14 +1,13 @@
import os
import sys
import pytest
import asyncio
# Add the parent directory to the Python path
parent_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
sys.path.append(parent_dir)
from crawl4ai.async_webcrawler import AsyncWebCrawler
from crawl4ai.async_crawler_strategy import AsyncPlaywrightCrawlerStrategy
@pytest.mark.asyncio
async def test_custom_user_agent():
@@ -20,6 +19,7 @@ async def test_custom_user_agent():
assert result.success
assert custom_user_agent in result.html
@pytest.mark.asyncio
async def test_custom_headers():
async with AsyncWebCrawler(verbose=True) as crawler:
@@ -31,6 +31,7 @@ async def test_custom_headers():
assert "X-Test-Header" in result.html
assert "TestValue" in result.html
@pytest.mark.asyncio
async def test_javascript_execution():
async with AsyncWebCrawler(verbose=True) as crawler:
@@ -40,19 +41,22 @@ async def test_javascript_execution():
assert result.success
assert "<h1>Modified by JS</h1>" in result.html
@pytest.mark.asyncio
async def test_hook_execution():
async with AsyncWebCrawler(verbose=True) as crawler:
async def test_hook(page):
await page.evaluate("document.body.style.backgroundColor = 'red';")
return page
crawler.crawler_strategy.set_hook('after_goto', test_hook)
crawler.crawler_strategy.set_hook("after_goto", test_hook)
url = "https://www.example.com"
result = await crawler.arun(url=url, bypass_cache=True)
assert result.success
assert "background-color: red" in result.html
@pytest.mark.asyncio
async def test_screenshot():
async with AsyncWebCrawler(verbose=True) as crawler:
@@ -63,6 +67,7 @@ async def test_screenshot():
assert isinstance(result.screenshot, str)
assert len(result.screenshot) > 0
# Entry point for debugging
if __name__ == "__main__":
pytest.main([__file__, "-v"])

View File

@@ -1,8 +1,6 @@
import os
import sys
import pytest
import asyncio
import json
# Add the parent directory to the Python path
parent_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
@@ -10,6 +8,7 @@ sys.path.append(parent_dir)
from crawl4ai.async_webcrawler import AsyncWebCrawler
@pytest.mark.asyncio
async def test_cache_url():
async with AsyncWebCrawler(verbose=True) as crawler:
@@ -23,6 +22,7 @@ async def test_cache_url():
assert result2.success
assert result2.html == result1.html
@pytest.mark.asyncio
async def test_bypass_cache():
async with AsyncWebCrawler(verbose=True) as crawler:
@@ -34,7 +34,10 @@ async def test_bypass_cache():
# Second run bypassing cache
result2 = await crawler.arun(url=url, bypass_cache=True)
assert result2.success
assert result2.html != result1.html # Content might be different due to dynamic nature of websites
assert (
result2.html != result1.html
) # Content might be different due to dynamic nature of websites
@pytest.mark.asyncio
async def test_cache_size():
@@ -47,6 +50,7 @@ async def test_cache_size():
new_size = await crawler.aget_cache_size()
assert new_size == initial_size + 1
@pytest.mark.asyncio
async def test_clear_cache():
async with AsyncWebCrawler(verbose=True) as crawler:
@@ -60,6 +64,7 @@ async def test_clear_cache():
new_size = await crawler.aget_cache_size()
assert new_size == 0
@pytest.mark.asyncio
async def test_flush_cache():
async with AsyncWebCrawler(verbose=True) as crawler:
@@ -75,7 +80,10 @@ async def test_flush_cache():
# Try to retrieve the previously cached URL
result = await crawler.arun(url=url, bypass_cache=False)
assert result.success # The crawler should still succeed, but it will fetch the content anew
assert (
result.success
) # The crawler should still succeed, but it will fetch the content anew
# Entry point for debugging
if __name__ == "__main__":

View File

@@ -1,114 +1,133 @@
import pytest
import asyncio, time
import time
from crawl4ai import (
AsyncWebCrawler, BrowserConfig, CrawlerRunConfig,
MemoryAdaptiveDispatcher, SemaphoreDispatcher,
RateLimiter, CrawlerMonitor, DisplayMode, CacheMode
AsyncWebCrawler,
BrowserConfig,
CrawlerRunConfig,
MemoryAdaptiveDispatcher,
SemaphoreDispatcher,
RateLimiter,
CrawlerMonitor,
DisplayMode,
CacheMode,
)
@pytest.fixture
def browser_config():
return BrowserConfig(
headless=True,
verbose=False
)
return BrowserConfig(headless=True, verbose=False)
@pytest.fixture
def run_config():
return CrawlerRunConfig(
cache_mode=CacheMode.BYPASS,
verbose=False
)
return CrawlerRunConfig(cache_mode=CacheMode.BYPASS, verbose=False)
@pytest.fixture
def test_urls():
return [
"http://example.com",
"http://example.com/page1",
"http://example.com/page2"
"http://example.com/page2",
]
@pytest.mark.asyncio
class TestDispatchStrategies:
async def test_memory_adaptive_basic(self, browser_config, run_config, test_urls):
async with AsyncWebCrawler(config=browser_config) as crawler:
dispatcher = MemoryAdaptiveDispatcher(
memory_threshold_percent=70.0,
max_session_permit=2,
check_interval=0.1
memory_threshold_percent=70.0, max_session_permit=2, check_interval=0.1
)
results = await crawler.arun_many(
test_urls, config=run_config, dispatcher=dispatcher
)
results = await crawler.arun_many(test_urls, config=run_config, dispatcher=dispatcher)
assert len(results) == len(test_urls)
assert all(r.success for r in results)
async def test_memory_adaptive_with_rate_limit(self, browser_config, run_config, test_urls):
async def test_memory_adaptive_with_rate_limit(
self, browser_config, run_config, test_urls
):
async with AsyncWebCrawler(config=browser_config) as crawler:
dispatcher = MemoryAdaptiveDispatcher(
memory_threshold_percent=70.0,
max_session_permit=2,
check_interval=0.1,
rate_limiter=RateLimiter(
base_delay=(0.1, 0.2),
max_delay=1.0,
max_retries=2
base_delay=(0.1, 0.2), max_delay=1.0, max_retries=2
),
)
results = await crawler.arun_many(
test_urls, config=run_config, dispatcher=dispatcher
)
results = await crawler.arun_many(test_urls, config=run_config, dispatcher=dispatcher)
assert len(results) == len(test_urls)
assert all(r.success for r in results)
async def test_semaphore_basic(self, browser_config, run_config, test_urls):
async with AsyncWebCrawler(config=browser_config) as crawler:
dispatcher = SemaphoreDispatcher(
semaphore_count=2
dispatcher = SemaphoreDispatcher(semaphore_count=2)
results = await crawler.arun_many(
test_urls, config=run_config, dispatcher=dispatcher
)
results = await crawler.arun_many(test_urls, config=run_config, dispatcher=dispatcher)
assert len(results) == len(test_urls)
assert all(r.success for r in results)
async def test_semaphore_with_rate_limit(self, browser_config, run_config, test_urls):
async def test_semaphore_with_rate_limit(
self, browser_config, run_config, test_urls
):
async with AsyncWebCrawler(config=browser_config) as crawler:
dispatcher = SemaphoreDispatcher(
semaphore_count=2,
rate_limiter=RateLimiter(
base_delay=(0.1, 0.2),
max_delay=1.0,
max_retries=2
base_delay=(0.1, 0.2), max_delay=1.0, max_retries=2
),
)
results = await crawler.arun_many(
test_urls, config=run_config, dispatcher=dispatcher
)
results = await crawler.arun_many(test_urls, config=run_config, dispatcher=dispatcher)
assert len(results) == len(test_urls)
assert all(r.success for r in results)
async def test_memory_adaptive_memory_error(self, browser_config, run_config, test_urls):
async def test_memory_adaptive_memory_error(
self, browser_config, run_config, test_urls
):
async with AsyncWebCrawler(config=browser_config) as crawler:
dispatcher = MemoryAdaptiveDispatcher(
memory_threshold_percent=1.0, # Set unrealistically low threshold
max_session_permit=2,
check_interval=0.1,
memory_wait_timeout=1.0 # Short timeout for testing
memory_wait_timeout=1.0, # Short timeout for testing
)
with pytest.raises(MemoryError):
await crawler.arun_many(test_urls, config=run_config, dispatcher=dispatcher)
await crawler.arun_many(
test_urls, config=run_config, dispatcher=dispatcher
)
async def test_empty_urls(self, browser_config, run_config):
async with AsyncWebCrawler(config=browser_config) as crawler:
dispatcher = MemoryAdaptiveDispatcher(max_session_permit=2)
results = await crawler.arun_many([], config=run_config, dispatcher=dispatcher)
results = await crawler.arun_many(
[], config=run_config, dispatcher=dispatcher
)
assert len(results) == 0
async def test_single_url(self, browser_config, run_config):
async with AsyncWebCrawler(config=browser_config) as crawler:
dispatcher = MemoryAdaptiveDispatcher(max_session_permit=2)
results = await crawler.arun_many(["http://example.com"], config=run_config, dispatcher=dispatcher)
results = await crawler.arun_many(
["http://example.com"], config=run_config, dispatcher=dispatcher
)
assert len(results) == 1
assert results[0].success
async def test_invalid_urls(self, browser_config, run_config):
async with AsyncWebCrawler(config=browser_config) as crawler:
dispatcher = MemoryAdaptiveDispatcher(max_session_permit=2)
results = await crawler.arun_many(["http://invalid.url.that.doesnt.exist"], config=run_config, dispatcher=dispatcher)
results = await crawler.arun_many(
["http://invalid.url.that.doesnt.exist"],
config=run_config,
dispatcher=dispatcher,
)
assert len(results) == 1
assert not results[0].success
@@ -121,27 +140,31 @@ class TestDispatchStrategies:
base_delay=(0.1, 0.2),
max_delay=1.0,
max_retries=2,
rate_limit_codes=[200] # Force rate limiting for testing
)
rate_limit_codes=[200], # Force rate limiting for testing
),
)
start_time = time.time()
results = await crawler.arun_many(urls, config=run_config, dispatcher=dispatcher)
results = await crawler.arun_many(
urls, config=run_config, dispatcher=dispatcher
)
duration = time.time() - start_time
assert len(results) == len(urls)
assert duration > 1.0 # Ensure rate limiting caused delays
async def test_monitor_integration(self, browser_config, run_config, test_urls):
async with AsyncWebCrawler(config=browser_config) as crawler:
monitor = CrawlerMonitor(max_visible_rows=5, display_mode=DisplayMode.DETAILED)
dispatcher = MemoryAdaptiveDispatcher(
max_session_permit=2,
monitor=monitor
monitor = CrawlerMonitor(
max_visible_rows=5, display_mode=DisplayMode.DETAILED
)
dispatcher = MemoryAdaptiveDispatcher(max_session_permit=2, monitor=monitor)
results = await crawler.arun_many(
test_urls, config=run_config, dispatcher=dispatcher
)
results = await crawler.arun_many(test_urls, config=run_config, dispatcher=dispatcher)
assert len(results) == len(test_urls)
# Check monitor stats
assert len(monitor.stats) == len(test_urls)
assert all(stat.end_time is not None for stat in monitor.stats.values())
if __name__ == "__main__":
pytest.main([__file__, "-v", "--asyncio-mode=auto"])

View File

@@ -2,9 +2,9 @@ import os
import re
import sys
import pytest
import json
from bs4 import BeautifulSoup
import asyncio
# Add the parent directory to the Python path
parent_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
sys.path.append(parent_dir)
@@ -59,19 +59,21 @@ from crawl4ai.async_webcrawler import AsyncWebCrawler
# assert result.success
# assert "github" in result.html.lower()
# Add this test to your existing test file
@pytest.mark.asyncio
async def test_typescript_commits_multi_page():
first_commit = ""
async def on_execution_started(page):
nonlocal first_commit
try:
# Check if the page firct commit h4 text is different from the first commit (use document.querySelector('li.Box-sc-g0xbh4-0 h4'))
while True:
await page.wait_for_selector('li.Box-sc-g0xbh4-0 h4')
commit = await page.query_selector('li.Box-sc-g0xbh4-0 h4')
commit = await commit.evaluate('(element) => element.textContent')
commit = re.sub(r'\s+', '', commit)
await page.wait_for_selector("li.Box-sc-g0xbh4-0 h4")
commit = await page.query_selector("li.Box-sc-g0xbh4-0 h4")
commit = await commit.evaluate("(element) => element.textContent")
commit = re.sub(r"\s+", "", commit)
if commit and commit != first_commit:
first_commit = commit
break
@@ -79,9 +81,8 @@ async def test_typescript_commits_multi_page():
except Exception as e:
print(f"Warning: New content didn't appear after JavaScript execution: {e}")
async with AsyncWebCrawler(verbose=True) as crawler:
crawler.crawler_strategy.set_hook('on_execution_started', on_execution_started)
crawler.crawler_strategy.set_hook("on_execution_started", on_execution_started)
url = "https://github.com/microsoft/TypeScript/commits/main"
session_id = "typescript_commits_session"
@@ -97,19 +98,21 @@ async def test_typescript_commits_multi_page():
url=url, # Only use URL for the first page
session_id=session_id,
css_selector="li.Box-sc-g0xbh4-0",
js=js_next_page if page > 0 else None, # Don't click 'next' on the first page
js=js_next_page
if page > 0
else None, # Don't click 'next' on the first page
bypass_cache=True,
js_only=page > 0 # Use js_only for subsequent pages
js_only=page > 0, # Use js_only for subsequent pages
)
assert result.success, f"Failed to crawl page {page + 1}"
# Parse the HTML and extract commits
soup = BeautifulSoup(result.cleaned_html, 'html.parser')
soup = BeautifulSoup(result.cleaned_html, "html.parser")
commits = soup.select("li")
# Take first commit find h4 extract text
first_commit = commits[0].find("h4").text
first_commit = re.sub(r'\s+', '', first_commit)
first_commit = re.sub(r"\s+", "", first_commit)
all_commits.extend(commits)
print(f"Page {page + 1}: Found {len(commits)} commits")
@@ -118,10 +121,13 @@ async def test_typescript_commits_multi_page():
await crawler.crawler_strategy.kill_session(session_id)
# Assertions
assert len(all_commits) >= 90, f"Expected at least 90 commits, but got {len(all_commits)}"
assert (
len(all_commits) >= 90
), f"Expected at least 90 commits, but got {len(all_commits)}"
print(f"Successfully crawled {len(all_commits)} commits across 3 pages")
# Entry point for debugging
if __name__ == "__main__":
pytest.main([__file__, "-v"])

View File

@@ -1,11 +1,15 @@
import json
import time
from bs4 import BeautifulSoup
from crawl4ai.content_scraping_strategy import WebScrapingStrategy, LXMLWebScrapingStrategy
from typing import Dict, Any, List, Tuple
from crawl4ai.content_scraping_strategy import (
WebScrapingStrategy,
LXMLWebScrapingStrategy,
)
from typing import Dict, List, Tuple
import difflib
from lxml import html as lhtml, etree
def normalize_dom(element):
"""
Recursively normalizes an lxml HTML element:
@@ -15,7 +19,7 @@ def normalize_dom(element):
Returns the same element (mutated).
"""
# Remove comment nodes
comments = element.xpath('//comment()')
comments = element.xpath("//comment()")
for c in comments:
p = c.getparent()
if p is not None:
@@ -53,8 +57,8 @@ def strip_html_body(root):
tag_name = (root.tag or "").lower()
# Case 1: The root is <html>
if tag_name == 'html':
bodies = root.xpath('./body')
if tag_name == "html":
bodies = root.xpath("./body")
if bodies:
body = bodies[0]
new_div = lhtml.Element("div")
@@ -66,7 +70,7 @@ def strip_html_body(root):
return root
# Case 2: The root is <body>
elif tag_name == 'body':
elif tag_name == "body":
new_div = lhtml.Element("div")
for child in root:
new_div.append(child)
@@ -92,7 +96,9 @@ def compare_nodes(node1, node2, differences, path="/"):
attrs1 = list(node1.attrib.items())
attrs2 = list(node2.attrib.items())
if attrs1 != attrs2:
differences.append(f"Attribute mismatch at {path}/{node1.tag}: {attrs1} vs. {attrs2}")
differences.append(
f"Attribute mismatch at {path}/{node1.tag}: {attrs1} vs. {attrs2}"
)
# 3) Compare text (trim or unify whitespace as needed)
text1 = (node1.text or "").strip()
@@ -102,7 +108,9 @@ def compare_nodes(node1, node2, differences, path="/"):
text2 = " ".join(text2.split())
if text1 != text2:
# If you prefer ignoring newlines or multiple whitespace, do a more robust cleanup
differences.append(f"Text mismatch at {path}/{node1.tag}: '{text1}' vs. '{text2}'")
differences.append(
f"Text mismatch at {path}/{node1.tag}: '{text1}' vs. '{text2}'"
)
# 4) Compare number of children
children1 = list(node1)
@@ -123,7 +131,9 @@ def compare_nodes(node1, node2, differences, path="/"):
tail1 = (node1.tail or "").strip()
tail2 = (node2.tail or "").strip()
if tail1 != tail2:
differences.append(f"Tail mismatch after {path}/{node1.tag}: '{tail1}' vs. '{tail2}'")
differences.append(
f"Tail mismatch after {path}/{node1.tag}: '{tail1}' vs. '{tail2}'"
)
def compare_html_structurally(html1, html2):
@@ -156,11 +166,11 @@ def compare_html_structurally(html1, html2):
return differences
def generate_large_html(n_elements=1000):
html = ['<!DOCTYPE html><html><head></head><body>']
html = ["<!DOCTYPE html><html><head></head><body>"]
for i in range(n_elements):
html.append(f'''
html.append(
f"""
<div class="article">
<h2>Heading {i}</h2>
<p>This is paragraph {i} with some content and a <a href="http://example.com/{i}">link</a></p>
@@ -170,9 +180,11 @@ def generate_large_html(n_elements=1000):
<li>List item {i}.2</li>
</ul>
</div>
''')
html.append('</body></html>')
return ''.join(html)
"""
)
html.append("</body></html>")
return "".join(html)
def generate_complicated_html():
"""
@@ -352,13 +364,12 @@ def get_test_scenarios():
return TEST_SCENARIOS
class ScraperEquivalenceTester:
def __init__(self):
self.test_cases = {
'basic': self.generate_basic_html(),
'complex': self.generate_complex_html(),
'malformed': self.generate_malformed_html(),
"basic": self.generate_basic_html(),
"complex": self.generate_complex_html(),
"malformed": self.generate_malformed_html(),
# 'real_world': self.load_real_samples()
}
@@ -399,20 +410,19 @@ class ScraperEquivalenceTester:
def load_real_samples(self):
# Load some real-world HTML samples you've collected
samples = {
'article': open('tests/samples/article.html').read(),
'product': open('tests/samples/product.html').read(),
'blog': open('tests/samples/blog.html').read()
"article": open("tests/samples/article.html").read(),
"product": open("tests/samples/product.html").read(),
"blog": open("tests/samples/blog.html").read(),
}
return samples
def deep_compare_links(self, old_links: Dict, new_links: Dict) -> List[str]:
"""Detailed comparison of link structures"""
differences = []
for category in ['internal', 'external']:
old_urls = {link['href'] for link in old_links[category]}
new_urls = {link['href'] for link in new_links[category]}
for category in ["internal", "external"]:
old_urls = {link["href"] for link in old_links[category]}
new_urls = {link["href"] for link in new_links[category]}
missing = old_urls - new_urls
extra = new_urls - old_urls
@@ -425,10 +435,10 @@ class ScraperEquivalenceTester:
# Compare link attributes for common URLs
common = old_urls & new_urls
for url in common:
old_link = next(l for l in old_links[category] if l['href'] == url)
new_link = next(l for l in new_links[category] if l['href'] == url)
old_link = next(l for l in old_links[category] if l["href"] == url)
new_link = next(l for l in new_links[category] if l["href"] == url)
for attr in ['text', 'title']:
for attr in ["text", "title"]:
if old_link[attr] != new_link[attr]:
differences.append(
f"Link attribute mismatch for {url} - {attr}:"
@@ -441,9 +451,9 @@ class ScraperEquivalenceTester:
"""Detailed comparison of media elements"""
differences = []
for media_type in ['images', 'videos', 'audios']:
old_srcs = {item['src'] for item in old_media[media_type]}
new_srcs = {item['src'] for item in new_media[media_type]}
for media_type in ["images", "videos", "audios"]:
old_srcs = {item["src"] for item in old_media[media_type]}
new_srcs = {item["src"] for item in new_media[media_type]}
missing = old_srcs - new_srcs
extra = new_srcs - old_srcs
@@ -456,10 +466,10 @@ class ScraperEquivalenceTester:
# Compare media attributes for common sources
common = old_srcs & new_srcs
for src in common:
old_item = next(m for m in old_media[media_type] if m['src'] == src)
new_item = next(m for m in new_media[media_type] if m['src'] == src)
old_item = next(m for m in old_media[media_type] if m["src"] == src)
new_item = next(m for m in new_media[media_type] if m["src"] == src)
for attr in ['alt', 'description']:
for attr in ["alt", "description"]:
if old_item.get(attr) != new_item.get(attr):
differences.append(
f"{media_type} attribute mismatch for {src} - {attr}:"
@@ -474,10 +484,10 @@ class ScraperEquivalenceTester:
differences = []
def normalize_html(html: str) -> Tuple[str, str]:
soup = BeautifulSoup(html, 'lxml')
soup = BeautifulSoup(html, "lxml")
# Get both structure and text
structure = ' '.join(tag.name for tag in soup.find_all())
text = ' '.join(soup.get_text().split())
structure = " ".join(tag.name for tag in soup.find_all())
text = " ".join(soup.get_text().split())
return structure, text
old_structure, old_text = normalize_html(old_html)
@@ -487,46 +497,47 @@ class ScraperEquivalenceTester:
if abs(len(old_structure) - len(new_structure)) > 100:
# if old_structure != new_structure:
diff = difflib.unified_diff(
old_structure.split(),
new_structure.split(),
lineterm=''
old_structure.split(), new_structure.split(), lineterm=""
)
differences.append("HTML structure differences:\n" + '\n'.join(diff))
differences.append("HTML structure differences:\n" + "\n".join(diff))
# Compare text content
if abs(len(old_text) - len(new_text)) > 100:
# if old_text != new_text:
# Show detailed text differences
text_diff = difflib.unified_diff(
old_text.split(),
new_text.split(),
lineterm=''
old_text.split(), new_text.split(), lineterm=""
)
differences.append("Text content differences:\n" + '\n'.join(text_diff))
differences.append("Text content differences:\n" + "\n".join(text_diff))
return differences
def compare_results(self, old_result: Dict, new_result: Dict) -> Dict[str, List[str]]:
def compare_results(
self, old_result: Dict, new_result: Dict
) -> Dict[str, List[str]]:
"""Comprehensive comparison of scraper outputs"""
differences = {}
# Compare links
link_differences = self.deep_compare_links(old_result['links'], new_result['links'])
link_differences = self.deep_compare_links(
old_result["links"], new_result["links"]
)
if link_differences:
differences['links'] = link_differences
differences["links"] = link_differences
# Compare media
media_differences = self.deep_compare_media(old_result['media'], new_result['media'])
media_differences = self.deep_compare_media(
old_result["media"], new_result["media"]
)
if media_differences:
differences['media'] = media_differences
differences["media"] = media_differences
# Compare HTML
html_differences = self.compare_html_content(
old_result['cleaned_html'],
new_result['cleaned_html']
old_result["cleaned_html"], new_result["cleaned_html"]
)
if html_differences:
differences['html'] = html_differences
differences["html"] = html_differences
return differences
@@ -535,10 +546,7 @@ class ScraperEquivalenceTester:
# We'll still keep some "test_cases" logic from above (basic, complex, malformed).
# But we add a new section for the complicated HTML scenarios.
results = {
'tests': [],
'summary': {'passed': 0, 'failed': 0}
}
results = {"tests": [], "summary": {"passed": 0, "failed": 0}}
# 1) First, run the existing 3 built-in test cases (basic, complex, malformed).
# for case_name, html in self.test_cases.items():
@@ -616,33 +624,38 @@ class ScraperEquivalenceTester:
lxml_time = time.time() - start
diffs = {}
link_diff = self.deep_compare_links(orig_result['links'], lxml_result['links'])
link_diff = self.deep_compare_links(
orig_result["links"], lxml_result["links"]
)
if link_diff:
diffs['links'] = link_diff
diffs["links"] = link_diff
media_diff = self.deep_compare_media(orig_result['media'], lxml_result['media'])
media_diff = self.deep_compare_media(
orig_result["media"], lxml_result["media"]
)
if media_diff:
diffs['media'] = media_diff
diffs["media"] = media_diff
html_diff = self.compare_html_content(orig_result['cleaned_html'], lxml_result['cleaned_html'])
html_diff = self.compare_html_content(
orig_result["cleaned_html"], lxml_result["cleaned_html"]
)
if html_diff:
diffs['html'] = html_diff
diffs["html"] = html_diff
test_result = {
'case': f"complicated_{scenario_name}",
'lxml_mode': {
'differences': diffs,
'execution_time': lxml_time
},
'original_time': orig_time
"case": f"complicated_{scenario_name}",
"lxml_mode": {"differences": diffs, "execution_time": lxml_time},
"original_time": orig_time,
}
results['tests'].append(test_result)
results["tests"].append(test_result)
if not diffs:
results['summary']['passed'] += 1
print(f"✅ [OK] No differences found. Time(Orig: {orig_time:.3f}s, LXML: {lxml_time:.3f}s)")
results["summary"]["passed"] += 1
print(
f"✅ [OK] No differences found. Time(Orig: {orig_time:.3f}s, LXML: {lxml_time:.3f}s)"
)
else:
results['summary']['failed'] += 1
results["summary"]["failed"] += 1
print("❌ Differences found:")
for category, dlist in diffs.items():
print(f" {category}:")
@@ -658,19 +671,21 @@ class ScraperEquivalenceTester:
print(f"Passed: {results['summary']['passed']}")
print(f"Failed: {results['summary']['failed']}")
for test in results['tests']:
for test in results["tests"]:
print(f"\nTest Case: {test['case']}")
if not test['lxml_mode']['differences']:
if not test["lxml_mode"]["differences"]:
print("✅ All implementations produced identical results")
print(f"Times - Original: {test['original_time']:.3f}s, "
f"LXML: {test['lxml_mode']['execution_time']:.3f}s")
print(
f"Times - Original: {test['original_time']:.3f}s, "
f"LXML: {test['lxml_mode']['execution_time']:.3f}s"
)
else:
print("❌ Differences found:")
if test['lxml_mode']['differences']:
if test["lxml_mode"]["differences"]:
print("\nLXML Mode Differences:")
for category, diffs in test['lxml_mode']['differences'].items():
for category, diffs in test["lxml_mode"]["differences"].items():
print(f"\n{category}:")
for diff in diffs:
print(f" - {diff}")
@@ -682,7 +697,7 @@ def main():
tester.print_report(results)
# Save detailed results for debugging
with open('scraper_equivalence_results.json', 'w') as f:
with open("scraper_equivalence_results.json", "w") as f:
json.dump(results, f, indent=2)

View File

@@ -4,10 +4,10 @@
# - **State:** open
import os, sys, time
parent_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
sys.path.append(parent_dir)
__location__ = os.path.realpath( os.path.join(os.getcwd(), os.path.dirname(__file__)))
import asyncio
__location__ = os.path.realpath(os.path.join(os.getcwd(), os.path.dirname(__file__)))
import os
import time
from typing import Dict, Any
@@ -16,12 +16,12 @@ from crawl4ai.markdown_generation_strategy import DefaultMarkdownGenerator
# Get current directory
__location__ = os.path.realpath(os.path.join(os.getcwd(), os.path.dirname(__file__)))
def print_test_result(name: str, result: Dict[str, Any], execution_time: float):
"""Helper function to print test results."""
print(f"\n{'='*20} {name} {'='*20}")
print(f"Execution time: {execution_time:.4f} seconds")
# Save markdown to files
for key, content in result.items():
if isinstance(content, str):
@@ -36,6 +36,7 @@ def print_test_result(name: str, result: Dict[str, Any], execution_time: float):
# print(preview)
# print(f"Total length: {len(content)} characters")
def test_basic_markdown_conversion():
"""Test basic markdown conversion with links."""
with open(__location__ + "/data/wikipedia.html", "r") as f:
@@ -45,23 +46,29 @@ def test_basic_markdown_conversion():
start_time = time.perf_counter()
result = generator.generate_markdown(
cleaned_html=cleaned_html,
base_url="https://en.wikipedia.org"
cleaned_html=cleaned_html, base_url="https://en.wikipedia.org"
)
execution_time = time.perf_counter() - start_time
print_test_result("Basic Markdown Conversion", {
'raw': result.raw_markdown,
'with_citations': result.markdown_with_citations,
'references': result.references_markdown
}, execution_time)
print_test_result(
"Basic Markdown Conversion",
{
"raw": result.raw_markdown,
"with_citations": result.markdown_with_citations,
"references": result.references_markdown,
},
execution_time,
)
# Basic assertions
assert result.raw_markdown, "Raw markdown should not be empty"
assert result.markdown_with_citations, "Markdown with citations should not be empty"
assert result.references_markdown, "References should not be empty"
assert "" in result.markdown_with_citations, "Citations should use ⟨⟩ brackets"
assert "## References" in result.references_markdown, "Should contain references section"
assert (
"## References" in result.references_markdown
), "Should contain references section"
def test_relative_links():
"""Test handling of relative links with base URL."""
@@ -72,14 +79,14 @@ def test_relative_links():
generator = DefaultMarkdownGenerator()
result = generator.generate_markdown(
cleaned_html=markdown,
base_url="https://en.wikipedia.org"
cleaned_html=markdown, base_url="https://en.wikipedia.org"
)
assert "https://en.wikipedia.org/wiki/Apple" in result.references_markdown
assert "https://example.com" in result.references_markdown
assert "https://en.wikipedia.org/images/test.png" in result.references_markdown
def test_duplicate_links():
"""Test handling of duplicate links."""
markdown = """
@@ -88,14 +95,14 @@ def test_duplicate_links():
generator = DefaultMarkdownGenerator()
result = generator.generate_markdown(
cleaned_html=markdown,
base_url="https://example.com"
cleaned_html=markdown, base_url="https://example.com"
)
# Count citations in markdown
citations = result.markdown_with_citations.count("⟨1⟩")
assert citations == 2, "Same link should use same citation number"
def test_link_descriptions():
"""Test handling of link titles and descriptions."""
markdown = """
@@ -104,12 +111,16 @@ def test_link_descriptions():
generator = DefaultMarkdownGenerator()
result = generator.generate_markdown(
cleaned_html=markdown,
base_url="https://example.com"
cleaned_html=markdown, base_url="https://example.com"
)
assert "Test Title" in result.references_markdown, "Link title should be in references"
assert "link with description" in result.references_markdown, "Link text should be in references"
assert (
"Test Title" in result.references_markdown
), "Link title should be in references"
assert (
"link with description" in result.references_markdown
), "Link text should be in references"
def test_performance_large_document():
"""Test performance with large document."""
@@ -125,18 +136,20 @@ def test_performance_large_document():
for i in range(iterations):
start_time = time.perf_counter()
result = generator.generate_markdown(
cleaned_html=markdown,
base_url="https://en.wikipedia.org"
cleaned_html=markdown, base_url="https://en.wikipedia.org"
)
end_time = time.perf_counter()
times.append(end_time - start_time)
avg_time = sum(times) / len(times)
print(f"\n{'='*20} Performance Test {'='*20}")
print(f"Average execution time over {iterations} iterations: {avg_time:.4f} seconds")
print(
f"Average execution time over {iterations} iterations: {avg_time:.4f} seconds"
)
print(f"Min time: {min(times):.4f} seconds")
print(f"Max time: {max(times):.4f} seconds")
def test_image_links():
"""Test handling of image links."""
markdown = """
@@ -146,12 +159,16 @@ def test_image_links():
generator = DefaultMarkdownGenerator()
result = generator.generate_markdown(
cleaned_html=markdown,
base_url="https://example.com"
cleaned_html=markdown, base_url="https://example.com"
)
assert "![" in result.markdown_with_citations, "Image markdown syntax should be preserved"
assert "Image Title" in result.references_markdown, "Image title should be in references"
assert (
"![" in result.markdown_with_citations
), "Image markdown syntax should be preserved"
assert (
"Image Title" in result.references_markdown
), "Image title should be in references"
if __name__ == "__main__":
print("Running markdown generation strategy tests...")
@@ -162,4 +179,3 @@ if __name__ == "__main__":
test_link_descriptions()
test_performance_large_document()
test_image_links()

View File

@@ -1,8 +1,6 @@
import os
import sys
import pytest
import asyncio
import json
# Add the parent directory to the Python path
parent_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
@@ -10,24 +8,37 @@ sys.path.append(parent_dir)
from crawl4ai.async_webcrawler import AsyncWebCrawler
@pytest.mark.asyncio
async def test_word_count_threshold():
async with AsyncWebCrawler(verbose=True) as crawler:
url = "https://www.nbcnews.com/business"
result_no_threshold = await crawler.arun(url=url, word_count_threshold=0, bypass_cache=True)
result_with_threshold = await crawler.arun(url=url, word_count_threshold=50, bypass_cache=True)
result_no_threshold = await crawler.arun(
url=url, word_count_threshold=0, bypass_cache=True
)
result_with_threshold = await crawler.arun(
url=url, word_count_threshold=50, bypass_cache=True
)
assert len(result_no_threshold.markdown) > len(result_with_threshold.markdown)
@pytest.mark.asyncio
async def test_css_selector():
async with AsyncWebCrawler(verbose=True) as crawler:
url = "https://www.nbcnews.com/business"
css_selector = "h1, h2, h3"
result = await crawler.arun(url=url, css_selector=css_selector, bypass_cache=True)
result = await crawler.arun(
url=url, css_selector=css_selector, bypass_cache=True
)
assert result.success
assert "<h1" in result.cleaned_html or "<h2" in result.cleaned_html or "<h3" in result.cleaned_html
assert (
"<h1" in result.cleaned_html
or "<h2" in result.cleaned_html
or "<h3" in result.cleaned_html
)
@pytest.mark.asyncio
async def test_javascript_execution():
@@ -37,12 +48,15 @@ async def test_javascript_execution():
# Crawl without JS
result_without_more = await crawler.arun(url=url, bypass_cache=True)
js_code = ["const loadMoreButton = Array.from(document.querySelectorAll('button')).find(button => button.textContent.includes('Load More')); loadMoreButton && loadMoreButton.click();"]
js_code = [
"const loadMoreButton = Array.from(document.querySelectorAll('button')).find(button => button.textContent.includes('Load More')); loadMoreButton && loadMoreButton.click();"
]
result_with_more = await crawler.arun(url=url, js=js_code, bypass_cache=True)
assert result_with_more.success
assert len(result_with_more.markdown) > len(result_without_more.markdown)
@pytest.mark.asyncio
async def test_screenshot():
async with AsyncWebCrawler(verbose=True) as crawler:
@@ -53,16 +67,20 @@ async def test_screenshot():
assert result.screenshot
assert isinstance(result.screenshot, str) # Should be a base64 encoded string
@pytest.mark.asyncio
async def test_custom_user_agent():
async with AsyncWebCrawler(verbose=True) as crawler:
url = "https://www.nbcnews.com/business"
custom_user_agent = "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.124 Safari/537.36 Crawl4AI/1.0"
result = await crawler.arun(url=url, user_agent=custom_user_agent, bypass_cache=True)
result = await crawler.arun(
url=url, user_agent=custom_user_agent, bypass_cache=True
)
assert result.success
# Note: We can't directly verify the user agent in the result, but we can check if the crawl was successful
@pytest.mark.asyncio
async def test_extract_media_and_links():
async with AsyncWebCrawler(verbose=True) as crawler:
@@ -72,10 +90,11 @@ async def test_extract_media_and_links():
assert result.success
assert result.media
assert isinstance(result.media, dict)
assert 'images' in result.media
assert "images" in result.media
assert result.links
assert isinstance(result.links, dict)
assert 'internal' in result.links and 'external' in result.links
assert "internal" in result.links and "external" in result.links
@pytest.mark.asyncio
async def test_metadata_extraction():
@@ -87,7 +106,10 @@ async def test_metadata_extraction():
assert result.metadata
assert isinstance(result.metadata, dict)
# Check for common metadata fields
assert any(key in result.metadata for key in ['title', 'description', 'keywords'])
assert any(
key in result.metadata for key in ["title", "description", "keywords"]
)
# Entry point for debugging
if __name__ == "__main__":

View File

@@ -1,7 +1,6 @@
import os
import sys
import pytest
import asyncio
import time
# Add the parent directory to the Python path
@@ -10,6 +9,7 @@ sys.path.append(parent_dir)
from crawl4ai.async_webcrawler import AsyncWebCrawler
@pytest.mark.asyncio
async def test_crawl_speed():
async with AsyncWebCrawler(verbose=True) as crawler:
@@ -24,6 +24,7 @@ async def test_crawl_speed():
assert crawl_time < 10, f"Crawl took too long: {crawl_time:.2f} seconds"
@pytest.mark.asyncio
async def test_concurrent_crawling_performance():
async with AsyncWebCrawler(verbose=True) as crawler:
@@ -32,7 +33,7 @@ async def test_concurrent_crawling_performance():
"https://www.example.com",
"https://www.python.org",
"https://www.github.com",
"https://www.stackoverflow.com"
"https://www.stackoverflow.com",
]
start_time = time.time()
@@ -45,7 +46,10 @@ async def test_concurrent_crawling_performance():
assert all(result.success for result in results)
assert len(results) == len(urls)
assert total_time < len(urls) * 5, f"Concurrent crawling not significantly faster: {total_time:.2f} seconds"
assert (
total_time < len(urls) * 5
), f"Concurrent crawling not significantly faster: {total_time:.2f} seconds"
@pytest.mark.asyncio
async def test_crawl_speed_with_caching():
@@ -66,7 +70,10 @@ async def test_crawl_speed_with_caching():
print(f"First crawl time: {first_crawl_time:.2f} seconds")
print(f"Second crawl time (cached): {second_crawl_time:.2f} seconds")
assert second_crawl_time < first_crawl_time / 2, "Cached crawl not significantly faster"
assert (
second_crawl_time < first_crawl_time / 2
), "Cached crawl not significantly faster"
if __name__ == "__main__":
pytest.main([__file__, "-v"])

View File

@@ -1,7 +1,6 @@
import os
import sys
import pytest
import asyncio
import base64
from PIL import Image
import io
@@ -12,6 +11,7 @@ sys.path.append(parent_dir)
from crawl4ai.async_webcrawler import AsyncWebCrawler
@pytest.mark.asyncio
async def test_basic_screenshot():
async with AsyncWebCrawler(verbose=True) as crawler:
@@ -26,6 +26,7 @@ async def test_basic_screenshot():
image = Image.open(io.BytesIO(image_data))
assert image.format == "PNG"
@pytest.mark.asyncio
async def test_screenshot_with_wait_for():
async with AsyncWebCrawler(verbose=True) as crawler:
@@ -34,10 +35,7 @@ async def test_screenshot_with_wait_for():
wait_for = "css:#content" # Wait for the main content to load
result = await crawler.arun(
url=url,
bypass_cache=True,
screenshot=True,
wait_for=wait_for
url=url, bypass_cache=True, screenshot=True, wait_for=wait_for
)
assert result.success
@@ -51,6 +49,7 @@ async def test_screenshot_with_wait_for():
# You might want to add more specific checks here, like image dimensions
# or even use image recognition to verify certain elements are present
@pytest.mark.asyncio
async def test_screenshot_with_js_wait_for():
async with AsyncWebCrawler(verbose=True) as crawler:
@@ -58,10 +57,7 @@ async def test_screenshot_with_js_wait_for():
wait_for = "js:() => document.querySelector('#nav-logo-sprites') !== null"
result = await crawler.arun(
url=url,
bypass_cache=True,
screenshot=True,
wait_for=wait_for
url=url, bypass_cache=True, screenshot=True, wait_for=wait_for
)
assert result.success
@@ -71,6 +67,7 @@ async def test_screenshot_with_js_wait_for():
image = Image.open(io.BytesIO(image_data))
assert image.format == "PNG"
@pytest.mark.asyncio
async def test_screenshot_without_wait_for():
async with AsyncWebCrawler(verbose=True) as crawler:
@@ -85,6 +82,7 @@ async def test_screenshot_without_wait_for():
image = Image.open(io.BytesIO(image_data))
assert image.format == "PNG"
@pytest.mark.asyncio
async def test_screenshot_comparison():
async with AsyncWebCrawler(verbose=True) as crawler:
@@ -93,17 +91,12 @@ async def test_screenshot_comparison():
# Take screenshot without wait_for
result_without_wait = await crawler.arun(
url=url,
bypass_cache=True,
screenshot=True
url=url, bypass_cache=True, screenshot=True
)
# Take screenshot with wait_for
result_with_wait = await crawler.arun(
url=url,
bypass_cache=True,
screenshot=True,
wait_for=wait_for
url=url, bypass_cache=True, screenshot=True, wait_for=wait_for
)
assert result_without_wait.success and result_with_wait.success
@@ -111,14 +104,19 @@ async def test_screenshot_comparison():
assert result_with_wait.screenshot is not None
# Compare the two screenshots
image_without_wait = Image.open(io.BytesIO(base64.b64decode(result_without_wait.screenshot)))
image_with_wait = Image.open(io.BytesIO(base64.b64decode(result_with_wait.screenshot)))
image_without_wait = Image.open(
io.BytesIO(base64.b64decode(result_without_wait.screenshot))
)
image_with_wait = Image.open(
io.BytesIO(base64.b64decode(result_with_wait.screenshot))
)
# This is a simple size comparison. In a real-world scenario, you might want to use
# more sophisticated image comparison techniques.
assert image_with_wait.size[0] >= image_without_wait.size[0]
assert image_with_wait.size[1] >= image_without_wait.size[1]
# Entry point for debugging
if __name__ == "__main__":
pytest.main([__file__, "-v"])

View File

@@ -6,15 +6,24 @@ import base64
import os
from typing import Dict, Any
class Crawl4AiTester:
def __init__(self, base_url: str = "http://localhost:11235", api_token: str = None):
self.base_url = base_url
self.api_token = api_token or os.getenv('CRAWL4AI_API_TOKEN') # Check environment variable as fallback
self.headers = {'Authorization': f'Bearer {self.api_token}'} if self.api_token else {}
self.api_token = api_token or os.getenv(
"CRAWL4AI_API_TOKEN"
) # Check environment variable as fallback
self.headers = (
{"Authorization": f"Bearer {self.api_token}"} if self.api_token else {}
)
def submit_and_wait(self, request_data: Dict[str, Any], timeout: int = 300) -> Dict[str, Any]:
def submit_and_wait(
self, request_data: Dict[str, Any], timeout: int = 300
) -> Dict[str, Any]:
# Submit crawl job
response = requests.post(f"{self.base_url}/crawl", json=request_data, headers=self.headers)
response = requests.post(
f"{self.base_url}/crawl", json=request_data, headers=self.headers
)
if response.status_code == 403:
raise Exception("API token is invalid or missing")
task_id = response.json()["task_id"]
@@ -24,9 +33,13 @@ class Crawl4AiTester:
start_time = time.time()
while True:
if time.time() - start_time > timeout:
raise TimeoutError(f"Task {task_id} did not complete within {timeout} seconds")
raise TimeoutError(
f"Task {task_id} did not complete within {timeout} seconds"
)
result = requests.get(f"{self.base_url}/task/{task_id}", headers=self.headers)
result = requests.get(
f"{self.base_url}/task/{task_id}", headers=self.headers
)
status = result.json()
if status["status"] == "failed":
@@ -39,17 +52,23 @@ class Crawl4AiTester:
time.sleep(2)
def submit_sync(self, request_data: Dict[str, Any]) -> Dict[str, Any]:
response = requests.post(f"{self.base_url}/crawl_sync", json=request_data, headers=self.headers, timeout=60)
response = requests.post(
f"{self.base_url}/crawl_sync",
json=request_data,
headers=self.headers,
timeout=60,
)
if response.status_code == 408:
raise TimeoutError("Task did not complete within server timeout")
response.raise_for_status()
return response.json()
def test_docker_deployment(version="basic"):
tester = Crawl4AiTester(
# base_url="http://localhost:11235" ,
base_url="https://crawl4ai-sby74.ondigitalocean.app",
api_token="test"
api_token="test",
)
print(f"Testing Crawl4AI Docker {version} version")
@@ -60,7 +79,7 @@ def test_docker_deployment(version="basic"):
health = requests.get(f"{tester.base_url}/health", timeout=10)
print("Health check:", health.json())
break
except requests.exceptions.RequestException as e:
except requests.exceptions.RequestException:
if i == max_retries - 1:
print(f"Failed to connect after {max_retries} attempts")
sys.exit(1)
@@ -88,7 +107,7 @@ def test_basic_crawl(tester: Crawl4AiTester):
request = {
"urls": "https://www.nbcnews.com/business",
"priority": 10,
"session_id": "test"
"session_id": "test",
}
result = tester.submit_and_wait(request)
@@ -96,19 +115,21 @@ def test_basic_crawl(tester: Crawl4AiTester):
assert result["result"]["success"]
assert len(result["result"]["markdown"]) > 0
def test_basic_crawl_sync(tester: Crawl4AiTester):
print("\n=== Testing Basic Crawl (Sync) ===")
request = {
"urls": "https://www.nbcnews.com/business",
"priority": 10,
"session_id": "test"
"session_id": "test",
}
result = tester.submit_sync(request)
print(f"Basic crawl result length: {len(result['result']['markdown'])}")
assert result['status'] == 'completed'
assert result['result']['success']
assert len(result['result']['markdown']) > 0
assert result["status"] == "completed"
assert result["result"]["success"]
assert len(result["result"]["markdown"]) > 0
def test_js_execution(tester: Crawl4AiTester):
print("\n=== Testing JS Execution ===")
@@ -119,32 +140,29 @@ def test_js_execution(tester: Crawl4AiTester):
"const loadMoreButton = Array.from(document.querySelectorAll('button')).find(button => button.textContent.includes('Load More')); loadMoreButton && loadMoreButton.click();"
],
"wait_for": "article.tease-card:nth-child(10)",
"crawler_params": {
"headless": True
}
"crawler_params": {"headless": True},
}
result = tester.submit_and_wait(request)
print(f"JS execution result length: {len(result['result']['markdown'])}")
assert result["result"]["success"]
def test_css_selector(tester: Crawl4AiTester):
print("\n=== Testing CSS Selector ===")
request = {
"urls": "https://www.nbcnews.com/business",
"priority": 7,
"css_selector": ".wide-tease-item__description",
"crawler_params": {
"headless": True
},
"extra": {"word_count_threshold": 10}
"crawler_params": {"headless": True},
"extra": {"word_count_threshold": 10},
}
result = tester.submit_and_wait(request)
print(f"CSS selector result length: {len(result['result']['markdown'])}")
assert result["result"]["success"]
def test_structured_extraction(tester: Crawl4AiTester):
print("\n=== Testing Structured Extraction ===")
schema = {
@@ -165,19 +183,14 @@ def test_structured_extraction(tester: Crawl4AiTester):
"name": "price",
"selector": "td:nth-child(2)",
"type": "text",
}
},
],
}
request = {
"urls": "https://www.coinbase.com/explore",
"priority": 9,
"extraction_config": {
"type": "json_css",
"params": {
"schema": schema
}
}
"extraction_config": {"type": "json_css", "params": {"schema": schema}},
}
result = tester.submit_and_wait(request)
@@ -187,6 +200,7 @@ def test_structured_extraction(tester: Crawl4AiTester):
assert result["result"]["success"]
assert len(extracted) > 0
def test_llm_extraction(tester: Crawl4AiTester):
print("\n=== Testing LLM Extraction ===")
schema = {
@@ -194,18 +208,18 @@ def test_llm_extraction(tester: Crawl4AiTester):
"properties": {
"model_name": {
"type": "string",
"description": "Name of the OpenAI model."
"description": "Name of the OpenAI model.",
},
"input_fee": {
"type": "string",
"description": "Fee for input token for the OpenAI model."
"description": "Fee for input token for the OpenAI model.",
},
"output_fee": {
"type": "string",
"description": "Fee for output token for the OpenAI model."
}
"description": "Fee for output token for the OpenAI model.",
},
"required": ["model_name", "input_fee", "output_fee"]
},
"required": ["model_name", "input_fee", "output_fee"],
}
request = {
@@ -218,10 +232,10 @@ def test_llm_extraction(tester: Crawl4AiTester):
"api_token": os.getenv("OPENAI_API_KEY"),
"schema": schema,
"extraction_type": "schema",
"instruction": """From the crawled content, extract all mentioned model names along with their fees for input and output tokens."""
}
"instruction": """From the crawled content, extract all mentioned model names along with their fees for input and output tokens.""",
},
"crawler_params": {"word_count_threshold": 1}
},
"crawler_params": {"word_count_threshold": 1},
}
try:
@@ -233,6 +247,7 @@ def test_llm_extraction(tester: Crawl4AiTester):
except Exception as e:
print(f"LLM extraction test failed (might be due to missing API key): {str(e)}")
def test_llm_with_ollama(tester: Crawl4AiTester):
print("\n=== Testing LLM with Ollama ===")
schema = {
@@ -240,18 +255,18 @@ def test_llm_with_ollama(tester: Crawl4AiTester):
"properties": {
"article_title": {
"type": "string",
"description": "The main title of the news article"
"description": "The main title of the news article",
},
"summary": {
"type": "string",
"description": "A brief summary of the article content"
"description": "A brief summary of the article content",
},
"main_topics": {
"type": "array",
"items": {"type": "string"},
"description": "Main topics or themes discussed in the article"
}
}
"description": "Main topics or themes discussed in the article",
},
},
}
request = {
@@ -263,11 +278,11 @@ def test_llm_with_ollama(tester: Crawl4AiTester):
"provider": "ollama/llama2",
"schema": schema,
"extraction_type": "schema",
"instruction": "Extract the main article information including title, summary, and main topics."
}
"instruction": "Extract the main article information including title, summary, and main topics.",
},
},
"extra": {"word_count_threshold": 1},
"crawler_params": {"verbose": True}
"crawler_params": {"verbose": True},
}
try:
@@ -278,6 +293,7 @@ def test_llm_with_ollama(tester: Crawl4AiTester):
except Exception as e:
print(f"Ollama extraction test failed: {str(e)}")
def test_cosine_extraction(tester: Crawl4AiTester):
print("\n=== Testing Cosine Extraction ===")
request = {
@@ -289,9 +305,9 @@ def test_cosine_extraction(tester: Crawl4AiTester):
"semantic_filter": "business finance economy",
"word_count_threshold": 10,
"max_dist": 0.2,
"top_k": 3
}
}
"top_k": 3,
},
},
}
try:
@@ -303,15 +319,14 @@ def test_cosine_extraction(tester: Crawl4AiTester):
except Exception as e:
print(f"Cosine extraction test failed: {str(e)}")
def test_screenshot(tester: Crawl4AiTester):
print("\n=== Testing Screenshot ===")
request = {
"urls": "https://www.nbcnews.com/business",
"priority": 5,
"screenshot": True,
"crawler_params": {
"headless": True
}
"crawler_params": {"headless": True},
}
result = tester.submit_and_wait(request)
@@ -326,6 +341,7 @@ def test_screenshot(tester: Crawl4AiTester):
assert result["result"]["success"]
if __name__ == "__main__":
version = sys.argv[1] if len(sys.argv) > 1 else "basic"
# version = "full"

View File

@@ -1,9 +1,9 @@
import asyncio
from pathlib import Path
from crawl4ai.docs_manager import DocsManager
from click.testing import CliRunner
from crawl4ai.cli import cli
def test_cli():
"""Test all CLI commands"""
runner = CliRunner()
@@ -35,9 +35,10 @@ def test_cli():
# print(f"First 200 chars: {result.output[:200]}...")
print("\n5. Testing combine all sections...")
result = runner.invoke(cli, ['docs', 'combine', '--mode', 'condensed'])
result = runner.invoke(cli, ["docs", "combine", "--mode", "condensed"])
print(f"Status: {'' if result.exit_code == 0 else ''}")
print(f"First 200 chars: {result.output[:200]}...")
if __name__ == "__main__":
test_cli()

View File

@@ -6,11 +6,14 @@ import base64
import os
from typing import Dict, Any
class Crawl4AiTester:
def __init__(self, base_url: str = "http://localhost:11235"):
self.base_url = base_url
def submit_and_wait(self, request_data: Dict[str, Any], timeout: int = 300) -> Dict[str, Any]:
def submit_and_wait(
self, request_data: Dict[str, Any], timeout: int = 300
) -> Dict[str, Any]:
# Submit crawl job
response = requests.post(f"{self.base_url}/crawl", json=request_data)
task_id = response.json()["task_id"]
@@ -20,7 +23,9 @@ class Crawl4AiTester:
start_time = time.time()
while True:
if time.time() - start_time > timeout:
raise TimeoutError(f"Task {task_id} did not complete within {timeout} seconds")
raise TimeoutError(
f"Task {task_id} did not complete within {timeout} seconds"
)
result = requests.get(f"{self.base_url}/task/{task_id}")
status = result.json()
@@ -34,6 +39,7 @@ class Crawl4AiTester:
time.sleep(2)
def test_docker_deployment(version="basic"):
tester = Crawl4AiTester()
print(f"Testing Crawl4AI Docker {version} version")
@@ -45,7 +51,7 @@ def test_docker_deployment(version="basic"):
health = requests.get(f"{tester.base_url}/health", timeout=10)
print("Health check:", health.json())
break
except requests.exceptions.RequestException as e:
except requests.exceptions.RequestException:
if i == max_retries - 1:
print(f"Failed to connect after {max_retries} attempts")
sys.exit(1)
@@ -68,16 +74,14 @@ def test_docker_deployment(version="basic"):
def test_basic_crawl(tester: Crawl4AiTester):
print("\n=== Testing Basic Crawl ===")
request = {
"urls": "https://www.nbcnews.com/business",
"priority": 10
}
request = {"urls": "https://www.nbcnews.com/business", "priority": 10}
result = tester.submit_and_wait(request)
print(f"Basic crawl result length: {len(result['result']['markdown'])}")
assert result["result"]["success"]
assert len(result["result"]["markdown"]) > 0
def test_js_execution(tester: Crawl4AiTester):
print("\n=== Testing JS Execution ===")
request = {
@@ -87,32 +91,29 @@ def test_js_execution(tester: Crawl4AiTester):
"const loadMoreButton = Array.from(document.querySelectorAll('button')).find(button => button.textContent.includes('Load More')); loadMoreButton && loadMoreButton.click();"
],
"wait_for": "article.tease-card:nth-child(10)",
"crawler_params": {
"headless": True
}
"crawler_params": {"headless": True},
}
result = tester.submit_and_wait(request)
print(f"JS execution result length: {len(result['result']['markdown'])}")
assert result["result"]["success"]
def test_css_selector(tester: Crawl4AiTester):
print("\n=== Testing CSS Selector ===")
request = {
"urls": "https://www.nbcnews.com/business",
"priority": 7,
"css_selector": ".wide-tease-item__description",
"crawler_params": {
"headless": True
},
"extra": {"word_count_threshold": 10}
"crawler_params": {"headless": True},
"extra": {"word_count_threshold": 10},
}
result = tester.submit_and_wait(request)
print(f"CSS selector result length: {len(result['result']['markdown'])}")
assert result["result"]["success"]
def test_structured_extraction(tester: Crawl4AiTester):
print("\n=== Testing Structured Extraction ===")
schema = {
@@ -133,19 +134,14 @@ def test_structured_extraction(tester: Crawl4AiTester):
"name": "price",
"selector": "td:nth-child(2)",
"type": "text",
}
},
],
}
request = {
"urls": "https://www.coinbase.com/explore",
"priority": 9,
"extraction_config": {
"type": "json_css",
"params": {
"schema": schema
}
}
"extraction_config": {"type": "json_css", "params": {"schema": schema}},
}
result = tester.submit_and_wait(request)
@@ -155,6 +151,7 @@ def test_structured_extraction(tester: Crawl4AiTester):
assert result["result"]["success"]
assert len(extracted) > 0
def test_llm_extraction(tester: Crawl4AiTester):
print("\n=== Testing LLM Extraction ===")
schema = {
@@ -162,18 +159,18 @@ def test_llm_extraction(tester: Crawl4AiTester):
"properties": {
"model_name": {
"type": "string",
"description": "Name of the OpenAI model."
"description": "Name of the OpenAI model.",
},
"input_fee": {
"type": "string",
"description": "Fee for input token for the OpenAI model."
"description": "Fee for input token for the OpenAI model.",
},
"output_fee": {
"type": "string",
"description": "Fee for output token for the OpenAI model."
}
"description": "Fee for output token for the OpenAI model.",
},
"required": ["model_name", "input_fee", "output_fee"]
},
"required": ["model_name", "input_fee", "output_fee"],
}
request = {
@@ -186,10 +183,10 @@ def test_llm_extraction(tester: Crawl4AiTester):
"api_token": os.getenv("OPENAI_API_KEY"),
"schema": schema,
"extraction_type": "schema",
"instruction": """From the crawled content, extract all mentioned model names along with their fees for input and output tokens."""
}
"instruction": """From the crawled content, extract all mentioned model names along with their fees for input and output tokens.""",
},
"crawler_params": {"word_count_threshold": 1}
},
"crawler_params": {"word_count_threshold": 1},
}
try:
@@ -201,6 +198,7 @@ def test_llm_extraction(tester: Crawl4AiTester):
except Exception as e:
print(f"LLM extraction test failed (might be due to missing API key): {str(e)}")
def test_llm_with_ollama(tester: Crawl4AiTester):
print("\n=== Testing LLM with Ollama ===")
schema = {
@@ -208,18 +206,18 @@ def test_llm_with_ollama(tester: Crawl4AiTester):
"properties": {
"article_title": {
"type": "string",
"description": "The main title of the news article"
"description": "The main title of the news article",
},
"summary": {
"type": "string",
"description": "A brief summary of the article content"
"description": "A brief summary of the article content",
},
"main_topics": {
"type": "array",
"items": {"type": "string"},
"description": "Main topics or themes discussed in the article"
}
}
"description": "Main topics or themes discussed in the article",
},
},
}
request = {
@@ -231,11 +229,11 @@ def test_llm_with_ollama(tester: Crawl4AiTester):
"provider": "ollama/llama2",
"schema": schema,
"extraction_type": "schema",
"instruction": "Extract the main article information including title, summary, and main topics."
}
"instruction": "Extract the main article information including title, summary, and main topics.",
},
},
"extra": {"word_count_threshold": 1},
"crawler_params": {"verbose": True}
"crawler_params": {"verbose": True},
}
try:
@@ -246,6 +244,7 @@ def test_llm_with_ollama(tester: Crawl4AiTester):
except Exception as e:
print(f"Ollama extraction test failed: {str(e)}")
def test_cosine_extraction(tester: Crawl4AiTester):
print("\n=== Testing Cosine Extraction ===")
request = {
@@ -257,9 +256,9 @@ def test_cosine_extraction(tester: Crawl4AiTester):
"semantic_filter": "business finance economy",
"word_count_threshold": 10,
"max_dist": 0.2,
"top_k": 3
}
}
"top_k": 3,
},
},
}
try:
@@ -271,15 +270,14 @@ def test_cosine_extraction(tester: Crawl4AiTester):
except Exception as e:
print(f"Cosine extraction test failed: {str(e)}")
def test_screenshot(tester: Crawl4AiTester):
print("\n=== Testing Screenshot ===")
request = {
"urls": "https://www.nbcnews.com/business",
"priority": 5,
"screenshot": True,
"crawler_params": {
"headless": True
}
"crawler_params": {"headless": True},
}
result = tester.submit_and_wait(request)
@@ -294,6 +292,7 @@ def test_screenshot(tester: Crawl4AiTester):
assert result["result"]["success"]
if __name__ == "__main__":
version = sys.argv[1] if len(sys.argv) > 1 else "basic"
# version = "full"

View File

@@ -3,6 +3,7 @@ from crawl4ai.async_logger import AsyncLogger
from pathlib import Path
import asyncio
async def main():
current_file = Path(__file__).resolve()
# base_dir = current_file.parent.parent / "local/_docs/llm.txt/test_docs"
@@ -26,8 +27,7 @@ async def main():
# Generate index files
print("\nGenerating index files...")
await manager.generate_index_files(
force_generate_facts=False,
clear_bm25_cache=False
force_generate_facts=False, clear_bm25_cache=False
)
# Test some relevant queries about Crawl4AI
@@ -41,9 +41,12 @@ async def main():
results = manager.search(query, top_k=2)
print(f"Results length: {len(results)} characters")
if results:
print("First 200 chars of results:", results[:200].replace('\n', ' '), "...")
print(
"First 200 chars of results:", results[:200].replace("\n", " "), "..."
)
else:
print("No results found")
if __name__ == "__main__":
asyncio.run(main())

View File

@@ -3,8 +3,8 @@ import aiohttp
import json
import time
import os
from typing import Optional, Dict, Any
from pydantic import BaseModel, HttpUrl
from typing import Dict, Any
class NBCNewsAPITest:
def __init__(self, base_url: str = "http://localhost:8000"):
@@ -20,7 +20,9 @@ class NBCNewsAPITest:
await self.session.close()
async def submit_crawl(self, request_data: Dict[str, Any]) -> str:
async with self.session.post(f"{self.base_url}/crawl", json=request_data) as response:
async with self.session.post(
f"{self.base_url}/crawl", json=request_data
) as response:
result = await response.json()
return result["task_id"]
@@ -28,11 +30,15 @@ class NBCNewsAPITest:
async with self.session.get(f"{self.base_url}/task/{task_id}") as response:
return await response.json()
async def wait_for_task(self, task_id: str, timeout: int = 300, poll_interval: int = 2) -> Dict[str, Any]:
async def wait_for_task(
self, task_id: str, timeout: int = 300, poll_interval: int = 2
) -> Dict[str, Any]:
start_time = time.time()
while True:
if time.time() - start_time > timeout:
raise TimeoutError(f"Task {task_id} did not complete within {timeout} seconds")
raise TimeoutError(
f"Task {task_id} did not complete within {timeout} seconds"
)
status = await self.get_task_status(task_id)
if status["status"] in ["completed", "failed"]:
@@ -44,13 +50,11 @@ class NBCNewsAPITest:
async with self.session.get(f"{self.base_url}/health") as response:
return await response.json()
async def test_basic_crawl():
print("\n=== Testing Basic Crawl ===")
async with NBCNewsAPITest() as api:
request = {
"urls": "https://www.nbcnews.com/business",
"priority": 10
}
request = {"urls": "https://www.nbcnews.com/business", "priority": 10}
task_id = await api.submit_crawl(request)
result = await api.wait_for_task(task_id)
print(f"Basic crawl result length: {len(result['result']['markdown'])}")
@@ -58,6 +62,7 @@ async def test_basic_crawl():
assert "result" in result
assert result["result"]["success"]
async def test_js_execution():
print("\n=== Testing JS Execution ===")
async with NBCNewsAPITest() as api:
@@ -68,9 +73,7 @@ async def test_js_execution():
"const loadMoreButton = Array.from(document.querySelectorAll('button')).find(button => button.textContent.includes('Load More')); loadMoreButton && loadMoreButton.click();"
],
"wait_for": "article.tease-card:nth-child(10)",
"crawler_params": {
"headless": True
}
"crawler_params": {"headless": True},
}
task_id = await api.submit_crawl(request)
result = await api.wait_for_task(task_id)
@@ -78,13 +81,14 @@ async def test_js_execution():
assert result["status"] == "completed"
assert result["result"]["success"]
async def test_css_selector():
print("\n=== Testing CSS Selector ===")
async with NBCNewsAPITest() as api:
request = {
"urls": "https://www.nbcnews.com/business",
"priority": 7,
"css_selector": ".wide-tease-item__description"
"css_selector": ".wide-tease-item__description",
}
task_id = await api.submit_crawl(request)
result = await api.wait_for_task(task_id)
@@ -92,6 +96,7 @@ async def test_css_selector():
assert result["status"] == "completed"
assert result["result"]["success"]
async def test_structured_extraction():
print("\n=== Testing Structured Extraction ===")
async with NBCNewsAPITest() as api:
@@ -99,34 +104,25 @@ async def test_structured_extraction():
"name": "NBC News Articles",
"baseSelector": "article.tease-card",
"fields": [
{
"name": "title",
"selector": "h2",
"type": "text"
},
{"name": "title", "selector": "h2", "type": "text"},
{
"name": "description",
"selector": ".tease-card__description",
"type": "text"
"type": "text",
},
{
"name": "link",
"selector": "a",
"type": "attribute",
"attribute": "href"
}
]
"attribute": "href",
},
],
}
request = {
"urls": "https://www.nbcnews.com/business",
"priority": 9,
"extraction_config": {
"type": "json_css",
"params": {
"schema": schema
}
}
"extraction_config": {"type": "json_css", "params": {"schema": schema}},
}
task_id = await api.submit_crawl(request)
result = await api.wait_for_task(task_id)
@@ -136,6 +132,7 @@ async def test_structured_extraction():
assert result["result"]["success"]
assert len(extracted) > 0
async def test_batch_crawl():
print("\n=== Testing Batch Crawl ===")
async with NBCNewsAPITest() as api:
@@ -143,12 +140,10 @@ async def test_batch_crawl():
"urls": [
"https://www.nbcnews.com/business",
"https://www.nbcnews.com/business/consumer",
"https://www.nbcnews.com/business/economy"
"https://www.nbcnews.com/business/economy",
],
"priority": 6,
"crawler_params": {
"headless": True
}
"crawler_params": {"headless": True},
}
task_id = await api.submit_crawl(request)
result = await api.wait_for_task(task_id)
@@ -157,6 +152,7 @@ async def test_batch_crawl():
assert "results" in result
assert len(result["results"]) == 3
async def test_llm_extraction():
print("\n=== Testing LLM Extraction with Ollama ===")
async with NBCNewsAPITest() as api:
@@ -165,19 +161,19 @@ async def test_llm_extraction():
"properties": {
"article_title": {
"type": "string",
"description": "The main title of the news article"
"description": "The main title of the news article",
},
"summary": {
"type": "string",
"description": "A brief summary of the article content"
"description": "A brief summary of the article content",
},
"main_topics": {
"type": "array",
"items": {"type": "string"},
"description": "Main topics or themes discussed in the article"
}
"description": "Main topics or themes discussed in the article",
},
"required": ["article_title", "summary", "main_topics"]
},
"required": ["article_title", "summary", "main_topics"],
}
request = {
@@ -191,13 +187,10 @@ async def test_llm_extraction():
"schema": schema,
"extraction_type": "schema",
"instruction": """Extract the main article information including title, a brief summary, and main topics discussed.
Focus on the primary business news article on the page."""
}
Focus on the primary business news article on the page.""",
},
"crawler_params": {
"headless": True,
"word_count_threshold": 1
}
},
"crawler_params": {"headless": True, "word_count_threshold": 1},
}
task_id = await api.submit_crawl(request)
@@ -205,12 +198,13 @@ async def test_llm_extraction():
if result["status"] == "completed":
extracted = json.loads(result["result"]["extracted_content"])
print(f"Extracted article analysis:")
print("Extracted article analysis:")
print(json.dumps(extracted, indent=2))
assert result["status"] == "completed"
assert result["result"]["success"]
async def test_screenshot():
print("\n=== Testing Screenshot ===")
async with NBCNewsAPITest() as api:
@@ -218,9 +212,7 @@ async def test_screenshot():
"urls": "https://www.nbcnews.com/business",
"priority": 5,
"screenshot": True,
"crawler_params": {
"headless": True
}
"crawler_params": {"headless": True},
}
task_id = await api.submit_crawl(request)
result = await api.wait_for_task(task_id)
@@ -229,6 +221,7 @@ async def test_screenshot():
assert result["result"]["success"]
assert result["result"]["screenshot"] is not None
async def test_priority_handling():
print("\n=== Testing Priority Handling ===")
async with NBCNewsAPITest() as api:
@@ -236,7 +229,7 @@ async def test_priority_handling():
low_priority = {
"urls": "https://www.nbcnews.com/business",
"priority": 1,
"crawler_params": {"headless": True}
"crawler_params": {"headless": True},
}
low_task_id = await api.submit_crawl(low_priority)
@@ -244,7 +237,7 @@ async def test_priority_handling():
high_priority = {
"urls": "https://www.nbcnews.com/business/consumer",
"priority": 10,
"crawler_params": {"headless": True}
"crawler_params": {"headless": True},
}
high_task_id = await api.submit_crawl(high_priority)
@@ -256,6 +249,7 @@ async def test_priority_handling():
assert high_result["status"] == "completed"
assert low_result["status"] == "completed"
async def main():
try:
# Start with health check
@@ -277,5 +271,6 @@ async def main():
print(f"Test failed: {str(e)}")
raise
if __name__ == "__main__":
asyncio.run(main())

View File

@@ -1,21 +1,26 @@
import nest_asyncio
nest_asyncio.apply()
import asyncio
from crawl4ai import AsyncWebCrawler, CrawlerRunConfig, LXMLWebScrapingStrategy, CacheMode
from crawl4ai import (
AsyncWebCrawler,
CrawlerRunConfig,
LXMLWebScrapingStrategy,
CacheMode,
)
async def main():
config = CrawlerRunConfig(
cache_mode=CacheMode.BYPASS,
scraping_strategy=LXMLWebScrapingStrategy() # Faster alternative to default BeautifulSoup
scraping_strategy=LXMLWebScrapingStrategy(), # Faster alternative to default BeautifulSoup
)
async with AsyncWebCrawler() as crawler:
result = await crawler.arun(
url="https://example.com",
config=config
)
result = await crawler.arun(url="https://example.com", config=config)
print(f"Success: {result.success}")
print(f"Markdown length: {len(result.markdown_v2.raw_markdown)}")
if __name__ == "__main__":
asyncio.run(main())

View File

@@ -1,10 +1,19 @@
import unittest, os
from crawl4ai.web_crawler import WebCrawler
from crawl4ai.chunking_strategy import RegexChunking, FixedLengthWordChunking, SlidingWindowChunking
from crawl4ai.extraction_strategy import CosineStrategy, LLMExtractionStrategy, TopicExtractionStrategy, NoExtractionStrategy
from crawl4ai.chunking_strategy import (
RegexChunking,
FixedLengthWordChunking,
SlidingWindowChunking,
)
from crawl4ai.extraction_strategy import (
CosineStrategy,
LLMExtractionStrategy,
TopicExtractionStrategy,
NoExtractionStrategy,
)
class TestWebCrawler(unittest.TestCase):
def setUp(self):
self.crawler = WebCrawler()
@@ -14,52 +23,72 @@ class TestWebCrawler(unittest.TestCase):
def test_run_default_strategies(self):
result = self.crawler.run(
url='https://www.nbcnews.com/business',
url="https://www.nbcnews.com/business",
word_count_threshold=5,
chunking_strategy=RegexChunking(),
extraction_strategy=CosineStrategy(), bypass_cache=True
extraction_strategy=CosineStrategy(),
bypass_cache=True,
)
self.assertTrue(
result.success, "Failed to crawl and extract using default strategies"
)
self.assertTrue(result.success, "Failed to crawl and extract using default strategies")
def test_run_different_strategies(self):
url = 'https://www.nbcnews.com/business'
url = "https://www.nbcnews.com/business"
# Test with FixedLengthWordChunking and LLMExtractionStrategy
result = self.crawler.run(
url=url,
word_count_threshold=5,
chunking_strategy=FixedLengthWordChunking(chunk_size=100),
extraction_strategy=LLMExtractionStrategy(provider="openai/gpt-3.5-turbo", api_token=os.getenv('OPENAI_API_KEY')), bypass_cache=True
extraction_strategy=LLMExtractionStrategy(
provider="openai/gpt-3.5-turbo", api_token=os.getenv("OPENAI_API_KEY")
),
bypass_cache=True,
)
self.assertTrue(
result.success,
"Failed to crawl and extract with FixedLengthWordChunking and LLMExtractionStrategy",
)
self.assertTrue(result.success, "Failed to crawl and extract with FixedLengthWordChunking and LLMExtractionStrategy")
# Test with SlidingWindowChunking and TopicExtractionStrategy
result = self.crawler.run(
url=url,
word_count_threshold=5,
chunking_strategy=SlidingWindowChunking(window_size=100, step=50),
extraction_strategy=TopicExtractionStrategy(num_keywords=5), bypass_cache=True
extraction_strategy=TopicExtractionStrategy(num_keywords=5),
bypass_cache=True,
)
self.assertTrue(
result.success,
"Failed to crawl and extract with SlidingWindowChunking and TopicExtractionStrategy",
)
self.assertTrue(result.success, "Failed to crawl and extract with SlidingWindowChunking and TopicExtractionStrategy")
def test_invalid_url(self):
with self.assertRaises(Exception) as context:
self.crawler.run(url='invalid_url', bypass_cache=True)
self.crawler.run(url="invalid_url", bypass_cache=True)
self.assertIn("Invalid URL", str(context.exception))
def test_unsupported_extraction_strategy(self):
with self.assertRaises(Exception) as context:
self.crawler.run(url='https://www.nbcnews.com/business', extraction_strategy="UnsupportedStrategy", bypass_cache=True)
self.crawler.run(
url="https://www.nbcnews.com/business",
extraction_strategy="UnsupportedStrategy",
bypass_cache=True,
)
self.assertIn("Unsupported extraction strategy", str(context.exception))
def test_invalid_css_selector(self):
with self.assertRaises(ValueError) as context:
self.crawler.run(url='https://www.nbcnews.com/business', css_selector="invalid_selector", bypass_cache=True)
self.crawler.run(
url="https://www.nbcnews.com/business",
css_selector="invalid_selector",
bypass_cache=True,
)
self.assertIn("Invalid CSS selector", str(context.exception))
def test_crawl_with_cache_and_bypass_cache(self):
url = 'https://www.nbcnews.com/business'
url = "https://www.nbcnews.com/business"
# First crawl with cache enabled
result = self.crawler.run(url=url, bypass_cache=False)
@@ -70,10 +99,7 @@ class TestWebCrawler(unittest.TestCase):
self.assertTrue(result.success, "Failed to bypass cache and fetch fresh data")
def test_fetch_multiple_pages(self):
urls = [
'https://www.nbcnews.com/business',
'https://www.bbc.com/news'
]
urls = ["https://www.nbcnews.com/business", "https://www.bbc.com/news"]
results = []
for url in urls:
result = self.crawler.run(
@@ -81,31 +107,42 @@ class TestWebCrawler(unittest.TestCase):
word_count_threshold=5,
chunking_strategy=RegexChunking(),
extraction_strategy=CosineStrategy(),
bypass_cache=True
bypass_cache=True,
)
results.append(result)
self.assertEqual(len(results), 2, "Failed to crawl and extract multiple pages")
for result in results:
self.assertTrue(result.success, "Failed to crawl and extract a page in the list")
self.assertTrue(
result.success, "Failed to crawl and extract a page in the list"
)
def test_run_fixed_length_word_chunking_and_no_extraction(self):
result = self.crawler.run(
url='https://www.nbcnews.com/business',
url="https://www.nbcnews.com/business",
word_count_threshold=5,
chunking_strategy=FixedLengthWordChunking(chunk_size=100),
extraction_strategy=NoExtractionStrategy(), bypass_cache=True
extraction_strategy=NoExtractionStrategy(),
bypass_cache=True,
)
self.assertTrue(
result.success,
"Failed to crawl and extract with FixedLengthWordChunking and NoExtractionStrategy",
)
self.assertTrue(result.success, "Failed to crawl and extract with FixedLengthWordChunking and NoExtractionStrategy")
def test_run_sliding_window_and_no_extraction(self):
result = self.crawler.run(
url='https://www.nbcnews.com/business',
url="https://www.nbcnews.com/business",
word_count_threshold=5,
chunking_strategy=SlidingWindowChunking(window_size=100, step=50),
extraction_strategy=NoExtractionStrategy(), bypass_cache=True
extraction_strategy=NoExtractionStrategy(),
bypass_cache=True,
)
self.assertTrue(
result.success,
"Failed to crawl and extract with SlidingWindowChunking and NoExtractionStrategy",
)
self.assertTrue(result.success, "Failed to crawl and extract with SlidingWindowChunking and NoExtractionStrategy")
if __name__ == '__main__':
if __name__ == "__main__":
unittest.main()