Apply Ruff Corrections

UncleCode
2025-01-13 19:19:58 +08:00
parent c3370ec5da
commit 8ec12d7d68
84 changed files with 6861 additions and 5076 deletions

.pre-commit-config.yaml (new file)

@@ -0,0 +1,8 @@
+# .pre-commit-config.yaml
+repos:
+  - repo: https://github.com/astral-sh/ruff-pre-commit
+    rev: v0.1.11
+    hooks:
+      - id: ruff
+        args: [--fix]
+      - id: ruff-format
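
With these hooks checked in, the cleanup in this commit becomes self-sustaining: pre-commit re-runs ruff --fix and ruff-format against staged files on every commit. Typical usage, shown here for orientation (standard pre-commit commands):

    pre-commit install            # register the git hook once per clone
    pre-commit run --all-files    # apply ruff and ruff-format to the whole tree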

crawl4ai/__init__.py

@@ -2,14 +2,28 @@
 from .async_webcrawler import AsyncWebCrawler, CacheMode
 from .async_configs import BrowserConfig, CrawlerRunConfig
-from .content_scraping_strategy import ContentScrapingStrategy, WebScrapingStrategy, LXMLWebScrapingStrategy
-from .extraction_strategy import ExtractionStrategy, LLMExtractionStrategy, CosineStrategy, JsonCssExtractionStrategy
+from .content_scraping_strategy import (
+    ContentScrapingStrategy,
+    WebScrapingStrategy,
+    LXMLWebScrapingStrategy,
+)
+from .extraction_strategy import (
+    ExtractionStrategy,
+    LLMExtractionStrategy,
+    CosineStrategy,
+    JsonCssExtractionStrategy,
+)
 from .chunking_strategy import ChunkingStrategy, RegexChunking
 from .markdown_generation_strategy import DefaultMarkdownGenerator
 from .content_filter_strategy import PruningContentFilter, BM25ContentFilter
 from .models import CrawlResult, MarkdownGenerationResult
-from .async_dispatcher import MemoryAdaptiveDispatcher, SemaphoreDispatcher, RateLimiter, CrawlerMonitor, DisplayMode
+from .async_dispatcher import (
+    MemoryAdaptiveDispatcher,
+    SemaphoreDispatcher,
+    RateLimiter,
+    CrawlerMonitor,
+    DisplayMode,
+)
 from .__version__ import __version__
 
 __all__ = [
     "AsyncWebCrawler",
@@ -18,40 +32,45 @@ __all__ = [
"ContentScrapingStrategy", "ContentScrapingStrategy",
"WebScrapingStrategy", "WebScrapingStrategy",
"LXMLWebScrapingStrategy", "LXMLWebScrapingStrategy",
'BrowserConfig', "BrowserConfig",
'CrawlerRunConfig', "CrawlerRunConfig",
'ExtractionStrategy', "ExtractionStrategy",
'LLMExtractionStrategy', "LLMExtractionStrategy",
'CosineStrategy', "CosineStrategy",
'JsonCssExtractionStrategy', "JsonCssExtractionStrategy",
'ChunkingStrategy', "ChunkingStrategy",
'RegexChunking', "RegexChunking",
'DefaultMarkdownGenerator', "DefaultMarkdownGenerator",
'PruningContentFilter', "PruningContentFilter",
'BM25ContentFilter', "BM25ContentFilter",
'MemoryAdaptiveDispatcher', "MemoryAdaptiveDispatcher",
'SemaphoreDispatcher', "SemaphoreDispatcher",
'RateLimiter', "RateLimiter",
'CrawlerMonitor', "CrawlerMonitor",
'DisplayMode', "DisplayMode",
'MarkdownGenerationResult', "MarkdownGenerationResult",
] ]
def is_sync_version_installed(): def is_sync_version_installed():
try: try:
import selenium import selenium
return True return True
except ImportError: except ImportError:
return False return False
if is_sync_version_installed(): if is_sync_version_installed():
try: try:
from .web_crawler import WebCrawler from .web_crawler import WebCrawler
__all__.append("WebCrawler") __all__.append("WebCrawler")
except ImportError: except ImportError:
import warnings print(
print("Warning: Failed to import WebCrawler even though selenium is installed. This might be due to other missing dependencies.") "Warning: Failed to import WebCrawler even though selenium is installed. This might be due to other missing dependencies."
)
else: else:
WebCrawler = None WebCrawler = None
# import warnings # import warnings
# print("Warning: Synchronous WebCrawler is not available. Install crawl4ai[sync] for synchronous support. However, please note that the synchronous version will be deprecated soon.") # print("Warning: Synchronous WebCrawler is not available. Install crawl4ai[sync] for synchronous support. However, please note that the synchronous version will be deprecated soon.")

crawl4ai/async_configs.py

@@ -5,7 +5,6 @@ from .config import (
     PAGE_TIMEOUT,
     IMAGE_SCORE_THRESHOLD,
     SOCIAL_MEDIA_DOMAINS,
 )
 from .user_agent_generator import UserAgentGenerator
 from .extraction_strategy import ExtractionStrategy
@@ -14,6 +13,7 @@ from .markdown_generation_strategy import MarkdownGenerationStrategy
 from .content_scraping_strategy import ContentScrapingStrategy, WebScrapingStrategy
 from typing import Union, List
 
+
 class BrowserConfig:
     """
     Configuration class for setting up a browser instance and its context in AsyncPlaywrightCrawlerStrategy.
@@ -84,7 +84,7 @@ class BrowserConfig:
         proxy: str = None,
         proxy_config: dict = None,
         viewport_width: int = 1080,
         viewport_height: int = 600,
         accept_downloads: bool = False,
         downloads_path: str = None,
         storage_state=None,
@@ -103,7 +103,7 @@ class BrowserConfig:
         text_mode: bool = False,
         light_mode: bool = False,
         extra_args: list = None,
-        debugging_port : int = 9222,
+        debugging_port: int = 9222,
     ):
         self.browser_type = browser_type
         self.headless = headless
@@ -142,7 +142,7 @@ class BrowserConfig:
             self.user_agent = user_agenr_generator.generate()
         else:
             pass
 
         self.browser_hint = user_agenr_generator.generate_client_hints(self.user_agent)
         self.headers.setdefault("sec-ch-ua", self.browser_hint)
@@ -313,7 +313,7 @@ class CrawlerRunConfig:
             Default: True.
         log_console (bool): If True, log console messages from the page.
             Default: False.
 
         # Optional Parameters
         url: str = None  # This is not a compulsory parameter
     """
@@ -335,10 +335,8 @@ class CrawlerRunConfig:
         prettiify: bool = False,
         parser_type: str = "lxml",
         scraping_strategy: ContentScrapingStrategy = None,
-
         # SSL Parameters
         fetch_ssl_certificate: bool = False,
-
         # Caching Parameters
         cache_mode=None,
         session_id: str = None,
@@ -346,7 +344,6 @@ class CrawlerRunConfig:
         disable_cache: bool = False,
         no_cache_read: bool = False,
         no_cache_write: bool = False,
-
         # Page Navigation and Timing Parameters
         wait_until: str = "domcontentloaded",
         page_timeout: int = PAGE_TIMEOUT,
@@ -356,7 +353,6 @@ class CrawlerRunConfig:
         mean_delay: float = 0.1,
         max_range: float = 0.3,
         semaphore_count: int = 5,
-
         # Page Interaction Parameters
         js_code: Union[str, List[str]] = None,
         js_only: bool = False,
@@ -369,7 +365,6 @@ class CrawlerRunConfig:
         override_navigator: bool = False,
         magic: bool = False,
         adjust_viewport_to_content: bool = False,
-
         # Media Handling Parameters
         screenshot: bool = False,
         screenshot_wait_for: float = None,
@@ -378,21 +373,18 @@ class CrawlerRunConfig:
         image_description_min_word_threshold: int = IMAGE_DESCRIPTION_MIN_WORD_THRESHOLD,
         image_score_threshold: int = IMAGE_SCORE_THRESHOLD,
         exclude_external_images: bool = False,
-
         # Link and Domain Handling Parameters
         exclude_social_media_domains: list = None,
         exclude_external_links: bool = False,
         exclude_social_media_links: bool = False,
         exclude_domains: list = None,
-
         # Debugging and Logging Parameters
         verbose: bool = True,
         log_console: bool = False,
         url: str = None,
     ):
         self.url = url
-
         # Content Processing Parameters
         self.word_count_threshold = word_count_threshold
         self.extraction_strategy = extraction_strategy
@@ -453,7 +445,9 @@ class CrawlerRunConfig:
         self.exclude_external_images = exclude_external_images
 
         # Link and Domain Handling Parameters
-        self.exclude_social_media_domains = exclude_social_media_domains or SOCIAL_MEDIA_DOMAINS
+        self.exclude_social_media_domains = (
+            exclude_social_media_domains or SOCIAL_MEDIA_DOMAINS
+        )
         self.exclude_external_links = exclude_external_links
         self.exclude_social_media_links = exclude_social_media_links
         self.exclude_domains = exclude_domains or []
@@ -466,11 +460,15 @@ class CrawlerRunConfig:
         if self.extraction_strategy is not None and not isinstance(
             self.extraction_strategy, ExtractionStrategy
         ):
-            raise ValueError("extraction_strategy must be an instance of ExtractionStrategy")
+            raise ValueError(
+                "extraction_strategy must be an instance of ExtractionStrategy"
+            )
         if self.chunking_strategy is not None and not isinstance(
             self.chunking_strategy, ChunkingStrategy
         ):
-            raise ValueError("chunking_strategy must be an instance of ChunkingStrategy")
+            raise ValueError(
+                "chunking_strategy must be an instance of ChunkingStrategy"
+            )
 
         # Set default chunking strategy if None
         if self.chunking_strategy is None:
@@ -494,10 +492,8 @@ class CrawlerRunConfig:
             prettiify=kwargs.get("prettiify", False),
             parser_type=kwargs.get("parser_type", "lxml"),
             scraping_strategy=kwargs.get("scraping_strategy"),
-
             # SSL Parameters
             fetch_ssl_certificate=kwargs.get("fetch_ssl_certificate", False),
-
             # Caching Parameters
             cache_mode=kwargs.get("cache_mode"),
             session_id=kwargs.get("session_id"),
@@ -505,7 +501,6 @@ class CrawlerRunConfig:
             disable_cache=kwargs.get("disable_cache", False),
             no_cache_read=kwargs.get("no_cache_read", False),
             no_cache_write=kwargs.get("no_cache_write", False),
-
             # Page Navigation and Timing Parameters
             wait_until=kwargs.get("wait_until", "domcontentloaded"),
             page_timeout=kwargs.get("page_timeout", 60000),
@@ -515,7 +510,6 @@ class CrawlerRunConfig:
             mean_delay=kwargs.get("mean_delay", 0.1),
             max_range=kwargs.get("max_range", 0.3),
             semaphore_count=kwargs.get("semaphore_count", 5),
-
             # Page Interaction Parameters
             js_code=kwargs.get("js_code"),
             js_only=kwargs.get("js_only", False),
@@ -528,29 +522,34 @@ class CrawlerRunConfig:
             override_navigator=kwargs.get("override_navigator", False),
             magic=kwargs.get("magic", False),
             adjust_viewport_to_content=kwargs.get("adjust_viewport_to_content", False),
-
             # Media Handling Parameters
             screenshot=kwargs.get("screenshot", False),
             screenshot_wait_for=kwargs.get("screenshot_wait_for"),
-            screenshot_height_threshold=kwargs.get("screenshot_height_threshold", SCREENSHOT_HEIGHT_TRESHOLD),
+            screenshot_height_threshold=kwargs.get(
+                "screenshot_height_threshold", SCREENSHOT_HEIGHT_TRESHOLD
+            ),
             pdf=kwargs.get("pdf", False),
-            image_description_min_word_threshold=kwargs.get("image_description_min_word_threshold", IMAGE_DESCRIPTION_MIN_WORD_THRESHOLD),
-            image_score_threshold=kwargs.get("image_score_threshold", IMAGE_SCORE_THRESHOLD),
+            image_description_min_word_threshold=kwargs.get(
+                "image_description_min_word_threshold",
+                IMAGE_DESCRIPTION_MIN_WORD_THRESHOLD,
+            ),
+            image_score_threshold=kwargs.get(
+                "image_score_threshold", IMAGE_SCORE_THRESHOLD
+            ),
             exclude_external_images=kwargs.get("exclude_external_images", False),
-
             # Link and Domain Handling Parameters
-            exclude_social_media_domains=kwargs.get("exclude_social_media_domains", SOCIAL_MEDIA_DOMAINS),
+            exclude_social_media_domains=kwargs.get(
+                "exclude_social_media_domains", SOCIAL_MEDIA_DOMAINS
+            ),
             exclude_external_links=kwargs.get("exclude_external_links", False),
             exclude_social_media_links=kwargs.get("exclude_social_media_links", False),
             exclude_domains=kwargs.get("exclude_domains", []),
-
             # Debugging and Logging Parameters
             verbose=kwargs.get("verbose", True),
             log_console=kwargs.get("log_console", False),
             url=kwargs.get("url"),
         )
 
     # Create a funciton returns dict of the object
     def to_dict(self):
         return {
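
Since from_kwargs above is built entirely from kwargs.get(...) calls that mirror the constructor, and to_dict returns the same keys, a config should survive a round trip through its dict form. A small sketch under that assumption (from_kwargs taking a plain dict):

    from crawl4ai import CrawlerRunConfig

    cfg = CrawlerRunConfig(word_count_threshold=5, screenshot=True)
    clone = CrawlerRunConfig.from_kwargs(cfg.to_dict())  # rebuild from the dict form
    assert clone.word_count_threshold == 5 and clone.screenshot is True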

(File diff suppressed because it is too large.)

crawl4ai/async_database.py

@@ -1,27 +1,29 @@
-import os, sys
+import os
 from pathlib import Path
 import aiosqlite
 import asyncio
-from typing import Optional, Tuple, Dict
+from typing import Optional, Dict
 from contextlib import asynccontextmanager
 import logging
 import json  # Added for serialization/deserialization
 from .utils import ensure_content_dirs, generate_content_hash
 from .models import CrawlResult, MarkdownGenerationResult
-import xxhash
 import aiofiles
-from .config import NEED_MIGRATION
 from .version_manager import VersionManager
 from .async_logger import AsyncLogger
 from .utils import get_error_context, create_box_message
 
 # Set up logging
 logging.basicConfig(level=logging.INFO)
 logger = logging.getLogger(__name__)
 
-base_directory = DB_PATH = os.path.join(os.getenv("CRAWL4_AI_BASE_DIRECTORY", Path.home()), ".crawl4ai")
+base_directory = DB_PATH = os.path.join(
+    os.getenv("CRAWL4_AI_BASE_DIRECTORY", Path.home()), ".crawl4ai"
+)
 os.makedirs(DB_PATH, exist_ok=True)
 DB_PATH = os.path.join(base_directory, "crawl4ai.db")
 
 
 class AsyncDatabaseManager:
     def __init__(self, pool_size: int = 10, max_retries: int = 3):
         self.db_path = DB_PATH
@@ -32,28 +34,27 @@ class AsyncDatabaseManager:
         self.pool_lock = asyncio.Lock()
         self.init_lock = asyncio.Lock()
         self.connection_semaphore = asyncio.Semaphore(pool_size)
         self._initialized = False
         self.version_manager = VersionManager()
         self.logger = AsyncLogger(
             log_file=os.path.join(base_directory, ".crawl4ai", "crawler_db.log"),
             verbose=False,
-            tag_width=10
+            tag_width=10,
         )
 
     async def initialize(self):
         """Initialize the database and connection pool"""
         try:
             self.logger.info("Initializing database", tag="INIT")
             # Ensure the database file exists
             os.makedirs(os.path.dirname(self.db_path), exist_ok=True)
 
             # Check if version update is needed
             needs_update = self.version_manager.needs_update()
 
             # Always ensure base table exists
             await self.ainit_db()
 
             # Verify the table exists
             async with aiosqlite.connect(self.db_path, timeout=30.0) as db:
                 async with db.execute(
@@ -62,33 +63,37 @@ class AsyncDatabaseManager:
                     result = await cursor.fetchone()
                     if not result:
                         raise Exception("crawled_data table was not created")
 
             # If version changed or fresh install, run updates
             if needs_update:
                 self.logger.info("New version detected, running updates", tag="INIT")
                 await self.update_db_schema()
-                from .migrations import run_migration  # Import here to avoid circular imports
+                from .migrations import (
+                    run_migration,
+                )  # Import here to avoid circular imports
+
                 await run_migration()
                 self.version_manager.update_version()  # Update stored version after successful migration
-                self.logger.success("Version update completed successfully", tag="COMPLETE")
+                self.logger.success(
+                    "Version update completed successfully", tag="COMPLETE"
+                )
             else:
-                self.logger.success("Database initialization completed successfully", tag="COMPLETE")
+                self.logger.success(
+                    "Database initialization completed successfully", tag="COMPLETE"
+                )
 
         except Exception as e:
             self.logger.error(
                 message="Database initialization error: {error}",
                 tag="ERROR",
-                params={"error": str(e)}
+                params={"error": str(e)},
             )
             self.logger.info(
-                message="Database will be initialized on first use",
-                tag="INIT"
+                message="Database will be initialized on first use", tag="INIT"
             )
             raise
 
     async def cleanup(self):
         """Cleanup connections when shutting down"""
         async with self.pool_lock:
@@ -107,6 +112,7 @@ class AsyncDatabaseManager:
                         self._initialized = True
                     except Exception as e:
                         import sys
+
                         error_context = get_error_context(sys.exc_info())
                         self.logger.error(
                             message="Database initialization failed:\n{error}\n\nContext:\n{context}\n\nTraceback:\n{traceback}",
@@ -115,41 +121,52 @@ class AsyncDatabaseManager:
                             params={
                                 "error": str(e),
                                 "context": error_context["code_context"],
-                                "traceback": error_context["full_traceback"]
-                            }
+                                "traceback": error_context["full_traceback"],
+                            },
                         )
                         raise
 
         await self.connection_semaphore.acquire()
         task_id = id(asyncio.current_task())
 
         try:
             async with self.pool_lock:
                 if task_id not in self.connection_pool:
                     try:
-                        conn = await aiosqlite.connect(
-                            self.db_path,
-                            timeout=30.0
-                        )
-                        await conn.execute('PRAGMA journal_mode = WAL')
-                        await conn.execute('PRAGMA busy_timeout = 5000')
+                        conn = await aiosqlite.connect(self.db_path, timeout=30.0)
+                        await conn.execute("PRAGMA journal_mode = WAL")
+                        await conn.execute("PRAGMA busy_timeout = 5000")
 
                         # Verify database structure
-                        async with conn.execute("PRAGMA table_info(crawled_data)") as cursor:
+                        async with conn.execute(
+                            "PRAGMA table_info(crawled_data)"
+                        ) as cursor:
                             columns = await cursor.fetchall()
                             column_names = [col[1] for col in columns]
                             expected_columns = {
-                                'url', 'html', 'cleaned_html', 'markdown', 'extracted_content',
-                                'success', 'media', 'links', 'metadata', 'screenshot',
-                                'response_headers', 'downloaded_files'
+                                "url",
+                                "html",
+                                "cleaned_html",
+                                "markdown",
+                                "extracted_content",
+                                "success",
+                                "media",
+                                "links",
+                                "metadata",
+                                "screenshot",
+                                "response_headers",
+                                "downloaded_files",
                             }
                             missing_columns = expected_columns - set(column_names)
                             if missing_columns:
-                                raise ValueError(f"Database missing columns: {missing_columns}")
+                                raise ValueError(
+                                    f"Database missing columns: {missing_columns}"
+                                )
                         self.connection_pool[task_id] = conn
                     except Exception as e:
                         import sys
+
                         error_context = get_error_context(sys.exc_info())
                         error_message = (
                             f"Unexpected error in db get_connection at line {error_context['line_no']} "
@@ -158,7 +175,7 @@ class AsyncDatabaseManager:
f"Code context:\n{error_context['code_context']}" f"Code context:\n{error_context['code_context']}"
) )
self.logger.error( self.logger.error(
message=create_box_message(error_message, type= "error"), message=create_box_message(error_message, type="error"),
) )
raise raise
@@ -167,6 +184,7 @@ class AsyncDatabaseManager:
except Exception as e: except Exception as e:
import sys import sys
error_context = get_error_context(sys.exc_info()) error_context = get_error_context(sys.exc_info())
error_message = ( error_message = (
f"Unexpected error in db get_connection at line {error_context['line_no']} " f"Unexpected error in db get_connection at line {error_context['line_no']} "
@@ -175,7 +193,7 @@ class AsyncDatabaseManager:
f"Code context:\n{error_context['code_context']}" f"Code context:\n{error_context['code_context']}"
) )
self.logger.error( self.logger.error(
message=create_box_message(error_message, type= "error"), message=create_box_message(error_message, type="error"),
) )
raise raise
finally: finally:
@@ -185,7 +203,6 @@ class AsyncDatabaseManager:
del self.connection_pool[task_id] del self.connection_pool[task_id]
self.connection_semaphore.release() self.connection_semaphore.release()
async def execute_with_retry(self, operation, *args): async def execute_with_retry(self, operation, *args):
"""Execute database operations with retry logic""" """Execute database operations with retry logic"""
for attempt in range(self.max_retries): for attempt in range(self.max_retries):
@@ -200,18 +217,16 @@ class AsyncDatabaseManager:
message="Operation failed after {retries} attempts: {error}", message="Operation failed after {retries} attempts: {error}",
tag="ERROR", tag="ERROR",
force_verbose=True, force_verbose=True,
params={ params={"retries": self.max_retries, "error": str(e)},
"retries": self.max_retries, )
"error": str(e)
}
)
raise raise
await asyncio.sleep(1 * (attempt + 1)) # Exponential backoff await asyncio.sleep(1 * (attempt + 1)) # Exponential backoff
async def ainit_db(self): async def ainit_db(self):
"""Initialize database schema""" """Initialize database schema"""
async with aiosqlite.connect(self.db_path, timeout=30.0) as db: async with aiosqlite.connect(self.db_path, timeout=30.0) as db:
await db.execute(''' await db.execute(
"""
CREATE TABLE IF NOT EXISTS crawled_data ( CREATE TABLE IF NOT EXISTS crawled_data (
url TEXT PRIMARY KEY, url TEXT PRIMARY KEY,
html TEXT, html TEXT,
@@ -226,21 +241,27 @@ class AsyncDatabaseManager:
response_headers TEXT DEFAULT "{}", response_headers TEXT DEFAULT "{}",
downloaded_files TEXT DEFAULT "{}" -- New column added downloaded_files TEXT DEFAULT "{}" -- New column added
) )
''') """
)
await db.commit() await db.commit()
async def update_db_schema(self): async def update_db_schema(self):
"""Update database schema if needed""" """Update database schema if needed"""
async with aiosqlite.connect(self.db_path, timeout=30.0) as db: async with aiosqlite.connect(self.db_path, timeout=30.0) as db:
cursor = await db.execute("PRAGMA table_info(crawled_data)") cursor = await db.execute("PRAGMA table_info(crawled_data)")
columns = await cursor.fetchall() columns = await cursor.fetchall()
column_names = [column[1] for column in columns] column_names = [column[1] for column in columns]
# List of new columns to add # List of new columns to add
new_columns = ['media', 'links', 'metadata', 'screenshot', 'response_headers', 'downloaded_files'] new_columns = [
"media",
"links",
"metadata",
"screenshot",
"response_headers",
"downloaded_files",
]
for column in new_columns: for column in new_columns:
if column not in column_names: if column not in column_names:
await self.aalter_db_add_column(column, db) await self.aalter_db_add_column(column, db)
@@ -248,75 +269,91 @@ class AsyncDatabaseManager:
     async def aalter_db_add_column(self, new_column: str, db):
         """Add new column to the database"""
-        if new_column == 'response_headers':
-            await db.execute(f'ALTER TABLE crawled_data ADD COLUMN {new_column} TEXT DEFAULT "{{}}"')
+        if new_column == "response_headers":
+            await db.execute(
+                f'ALTER TABLE crawled_data ADD COLUMN {new_column} TEXT DEFAULT "{{}}"'
+            )
         else:
-            await db.execute(f'ALTER TABLE crawled_data ADD COLUMN {new_column} TEXT DEFAULT ""')
+            await db.execute(
+                f'ALTER TABLE crawled_data ADD COLUMN {new_column} TEXT DEFAULT ""'
+            )
         self.logger.info(
             message="Added column '{column}' to the database",
             tag="INIT",
-            params={"column": new_column}
+            params={"column": new_column},
         )
 
     async def aget_cached_url(self, url: str) -> Optional[CrawlResult]:
         """Retrieve cached URL data as CrawlResult"""
 
         async def _get(db):
             async with db.execute(
-                'SELECT * FROM crawled_data WHERE url = ?', (url,)
+                "SELECT * FROM crawled_data WHERE url = ?", (url,)
             ) as cursor:
                 row = await cursor.fetchone()
                 if not row:
                     return None
 
                 # Get column names
                 columns = [description[0] for description in cursor.description]
                 # Create dict from row data
                 row_dict = dict(zip(columns, row))
 
                 # Load content from files using stored hashes
                 content_fields = {
-                    'html': row_dict['html'],
-                    'cleaned_html': row_dict['cleaned_html'],
-                    'markdown': row_dict['markdown'],
-                    'extracted_content': row_dict['extracted_content'],
-                    'screenshot': row_dict['screenshot'],
-                    'screenshots': row_dict['screenshot'],
+                    "html": row_dict["html"],
+                    "cleaned_html": row_dict["cleaned_html"],
+                    "markdown": row_dict["markdown"],
+                    "extracted_content": row_dict["extracted_content"],
+                    "screenshot": row_dict["screenshot"],
+                    "screenshots": row_dict["screenshot"],
                 }
                 for field, hash_value in content_fields.items():
                     if hash_value:
                         content = await self._load_content(
                             hash_value,
-                            field.split('_')[0]  # Get content type from field name
+                            field.split("_")[0],  # Get content type from field name
                         )
                         row_dict[field] = content or ""
                     else:
                         row_dict[field] = ""
 
                 # Parse JSON fields
-                json_fields = ['media', 'links', 'metadata', 'response_headers', 'markdown']
+                json_fields = [
+                    "media",
+                    "links",
+                    "metadata",
+                    "response_headers",
+                    "markdown",
+                ]
                 for field in json_fields:
                     try:
-                        row_dict[field] = json.loads(row_dict[field]) if row_dict[field] else {}
+                        row_dict[field] = (
+                            json.loads(row_dict[field]) if row_dict[field] else {}
+                        )
                     except json.JSONDecodeError:
                         row_dict[field] = {}
 
-                if isinstance(row_dict['markdown'], Dict):
-                    row_dict['markdown_v2'] = row_dict['markdown']
-                    if row_dict['markdown'].get('raw_markdown'):
-                        row_dict['markdown'] = row_dict['markdown']['raw_markdown']
+                if isinstance(row_dict["markdown"], Dict):
+                    row_dict["markdown_v2"] = row_dict["markdown"]
+                    if row_dict["markdown"].get("raw_markdown"):
+                        row_dict["markdown"] = row_dict["markdown"]["raw_markdown"]
 
                 # Parse downloaded_files
                 try:
-                    row_dict['downloaded_files'] = json.loads(row_dict['downloaded_files']) if row_dict['downloaded_files'] else []
+                    row_dict["downloaded_files"] = (
+                        json.loads(row_dict["downloaded_files"])
+                        if row_dict["downloaded_files"]
+                        else []
+                    )
                 except json.JSONDecodeError:
-                    row_dict['downloaded_files'] = []
+                    row_dict["downloaded_files"] = []
 
                 # Remove any fields not in CrawlResult model
                 valid_fields = CrawlResult.__annotations__.keys()
                 filtered_dict = {k: v for k, v in row_dict.items() if k in valid_fields}
 
                 return CrawlResult(**filtered_dict)
 
         try:
@@ -326,7 +363,7 @@ class AsyncDatabaseManager:
message="Error retrieving cached URL: {error}", message="Error retrieving cached URL: {error}",
tag="ERROR", tag="ERROR",
force_verbose=True, force_verbose=True,
params={"error": str(e)} params={"error": str(e)},
) )
return None return None
@@ -334,37 +371,52 @@ class AsyncDatabaseManager:
"""Cache CrawlResult data""" """Cache CrawlResult data"""
# Store content files and get hashes # Store content files and get hashes
content_map = { content_map = {
'html': (result.html, 'html'), "html": (result.html, "html"),
'cleaned_html': (result.cleaned_html or "", 'cleaned'), "cleaned_html": (result.cleaned_html or "", "cleaned"),
'markdown': None, "markdown": None,
'extracted_content': (result.extracted_content or "", 'extracted'), "extracted_content": (result.extracted_content or "", "extracted"),
'screenshot': (result.screenshot or "", 'screenshots') "screenshot": (result.screenshot or "", "screenshots"),
} }
try: try:
if isinstance(result.markdown, MarkdownGenerationResult): if isinstance(result.markdown, MarkdownGenerationResult):
content_map['markdown'] = (result.markdown.model_dump_json(), 'markdown') content_map["markdown"] = (
elif hasattr(result, 'markdown_v2'): result.markdown.model_dump_json(),
content_map['markdown'] = (result.markdown_v2.model_dump_json(), 'markdown') "markdown",
)
elif hasattr(result, "markdown_v2"):
content_map["markdown"] = (
result.markdown_v2.model_dump_json(),
"markdown",
)
elif isinstance(result.markdown, str): elif isinstance(result.markdown, str):
markdown_result = MarkdownGenerationResult(raw_markdown=result.markdown) markdown_result = MarkdownGenerationResult(raw_markdown=result.markdown)
content_map['markdown'] = (markdown_result.model_dump_json(), 'markdown') content_map["markdown"] = (
markdown_result.model_dump_json(),
"markdown",
)
else: else:
content_map['markdown'] = (MarkdownGenerationResult().model_dump_json(), 'markdown') content_map["markdown"] = (
MarkdownGenerationResult().model_dump_json(),
"markdown",
)
except Exception as e: except Exception as e:
self.logger.warning( self.logger.warning(
message=f"Error processing markdown content: {str(e)}", message=f"Error processing markdown content: {str(e)}", tag="WARNING"
tag="WARNING"
) )
# Fallback to empty markdown result # Fallback to empty markdown result
content_map['markdown'] = (MarkdownGenerationResult().model_dump_json(), 'markdown') content_map["markdown"] = (
MarkdownGenerationResult().model_dump_json(),
"markdown",
)
content_hashes = {} content_hashes = {}
for field, (content, content_type) in content_map.items(): for field, (content, content_type) in content_map.items():
content_hashes[field] = await self._store_content(content, content_type) content_hashes[field] = await self._store_content(content, content_type)
async def _cache(db): async def _cache(db):
await db.execute(''' await db.execute(
"""
INSERT INTO crawled_data ( INSERT INTO crawled_data (
url, html, cleaned_html, markdown, url, html, cleaned_html, markdown,
extracted_content, success, media, links, metadata, extracted_content, success, media, links, metadata,
@@ -383,20 +435,22 @@ class AsyncDatabaseManager:
                     screenshot = excluded.screenshot,
                     response_headers = excluded.response_headers,
                     downloaded_files = excluded.downloaded_files
-            ''', (
-                result.url,
-                content_hashes['html'],
-                content_hashes['cleaned_html'],
-                content_hashes['markdown'],
-                content_hashes['extracted_content'],
-                result.success,
-                json.dumps(result.media),
-                json.dumps(result.links),
-                json.dumps(result.metadata or {}),
-                content_hashes['screenshot'],
-                json.dumps(result.response_headers or {}),
-                json.dumps(result.downloaded_files or [])
-            ))
+            """,
+                (
+                    result.url,
+                    content_hashes["html"],
+                    content_hashes["cleaned_html"],
+                    content_hashes["markdown"],
+                    content_hashes["extracted_content"],
+                    result.success,
+                    json.dumps(result.media),
+                    json.dumps(result.links),
+                    json.dumps(result.metadata or {}),
+                    content_hashes["screenshot"],
+                    json.dumps(result.response_headers or {}),
+                    json.dumps(result.downloaded_files or []),
+                ),
+            )
 
         try:
             await self.execute_with_retry(_cache)
@@ -405,14 +459,14 @@ class AsyncDatabaseManager:
                 message="Error caching URL: {error}",
                 tag="ERROR",
                 force_verbose=True,
-                params={"error": str(e)}
+                params={"error": str(e)},
             )
 
     async def aget_total_count(self) -> int:
         """Get total number of cached URLs"""
 
         async def _count(db):
-            async with db.execute('SELECT COUNT(*) FROM crawled_data') as cursor:
+            async with db.execute("SELECT COUNT(*) FROM crawled_data") as cursor:
                 result = await cursor.fetchone()
                 return result[0] if result else 0
@@ -423,14 +477,15 @@ class AsyncDatabaseManager:
message="Error getting total count: {error}", message="Error getting total count: {error}",
tag="ERROR", tag="ERROR",
force_verbose=True, force_verbose=True,
params={"error": str(e)} params={"error": str(e)},
) )
return 0 return 0
async def aclear_db(self): async def aclear_db(self):
"""Clear all data from the database""" """Clear all data from the database"""
async def _clear(db): async def _clear(db):
await db.execute('DELETE FROM crawled_data') await db.execute("DELETE FROM crawled_data")
try: try:
await self.execute_with_retry(_clear) await self.execute_with_retry(_clear)
@@ -439,13 +494,14 @@ class AsyncDatabaseManager:
message="Error clearing database: {error}", message="Error clearing database: {error}",
tag="ERROR", tag="ERROR",
force_verbose=True, force_verbose=True,
params={"error": str(e)} params={"error": str(e)},
) )
async def aflush_db(self): async def aflush_db(self):
"""Drop the entire table""" """Drop the entire table"""
async def _flush(db): async def _flush(db):
await db.execute('DROP TABLE IF EXISTS crawled_data') await db.execute("DROP TABLE IF EXISTS crawled_data")
try: try:
await self.execute_with_retry(_flush) await self.execute_with_retry(_flush)
@@ -454,42 +510,44 @@ class AsyncDatabaseManager:
message="Error flushing database: {error}", message="Error flushing database: {error}",
tag="ERROR", tag="ERROR",
force_verbose=True, force_verbose=True,
params={"error": str(e)} params={"error": str(e)},
) )
async def _store_content(self, content: str, content_type: str) -> str: async def _store_content(self, content: str, content_type: str) -> str:
"""Store content in filesystem and return hash""" """Store content in filesystem and return hash"""
if not content: if not content:
return "" return ""
content_hash = generate_content_hash(content) content_hash = generate_content_hash(content)
file_path = os.path.join(self.content_paths[content_type], content_hash) file_path = os.path.join(self.content_paths[content_type], content_hash)
# Only write if file doesn't exist # Only write if file doesn't exist
if not os.path.exists(file_path): if not os.path.exists(file_path):
async with aiofiles.open(file_path, 'w', encoding='utf-8') as f: async with aiofiles.open(file_path, "w", encoding="utf-8") as f:
await f.write(content) await f.write(content)
return content_hash return content_hash
async def _load_content(self, content_hash: str, content_type: str) -> Optional[str]: async def _load_content(
self, content_hash: str, content_type: str
) -> Optional[str]:
"""Load content from filesystem by hash""" """Load content from filesystem by hash"""
if not content_hash: if not content_hash:
return None return None
file_path = os.path.join(self.content_paths[content_type], content_hash) file_path = os.path.join(self.content_paths[content_type], content_hash)
try: try:
async with aiofiles.open(file_path, 'r', encoding='utf-8') as f: async with aiofiles.open(file_path, "r", encoding="utf-8") as f:
return await f.read() return await f.read()
except: except:
self.logger.error( self.logger.error(
message="Failed to load content: {file_path}", message="Failed to load content: {file_path}",
tag="ERROR", tag="ERROR",
force_verbose=True, force_verbose=True,
params={"file_path": file_path} params={"file_path": file_path},
) )
return None return None
# Create a singleton instance # Create a singleton instance
async_db_manager = AsyncDatabaseManager() async_db_manager = AsyncDatabaseManager()
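
Callers go through the async_db_manager singleton created above. A hedged sketch of the read path, using only method names visible in this diff (the crawl4ai.async_database module path is an assumption inferred from the file's contents):

    import asyncio

    from crawl4ai.async_database import async_db_manager  # module path assumed


    async def main():
        await async_db_manager.initialize()  # create/migrate schema
        cached = await async_db_manager.aget_cached_url("https://example.com")
        print(cached.markdown if cached else "not cached yet")
        await async_db_manager.cleanup()  # close pooled connections


    asyncio.run(main())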

crawl4ai/async_dispatcher.py

@@ -1,14 +1,19 @@
-from typing import Dict, Optional, List
-from .async_configs import *
-from .models import *
+from typing import Dict, Optional, List, Tuple
+from .async_configs import CrawlerRunConfig
+from .models import (
+    CrawlResult,
+    CrawlerTaskResult,
+    CrawlStatus,
+    DisplayMode,
+    CrawlStats,
+    DomainState,
+)
 from rich.live import Live
 from rich.table import Table
 from rich.console import Console
-from rich.style import Style
 from rich import box
 from datetime import datetime, timedelta
-from dataclasses import dataclass
 import time
 import psutil
@@ -26,63 +31,66 @@ class RateLimiter:
         base_delay: Tuple[float, float] = (1.0, 3.0),
         max_delay: float = 60.0,
         max_retries: int = 3,
-        rate_limit_codes: List[int] = None
+        rate_limit_codes: List[int] = None,
     ):
         self.base_delay = base_delay
         self.max_delay = max_delay
         self.max_retries = max_retries
         self.rate_limit_codes = rate_limit_codes or [429, 503]
         self.domains: Dict[str, DomainState] = {}
 
     def get_domain(self, url: str) -> str:
         return urlparse(url).netloc
 
     async def wait_if_needed(self, url: str) -> None:
         domain = self.get_domain(url)
         state = self.domains.get(domain)
 
         if not state:
             self.domains[domain] = DomainState()
             state = self.domains[domain]
 
         now = time.time()
         if state.last_request_time:
             wait_time = max(0, state.current_delay - (now - state.last_request_time))
             if wait_time > 0:
                 await asyncio.sleep(wait_time)
 
         # Random delay within base range if no current delay
         if state.current_delay == 0:
             state.current_delay = random.uniform(*self.base_delay)
 
         state.last_request_time = time.time()
 
     def update_delay(self, url: str, status_code: int) -> bool:
         domain = self.get_domain(url)
         state = self.domains[domain]
 
         if status_code in self.rate_limit_codes:
             state.fail_count += 1
             if state.fail_count > self.max_retries:
                 return False
 
             # Exponential backoff with random jitter
             state.current_delay = min(
-                state.current_delay * 2 * random.uniform(0.75, 1.25),
-                self.max_delay
+                state.current_delay * 2 * random.uniform(0.75, 1.25), self.max_delay
             )
         else:
             # Gradually reduce delay on success
             state.current_delay = max(
-                random.uniform(*self.base_delay),
-                state.current_delay * 0.75
+                random.uniform(*self.base_delay), state.current_delay * 0.75
             )
             state.fail_count = 0
 
         return True
 
 
 class CrawlerMonitor:
-    def __init__(self, max_visible_rows: int = 15, display_mode: DisplayMode = DisplayMode.DETAILED):
+    def __init__(
+        self,
+        max_visible_rows: int = 15,
+        display_mode: DisplayMode = DisplayMode.DETAILED,
+    ):
         self.console = Console()
         self.max_visible_rows = max_visible_rows
         self.display_mode = display_mode
@@ -90,23 +98,25 @@ class CrawlerMonitor:
         self.process = psutil.Process()
         self.start_time = datetime.now()
         self.live = Live(self._create_table(), refresh_per_second=2)
 
     def start(self):
         self.live.start()
 
     def stop(self):
         self.live.stop()
 
     def add_task(self, task_id: str, url: str):
-        self.stats[task_id] = CrawlStats(task_id=task_id, url=url, status=CrawlStatus.QUEUED)
+        self.stats[task_id] = CrawlStats(
+            task_id=task_id, url=url, status=CrawlStatus.QUEUED
+        )
         self.live.update(self._create_table())
 
     def update_task(self, task_id: str, **kwargs):
         if task_id in self.stats:
             for key, value in kwargs.items():
                 setattr(self.stats[task_id], key, value)
             self.live.update(self._create_table())
 
     def _create_aggregated_table(self) -> Table:
         """Creates a compact table showing only aggregated statistics"""
         table = Table(
@@ -114,78 +124,78 @@ class CrawlerMonitor:
title="Crawler Status Overview", title="Crawler Status Overview",
title_style="bold magenta", title_style="bold magenta",
header_style="bold blue", header_style="bold blue",
show_lines=True show_lines=True,
) )
# Calculate statistics # Calculate statistics
total_tasks = len(self.stats) total_tasks = len(self.stats)
queued = sum(1 for stat in self.stats.values() if stat.status == CrawlStatus.QUEUED) queued = sum(
in_progress = sum(1 for stat in self.stats.values() if stat.status == CrawlStatus.IN_PROGRESS) 1 for stat in self.stats.values() if stat.status == CrawlStatus.QUEUED
completed = sum(1 for stat in self.stats.values() if stat.status == CrawlStatus.COMPLETED) )
failed = sum(1 for stat in self.stats.values() if stat.status == CrawlStatus.FAILED) in_progress = sum(
1 for stat in self.stats.values() if stat.status == CrawlStatus.IN_PROGRESS
)
completed = sum(
1 for stat in self.stats.values() if stat.status == CrawlStatus.COMPLETED
)
failed = sum(
1 for stat in self.stats.values() if stat.status == CrawlStatus.FAILED
)
# Memory statistics # Memory statistics
current_memory = self.process.memory_info().rss / (1024 * 1024) current_memory = self.process.memory_info().rss / (1024 * 1024)
total_task_memory = sum(stat.memory_usage for stat in self.stats.values()) total_task_memory = sum(stat.memory_usage for stat in self.stats.values())
peak_memory = max((stat.peak_memory for stat in self.stats.values()), default=0.0) peak_memory = max(
(stat.peak_memory for stat in self.stats.values()), default=0.0
)
# Duration # Duration
duration = datetime.now() - self.start_time duration = datetime.now() - self.start_time
# Create status row # Create status row
table.add_column("Status", style="bold cyan") table.add_column("Status", style="bold cyan")
table.add_column("Count", justify="right") table.add_column("Count", justify="right")
table.add_column("Percentage", justify="right") table.add_column("Percentage", justify="right")
table.add_row( table.add_row("Total Tasks", str(total_tasks), "100%")
"Total Tasks",
str(total_tasks),
"100%"
)
table.add_row( table.add_row(
"[yellow]In Queue[/yellow]", "[yellow]In Queue[/yellow]",
str(queued), str(queued),
f"{(queued/total_tasks*100):.1f}%" if total_tasks > 0 else "0%" f"{(queued/total_tasks*100):.1f}%" if total_tasks > 0 else "0%",
) )
table.add_row( table.add_row(
"[blue]In Progress[/blue]", "[blue]In Progress[/blue]",
str(in_progress), str(in_progress),
f"{(in_progress/total_tasks*100):.1f}%" if total_tasks > 0 else "0%" f"{(in_progress/total_tasks*100):.1f}%" if total_tasks > 0 else "0%",
) )
table.add_row( table.add_row(
"[green]Completed[/green]", "[green]Completed[/green]",
str(completed), str(completed),
f"{(completed/total_tasks*100):.1f}%" if total_tasks > 0 else "0%" f"{(completed/total_tasks*100):.1f}%" if total_tasks > 0 else "0%",
) )
table.add_row( table.add_row(
"[red]Failed[/red]", "[red]Failed[/red]",
str(failed), str(failed),
f"{(failed/total_tasks*100):.1f}%" if total_tasks > 0 else "0%" f"{(failed/total_tasks*100):.1f}%" if total_tasks > 0 else "0%",
) )
# Add memory information # Add memory information
table.add_section() table.add_section()
table.add_row( table.add_row(
"[magenta]Current Memory[/magenta]", "[magenta]Current Memory[/magenta]", f"{current_memory:.1f} MB", ""
f"{current_memory:.1f} MB",
""
) )
table.add_row( table.add_row(
"[magenta]Total Task Memory[/magenta]", "[magenta]Total Task Memory[/magenta]", f"{total_task_memory:.1f} MB", ""
f"{total_task_memory:.1f} MB",
""
) )
table.add_row( table.add_row(
"[magenta]Peak Task Memory[/magenta]", "[magenta]Peak Task Memory[/magenta]", f"{peak_memory:.1f} MB", ""
f"{peak_memory:.1f} MB",
""
) )
table.add_row( table.add_row(
"[yellow]Runtime[/yellow]", "[yellow]Runtime[/yellow]",
str(timedelta(seconds=int(duration.total_seconds()))), str(timedelta(seconds=int(duration.total_seconds()))),
"" "",
) )
return table return table
def _create_detailed_table(self) -> Table: def _create_detailed_table(self) -> Table:
@@ -193,9 +203,9 @@ class CrawlerMonitor:
             box=box.ROUNDED,
             title="Crawler Performance Monitor",
             title_style="bold magenta",
-            header_style="bold blue"
+            header_style="bold blue",
         )
 
         # Add columns
         table.add_column("Task ID", style="cyan", no_wrap=True)
         table.add_column("URL", style="cyan", no_wrap=True)
@@ -204,47 +214,54 @@ class CrawlerMonitor:
table.add_column("Peak (MB)", justify="right") table.add_column("Peak (MB)", justify="right")
table.add_column("Duration", justify="right") table.add_column("Duration", justify="right")
table.add_column("Info", style="italic") table.add_column("Info", style="italic")
# Add summary row # Add summary row
total_memory = sum(stat.memory_usage for stat in self.stats.values()) total_memory = sum(stat.memory_usage for stat in self.stats.values())
active_count = sum(1 for stat in self.stats.values() active_count = sum(
if stat.status == CrawlStatus.IN_PROGRESS) 1 for stat in self.stats.values() if stat.status == CrawlStatus.IN_PROGRESS
completed_count = sum(1 for stat in self.stats.values() )
if stat.status == CrawlStatus.COMPLETED) completed_count = sum(
failed_count = sum(1 for stat in self.stats.values() 1 for stat in self.stats.values() if stat.status == CrawlStatus.COMPLETED
if stat.status == CrawlStatus.FAILED) )
failed_count = sum(
1 for stat in self.stats.values() if stat.status == CrawlStatus.FAILED
)
table.add_row( table.add_row(
"[bold yellow]SUMMARY", "[bold yellow]SUMMARY",
f"Total: {len(self.stats)}", f"Total: {len(self.stats)}",
f"Active: {active_count}", f"Active: {active_count}",
f"{total_memory:.1f}", f"{total_memory:.1f}",
f"{self.process.memory_info().rss / (1024 * 1024):.1f}", f"{self.process.memory_info().rss / (1024 * 1024):.1f}",
str(timedelta(seconds=int((datetime.now() - self.start_time).total_seconds()))), str(
timedelta(
seconds=int((datetime.now() - self.start_time).total_seconds())
)
),
f"{completed_count}{failed_count}", f"{completed_count}{failed_count}",
style="bold" style="bold",
) )
table.add_section() table.add_section()
# Add rows for each task # Add rows for each task
visible_stats = sorted( visible_stats = sorted(
self.stats.values(), self.stats.values(),
key=lambda x: ( key=lambda x: (
x.status != CrawlStatus.IN_PROGRESS, x.status != CrawlStatus.IN_PROGRESS,
x.status != CrawlStatus.QUEUED, x.status != CrawlStatus.QUEUED,
x.end_time or datetime.max x.end_time or datetime.max,
) ),
)[:self.max_visible_rows] )[: self.max_visible_rows]
for stat in visible_stats: for stat in visible_stats:
status_style = { status_style = {
CrawlStatus.QUEUED: "white", CrawlStatus.QUEUED: "white",
CrawlStatus.IN_PROGRESS: "yellow", CrawlStatus.IN_PROGRESS: "yellow",
CrawlStatus.COMPLETED: "green", CrawlStatus.COMPLETED: "green",
CrawlStatus.FAILED: "red" CrawlStatus.FAILED: "red",
}[stat.status] }[stat.status]
table.add_row( table.add_row(
stat.task_id[:8], # Show first 8 chars of task ID stat.task_id[:8], # Show first 8 chars of task ID
stat.url[:40] + "..." if len(stat.url) > 40 else stat.url, stat.url[:40] + "..." if len(stat.url) > 40 else stat.url,
@@ -252,9 +269,9 @@ class CrawlerMonitor:
f"{stat.memory_usage:.1f}", f"{stat.memory_usage:.1f}",
f"{stat.peak_memory:.1f}", f"{stat.peak_memory:.1f}",
stat.duration, stat.duration,
stat.error_message[:40] if stat.error_message else "" stat.error_message[:40] if stat.error_message else "",
) )
return table return table
def _create_table(self) -> Table: def _create_table(self) -> Table:
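
Before the dispatcher classes below, note how the RateLimiter defined earlier in this file is meant to be driven: await wait_if_needed() before each request, then pass the response status to update_delay(), which backs off exponentially with jitter on 429/503 and decays the delay on success. A standalone sketch with simulated status codes (not part of this commit):

    import asyncio

    from crawl4ai import RateLimiter


    async def main():
        limiter = RateLimiter(base_delay=(0.1, 0.3), max_delay=5.0, max_retries=3)
        url = "https://example.com/page"
        for status in (200, 429, 429, 200):  # simulated responses
            await limiter.wait_if_needed(url)  # sleeps by the current delay
            if not limiter.update_delay(url, status):  # False once retries are spent
                print("giving up on", url)
                break


    asyncio.run(main())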
@@ -268,7 +285,7 @@ class BaseDispatcher(ABC):
     def __init__(
         self,
         rate_limiter: Optional[RateLimiter] = None,
-        monitor: Optional[CrawlerMonitor] = None
+        monitor: Optional[CrawlerMonitor] = None,
     ):
         self.crawler = None
         self._domain_last_hit: Dict[str, float] = {}
@@ -278,24 +295,25 @@ class BaseDispatcher(ABC):
     @abstractmethod
     async def crawl_url(
         self,
         url: str,
         config: CrawlerRunConfig,
         task_id: str,
-        monitor: Optional[CrawlerMonitor] = None
+        monitor: Optional[CrawlerMonitor] = None,
     ) -> CrawlerTaskResult:
         pass
 
     @abstractmethod
     async def run_urls(
         self,
         urls: List[str],
-        crawler: "AsyncWebCrawler",
+        crawler: "AsyncWebCrawler",  # noqa: F821
         config: CrawlerRunConfig,
-        monitor: Optional[CrawlerMonitor] = None
+        monitor: Optional[CrawlerMonitor] = None,
     ) -> List[CrawlerTaskResult]:
         pass
 
 
 class MemoryAdaptiveDispatcher(BaseDispatcher):
     def __init__(
         self,
@@ -304,39 +322,41 @@ class MemoryAdaptiveDispatcher(BaseDispatcher):
         max_session_permit: int = 20,
         memory_wait_timeout: float = 300.0,  # 5 minutes default timeout
         rate_limiter: Optional[RateLimiter] = None,
-        monitor: Optional[CrawlerMonitor] = None
+        monitor: Optional[CrawlerMonitor] = None,
     ):
         super().__init__(rate_limiter, monitor)
         self.memory_threshold_percent = memory_threshold_percent
         self.check_interval = check_interval
         self.max_session_permit = max_session_permit
         self.memory_wait_timeout = memory_wait_timeout
 
     async def crawl_url(
         self,
         url: str,
         config: CrawlerRunConfig,
         task_id: str,
     ) -> CrawlerTaskResult:
         start_time = datetime.now()
         error_message = ""
         memory_usage = peak_memory = 0.0
 
         try:
             if self.monitor:
-                self.monitor.update_task(task_id, status=CrawlStatus.IN_PROGRESS, start_time=start_time)
+                self.monitor.update_task(
+                    task_id, status=CrawlStatus.IN_PROGRESS, start_time=start_time
+                )
             self.concurrent_sessions += 1
 
             if self.rate_limiter:
                 await self.rate_limiter.wait_if_needed(url)
 
             process = psutil.Process()
             start_memory = process.memory_info().rss / (1024 * 1024)
             result = await self.crawler.arun(url, config=config, session_id=task_id)
             end_memory = process.memory_info().rss / (1024 * 1024)
 
             memory_usage = peak_memory = end_memory - start_memory
 
             if self.rate_limiter and result.status_code:
                 if not self.rate_limiter.update_delay(url, result.status_code):
                     error_message = f"Rate limit retry count exceeded for domain {urlparse(url).netloc}"
@@ -350,22 +370,24 @@ class MemoryAdaptiveDispatcher(BaseDispatcher):
peak_memory=peak_memory, peak_memory=peak_memory,
start_time=start_time, start_time=start_time,
end_time=datetime.now(), end_time=datetime.now(),
error_message=error_message error_message=error_message,
) )
if not result.success: if not result.success:
error_message = result.error_message error_message = result.error_message
if self.monitor: if self.monitor:
self.monitor.update_task(task_id, status=CrawlStatus.FAILED) self.monitor.update_task(task_id, status=CrawlStatus.FAILED)
elif self.monitor: elif self.monitor:
self.monitor.update_task(task_id, status=CrawlStatus.COMPLETED) self.monitor.update_task(task_id, status=CrawlStatus.COMPLETED)
except Exception as e: except Exception as e:
error_message = str(e) error_message = str(e)
if self.monitor: if self.monitor:
self.monitor.update_task(task_id, status=CrawlStatus.FAILED) self.monitor.update_task(task_id, status=CrawlStatus.FAILED)
result = CrawlResult(url=url, html="", metadata={}, success=False, error_message=str(e)) result = CrawlResult(
url=url, html="", metadata={}, success=False, error_message=str(e)
)
finally: finally:
end_time = datetime.now() end_time = datetime.now()
if self.monitor: if self.monitor:
@@ -374,10 +396,10 @@ class MemoryAdaptiveDispatcher(BaseDispatcher):
end_time=end_time, end_time=end_time,
memory_usage=memory_usage, memory_usage=memory_usage,
peak_memory=peak_memory, peak_memory=peak_memory,
error_message=error_message error_message=error_message,
) )
self.concurrent_sessions -= 1 self.concurrent_sessions -= 1
return CrawlerTaskResult( return CrawlerTaskResult(
task_id=task_id, task_id=task_id,
url=url, url=url,
@@ -386,20 +408,20 @@ class MemoryAdaptiveDispatcher(BaseDispatcher):
peak_memory=peak_memory, peak_memory=peak_memory,
start_time=start_time, start_time=start_time,
end_time=end_time, end_time=end_time,
error_message=error_message error_message=error_message,
) )
async def run_urls( async def run_urls(
self, self,
urls: List[str], urls: List[str],
crawler: "AsyncWebCrawler", crawler: "AsyncWebCrawler", # noqa: F821
config: CrawlerRunConfig, config: CrawlerRunConfig,
) -> List[CrawlerTaskResult]: ) -> List[CrawlerTaskResult]:
self.crawler = crawler self.crawler = crawler
if self.monitor: if self.monitor:
self.monitor.start() self.monitor.start()
try: try:
pending_tasks = [] pending_tasks = []
active_tasks = [] active_tasks = []
@@ -417,23 +439,24 @@ class MemoryAdaptiveDispatcher(BaseDispatcher):
if psutil.virtual_memory().percent >= self.memory_threshold_percent: if psutil.virtual_memory().percent >= self.memory_threshold_percent:
# Check if we've exceeded the timeout # Check if we've exceeded the timeout
if time.time() - wait_start_time > self.memory_wait_timeout: if time.time() - wait_start_time > self.memory_wait_timeout:
raise MemoryError(f"Memory usage above threshold ({self.memory_threshold_percent}%) for more than {self.memory_wait_timeout} seconds") raise MemoryError(
f"Memory usage above threshold ({self.memory_threshold_percent}%) for more than {self.memory_wait_timeout} seconds"
)
await asyncio.sleep(self.check_interval) await asyncio.sleep(self.check_interval)
continue continue
url, task_id = task_queue.pop(0) url, task_id = task_queue.pop(0)
task = asyncio.create_task(self.crawl_url(url, config, task_id)) task = asyncio.create_task(self.crawl_url(url, config, task_id))
active_tasks.append(task) active_tasks.append(task)
if not active_tasks: if not active_tasks:
await asyncio.sleep(self.check_interval) await asyncio.sleep(self.check_interval)
continue continue
done, pending = await asyncio.wait( done, pending = await asyncio.wait(
active_tasks, active_tasks, return_when=asyncio.FIRST_COMPLETED
return_when=asyncio.FIRST_COMPLETED
) )
pending_tasks.extend(done) pending_tasks.extend(done)
active_tasks = list(pending) active_tasks = list(pending)
@@ -442,24 +465,25 @@ class MemoryAdaptiveDispatcher(BaseDispatcher):
if self.monitor: if self.monitor:
self.monitor.stop() self.monitor.stop()
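
For orientation, a minimal usage sketch of the class above — assuming AsyncWebCrawler works as an async context manager (not shown in this diff) and using the run_urls signature exactly as it appears here:

    import asyncio
    from crawl4ai import AsyncWebCrawler, CrawlerRunConfig, MemoryAdaptiveDispatcher

    async def main():
        dispatcher = MemoryAdaptiveDispatcher(max_session_permit=10, memory_wait_timeout=300.0)
        async with AsyncWebCrawler() as crawler:
            # note the argument order for this dispatcher: urls, then crawler, then config
            results = await dispatcher.run_urls(
                ["https://example.com"], crawler, CrawlerRunConfig()
            )
        for r in results:
            print(r.url, r.memory_usage, r.error_message)

    asyncio.run(main())
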
class SemaphoreDispatcher(BaseDispatcher): class SemaphoreDispatcher(BaseDispatcher):
def __init__( def __init__(
self, self,
semaphore_count: int = 5, semaphore_count: int = 5,
max_session_permit: int = 20, max_session_permit: int = 20,
rate_limiter: Optional[RateLimiter] = None, rate_limiter: Optional[RateLimiter] = None,
monitor: Optional[CrawlerMonitor] = None monitor: Optional[CrawlerMonitor] = None,
): ):
super().__init__(rate_limiter, monitor) super().__init__(rate_limiter, monitor)
self.semaphore_count = semaphore_count self.semaphore_count = semaphore_count
self.max_session_permit = max_session_permit self.max_session_permit = max_session_permit
async def crawl_url( async def crawl_url(
self, self,
url: str, url: str,
config: CrawlerRunConfig, config: CrawlerRunConfig,
task_id: str, task_id: str,
semaphore: asyncio.Semaphore = None semaphore: asyncio.Semaphore = None,
) -> CrawlerTaskResult: ) -> CrawlerTaskResult:
start_time = datetime.now() start_time = datetime.now()
error_message = "" error_message = ""
@@ -467,7 +491,9 @@ class SemaphoreDispatcher(BaseDispatcher):
try: try:
if self.monitor: if self.monitor:
self.monitor.update_task(task_id, status=CrawlStatus.IN_PROGRESS, start_time=start_time) self.monitor.update_task(
task_id, status=CrawlStatus.IN_PROGRESS, start_time=start_time
)
if self.rate_limiter: if self.rate_limiter:
await self.rate_limiter.wait_if_needed(url) await self.rate_limiter.wait_if_needed(url)
@@ -477,7 +503,7 @@ class SemaphoreDispatcher(BaseDispatcher):
start_memory = process.memory_info().rss / (1024 * 1024) start_memory = process.memory_info().rss / (1024 * 1024)
result = await self.crawler.arun(url, config=config, session_id=task_id) result = await self.crawler.arun(url, config=config, session_id=task_id)
end_memory = process.memory_info().rss / (1024 * 1024) end_memory = process.memory_info().rss / (1024 * 1024)
memory_usage = peak_memory = end_memory - start_memory memory_usage = peak_memory = end_memory - start_memory
if self.rate_limiter and result.status_code: if self.rate_limiter and result.status_code:
@@ -493,7 +519,7 @@ class SemaphoreDispatcher(BaseDispatcher):
peak_memory=peak_memory, peak_memory=peak_memory,
start_time=start_time, start_time=start_time,
end_time=datetime.now(), end_time=datetime.now(),
error_message=error_message error_message=error_message,
) )
if not result.success: if not result.success:
@@ -507,7 +533,9 @@ class SemaphoreDispatcher(BaseDispatcher):
error_message = str(e) error_message = str(e)
if self.monitor: if self.monitor:
self.monitor.update_task(task_id, status=CrawlStatus.FAILED) self.monitor.update_task(task_id, status=CrawlStatus.FAILED)
result = CrawlResult(url=url, html="", metadata={}, success=False, error_message=str(e)) result = CrawlResult(
url=url, html="", metadata={}, success=False, error_message=str(e)
)
finally: finally:
end_time = datetime.now() end_time = datetime.now()
@@ -517,7 +545,7 @@ class SemaphoreDispatcher(BaseDispatcher):
end_time=end_time, end_time=end_time,
memory_usage=memory_usage, memory_usage=memory_usage,
peak_memory=peak_memory, peak_memory=peak_memory,
error_message=error_message error_message=error_message,
) )
return CrawlerTaskResult( return CrawlerTaskResult(
@@ -528,13 +556,13 @@ class SemaphoreDispatcher(BaseDispatcher):
peak_memory=peak_memory, peak_memory=peak_memory,
start_time=start_time, start_time=start_time,
end_time=end_time, end_time=end_time,
error_message=error_message error_message=error_message,
) )
async def run_urls( async def run_urls(
self, self,
crawler: "AsyncWebCrawler", crawler: "AsyncWebCrawler", # noqa: F821
urls: List[str], urls: List[str],
config: CrawlerRunConfig, config: CrawlerRunConfig,
) -> List[CrawlerTaskResult]: ) -> List[CrawlerTaskResult]:
self.crawler = crawler self.crawler = crawler
@@ -557,4 +585,4 @@ class SemaphoreDispatcher(BaseDispatcher):
return await asyncio.gather(*tasks, return_exceptions=True) return await asyncio.gather(*tasks, return_exceptions=True)
finally: finally:
if self.monitor: if self.monitor:
self.monitor.stop() self.monitor.stop()
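
SemaphoreDispatcher.run_urls takes crawler before urls — the opposite order of MemoryAdaptiveDispatcher above. A sketch under the same assumptions:

    dispatcher = SemaphoreDispatcher(semaphore_count=5)
    # inside an async function, with `crawler` and `config` as in the previous sketch:
    results = await dispatcher.run_urls(
        crawler, ["https://a.example", "https://b.example"], config
    )
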
View File
@@ -1,10 +1,10 @@
from enum import Enum from enum import Enum
from typing import Optional, Dict, Any, Union from typing import Optional, Dict, Any
from colorama import Fore, Back, Style, init from colorama import Fore, Style, init
import time
import os import os
from datetime import datetime from datetime import datetime
class LogLevel(Enum): class LogLevel(Enum):
DEBUG = 1 DEBUG = 1
INFO = 2 INFO = 2
@@ -12,23 +12,24 @@ class LogLevel(Enum):
WARNING = 4 WARNING = 4
ERROR = 5 ERROR = 5
class AsyncLogger: class AsyncLogger:
""" """
Asynchronous logger with support for colored console output and file logging. Asynchronous logger with support for colored console output and file logging.
Supports templated messages with colored components. Supports templated messages with colored components.
""" """
DEFAULT_ICONS = { DEFAULT_ICONS = {
'INIT': '', "INIT": "",
'READY': '', "READY": "",
'FETCH': '', "FETCH": "",
'SCRAPE': '', "SCRAPE": "",
'EXTRACT': '', "EXTRACT": "",
'COMPLETE': '', "COMPLETE": "",
'ERROR': '×', "ERROR": "×",
'DEBUG': '', "DEBUG": "",
'INFO': '', "INFO": "",
'WARNING': '', "WARNING": "",
} }
DEFAULT_COLORS = { DEFAULT_COLORS = {
@@ -46,11 +47,11 @@ class AsyncLogger:
tag_width: int = 10, tag_width: int = 10,
icons: Optional[Dict[str, str]] = None, icons: Optional[Dict[str, str]] = None,
colors: Optional[Dict[LogLevel, str]] = None, colors: Optional[Dict[LogLevel, str]] = None,
verbose: bool = True verbose: bool = True,
): ):
""" """
Initialize the logger. Initialize the logger.
Args: Args:
log_file: Optional file path for logging log_file: Optional file path for logging
log_level: Minimum log level to display log_level: Minimum log level to display
@@ -66,7 +67,7 @@ class AsyncLogger:
self.icons = icons or self.DEFAULT_ICONS self.icons = icons or self.DEFAULT_ICONS
self.colors = colors or self.DEFAULT_COLORS self.colors = colors or self.DEFAULT_COLORS
self.verbose = verbose self.verbose = verbose
# Create log file directory if needed # Create log file directory if needed
if log_file: if log_file:
os.makedirs(os.path.dirname(os.path.abspath(log_file)), exist_ok=True) os.makedirs(os.path.dirname(os.path.abspath(log_file)), exist_ok=True)
@@ -77,18 +78,20 @@ class AsyncLogger:
def _get_icon(self, tag: str) -> str: def _get_icon(self, tag: str) -> str:
"""Get the icon for a tag, defaulting to info icon if not found.""" """Get the icon for a tag, defaulting to info icon if not found."""
return self.icons.get(tag, self.icons['INFO']) return self.icons.get(tag, self.icons["INFO"])
def _write_to_file(self, message: str): def _write_to_file(self, message: str):
"""Write a message to the log file if configured.""" """Write a message to the log file if configured."""
if self.log_file: if self.log_file:
timestamp = datetime.now().strftime('%Y-%m-%d %H:%M:%S.%f')[:-3] timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S.%f")[:-3]
with open(self.log_file, 'a', encoding='utf-8') as f: with open(self.log_file, "a", encoding="utf-8") as f:
# Strip ANSI color codes for file output # Strip ANSI color codes for file output
clean_message = message.replace(Fore.RESET, '').replace(Style.RESET_ALL, '') clean_message = message.replace(Fore.RESET, "").replace(
Style.RESET_ALL, ""
)
for color in vars(Fore).values(): for color in vars(Fore).values():
if isinstance(color, str): if isinstance(color, str):
clean_message = clean_message.replace(color, '') clean_message = clean_message.replace(color, "")
f.write(f"[{timestamp}] {clean_message}\n") f.write(f"[{timestamp}] {clean_message}\n")
def _log( def _log(
@@ -99,11 +102,11 @@ class AsyncLogger:
params: Optional[Dict[str, Any]] = None, params: Optional[Dict[str, Any]] = None,
colors: Optional[Dict[str, str]] = None, colors: Optional[Dict[str, str]] = None,
base_color: Optional[str] = None, base_color: Optional[str] = None,
**kwargs **kwargs,
): ):
""" """
Core logging method that handles message formatting and output. Core logging method that handles message formatting and output.
Args: Args:
level: Log level for this message level: Log level for this message
message: Message template string message: Message template string
@@ -120,7 +123,7 @@ class AsyncLogger:
try: try:
# First format the message with raw parameters # First format the message with raw parameters
formatted_message = message.format(**params) formatted_message = message.format(**params)
# Then apply colors if specified # Then apply colors if specified
if colors: if colors:
for key, color in colors.items(): for key, color in colors.items():
@@ -128,12 +131,13 @@ class AsyncLogger:
if key in params: if key in params:
value_str = str(params[key]) value_str = str(params[key])
formatted_message = formatted_message.replace( formatted_message = formatted_message.replace(
value_str, value_str, f"{color}{value_str}{Style.RESET_ALL}"
f"{color}{value_str}{Style.RESET_ALL}"
) )
except KeyError as e: except KeyError as e:
formatted_message = f"LOGGING ERROR: Missing parameter {e} in message template" formatted_message = (
f"LOGGING ERROR: Missing parameter {e} in message template"
)
level = LogLevel.ERROR level = LogLevel.ERROR
else: else:
formatted_message = message formatted_message = message
@@ -175,11 +179,11 @@ class AsyncLogger:
success: bool, success: bool,
timing: float, timing: float,
tag: str = "FETCH", tag: str = "FETCH",
url_length: int = 50 url_length: int = 50,
): ):
""" """
Convenience method for logging URL fetch status. Convenience method for logging URL fetch status.
Args: Args:
url: The URL being processed url: The URL being processed
success: Whether the operation was successful success: Whether the operation was successful
@@ -195,24 +199,20 @@ class AsyncLogger:
"url": url, "url": url,
"url_length": url_length, "url_length": url_length,
"status": success, "status": success,
"timing": timing "timing": timing,
}, },
colors={ colors={
"status": Fore.GREEN if success else Fore.RED, "status": Fore.GREEN if success else Fore.RED,
"timing": Fore.YELLOW "timing": Fore.YELLOW,
} },
) )
def error_status( def error_status(
self, self, url: str, error: str, tag: str = "ERROR", url_length: int = 50
url: str,
error: str,
tag: str = "ERROR",
url_length: int = 50
): ):
""" """
Convenience method for logging error status. Convenience method for logging error status.
Args: Args:
url: The URL being processed url: The URL being processed
error: Error message error: Error message
@@ -223,9 +223,5 @@ class AsyncLogger:
level=LogLevel.ERROR, level=LogLevel.ERROR,
message="{url:.{url_length}}... | Error: {error}", message="{url:.{url_length}}... | Error: {error}",
tag=tag, tag=tag,
params={ params={"url": url, "url_length": url_length, "error": error},
"url": url, )
"url_length": url_length,
"error": error
}
)
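
A short sketch of the two convenience methods reformatted above; the module path matches the `from .async_logger import AsyncLogger` import that appears in the CLI file later in this diff:

    from crawl4ai.async_logger import AsyncLogger, LogLevel

    logger = AsyncLogger(log_file="crawl.log", log_level=LogLevel.DEBUG, verbose=True)
    logger.url_status(url="https://example.com/some/long/path", success=True, timing=1.23)
    logger.error_status(url="https://example.com/bad", error="Connection timeout")
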
File diff suppressed because it is too large

View File
@@ -4,7 +4,7 @@ from enum import Enum
class CacheMode(Enum): class CacheMode(Enum):
""" """
Defines the caching behavior for web crawling operations. Defines the caching behavior for web crawling operations.
Modes: Modes:
- ENABLED: Normal caching behavior (read and write) - ENABLED: Normal caching behavior (read and write)
- DISABLED: No caching at all - DISABLED: No caching at all
@@ -12,6 +12,7 @@ class CacheMode(Enum):
- WRITE_ONLY: Only write to cache, don't read - WRITE_ONLY: Only write to cache, don't read
- BYPASS: Bypass cache for this operation - BYPASS: Bypass cache for this operation
""" """
ENABLED = "enabled" ENABLED = "enabled"
DISABLED = "disabled" DISABLED = "disabled"
READ_ONLY = "read_only" READ_ONLY = "read_only"
@@ -22,10 +23,10 @@ class CacheMode(Enum):
class CacheContext: class CacheContext:
""" """
Encapsulates cache-related decisions and URL handling. Encapsulates cache-related decisions and URL handling.
This class centralizes all cache-related logic and URL type checking, This class centralizes all cache-related logic and URL type checking,
making the caching behavior more predictable and maintainable. making the caching behavior more predictable and maintainable.
Attributes: Attributes:
url (str): The URL being processed. url (str): The URL being processed.
cache_mode (CacheMode): The cache mode for the current operation. cache_mode (CacheMode): The cache mode for the current operation.
@@ -36,10 +37,11 @@ class CacheContext:
is_raw_html (bool): True if the URL is raw HTML, False otherwise. is_raw_html (bool): True if the URL is raw HTML, False otherwise.
_url_display (str): The display name for the URL (web, local file, or raw HTML). _url_display (str): The display name for the URL (web, local file, or raw HTML).
""" """
def __init__(self, url: str, cache_mode: CacheMode, always_bypass: bool = False): def __init__(self, url: str, cache_mode: CacheMode, always_bypass: bool = False):
""" """
Initializes the CacheContext with the provided URL and cache mode. Initializes the CacheContext with the provided URL and cache mode.
Args: Args:
url (str): The URL being processed. url (str): The URL being processed.
cache_mode (CacheMode): The cache mode for the current operation. cache_mode (CacheMode): The cache mode for the current operation.
@@ -48,42 +50,42 @@ class CacheContext:
self.url = url self.url = url
self.cache_mode = cache_mode self.cache_mode = cache_mode
self.always_bypass = always_bypass self.always_bypass = always_bypass
self.is_cacheable = url.startswith(('http://', 'https://', 'file://')) self.is_cacheable = url.startswith(("http://", "https://", "file://"))
self.is_web_url = url.startswith(('http://', 'https://')) self.is_web_url = url.startswith(("http://", "https://"))
self.is_local_file = url.startswith("file://") self.is_local_file = url.startswith("file://")
self.is_raw_html = url.startswith("raw:") self.is_raw_html = url.startswith("raw:")
self._url_display = url if not self.is_raw_html else "Raw HTML" self._url_display = url if not self.is_raw_html else "Raw HTML"
def should_read(self) -> bool: def should_read(self) -> bool:
""" """
Determines if cache should be read based on context. Determines if cache should be read based on context.
How it works: How it works:
1. If always_bypass is True or is_cacheable is False, return False. 1. If always_bypass is True or is_cacheable is False, return False.
2. If cache_mode is ENABLED or READ_ONLY, return True. 2. If cache_mode is ENABLED or READ_ONLY, return True.
Returns: Returns:
bool: True if cache should be read, False otherwise. bool: True if cache should be read, False otherwise.
""" """
if self.always_bypass or not self.is_cacheable: if self.always_bypass or not self.is_cacheable:
return False return False
return self.cache_mode in [CacheMode.ENABLED, CacheMode.READ_ONLY] return self.cache_mode in [CacheMode.ENABLED, CacheMode.READ_ONLY]
def should_write(self) -> bool: def should_write(self) -> bool:
""" """
Determines if cache should be written based on context. Determines if cache should be written based on context.
How it works: How it works:
1. If always_bypass is True or is_cacheable is False, return False. 1. If always_bypass is True or is_cacheable is False, return False.
2. If cache_mode is ENABLED or WRITE_ONLY, return True. 2. If cache_mode is ENABLED or WRITE_ONLY, return True.
Returns: Returns:
bool: True if cache should be written, False otherwise. bool: True if cache should be written, False otherwise.
""" """
if self.always_bypass or not self.is_cacheable: if self.always_bypass or not self.is_cacheable:
return False return False
return self.cache_mode in [CacheMode.ENABLED, CacheMode.WRITE_ONLY] return self.cache_mode in [CacheMode.ENABLED, CacheMode.WRITE_ONLY]
@property @property
def display_url(self) -> str: def display_url(self) -> str:
"""Returns the URL in display format.""" """Returns the URL in display format."""
@@ -94,11 +96,11 @@ def _legacy_to_cache_mode(
disable_cache: bool = False, disable_cache: bool = False,
bypass_cache: bool = False, bypass_cache: bool = False,
no_cache_read: bool = False, no_cache_read: bool = False,
no_cache_write: bool = False no_cache_write: bool = False,
) -> CacheMode: ) -> CacheMode:
""" """
Converts legacy cache parameters to the new CacheMode enum. Converts legacy cache parameters to the new CacheMode enum.
This is an internal function to help transition from the old boolean flags This is an internal function to help transition from the old boolean flags
to the new CacheMode system. to the new CacheMode system.
""" """
View File
@@ -3,49 +3,53 @@ import re
from collections import Counter from collections import Counter
import string import string
from .model_loader import load_nltk_punkt from .model_loader import load_nltk_punkt
from .utils import *
# Define the abstract base class for chunking strategies # Define the abstract base class for chunking strategies
class ChunkingStrategy(ABC): class ChunkingStrategy(ABC):
""" """
Abstract base class for chunking strategies. Abstract base class for chunking strategies.
""" """
@abstractmethod @abstractmethod
def chunk(self, text: str) -> list: def chunk(self, text: str) -> list:
""" """
Abstract method to chunk the given text. Abstract method to chunk the given text.
Args: Args:
text (str): The text to chunk. text (str): The text to chunk.
Returns: Returns:
list: A list of chunks. list: A list of chunks.
""" """
pass pass
# Create an identity chunking strategy f(x) = [x] # Create an identity chunking strategy f(x) = [x]
class IdentityChunking(ChunkingStrategy): class IdentityChunking(ChunkingStrategy):
""" """
Chunking strategy that returns the input text as a single chunk. Chunking strategy that returns the input text as a single chunk.
""" """
def chunk(self, text: str) -> list: def chunk(self, text: str) -> list:
return [text] return [text]
# Regex-based chunking # Regex-based chunking
class RegexChunking(ChunkingStrategy): class RegexChunking(ChunkingStrategy):
""" """
Chunking strategy that splits text based on regular expression patterns. Chunking strategy that splits text based on regular expression patterns.
""" """
def __init__(self, patterns=None, **kwargs): def __init__(self, patterns=None, **kwargs):
""" """
Initialize the RegexChunking object. Initialize the RegexChunking object.
Args: Args:
patterns (list): A list of regular expression patterns to split text. patterns (list): A list of regular expression patterns to split text.
""" """
if patterns is None: if patterns is None:
patterns = [r'\n\n'] # Default split pattern patterns = [r"\n\n"] # Default split pattern
self.patterns = patterns self.patterns = patterns
def chunk(self, text: str) -> list: def chunk(self, text: str) -> list:
@@ -56,18 +60,19 @@ class RegexChunking(ChunkingStrategy):
new_paragraphs.extend(re.split(pattern, paragraph)) new_paragraphs.extend(re.split(pattern, paragraph))
paragraphs = new_paragraphs paragraphs = new_paragraphs
return paragraphs return paragraphs
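
With the default pattern this reduces to paragraph splitting:

    chunks = RegexChunking().chunk("first paragraph\n\nsecond paragraph")
    # -> ["first paragraph", "second paragraph"]
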
# NLP-based sentence chunking
# NLP-based sentence chunking
class NlpSentenceChunking(ChunkingStrategy): class NlpSentenceChunking(ChunkingStrategy):
""" """
Chunking strategy that splits text into sentences using NLTK's sentence tokenizer. Chunking strategy that splits text into sentences using NLTK's sentence tokenizer.
""" """
def __init__(self, **kwargs): def __init__(self, **kwargs):
""" """
Initialize the NlpSentenceChunking object. Initialize the NlpSentenceChunking object.
""" """
load_nltk_punkt() load_nltk_punkt()
def chunk(self, text: str) -> list: def chunk(self, text: str) -> list:
# Improved regex for sentence splitting # Improved regex for sentence splitting
@@ -75,31 +80,34 @@ class NlpSentenceChunking(ChunkingStrategy):
# r'(?<!\w\.\w.)(?<![A-Z][a-z]\.)(?<![A-Z][A-Z]\.)(?<![A-Za-z]\.)(?<=\.|\?|\!|\n)\s' # r'(?<!\w\.\w.)(?<![A-Z][a-z]\.)(?<![A-Z][A-Z]\.)(?<![A-Za-z]\.)(?<=\.|\?|\!|\n)\s'
# ) # )
# sentences = sentence_endings.split(text) # sentences = sentence_endings.split(text)
# sens = [sent.strip() for sent in sentences if sent] # sens = [sent.strip() for sent in sentences if sent]
from nltk.tokenize import sent_tokenize from nltk.tokenize import sent_tokenize
sentences = sent_tokenize(text) sentences = sent_tokenize(text)
sens = [sent.strip() for sent in sentences] sens = [sent.strip() for sent in sentences]
return list(set(sens)) return list(set(sens))
# Topic-based segmentation using TextTiling # Topic-based segmentation using TextTiling
class TopicSegmentationChunking(ChunkingStrategy): class TopicSegmentationChunking(ChunkingStrategy):
""" """
Chunking strategy that segments text into topics using NLTK's TextTilingTokenizer. Chunking strategy that segments text into topics using NLTK's TextTilingTokenizer.
How it works: How it works:
1. Segment the text into topics using TextTilingTokenizer 1. Segment the text into topics using TextTilingTokenizer
2. Extract keywords for each topic segment 2. Extract keywords for each topic segment
""" """
def __init__(self, num_keywords=3, **kwargs): def __init__(self, num_keywords=3, **kwargs):
""" """
Initialize the TopicSegmentationChunking object. Initialize the TopicSegmentationChunking object.
Args: Args:
num_keywords (int): The number of keywords to extract for each topic segment. num_keywords (int): The number of keywords to extract for each topic segment.
""" """
import nltk as nl import nltk as nl
self.tokenizer = nl.tokenize.TextTilingTokenizer() self.tokenizer = nl.tokenize.TextTilingTokenizer()
self.num_keywords = num_keywords self.num_keywords = num_keywords
@@ -111,8 +119,14 @@ class TopicSegmentationChunking(ChunkingStrategy):
def extract_keywords(self, text: str) -> list: def extract_keywords(self, text: str) -> list:
# Tokenize and remove stopwords and punctuation # Tokenize and remove stopwords and punctuation
import nltk as nl import nltk as nl
tokens = nl.tokenize.word_tokenize(text) tokens = nl.tokenize.word_tokenize(text)
tokens = [token.lower() for token in tokens if token not in nl.corpus.stopwords.words('english') and token not in string.punctuation] tokens = [
token.lower()
for token in tokens
if token not in nl.corpus.stopwords.words("english")
and token not in string.punctuation
]
# Calculate frequency distribution # Calculate frequency distribution
freq_dist = Counter(tokens) freq_dist = Counter(tokens)
@@ -123,23 +137,27 @@ class TopicSegmentationChunking(ChunkingStrategy):
# Segment the text into topics # Segment the text into topics
segments = self.chunk(text) segments = self.chunk(text)
# Extract keywords for each topic segment # Extract keywords for each topic segment
segments_with_topics = [(segment, self.extract_keywords(segment)) for segment in segments] segments_with_topics = [
(segment, self.extract_keywords(segment)) for segment in segments
]
return segments_with_topics return segments_with_topics
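
extract_keywords depends on the NLTK stopwords corpus (and chunk on the TextTiling tokenizer), which must be downloaded separately; a hedged sketch:

    import nltk
    nltk.download("stopwords")  # required by extract_keywords

    chunker = TopicSegmentationChunking(num_keywords=3)
    chunker.extract_keywords("Crawlers fetch pages. Crawlers parse pages quickly.")
    # -> the most frequent non-stopword tokens, e.g. ["crawlers", "pages", "fetch"]
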
# Fixed-length word chunks # Fixed-length word chunks
class FixedLengthWordChunking(ChunkingStrategy): class FixedLengthWordChunking(ChunkingStrategy):
""" """
Chunking strategy that splits text into fixed-length word chunks. Chunking strategy that splits text into fixed-length word chunks.
How it works: How it works:
1. Split the text into words 1. Split the text into words
2. Create chunks of fixed length 2. Create chunks of fixed length
3. Return the list of chunks 3. Return the list of chunks
""" """
def __init__(self, chunk_size=100, **kwargs): def __init__(self, chunk_size=100, **kwargs):
""" """
Initialize the fixed-length word chunking strategy with the given chunk size. Initialize the fixed-length word chunking strategy with the given chunk size.
Args: Args:
chunk_size (int): The size of each chunk in words. chunk_size (int): The size of each chunk in words.
""" """
@@ -147,23 +165,28 @@ class FixedLengthWordChunking(ChunkingStrategy):
def chunk(self, text: str) -> list: def chunk(self, text: str) -> list:
words = text.split() words = text.split()
return [' '.join(words[i:i + self.chunk_size]) for i in range(0, len(words), self.chunk_size)] return [
" ".join(words[i : i + self.chunk_size])
for i in range(0, len(words), self.chunk_size)
]
# Sliding window chunking # Sliding window chunking
class SlidingWindowChunking(ChunkingStrategy): class SlidingWindowChunking(ChunkingStrategy):
""" """
Chunking strategy that splits text into overlapping word chunks. Chunking strategy that splits text into overlapping word chunks.
How it works: How it works:
1. Split the text into words 1. Split the text into words
2. Slide a window of window_size words across them in steps of step words 2. Slide a window of window_size words across them in steps of step words
3. Return the list of overlapping chunks 3. Return the list of overlapping chunks
""" """
def __init__(self, window_size=100, step=50, **kwargs): def __init__(self, window_size=100, step=50, **kwargs):
""" """
Initialize the sliding window chunking strategy with the given window size and Initialize the sliding window chunking strategy with the given window size and
step size. step size.
Args: Args:
window_size (int): The size of the sliding window in words. window_size (int): The size of the sliding window in words.
step (int): The step size for sliding the window in words. step (int): The step size for sliding the window in words.
@@ -174,35 +197,37 @@ class SlidingWindowChunking(ChunkingStrategy):
def chunk(self, text: str) -> list: def chunk(self, text: str) -> list:
words = text.split() words = text.split()
chunks = [] chunks = []
if len(words) <= self.window_size: if len(words) <= self.window_size:
return [text] return [text]
for i in range(0, len(words) - self.window_size + 1, self.step): for i in range(0, len(words) - self.window_size + 1, self.step):
chunk = ' '.join(words[i:i + self.window_size]) chunk = " ".join(words[i : i + self.window_size])
chunks.append(chunk) chunks.append(chunk)
# Handle the last chunk if it doesn't align perfectly # Handle the last chunk if it doesn't align perfectly
if i + self.window_size < len(words): if i + self.window_size < len(words):
chunks.append(' '.join(words[-self.window_size:])) chunks.append(" ".join(words[-self.window_size :]))
return chunks return chunks
class OverlappingWindowChunking(ChunkingStrategy): class OverlappingWindowChunking(ChunkingStrategy):
""" """
Chunking strategy that splits text into overlapping word chunks. Chunking strategy that splits text into overlapping word chunks.
How it works: How it works:
1. Split the text into words using whitespace 1. Split the text into words using whitespace
2. Create chunks of fixed length equal to the window size 2. Create chunks of fixed length equal to the window size
3. Slide the window by the overlap size 3. Slide the window by the overlap size
4. Return the list of chunks 4. Return the list of chunks
""" """
def __init__(self, window_size=1000, overlap=100, **kwargs): def __init__(self, window_size=1000, overlap=100, **kwargs):
""" """
Initialize the overlapping window chunking strategy with the given window size and Initialize the overlapping window chunking strategy with the given window size and
overlap size. overlap size.
Args: Args:
window_size (int): The size of the window in words. window_size (int): The size of the window in words.
overlap (int): The size of the overlap between consecutive chunks in words. overlap (int): The size of the overlap between consecutive chunks in words.
@@ -213,19 +238,19 @@ class OverlappingWindowChunking(ChunkingStrategy):
def chunk(self, text: str) -> list: def chunk(self, text: str) -> list:
words = text.split() words = text.split()
chunks = [] chunks = []
if len(words) <= self.window_size: if len(words) <= self.window_size:
return [text] return [text]
start = 0 start = 0
while start < len(words): while start < len(words):
end = start + self.window_size end = start + self.window_size
chunk = ' '.join(words[start:end]) chunk = " ".join(words[start:end])
chunks.append(chunk) chunks.append(chunk)
if end >= len(words): if end >= len(words):
break break
start = end - self.overlap start = end - self.overlap
return chunks return chunks
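
A worked comparison of the three word-window strategies, with outputs traced from the logic in this diff:

    words10 = " ".join(f"w{i}" for i in range(10))   # "w0 w1 ... w9"

    FixedLengthWordChunking(chunk_size=4).chunk(words10)
    # -> ["w0 w1 w2 w3", "w4 w5 w6 w7", "w8 w9"]

    SlidingWindowChunking(window_size=4, step=2).chunk(words10)
    # -> four chunks starting at word offsets 0, 2, 4, 6

    OverlappingWindowChunking(window_size=4, overlap=1).chunk(words10)
    # -> ["w0 w1 w2 w3", "w3 w4 w5 w6", "w6 w7 w8 w9"]
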
View File
@@ -8,15 +8,22 @@ from .async_logger import AsyncLogger
logger = AsyncLogger(verbose=True) logger = AsyncLogger(verbose=True)
docs_manager = DocsManager(logger) docs_manager = DocsManager(logger)
def print_table(headers: List[str], rows: List[List[str]], padding: int = 2): def print_table(headers: List[str], rows: List[List[str]], padding: int = 2):
"""Print formatted table with headers and rows""" """Print formatted table with headers and rows"""
widths = [max(len(str(cell)) for cell in col) for col in zip(headers, *rows)] widths = [max(len(str(cell)) for cell in col) for col in zip(headers, *rows)]
border = '+' + '+'.join('-' * (w + 2 * padding) for w in widths) + '+' border = "+" + "+".join("-" * (w + 2 * padding) for w in widths) + "+"
def format_row(row): def format_row(row):
return '|' + '|'.join(f"{' ' * padding}{str(cell):<{w}}{' ' * padding}" return (
for cell, w in zip(row, widths)) + '|' "|"
+ "|".join(
f"{' ' * padding}{str(cell):<{w}}{' ' * padding}"
for cell, w in zip(row, widths)
)
+ "|"
)
click.echo(border) click.echo(border)
click.echo(format_row(headers)) click.echo(format_row(headers))
click.echo(border) click.echo(border)
@@ -24,19 +31,24 @@ def print_table(headers: List[str], rows: List[List[str]], padding: int = 2):
click.echo(format_row(row)) click.echo(format_row(row))
click.echo(border) click.echo(border)
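
For instance:

    print_table(["Section", "Words"], [["intro", "1200"], ["advanced", "3400"]])
    # prints a +---+ bordered table with one row per section
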
@click.group() @click.group()
def cli(): def cli():
"""Crawl4AI Command Line Interface""" """Crawl4AI Command Line Interface"""
pass pass
@cli.group() @cli.group()
def docs(): def docs():
"""Documentation operations""" """Documentation operations"""
pass pass
@docs.command() @docs.command()
@click.argument('sections', nargs=-1) @click.argument("sections", nargs=-1)
@click.option('--mode', type=click.Choice(['extended', 'condensed']), default='extended') @click.option(
"--mode", type=click.Choice(["extended", "condensed"]), default="extended"
)
def combine(sections: tuple, mode: str): def combine(sections: tuple, mode: str):
"""Combine documentation sections""" """Combine documentation sections"""
try: try:
@@ -46,16 +58,17 @@ def combine(sections: tuple, mode: str):
logger.error(str(e), tag="ERROR") logger.error(str(e), tag="ERROR")
sys.exit(1) sys.exit(1)
@docs.command() @docs.command()
@click.argument('query') @click.argument("query")
@click.option('--top-k', '-k', default=5) @click.option("--top-k", "-k", default=5)
@click.option('--build-index', is_flag=True, help='Build index if missing') @click.option("--build-index", is_flag=True, help="Build index if missing")
def search(query: str, top_k: int, build_index: bool): def search(query: str, top_k: int, build_index: bool):
"""Search documentation""" """Search documentation"""
try: try:
result = docs_manager.search(query, top_k) result = docs_manager.search(query, top_k)
if result == "No search index available. Call build_search_index() first.": if result == "No search index available. Call build_search_index() first.":
if build_index or click.confirm('No search index found. Build it now?'): if build_index or click.confirm("No search index found. Build it now?"):
asyncio.run(docs_manager.llm_text.generate_index_files()) asyncio.run(docs_manager.llm_text.generate_index_files())
result = docs_manager.search(query, top_k) result = docs_manager.search(query, top_k)
click.echo(result) click.echo(result)
@@ -63,6 +76,7 @@ def search(query: str, top_k: int, build_index: bool):
click.echo(f"Error: {str(e)}", err=True) click.echo(f"Error: {str(e)}", err=True)
sys.exit(1) sys.exit(1)
@docs.command() @docs.command()
def update(): def update():
"""Update docs from GitHub""" """Update docs from GitHub"""
@@ -73,22 +87,25 @@ def update():
click.echo(f"Error: {str(e)}", err=True) click.echo(f"Error: {str(e)}", err=True)
sys.exit(1) sys.exit(1)
@docs.command() @docs.command()
@click.option('--force-facts', is_flag=True, help='Force regenerate fact files') @click.option("--force-facts", is_flag=True, help="Force regenerate fact files")
@click.option('--clear-cache', is_flag=True, help='Clear BM25 cache') @click.option("--clear-cache", is_flag=True, help="Clear BM25 cache")
def index(force_facts: bool, clear_cache: bool): def index(force_facts: bool, clear_cache: bool):
"""Build or rebuild search indexes""" """Build or rebuild search indexes"""
try: try:
asyncio.run(docs_manager.ensure_docs_exist()) asyncio.run(docs_manager.ensure_docs_exist())
asyncio.run(docs_manager.llm_text.generate_index_files( asyncio.run(
force_generate_facts=force_facts, docs_manager.llm_text.generate_index_files(
clear_bm25_cache=clear_cache force_generate_facts=force_facts, clear_bm25_cache=clear_cache
)) )
)
click.echo("Search indexes built successfully") click.echo("Search indexes built successfully")
except Exception as e: except Exception as e:
click.echo(f"Error: {str(e)}", err=True) click.echo(f"Error: {str(e)}", err=True)
sys.exit(1) sys.exit(1)
# Add docs list command # Add docs list command
@docs.command() @docs.command()
def list(): def list():
@@ -96,10 +113,11 @@ def list():
try: try:
sections = docs_manager.list() sections = docs_manager.list()
print_table(["Sections"], [[section] for section in sections]) print_table(["Sections"], [[section] for section in sections])
except Exception as e: except Exception as e:
click.echo(f"Error: {str(e)}", err=True) click.echo(f"Error: {str(e)}", err=True)
sys.exit(1) sys.exit(1)
if __name__ == '__main__':
cli() if __name__ == "__main__":
cli()
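
The commands can be exercised without installing a console script via click's test runner; the import path here is a guess, since this diff does not show the CLI module's filename:

    from click.testing import CliRunner
    from crawl4ai.cli import cli  # hypothetical module path

    runner = CliRunner()
    result = runner.invoke(cli, ["docs", "search", "memory dispatcher", "--top-k", "3"])
    print(result.output)
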
View File
@@ -8,7 +8,7 @@ DEFAULT_PROVIDER = "openai/gpt-4o-mini"
MODEL_REPO_BRANCH = "new-release-0.0.2" MODEL_REPO_BRANCH = "new-release-0.0.2"
# Provider-model dictionary, ONLY used when the extraction strategy is LLMExtractionStrategy # Provider-model dictionary, ONLY used when the extraction strategy is LLMExtractionStrategy
PROVIDER_MODELS = { PROVIDER_MODELS = {
"ollama/llama3": "no-token-needed", # Any model from Ollama no need for API token "ollama/llama3": "no-token-needed", # Any model from Ollama no need for API token
"groq/llama3-70b-8192": os.getenv("GROQ_API_KEY"), "groq/llama3-70b-8192": os.getenv("GROQ_API_KEY"),
"groq/llama3-8b-8192": os.getenv("GROQ_API_KEY"), "groq/llama3-8b-8192": os.getenv("GROQ_API_KEY"),
"openai/gpt-4o-mini": os.getenv("OPENAI_API_KEY"), "openai/gpt-4o-mini": os.getenv("OPENAI_API_KEY"),
@@ -22,27 +22,49 @@ PROVIDER_MODELS = {
} }
# Chunk token threshold # Chunk token threshold
CHUNK_TOKEN_THRESHOLD = 2 ** 11 # 2048 tokens CHUNK_TOKEN_THRESHOLD = 2**11 # 2048 tokens
OVERLAP_RATE = 0.1 OVERLAP_RATE = 0.1
WORD_TOKEN_RATE = 1.3 WORD_TOKEN_RATE = 1.3
# Threshold for the minimum number of words in an HTML tag to be considered # Threshold for the minimum number of words in an HTML tag to be considered
MIN_WORD_THRESHOLD = 1 MIN_WORD_THRESHOLD = 1
IMAGE_DESCRIPTION_MIN_WORD_THRESHOLD = 1 IMAGE_DESCRIPTION_MIN_WORD_THRESHOLD = 1
IMPORTANT_ATTRS = ['src', 'href', 'alt', 'title', 'width', 'height'] IMPORTANT_ATTRS = ["src", "href", "alt", "title", "width", "height"]
ONLY_TEXT_ELIGIBLE_TAGS = ['b', 'i', 'u', 'span', 'del', 'ins', 'sub', 'sup', 'strong', 'em', 'code', 'kbd', 'var', 's', 'q', 'abbr', 'cite', 'dfn', 'time', 'small', 'mark'] ONLY_TEXT_ELIGIBLE_TAGS = [
"b",
"i",
"u",
"span",
"del",
"ins",
"sub",
"sup",
"strong",
"em",
"code",
"kbd",
"var",
"s",
"q",
"abbr",
"cite",
"dfn",
"time",
"small",
"mark",
]
SOCIAL_MEDIA_DOMAINS = [ SOCIAL_MEDIA_DOMAINS = [
'facebook.com', "facebook.com",
'twitter.com', "twitter.com",
'x.com', "x.com",
'linkedin.com', "linkedin.com",
'instagram.com', "instagram.com",
'pinterest.com', "pinterest.com",
'tiktok.com', "tiktok.com",
'snapchat.com', "snapchat.com",
'reddit.com', "reddit.com",
] ]
# Threshold for the Image extraction - Range is 1 to 6 # Threshold for the Image extraction - Range is 1 to 6
# Images are scored on a point-based system to filter by usefulness. Points are assigned # Images are scored on a point-based system to filter by usefulness. Points are assigned
@@ -60,5 +82,5 @@ NEED_MIGRATION = True
URL_LOG_SHORTEN_LENGTH = 30 URL_LOG_SHORTEN_LENGTH = 30
SHOW_DEPRECATION_WARNINGS = True SHOW_DEPRECATION_WARNINGS = True
SCREENSHOT_HEIGHT_TRESHOLD = 10000 SCREENSHOT_HEIGHT_TRESHOLD = 10000
PAGE_TIMEOUT=60000 PAGE_TIMEOUT = 60000
DOWNLOAD_PAGE_TIMEOUT=60000 DOWNLOAD_PAGE_TIMEOUT = 60000
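
As a sanity check on the numbers above: 2**11 is 2048 tokens, and at WORD_TOKEN_RATE = 1.3 tokens per word that budgets roughly 1575 words per chunk:

    CHUNK_TOKEN_THRESHOLD = 2**11          # 2048 tokens
    WORD_TOKEN_RATE = 1.3                  # estimated tokens per word

    max_words_per_chunk = int(CHUNK_TOKEN_THRESHOLD / WORD_TOKEN_RATE)   # 1575
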
View File
@@ -1,59 +1,100 @@
import re import re
from bs4 import BeautifulSoup, Tag from bs4 import BeautifulSoup, Tag
from typing import List, Tuple, Dict from typing import List, Tuple
from rank_bm25 import BM25Okapi from rank_bm25 import BM25Okapi
from time import perf_counter
from collections import deque from collections import deque
from bs4 import BeautifulSoup, NavigableString, Tag, Comment from bs4 import NavigableString, Comment
from .utils import clean_tokens from .utils import clean_tokens
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
import math import math
from snowballstemmer import stemmer from snowballstemmer import stemmer
class RelevantContentFilter(ABC): class RelevantContentFilter(ABC):
"""Abstract base class for content filtering strategies""" """Abstract base class for content filtering strategies"""
def __init__(self, user_query: str = None): def __init__(self, user_query: str = None):
self.user_query = user_query self.user_query = user_query
self.included_tags = { self.included_tags = {
# Primary structure # Primary structure
'article', 'main', 'section', 'div', "article",
"main",
"section",
"div",
# List structures # List structures
'ul', 'ol', 'li', 'dl', 'dt', 'dd', "ul",
"ol",
"li",
"dl",
"dt",
"dd",
# Text content # Text content
'p', 'span', 'blockquote', 'pre', 'code', "p",
"span",
"blockquote",
"pre",
"code",
# Headers # Headers
'h1', 'h2', 'h3', 'h4', 'h5', 'h6', "h1",
"h2",
"h3",
"h4",
"h5",
"h6",
# Tables # Tables
'table', 'thead', 'tbody', 'tr', 'td', 'th', "table",
"thead",
"tbody",
"tr",
"td",
"th",
# Other semantic elements # Other semantic elements
'figure', 'figcaption', 'details', 'summary', "figure",
"figcaption",
"details",
"summary",
# Text formatting # Text formatting
'em', 'strong', 'b', 'i', 'mark', 'small', "em",
"strong",
"b",
"i",
"mark",
"small",
# Rich content # Rich content
'time', 'address', 'cite', 'q' "time",
"address",
"cite",
"q",
} }
self.excluded_tags = { self.excluded_tags = {
'nav', 'footer', 'header', 'aside', 'script', "nav",
'style', 'form', 'iframe', 'noscript' "footer",
"header",
"aside",
"script",
"style",
"form",
"iframe",
"noscript",
} }
self.header_tags = {'h1', 'h2', 'h3', 'h4', 'h5', 'h6'} self.header_tags = {"h1", "h2", "h3", "h4", "h5", "h6"}
self.negative_patterns = re.compile( self.negative_patterns = re.compile(
r'nav|footer|header|sidebar|ads|comment|promo|advert|social|share', r"nav|footer|header|sidebar|ads|comment|promo|advert|social|share", re.I
re.I
) )
self.min_word_count = 2 self.min_word_count = 2
@abstractmethod @abstractmethod
def filter_content(self, html: str) -> List[str]: def filter_content(self, html: str) -> List[str]:
"""Abstract method to be implemented by specific filtering strategies""" """Abstract method to be implemented by specific filtering strategies"""
pass pass
def extract_page_query(self, soup: BeautifulSoup, body: Tag) -> str: def extract_page_query(self, soup: BeautifulSoup, body: Tag) -> str:
"""Common method to extract page metadata with fallbacks""" """Common method to extract page metadata with fallbacks"""
if self.user_query: if self.user_query:
return self.user_query return self.user_query
query_parts = [] query_parts = []
# Title # Title
try: try:
title = soup.title.string title = soup.title.string
@@ -62,109 +103,145 @@ class RelevantContentFilter(ABC):
except Exception: except Exception:
pass pass
if soup.find('h1'): if soup.find("h1"):
query_parts.append(soup.find('h1').get_text()) query_parts.append(soup.find("h1").get_text())
# Meta tags # Meta tags
temp = "" temp = ""
for meta_name in ['keywords', 'description']: for meta_name in ["keywords", "description"]:
meta = soup.find('meta', attrs={'name': meta_name}) meta = soup.find("meta", attrs={"name": meta_name})
if meta and meta.get('content'): if meta and meta.get("content"):
query_parts.append(meta['content']) query_parts.append(meta["content"])
temp += meta['content'] temp += meta["content"]
# If still empty, grab first significant paragraph # If still empty, grab first significant paragraph
if not temp: if not temp:
# Find the first <p> tag whose text contains more than 150 characters # Find the first <p> tag whose text contains more than 150 characters
for p in body.find_all('p'): for p in body.find_all("p"):
if len(p.get_text()) > 150: if len(p.get_text()) > 150:
query_parts.append(p.get_text()[:150]) query_parts.append(p.get_text()[:150])
break break
return ' '.join(filter(None, query_parts))
def extract_text_chunks(self, body: Tag, min_word_threshold: int = None) -> List[Tuple[str, str]]: return " ".join(filter(None, query_parts))
def extract_text_chunks(
self, body: Tag, min_word_threshold: int = None
) -> List[Tuple[str, str]]:
""" """
Extracts text chunks from a BeautifulSoup body element while preserving order. Extracts text chunks from a BeautifulSoup body element while preserving order.
Returns list of tuples (text, tag_name) for classification. Returns list of tuples (text, tag_name) for classification.
Args: Args:
body: BeautifulSoup Tag object representing the body element body: BeautifulSoup Tag object representing the body element
Returns: Returns:
List of (text, tag_name) tuples List of (text, tag_name) tuples
""" """
# Tags to ignore - inline elements that shouldn't break text flow # Tags to ignore - inline elements that shouldn't break text flow
INLINE_TAGS = { INLINE_TAGS = {
'a', 'abbr', 'acronym', 'b', 'bdo', 'big', 'br', 'button', 'cite', 'code', "a",
'dfn', 'em', 'i', 'img', 'input', 'kbd', 'label', 'map', 'object', 'q', "abbr",
'samp', 'script', 'select', 'small', 'span', 'strong', 'sub', 'sup', "acronym",
'textarea', 'time', 'tt', 'var' "b",
"bdo",
"big",
"br",
"button",
"cite",
"code",
"dfn",
"em",
"i",
"img",
"input",
"kbd",
"label",
"map",
"object",
"q",
"samp",
"script",
"select",
"small",
"span",
"strong",
"sub",
"sup",
"textarea",
"time",
"tt",
"var",
} }
# Tags that typically contain meaningful headers # Tags that typically contain meaningful headers
HEADER_TAGS = {'h1', 'h2', 'h3', 'h4', 'h5', 'h6', 'header'} HEADER_TAGS = {"h1", "h2", "h3", "h4", "h5", "h6", "header"}
chunks = [] chunks = []
current_text = [] current_text = []
chunk_index = 0 chunk_index = 0
def should_break_chunk(tag: Tag) -> bool: def should_break_chunk(tag: Tag) -> bool:
"""Determine if a tag should cause a break in the current text chunk""" """Determine if a tag should cause a break in the current text chunk"""
return ( return tag.name not in INLINE_TAGS and not (
tag.name not in INLINE_TAGS tag.name == "p" and len(current_text) == 0
and not (tag.name == 'p' and len(current_text) == 0)
) )
# Use deque for efficient push/pop operations # Use deque for efficient push/pop operations
stack = deque([(body, False)]) stack = deque([(body, False)])
while stack: while stack:
element, visited = stack.pop() element, visited = stack.pop()
if visited: if visited:
# End of block element - flush accumulated text # End of block element - flush accumulated text
if current_text and should_break_chunk(element): if current_text and should_break_chunk(element):
text = ' '.join(''.join(current_text).split()) text = " ".join("".join(current_text).split())
if text: if text:
tag_type = 'header' if element.name in HEADER_TAGS else 'content' tag_type = (
"header" if element.name in HEADER_TAGS else "content"
)
chunks.append((chunk_index, text, tag_type, element)) chunks.append((chunk_index, text, tag_type, element))
chunk_index += 1 chunk_index += 1
current_text = [] current_text = []
continue continue
if isinstance(element, NavigableString): if isinstance(element, NavigableString):
if str(element).strip(): if str(element).strip():
current_text.append(str(element).strip()) current_text.append(str(element).strip())
continue continue
# Pre-allocate children to avoid multiple list operations # Pre-allocate children to avoid multiple list operations
children = list(element.children) children = list(element.children)
if not children: if not children:
continue continue
# Mark block for revisit after processing children # Mark block for revisit after processing children
stack.append((element, True)) stack.append((element, True))
# Add children in reverse order for correct processing # Add children in reverse order for correct processing
for child in reversed(children): for child in reversed(children):
if isinstance(child, (Tag, NavigableString)): if isinstance(child, (Tag, NavigableString)):
stack.append((child, False)) stack.append((child, False))
# Handle any remaining text # Handle any remaining text
if current_text: if current_text:
text = ' '.join(''.join(current_text).split()) text = " ".join("".join(current_text).split())
if text: if text:
chunks.append((chunk_index, text, 'content', body)) chunks.append((chunk_index, text, "content", body))
if min_word_threshold:
chunks = [chunk for chunk in chunks if len(chunk[1].split()) >= min_word_threshold]
return chunks
def _deprecated_extract_text_chunks(self, soup: BeautifulSoup) -> List[Tuple[int, str, Tag]]: if min_word_threshold:
chunks = [
chunk for chunk in chunks if len(chunk[1].split()) >= min_word_threshold
]
return chunks
def _deprecated_extract_text_chunks(
self, soup: BeautifulSoup
) -> List[Tuple[int, str, Tag]]:
"""Common method for extracting text chunks""" """Common method for extracting text chunks"""
_text_cache = {} _text_cache = {}
def fast_text(element: Tag) -> str: def fast_text(element: Tag) -> str:
elem_id = id(element) elem_id = id(element)
if elem_id in _text_cache: if elem_id in _text_cache:
@@ -175,13 +252,13 @@ class RelevantContentFilter(ABC):
text = content.strip() text = content.strip()
if text: if text:
texts.append(text) texts.append(text)
result = ' '.join(texts) result = " ".join(texts)
_text_cache[elem_id] = result _text_cache[elem_id] = result
return result return result
candidates = [] candidates = []
index = 0 index = 0
def dfs(element): def dfs(element):
nonlocal index nonlocal index
if isinstance(element, Tag): if isinstance(element, Tag):
@@ -189,7 +266,7 @@ class RelevantContentFilter(ABC):
if not self.is_excluded(element): if not self.is_excluded(element):
text = fast_text(element) text = fast_text(element)
word_count = len(text.split()) word_count = len(text.split())
# Headers pass through with adjusted minimum # Headers pass through with adjusted minimum
if element.name in self.header_tags: if element.name in self.header_tags:
if word_count >= 3: # Minimal sanity check for headers if word_count >= 3: # Minimal sanity check for headers
@@ -199,7 +276,7 @@ class RelevantContentFilter(ABC):
elif word_count >= self.min_word_count: elif word_count >= self.min_word_count:
candidates.append((index, text, element)) candidates.append((index, text, element))
index += 1 index += 1
for child in element.children: for child in element.children:
dfs(child) dfs(child)
@@ -210,59 +287,67 @@ class RelevantContentFilter(ABC):
"""Common method for exclusion logic""" """Common method for exclusion logic"""
if tag.name in self.excluded_tags: if tag.name in self.excluded_tags:
return True return True
class_id = ' '.join(filter(None, [ class_id = " ".join(
' '.join(tag.get('class', [])), filter(None, [" ".join(tag.get("class", [])), tag.get("id", "")])
tag.get('id', '') )
]))
return bool(self.negative_patterns.search(class_id)) return bool(self.negative_patterns.search(class_id))
def clean_element(self, tag: Tag) -> str: def clean_element(self, tag: Tag) -> str:
"""Common method for cleaning HTML elements with minimal overhead""" """Common method for cleaning HTML elements with minimal overhead"""
if not tag or not isinstance(tag, Tag): if not tag or not isinstance(tag, Tag):
return "" return ""
unwanted_tags = {'script', 'style', 'aside', 'form', 'iframe', 'noscript'} unwanted_tags = {"script", "style", "aside", "form", "iframe", "noscript"}
unwanted_attrs = {'style', 'onclick', 'onmouseover', 'align', 'bgcolor', 'class', 'id'} unwanted_attrs = {
"style",
"onclick",
"onmouseover",
"align",
"bgcolor",
"class",
"id",
}
# Use string builder pattern for better performance # Use string builder pattern for better performance
builder = [] builder = []
def render_tag(elem): def render_tag(elem):
if not isinstance(elem, Tag): if not isinstance(elem, Tag):
if isinstance(elem, str): if isinstance(elem, str):
builder.append(elem.strip()) builder.append(elem.strip())
return return
if elem.name in unwanted_tags: if elem.name in unwanted_tags:
return return
# Start tag # Start tag
builder.append(f'<{elem.name}') builder.append(f"<{elem.name}")
# Add cleaned attributes # Add cleaned attributes
attrs = {k: v for k, v in elem.attrs.items() if k not in unwanted_attrs} attrs = {k: v for k, v in elem.attrs.items() if k not in unwanted_attrs}
for key, value in attrs.items(): for key, value in attrs.items():
builder.append(f' {key}="{value}"') builder.append(f' {key}="{value}"')
builder.append('>') builder.append(">")
# Process children # Process children
for child in elem.children: for child in elem.children:
render_tag(child) render_tag(child)
# Close tag # Close tag
builder.append(f'</{elem.name}>') builder.append(f"</{elem.name}>")
try: try:
render_tag(tag) render_tag(tag)
return ''.join(builder) return "".join(builder)
except Exception: except Exception:
return str(tag) # Fallback to original if anything fails return str(tag) # Fallback to original if anything fails
class BM25ContentFilter(RelevantContentFilter): class BM25ContentFilter(RelevantContentFilter):
""" """
Content filtering using BM25 algorithm with priority tag handling. Content filtering using BM25 algorithm with priority tag handling.
How it works: How it works:
1. Extracts page metadata with fallbacks. 1. Extracts page metadata with fallbacks.
2. Extracts text chunks from the body element. 2. Extracts text chunks from the body element.
@@ -271,22 +356,28 @@ class BM25ContentFilter(RelevantContentFilter):
5. Filters out chunks below the threshold. 5. Filters out chunks below the threshold.
6. Sorts chunks by score in descending order. 6. Sorts chunks by score in descending order.
7. Returns the top N chunks. 7. Returns the top N chunks.
Attributes: Attributes:
user_query (str): User query for filtering (optional). user_query (str): User query for filtering (optional).
bm25_threshold (float): BM25 threshold for filtering (default: 1.0). bm25_threshold (float): BM25 threshold for filtering (default: 1.0).
language (str): Language for stemming (default: 'english'). language (str): Language for stemming (default: 'english').
Methods: Methods:
filter_content(self, html: str, min_word_threshold: int = None) filter_content(self, html: str, min_word_threshold: int = None)
""" """
def __init__(self, user_query: str = None, bm25_threshold: float = 1.0, language: str = 'english'):
def __init__(
self,
user_query: str = None,
bm25_threshold: float = 1.0,
language: str = "english",
):
""" """
Initializes the BM25ContentFilter; if no user query is provided, it falls back to page metadata. Initializes the BM25ContentFilter; if no user query is provided, it falls back to page metadata.
Note: Note:
If no query is given and no page metadata is available, then it tries to pick up the first significant paragraph. If no query is given and no page metadata is available, then it tries to pick up the first significant paragraph.
Args: Args:
user_query (str): User query for filtering (optional). user_query (str): User query for filtering (optional).
bm25_threshold (float): BM25 threshold for filtering (default: 1.0). bm25_threshold (float): BM25 threshold for filtering (default: 1.0).
@@ -295,52 +386,52 @@ class BM25ContentFilter(RelevantContentFilter):
super().__init__(user_query=user_query) super().__init__(user_query=user_query)
self.bm25_threshold = bm25_threshold self.bm25_threshold = bm25_threshold
self.priority_tags = { self.priority_tags = {
'h1': 5.0, "h1": 5.0,
'h2': 4.0, "h2": 4.0,
'h3': 3.0, "h3": 3.0,
'title': 4.0, "title": 4.0,
'strong': 2.0, "strong": 2.0,
'b': 1.5, "b": 1.5,
'em': 1.5, "em": 1.5,
'blockquote': 2.0, "blockquote": 2.0,
'code': 2.0, "code": 2.0,
'pre': 1.5, "pre": 1.5,
'th': 1.5, # Table headers "th": 1.5, # Table headers
} }
self.stemmer = stemmer(language) self.stemmer = stemmer(language)
    def filter_content(self, html: str, min_word_threshold: int = None) -> List[str]:
        """
        Implements content filtering using BM25 algorithm with priority tag handling.

        Note:
            This method implements the filtering logic for the BM25ContentFilter class.
            It takes HTML content as input and returns a list of filtered text chunks.

        Args:
            html (str): HTML content to be filtered.
            min_word_threshold (int): Minimum word threshold for filtering (optional).

        Returns:
            List[str]: List of filtered text chunks.
        """
        if not html or not isinstance(html, str):
            return []

        soup = BeautifulSoup(html, "lxml")

        # Check if body is present
        if not soup.body:
            # Wrap in body tag if missing
            soup = BeautifulSoup(f"<body>{html}</body>", "lxml")

        body = soup.find("body")

        query = self.extract_page_query(soup, body)
        if not query:
            return []
            # return [self.clean_element(soup)]

        candidates = self.extract_text_chunks(body, min_word_threshold)
        if not candidates:
@@ -349,16 +440,20 @@ class BM25ContentFilter(RelevantContentFilter):
        # Tokenize corpus
        # tokenized_corpus = [chunk.lower().split() for _, chunk, _, _ in candidates]
        # tokenized_query = query.lower().split()

        # tokenized_corpus = [[ps.stem(word) for word in chunk.lower().split()]
        #                     for _, chunk, _, _ in candidates]
        # tokenized_query = [ps.stem(word) for word in query.lower().split()]

        tokenized_corpus = [
            [self.stemmer.stemWord(word) for word in chunk.lower().split()]
            for _, chunk, _, _ in candidates
        ]
        tokenized_query = [
            self.stemmer.stemWord(word) for word in query.lower().split()
        ]

        # tokenized_corpus = [[self.stemmer.stemWord(word) for word in tokenize_text(chunk.lower())]
        #                     for _, chunk, _, _ in candidates]
        # tokenized_query = [self.stemmer.stemWord(word) for word in tokenize_text(query.lower())]
@@ -378,7 +473,8 @@ class BM25ContentFilter(RelevantContentFilter):
        # Filter candidates by threshold
        selected_candidates = [
            (index, chunk, tag)
            for adjusted_score, index, chunk, tag in adjusted_candidates
            if adjusted_score >= self.bm25_threshold
        ]
@@ -390,10 +486,11 @@ class BM25ContentFilter(RelevantContentFilter):
        return [self.clean_element(tag) for _, _, tag in selected_candidates]
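
To make the filter concrete, here is a minimal usage sketch; the HTML snippet and query are illustrative, assuming BM25ContentFilter is importable from the crawl4ai package:

# Illustrative sketch: rank and keep HTML chunks relevant to a query.
from crawl4ai import BM25ContentFilter

sample_html = (
    "<html><body><h1>Web Crawling</h1>"
    "<p>BM25 scores each extracted text chunk against the query.</p>"
    "<footer>unrelated boilerplate</footer></body></html>"
)
bm25_filter = BM25ContentFilter(user_query="web crawling", bm25_threshold=1.0)
for chunk in bm25_filter.filter_content(sample_html):
    print(chunk)  # cleaned HTML fragments, sorted by descending BM25 score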
class PruningContentFilter(RelevantContentFilter):
    """
    Content filtering using pruning algorithm with dynamic threshold.

    How it works:
    1. Extracts page metadata with fallbacks.
    2. Extracts text chunks from the body element.
@@ -407,18 +504,24 @@ class PruningContentFilter(RelevantContentFilter):
        min_word_threshold (int): Minimum word threshold for filtering (optional).
        threshold_type (str): Threshold type for dynamic threshold (default: 'fixed').
        threshold (float): Fixed threshold value (default: 0.48).

    Methods:
        filter_content(self, html: str, min_word_threshold: int = None):
    """
    def __init__(
        self,
        user_query: str = None,
        min_word_threshold: int = None,
        threshold_type: str = "fixed",
        threshold: float = 0.48,
    ):
        """
        Initializes the PruningContentFilter class. If no user query is provided, it falls back to page metadata.

        Note:
            If no query is given and no page metadata is available, it tries to pick up the first significant paragraph.

        Args:
            user_query (str): User query for filtering (optional).
            min_word_threshold (int): Minimum word threshold for filtering (optional).
@@ -429,92 +532,92 @@ class PruningContentFilter(RelevantContentFilter):
        self.min_word_threshold = min_word_threshold
        self.threshold_type = threshold_type
        self.threshold = threshold

        # Add tag importance for dynamic threshold
        self.tag_importance = {
            "article": 1.5,
            "main": 1.4,
            "section": 1.3,
            "p": 1.2,
            "h1": 1.4,
            "h2": 1.3,
            "h3": 1.2,
            "div": 0.7,
            "span": 0.6,
        }

        # Metric configuration
        self.metric_config = {
            "text_density": True,
            "link_density": True,
            "tag_weight": True,
            "class_id_weight": True,
            "text_length": True,
        }
        self.metric_weights = {
            "text_density": 0.4,
            "link_density": 0.2,
            "tag_weight": 0.2,
            "class_id_weight": 0.1,
            "text_length": 0.1,
        }
        self.tag_weights = {
            "div": 0.5,
            "p": 1.0,
            "article": 1.5,
            "section": 1.0,
            "span": 0.3,
            "li": 0.5,
            "ul": 0.5,
            "ol": 0.5,
            "h1": 1.2,
            "h2": 1.1,
            "h3": 1.0,
            "h4": 0.9,
            "h5": 0.8,
            "h6": 0.7,
        }
    def filter_content(self, html: str, min_word_threshold: int = None) -> List[str]:
        """
        Implements content filtering using pruning algorithm with dynamic threshold.

        Note:
            This method implements the filtering logic for the PruningContentFilter class.
            It takes HTML content as input and returns a list of filtered text chunks.

        Args:
            html (str): HTML content to be filtered.
            min_word_threshold (int): Minimum word threshold for filtering (optional).

        Returns:
            List[str]: List of filtered text chunks.
        """
        if not html or not isinstance(html, str):
            return []

        soup = BeautifulSoup(html, "lxml")
        if not soup.body:
            soup = BeautifulSoup(f"<body>{html}</body>", "lxml")

        # Remove comments and unwanted tags
        self._remove_comments(soup)
        self._remove_unwanted_tags(soup)

        # Prune tree starting from body
        body = soup.find("body")
        self._prune_tree(body)

        # Extract remaining content as list of HTML strings
        content_blocks = []
        for element in body.children:
            if isinstance(element, str) or not hasattr(element, "name"):
                continue
            if len(element.get_text(strip=True)) > 0:
                content_blocks.append(str(element))
        return content_blocks
    def _remove_comments(self, soup):
@@ -531,34 +634,38 @@ class PruningContentFilter(RelevantContentFilter):
    def _prune_tree(self, node):
        """
        Prunes the tree starting from the given node.

        Args:
            node (Tag): The node from which the pruning starts.
        """
        if not node or not hasattr(node, "name") or node.name is None:
            return

        text_len = len(node.get_text(strip=True))
        tag_len = len(node.encode_contents().decode("utf-8"))
        link_text_len = sum(
            len(s.strip())
            for s in (a.string for a in node.find_all("a", recursive=False))
            if s
        )

        metrics = {
            "node": node,
            "tag_name": node.name,
            "text_len": text_len,
            "tag_len": tag_len,
            "link_text_len": link_text_len,
        }

        score = self._compute_composite_score(metrics, text_len, tag_len, link_text_len)

        if self.threshold_type == "fixed":
            should_remove = score < self.threshold
        else:  # dynamic
            tag_importance = self.tag_importance.get(node.name, 0.7)
            text_ratio = text_len / tag_len if tag_len > 0 else 0
            link_ratio = link_text_len / text_len if text_len > 0 else 1

            threshold = self.threshold  # base threshold
            if tag_importance > 1:
                threshold *= 0.8
@@ -566,13 +673,13 @@ class PruningContentFilter(RelevantContentFilter):
                threshold *= 0.9
            if link_ratio > 0.6:
                threshold *= 1.2

            should_remove = score < threshold

        if should_remove:
            node.decompose()
        else:
            children = [child for child in node.children if hasattr(child, "name")]
            for child in children:
                self._prune_tree(child)
@@ -580,48 +687,48 @@ class PruningContentFilter(RelevantContentFilter):
        """Computes the composite score"""
        if self.min_word_threshold:
            # Get raw text from metrics node - avoid extra processing
            text = metrics["node"].get_text(strip=True)
            word_count = text.count(" ") + 1
            if word_count < self.min_word_threshold:
                return -1.0  # Guaranteed removal

        score = 0.0
        total_weight = 0.0

        if self.metric_config["text_density"]:
            density = text_len / tag_len if tag_len > 0 else 0
            score += self.metric_weights["text_density"] * density
            total_weight += self.metric_weights["text_density"]

        if self.metric_config["link_density"]:
            density = 1 - (link_text_len / text_len if text_len > 0 else 0)
            score += self.metric_weights["link_density"] * density
            total_weight += self.metric_weights["link_density"]

        if self.metric_config["tag_weight"]:
            tag_score = self.tag_weights.get(metrics["tag_name"], 0.5)
            score += self.metric_weights["tag_weight"] * tag_score
            total_weight += self.metric_weights["tag_weight"]

        if self.metric_config["class_id_weight"]:
            class_score = self._compute_class_id_weight(metrics["node"])
            score += self.metric_weights["class_id_weight"] * max(0, class_score)
            total_weight += self.metric_weights["class_id_weight"]

        if self.metric_config["text_length"]:
            score += self.metric_weights["text_length"] * math.log(text_len + 1)
            total_weight += self.metric_weights["text_length"]

        return score / total_weight if total_weight > 0 else 0

    def _compute_class_id_weight(self, node):
        """Computes the class ID weight"""
        class_id_score = 0
        if "class" in node.attrs:
            classes = " ".join(node["class"])
            if self.negative_patterns.match(classes):
                class_id_score -= 0.5
        if "id" in node.attrs:
            element_id = node["id"]
            if self.negative_patterns.match(element_id):
                class_id_score -= 0.5
        return class_id_score
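
By contrast, the pruning filter needs no query at all. A sketch with a dynamic threshold follows; the parameter values are illustrative, not recommendations:

# Illustrative sketch: prune low-scoring nodes instead of ranking chunks.
from crawl4ai import PruningContentFilter

prune_filter = PruningContentFilter(
    threshold=0.48,            # base threshold (constructor default)
    threshold_type="dynamic",  # relaxed for article/h1, tightened for link-heavy nodes
    min_word_threshold=5,      # nodes below this word count score -1.0 and are pruned
)
blocks = prune_filter.filter_content(
    "<html><body><article><p>Substantial body text tends to survive pruning.</p>"
    "</article><div><a href='#'>nav</a><a href='#'>nav</a></div></body></html>"
)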

File diff suppressed because it is too large

View File

@@ -15,54 +15,53 @@ import logging, time
import base64
from PIL import Image, ImageDraw, ImageFont
from io import BytesIO
from typing import Callable
import requests
import os
from pathlib import Path
from .utils import *

logger = logging.getLogger("selenium.webdriver.remote.remote_connection")
logger.setLevel(logging.WARNING)
logger_driver = logging.getLogger("selenium.webdriver.common.service")
logger_driver.setLevel(logging.WARNING)
urllib3_logger = logging.getLogger("urllib3.connectionpool")
urllib3_logger.setLevel(logging.WARNING)
# Disable http.client logging
http_client_logger = logging.getLogger("http.client")
http_client_logger.setLevel(logging.WARNING)
# Disable driver_finder and service logging
driver_finder_logger = logging.getLogger("selenium.webdriver.common.driver_finder")
driver_finder_logger.setLevel(logging.WARNING)
class CrawlerStrategy(ABC):
    @abstractmethod
    def crawl(self, url: str, **kwargs) -> str:
        pass

    @abstractmethod
    def take_screenshot(self, save_path: str):
        pass

    @abstractmethod
    def update_user_agent(self, user_agent: str):
        pass

    @abstractmethod
    def set_hook(self, hook_type: str, hook: Callable):
        pass
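
Any concrete strategy must supply all four abstract methods. A skeletal subclass might look like this; it is hypothetical and simply reuses the requests and Callable imports already present in this module:

class StaticCrawlerStrategy(CrawlerStrategy):
    """Hypothetical strategy: fetch raw HTML without driving a browser."""

    def __init__(self):
        self.hooks = {}

    def crawl(self, url: str, **kwargs) -> str:
        return requests.get(url, timeout=30).text

    def take_screenshot(self, save_path: str):
        raise NotImplementedError("No browser available for screenshots.")

    def update_user_agent(self, user_agent: str):
        self.user_agent = user_agent

    def set_hook(self, hook_type: str, hook: Callable):
        self.hooks[hook_type] = hook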
class CloudCrawlerStrategy(CrawlerStrategy):
    def __init__(self, use_cached_html=False):
        super().__init__()
        self.use_cached_html = use_cached_html

    def crawl(self, url: str) -> str:
        data = {
            "urls": [url],
@@ -76,6 +75,7 @@ class CloudCrawlerStrategy(CrawlerStrategy):
        html = response["results"][0]["html"]
        return sanitize_input_encode(html)


class LocalSeleniumCrawlerStrategy(CrawlerStrategy):
    def __init__(self, use_cached_html=False, js_code=None, **kwargs):
        super().__init__()
@@ -87,20 +87,25 @@ class LocalSeleniumCrawlerStrategy(CrawlerStrategy):
        if kwargs.get("user_agent"):
            self.options.add_argument("--user-agent=" + kwargs.get("user_agent"))
        else:
            user_agent = kwargs.get(
                "user_agent",
                "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.124 Safari/537.36",
            )
            self.options.add_argument(f"--user-agent={user_agent}")
            self.options.add_argument(
                "user-agent=Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.124 Safari/537.36"
            )

        self.options.headless = kwargs.get("headless", True)
        if self.options.headless:
            self.options.add_argument("--headless")

        self.options.add_argument("--disable-gpu")
        self.options.add_argument("--window-size=1920,1080")
        self.options.add_argument("--no-sandbox")
        self.options.add_argument("--disable-dev-shm-usage")
        self.options.add_argument("--disable-blink-features=AutomationControlled")
        # self.options.add_argument("--disable-dev-shm-usage")
        self.options.add_argument("--disable-gpu")
        # self.options.add_argument("--disable-extensions")
@@ -120,14 +125,14 @@ class LocalSeleniumCrawlerStrategy(CrawlerStrategy):
        self.use_cached_html = use_cached_html
        self.js_code = js_code
        self.verbose = kwargs.get("verbose", False)

        # Hooks
        self.hooks = {
            "on_driver_created": None,
            "on_user_agent_updated": None,
            "before_get_url": None,
            "after_get_url": None,
            "before_return_html": None,
        }

        # chromedriver_autoinstaller.install()
@@ -137,31 +142,28 @@ class LocalSeleniumCrawlerStrategy(CrawlerStrategy):
        # chromedriver_path = chromedriver_autoinstaller.install()
        # chromedriver_path = chromedriver_autoinstaller.utils.download_chromedriver()
        # self.service = Service(chromedriver_autoinstaller.install())

        # chromedriver_path = ChromeDriverManager().install()
        # self.service = Service(chromedriver_path)
        # self.service.log_path = "NUL"
        # self.driver = webdriver.Chrome(service=self.service, options=self.options)

        # Use selenium-manager (built into Selenium 4.10.0+)
        self.service = Service()
        self.driver = webdriver.Chrome(options=self.options)

        self.driver = self.execute_hook("on_driver_created", self.driver)

        if kwargs.get("cookies"):
            for cookie in kwargs.get("cookies"):
                self.driver.add_cookie(cookie)

    def set_hook(self, hook_type: str, hook: Callable):
        if hook_type in self.hooks:
            self.hooks[hook_type] = hook
        else:
            raise ValueError(f"Invalid hook type: {hook_type}")

    def execute_hook(self, hook_type: str, *args):
        hook = self.hooks.get(hook_type)
        if hook:
@@ -170,7 +172,9 @@ class LocalSeleniumCrawlerStrategy(CrawlerStrategy):
            if isinstance(result, webdriver.Chrome):
                return result
            else:
                raise TypeError(
                    f"Hook {hook_type} must return an instance of webdriver.Chrome or None."
                )
        # If the hook returns None or there is no hook, return self.driver
        return self.driver
@@ -178,60 +182,77 @@ class LocalSeleniumCrawlerStrategy(CrawlerStrategy):
        self.options.add_argument(f"user-agent={user_agent}")
        self.driver.quit()
        self.driver = webdriver.Chrome(service=self.service, options=self.options)
        self.driver = self.execute_hook("on_user_agent_updated", self.driver)

    def set_custom_headers(self, headers: dict):
        # Enable Network domain for sending headers
        self.driver.execute_cdp_cmd("Network.enable", {})
        # Set extra HTTP headers
        self.driver.execute_cdp_cmd("Network.setExtraHTTPHeaders", {"headers": headers})
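
Taken together, the hook registry and the CDP header plumbing can be driven like this (a sketch; the hook body and header values are illustrative):

strategy = LocalSeleniumCrawlerStrategy(verbose=True)

def announce(driver):
    # Hooks may return a webdriver.Chrome instance or None (see execute_hook above).
    print("[HOOK] navigation about to start")
    return driver

strategy.set_hook("before_get_url", announce)
strategy.set_custom_headers({"X-Requested-With": "crawl4ai"})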
    def _ensure_page_load(self, max_checks=6, check_interval=0.01):
        initial_length = len(self.driver.page_source)
        for ix in range(max_checks):
            # print(f"Checking page load: {ix}")
            time.sleep(check_interval)
            current_length = len(self.driver.page_source)
            if current_length != initial_length:
                break
        return self.driver.page_source
    def crawl(self, url: str, **kwargs) -> str:
        # Create md5 hash of the URL
        import hashlib

        url_hash = hashlib.md5(url.encode()).hexdigest()

        if self.use_cached_html:
            cache_file_path = os.path.join(
                os.getenv("CRAWL4_AI_BASE_DIRECTORY", Path.home()),
                ".crawl4ai",
                "cache",
                url_hash,
            )
            if os.path.exists(cache_file_path):
                with open(cache_file_path, "r") as f:
                    return sanitize_input_encode(f.read())

        try:
            self.driver = self.execute_hook("before_get_url", self.driver)
            if self.verbose:
                print(f"[LOG] 🕸️ Crawling {url} using LocalSeleniumCrawlerStrategy...")
            self.driver.get(url)  # <html><head></head><body></body></html>

            WebDriverWait(self.driver, 20).until(
                lambda d: d.execute_script("return document.readyState") == "complete"
            )
            WebDriverWait(self.driver, 10).until(
                EC.presence_of_all_elements_located((By.TAG_NAME, "body"))
            )

            self.driver.execute_script(
                "window.scrollTo(0, document.body.scrollHeight);"
            )

            self.driver = self.execute_hook("after_get_url", self.driver)
            html = sanitize_input_encode(
                self._ensure_page_load()
            )  # self.driver.page_source
            can_not_be_done_headless = (
                False  # Look at my creativity for naming variables
            )

            # TODO: Very ugly approach, but promise to change it!
            if (
                kwargs.get("bypass_headless", False)
                or html == "<html><head></head><body></body></html>"
            ):
                print(
                    "[LOG] 🙌 Page could not be loaded in headless mode. Trying non-headless mode..."
                )
                can_not_be_done_headless = True
                options = Options()
                options.headless = False
@@ -239,27 +260,31 @@ class LocalSeleniumCrawlerStrategy(CrawlerStrategy):
                options.add_argument("--window-size=5,5")
                driver = webdriver.Chrome(service=self.service, options=options)
                driver.get(url)
                self.driver = self.execute_hook("after_get_url", driver)
                html = sanitize_input_encode(driver.page_source)
                driver.quit()
            # Execute JS code if provided
            self.js_code = kwargs.get("js_code", self.js_code)
            if self.js_code and type(self.js_code) == str:
                self.driver.execute_script(self.js_code)
                # Optionally, wait for some condition after executing the JS code
                WebDriverWait(self.driver, 10).until(
                    lambda driver: driver.execute_script("return document.readyState")
                    == "complete"
                )
            elif self.js_code and type(self.js_code) == list:
                for js in self.js_code:
                    self.driver.execute_script(js)
                    WebDriverWait(self.driver, 10).until(
                        lambda driver: driver.execute_script(
                            "return document.readyState"
                        )
                        == "complete"
                    )

            # Optionally, wait for some condition after executing the JS code : Contributed by (https://github.com/jonymusky)
            wait_for = kwargs.get("wait_for", False)
            if wait_for:
                if callable(wait_for):
                    print("[LOG] 🔄 Waiting for condition...")
@@ -268,32 +293,37 @@ class LocalSeleniumCrawlerStrategy(CrawlerStrategy):
                    print("[LOG] 🔄 Waiting for condition...")
                    WebDriverWait(self.driver, 20).until(
                        EC.presence_of_element_located((By.CSS_SELECTOR, wait_for))
                    )

            if not can_not_be_done_headless:
                html = sanitize_input_encode(self.driver.page_source)
            self.driver = self.execute_hook("before_return_html", self.driver, html)

            # Store in cache
            cache_file_path = os.path.join(
                os.getenv("CRAWL4_AI_BASE_DIRECTORY", Path.home()),
                ".crawl4ai",
                "cache",
                url_hash,
            )
            with open(cache_file_path, "w", encoding="utf-8") as f:
                f.write(html)

            if self.verbose:
                print(f"[LOG] ✅ Crawled {url} successfully!")

            return html

        except InvalidArgumentException as e:
            if not hasattr(e, "msg"):
                e.msg = sanitize_input_encode(str(e))
            raise InvalidArgumentException(f"Failed to crawl {url}: {e.msg}")
        except WebDriverException as e:
            # If e does not have a msg attribute, create it and set it to str(e)
            if not hasattr(e, "msg"):
                e.msg = sanitize_input_encode(str(e))
            raise WebDriverException(f"Failed to crawl {url}: {e.msg}")
        except Exception as e:
            if not hasattr(e, "msg"):
                e.msg = sanitize_input_encode(str(e))
            raise Exception(f"Failed to crawl {url}: {e.msg}")
@@ -301,7 +331,9 @@ class LocalSeleniumCrawlerStrategy(CrawlerStrategy):
        try:
            # Get the dimensions of the page
            total_width = self.driver.execute_script("return document.body.scrollWidth")
            total_height = self.driver.execute_script(
                "return document.body.scrollHeight"
            )

            # Set the window size to the dimensions of the page
            self.driver.set_window_size(total_width, total_height)
@@ -313,25 +345,27 @@ class LocalSeleniumCrawlerStrategy(CrawlerStrategy):
            image = Image.open(BytesIO(screenshot))

            # Convert image to RGB mode (this will handle both RGB and RGBA images)
            rgb_image = image.convert("RGB")

            # Convert to JPEG and compress
            buffered = BytesIO()
            rgb_image.save(buffered, format="JPEG", quality=85)
            img_base64 = base64.b64encode(buffered.getvalue()).decode("utf-8")

            if self.verbose:
                print("[LOG] 📸 Screenshot taken and converted to base64")

            return img_base64
        except Exception as e:
            error_message = sanitize_input_encode(
                f"Failed to take screenshot: {str(e)}"
            )
            print(error_message)

            # Generate an image with black background
            img = Image.new("RGB", (800, 600), color="black")
            draw = ImageDraw.Draw(img)

            # Load a font
            try:
                font = ImageFont.truetype("arial.ttf", 40)
@@ -345,16 +379,16 @@ class LocalSeleniumCrawlerStrategy(CrawlerStrategy):
            # Calculate text position
            text_position = (10, 10)

            # Draw the text on the image
            draw.text(text_position, wrapped_text, fill=text_color, font=font)

            # Convert to base64
            buffered = BytesIO()
            img.save(buffered, format="JPEG")
            img_base64 = base64.b64encode(buffered.getvalue()).decode("utf-8")

            return img_base64

    def quit(self):
        self.driver.quit()
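
End to end, the strategy is exercised roughly like this (a sketch; the URL, JS snippet, and selector are placeholders):

strategy = LocalSeleniumCrawlerStrategy(use_cached_html=False)
try:
    html = strategy.crawl(
        "https://example.com",
        js_code="window.scrollTo(0, document.body.scrollHeight);",
        wait_for="article.main",  # CSS selector string, or a callable given the driver
    )
finally:
    strategy.quit()  # always release the Chrome instance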

View File

@@ -7,11 +7,13 @@ DB_PATH = os.path.join(os.getenv("CRAWL4_AI_BASE_DIRECTORY", Path.home()), ".cra
os.makedirs(DB_PATH, exist_ok=True)
DB_PATH = os.path.join(DB_PATH, "crawl4ai.db")


def init_db():
    global DB_PATH
    conn = sqlite3.connect(DB_PATH)
    cursor = conn.cursor()
    cursor.execute(
        """
        CREATE TABLE IF NOT EXISTS crawled_data (
            url TEXT PRIMARY KEY,
            html TEXT,
@@ -24,31 +26,42 @@ def init_db():
            metadata TEXT DEFAULT "{}",
            screenshot TEXT DEFAULT ""
        )
        """
    )
    conn.commit()
    conn.close()
def alter_db_add_screenshot(new_column: str = "media"):
    check_db_path()
    try:
        conn = sqlite3.connect(DB_PATH)
        cursor = conn.cursor()
        cursor.execute(
            f'ALTER TABLE crawled_data ADD COLUMN {new_column} TEXT DEFAULT ""'
        )
        conn.commit()
        conn.close()
    except Exception as e:
        print(f"Error altering database to add screenshot column: {e}")


def check_db_path():
    if not DB_PATH:
        raise ValueError("Database path is not set or is empty.")


def get_cached_url(
    url: str,
) -> Optional[Tuple[str, str, str, str, str, str, str, bool, str]]:
    check_db_path()
    try:
        conn = sqlite3.connect(DB_PATH)
        cursor = conn.cursor()
        cursor.execute(
            "SELECT url, html, cleaned_html, markdown, extracted_content, success, media, links, metadata, screenshot FROM crawled_data WHERE url = ?",
            (url,),
        )
        result = cursor.fetchone()
        conn.close()
        return result
@@ -56,12 +69,25 @@ def get_cached_url(url: str) -> Optional[Tuple[str, str, str, str, str, str, str
        print(f"Error retrieving cached URL: {e}")
        return None
def cache_url(
    url: str,
    html: str,
    cleaned_html: str,
    markdown: str,
    extracted_content: str,
    success: bool,
    media: str = "{}",
    links: str = "{}",
    metadata: str = "{}",
    screenshot: str = "",
):
    check_db_path()
    try:
        conn = sqlite3.connect(DB_PATH)
        cursor = conn.cursor()
        cursor.execute(
            """
            INSERT INTO crawled_data (url, html, cleaned_html, markdown, extracted_content, success, media, links, metadata, screenshot)
            VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
            ON CONFLICT(url) DO UPDATE SET
@@ -74,18 +100,32 @@ def cache_url(url: str, html: str, cleaned_html: str, markdown: str, extracted_c
                links = excluded.links,
                metadata = excluded.metadata,
                screenshot = excluded.screenshot
            """,
            (
                url,
                html,
                cleaned_html,
                markdown,
                extracted_content,
                success,
                media,
                links,
                metadata,
                screenshot,
            ),
        )
        conn.commit()
        conn.close()
    except Exception as e:
        print(f"Error caching URL: {e}")
def get_total_count() -> int:
    check_db_path()
    try:
        conn = sqlite3.connect(DB_PATH)
        cursor = conn.cursor()
        cursor.execute("SELECT COUNT(*) FROM crawled_data")
        result = cursor.fetchone()
        conn.close()
        return result[0]
@@ -93,43 +133,48 @@ def get_total_count() -> int:
        print(f"Error getting total count: {e}")
        return 0


def clear_db():
    check_db_path()
    try:
        conn = sqlite3.connect(DB_PATH)
        cursor = conn.cursor()
        cursor.execute("DELETE FROM crawled_data")
        conn.commit()
        conn.close()
    except Exception as e:
        print(f"Error clearing database: {e}")


def flush_db():
    check_db_path()
    try:
        conn = sqlite3.connect(DB_PATH)
        cursor = conn.cursor()
        cursor.execute("DROP TABLE crawled_data")
        conn.commit()
        conn.close()
    except Exception as e:
        print(f"Error flushing database: {e}")


def update_existing_records(new_column: str = "media", default_value: str = "{}"):
    check_db_path()
    try:
        conn = sqlite3.connect(DB_PATH)
        cursor = conn.cursor()
        cursor.execute(
            f'UPDATE crawled_data SET {new_column} = "{default_value}" WHERE screenshot IS NULL'
        )
        conn.commit()
        conn.close()
    except Exception as e:
        print(f"Error updating existing records: {e}")


if __name__ == "__main__":
    # Delete the existing database file
    if os.path.exists(DB_PATH):
        os.remove(DB_PATH)

    init_db()
    # alter_db_add_screenshot("COL_NAME")

View File

@@ -4,6 +4,7 @@ from pathlib import Path
from crawl4ai.async_logger import AsyncLogger
from crawl4ai.llmtxt import AsyncLLMTextManager


class DocsManager:
    def __init__(self, logger=None):
        self.docs_dir = Path.home() / ".crawl4ai" / "docs"
@@ -21,11 +22,14 @@ class DocsManager:
        """Copy from local docs or download from GitHub"""
        try:
            # Try local first
            if self.local_docs.exists() and (
                any(self.local_docs.glob("*.md"))
                or any(self.local_docs.glob("*.tokens"))
            ):
                # Empty the local docs directory
                for file_path in self.docs_dir.glob("*.md"):
                    file_path.unlink()
                # for file_path in self.docs_dir.glob("*.tokens"):
                #     file_path.unlink()

                for file_path in self.local_docs.glob("*.md"):
                    shutil.copy2(file_path, self.docs_dir / file_path.name)
@@ -36,14 +40,14 @@ class DocsManager:
            # Fallback to GitHub
            response = requests.get(
                "https://api.github.com/repos/unclecode/crawl4ai/contents/docs/llm.txt",
                headers={"Accept": "application/vnd.github.v3+json"},
            )
            response.raise_for_status()

            for item in response.json():
                if item["type"] == "file" and item["name"].endswith(".md"):
                    content = requests.get(item["download_url"]).text
                    with open(self.docs_dir / item["name"], "w", encoding="utf-8") as f:
                        f.write(content)
            return True
@@ -57,11 +61,15 @@ class DocsManager:
        # Remove [0-9]+_ prefix
        names = [name.split("_", 1)[1] if name[0].isdigit() else name for name in names]
        # Exclude names ending with .xs.md and .q.md
        names = [
            name
            for name in names
            if not name.endswith(".xs") and not name.endswith(".q")
        ]
        return names

    def generate(self, sections, mode="extended"):
        return self.llm_text.generate(sections, mode)

    def search(self, query: str, top_k: int = 5):
        return self.llm_text.search(query, top_k)
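
Typical use is a small pipeline; a sketch, assuming the docs have already been fetched by the method above (the query string and section names are illustrative):

docs = DocsManager()
hits = docs.search("memory adaptive dispatcher", top_k=3)
llm_txt = docs.generate(sections=["quickstart"], mode="extended")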

File diff suppressed because it is too large

View File

@@ -54,13 +54,13 @@ class HTML2Text(html.parser.HTMLParser):
        self.td_count = 0
        self.table_start = False
        self.unicode_snob = config.UNICODE_SNOB  # covered in cli
        self.escape_snob = config.ESCAPE_SNOB  # covered in cli
        self.escape_backslash = config.ESCAPE_BACKSLASH  # covered in cli
        self.escape_dot = config.ESCAPE_DOT  # covered in cli
        self.escape_plus = config.ESCAPE_PLUS  # covered in cli
        self.escape_dash = config.ESCAPE_DASH  # covered in cli
        self.links_each_paragraph = config.LINKS_EACH_PARAGRAPH
        self.body_width = bodywidth  # covered in cli
        self.skip_internal_links = config.SKIP_INTERNAL_LINKS  # covered in cli
@@ -144,8 +144,8 @@ class HTML2Text(html.parser.HTMLParser):
    def update_params(self, **kwargs):
        for key, value in kwargs.items():
            setattr(self, key, value)

    def feed(self, data: str) -> None:
        data = data.replace("</' + 'script>", "</ignore>")
        super().feed(data)
@@ -903,7 +903,13 @@ class HTML2Text(html.parser.HTMLParser):
            self.empty_link = False

        if not self.code and not self.pre and not entity_char:
            data = escape_md_section(
                data,
                snob=self.escape_snob,
                escape_dot=self.escape_dot,
                escape_plus=self.escape_plus,
                escape_dash=self.escape_dash,
            )
        self.preceding_data = data
        self.o(data, puredata=True)
@@ -1006,6 +1012,7 @@ class HTML2Text(html.parser.HTMLParser):
            newlines += 1
        return result


def html2text(html: str, baseurl: str = "", bodywidth: Optional[int] = None) -> str:
    if bodywidth is None:
        bodywidth = config.BODY_WIDTH
@@ -1013,6 +1020,7 @@ def html2text(html: str, baseurl: str = "", bodywidth: Optional[int] = None) ->
    return h.handle(html)
class CustomHTML2Text(HTML2Text):
    def __init__(self, *args, handle_code_in_pre=False, **kwargs):
        super().__init__(*args, **kwargs)
@@ -1022,8 +1030,8 @@ class CustomHTML2Text(HTML2Text):
        self.current_preserved_tag = None
        self.preserved_content = []
        self.preserve_depth = 0
        self.handle_code_in_pre = handle_code_in_pre

        # Configuration options
        self.skip_internal_links = False
        self.single_line_break = False
@@ -1041,9 +1049,9 @@ class CustomHTML2Text(HTML2Text):
    def update_params(self, **kwargs):
        """Update parameters and set preserved tags."""
        for key, value in kwargs.items():
            if key == "preserve_tags":
                self.preserve_tags = set(value)
            elif key == "handle_code_in_pre":
                self.handle_code_in_pre = value
            else:
                setattr(self, key, value)
@@ -1056,17 +1064,19 @@ class CustomHTML2Text(HTML2Text):
                self.current_preserved_tag = tag
                self.preserved_content = []
                # Format opening tag with attributes
                attr_str = "".join(
                    f' {k}="{v}"' for k, v in attrs.items() if v is not None
                )
                self.preserved_content.append(f"<{tag}{attr_str}>")
                self.preserve_depth += 1
                return
            else:
                self.preserve_depth -= 1
                if self.preserve_depth == 0:
                    self.preserved_content.append(f"</{tag}>")
                    # Output the preserved HTML block with proper spacing
                    preserved_html = "".join(self.preserved_content)
                    self.o("\n" + preserved_html + "\n")
                    self.current_preserved_tag = None
                return
@@ -1074,29 +1084,31 @@ class CustomHTML2Text(HTML2Text):
        if self.preserve_depth > 0:
            if start:
                # Format nested tags with attributes
                attr_str = "".join(
                    f' {k}="{v}"' for k, v in attrs.items() if v is not None
                )
                self.preserved_content.append(f"<{tag}{attr_str}>")
            else:
                self.preserved_content.append(f"</{tag}>")
            return

        # Handle pre tags
        if tag == "pre":
            if start:
                self.o("```\n")  # Markdown code block start
                self.inside_pre = True
            else:
                self.o("\n```\n")  # Markdown code block end
                self.inside_pre = False
        elif tag == "code":
            if self.inside_pre and not self.handle_code_in_pre:
                # Ignore code tags inside pre blocks if handle_code_in_pre is False
                return
            if start:
                self.o("`")  # Markdown inline code start
                self.inside_code = True
            else:
                self.o("`")  # Markdown inline code end
                self.inside_code = False
        else:
            super().handle_tag(tag, attrs, start)
@@ -1113,13 +1125,12 @@ class CustomHTML2Text(HTML2Text):
            return
        if self.inside_code:
            # Inline code: no newlines allowed
            self.o(data.replace("\n", " "))
            return

        # Default behavior for other tags
        super().handle_data(data, entity_char)

        # # Handle pre tags
        # if tag == 'pre':
        #     if start:
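
A sketch of how the preserved-tags machinery behaves (the input HTML is illustrative): tables pass through as raw HTML, pre opens a fenced block (code tags inside it are ignored unless handle_code_in_pre=True), and standalone code becomes inline backticks.

h = CustomHTML2Text(handle_code_in_pre=False)
h.update_params(preserve_tags=["table"])
markdown = h.handle(
    "<h1>Title</h1>"
    "<pre>print('hi')</pre>"
    "<p>Use <code>arun</code> for async crawls.</p>"
    "<table><tr><td>kept verbatim</td></tr></table>"
)
print(markdown)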

View File

@@ -1,2 +1,3 @@
class OutCallback:
    def __call__(self, s: str) -> None:
        ...

View File

@@ -210,7 +210,7 @@ def escape_md_section(
    snob: bool = False,
    escape_dot: bool = True,
    escape_plus: bool = True,
    escape_dash: bool = True,
) -> str:
    """
    Escapes markdown-sensitive characters across whole document sections.
@@ -233,6 +233,7 @@ def escape_md_section(
    return text
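
For instance (illustrative, and assuming the elided first positional parameter is the text to escape, as the docstring suggests):

escaped = escape_md_section(
    "1. not a list + not a bullet - not a rule",
    snob=False,
    escape_dot=True,
    escape_plus=True,
    escape_dash=True,
)
# Characters that could start markdown constructs ('.', '+', '-') come back
# backslash-escaped so a downstream renderer treats them as literal text.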
def reformat_table(lines: List[str], right_margin: int) -> List[str]:
    """
    Given the lines of a table

View File

@@ -6,25 +6,44 @@ from .async_logger import AsyncLogger, LogLevel
# Initialize logger
logger = AsyncLogger(log_level=LogLevel.DEBUG, verbose=True)


def post_install():
    """Run all post-installation tasks"""
    logger.info("Running post-installation setup...", tag="INIT")
    install_playwright()
    run_migration()
    logger.success("Post-installation setup completed!", tag="COMPLETE")
def install_playwright():
    logger.info("Installing Playwright browsers...", tag="INIT")
    try:
        # subprocess.check_call([sys.executable, "-m", "playwright", "install", "--with-deps", "--force", "chrome"])
        subprocess.check_call(
            [
                sys.executable,
                "-m",
                "playwright",
                "install",
                "--with-deps",
                "--force",
                "chromium",
            ]
        )
        logger.success(
            "Playwright installation completed successfully.", tag="COMPLETE"
        )
    except subprocess.CalledProcessError:
        # logger.error(f"Error during Playwright installation: {e}", tag="ERROR")
        logger.warning(
            f"Please run '{sys.executable} -m playwright install --with-deps' manually after the installation."
        )
    except Exception:
        # logger.error(f"Unexpected error during Playwright installation: {e}", tag="ERROR")
        logger.warning(
            f"Please run '{sys.executable} -m playwright install --with-deps' manually after the installation."
        )
def run_migration(): def run_migration():
"""Initialize database during installation""" """Initialize database during installation"""
@@ -33,18 +52,26 @@ def run_migration():
from crawl4ai.async_database import async_db_manager from crawl4ai.async_database import async_db_manager
asyncio.run(async_db_manager.initialize()) asyncio.run(async_db_manager.initialize())
logger.success("Database initialization completed successfully.", tag="COMPLETE") logger.success(
"Database initialization completed successfully.", tag="COMPLETE"
)
except ImportError: except ImportError:
logger.warning("Database module not found. Will initialize on first use.") logger.warning("Database module not found. Will initialize on first use.")
except Exception as e: except Exception as e:
logger.warning(f"Database initialization failed: {e}") logger.warning(f"Database initialization failed: {e}")
logger.warning("Database will be initialized on first use") logger.warning("Database will be initialized on first use")
async def run_doctor(): async def run_doctor():
"""Test if Crawl4AI is working properly""" """Test if Crawl4AI is working properly"""
logger.info("Running Crawl4AI health check...", tag="INIT") logger.info("Running Crawl4AI health check...", tag="INIT")
try: try:
from .async_webcrawler import AsyncWebCrawler, BrowserConfig, CrawlerRunConfig, CacheMode from .async_webcrawler import (
AsyncWebCrawler,
BrowserConfig,
CrawlerRunConfig,
CacheMode,
)
browser_config = BrowserConfig( browser_config = BrowserConfig(
headless=True, headless=True,
@@ -52,7 +79,7 @@ async def run_doctor():
ignore_https_errors=True, ignore_https_errors=True,
light_mode=True, light_mode=True,
viewport_width=1280, viewport_width=1280,
viewport_height=720 viewport_height=720,
) )
run_config = CrawlerRunConfig( run_config = CrawlerRunConfig(
@@ -62,10 +89,7 @@ async def run_doctor():
async with AsyncWebCrawler(config=browser_config) as crawler: async with AsyncWebCrawler(config=browser_config) as crawler:
logger.info("Testing crawling capabilities...", tag="TEST") logger.info("Testing crawling capabilities...", tag="TEST")
result = await crawler.arun( result = await crawler.arun(url="https://crawl4ai.com", config=run_config)
url="https://crawl4ai.com",
config=run_config
)
if result and result.markdown: if result and result.markdown:
logger.success("✅ Crawling test passed!", tag="COMPLETE") logger.success("✅ Crawling test passed!", tag="COMPLETE")
@@ -77,7 +101,9 @@ async def run_doctor():
logger.error(f"❌ Test failed: {e}", tag="ERROR") logger.error(f"❌ Test failed: {e}", tag="ERROR")
return False return False
def doctor(): def doctor():
"""Entry point for the doctor command""" """Entry point for the doctor command"""
import asyncio import asyncio
return asyncio.run(run_doctor()) return asyncio.run(run_doctor())


@@ -1,15 +1,18 @@
-import os, sys
+import os
 # Create a function get name of a js script, then load from the CURRENT folder of this script and return its content as string, make sure its error free
 def load_js_script(script_name):
     # Get the path of the current script
     current_script_path = os.path.dirname(os.path.realpath(__file__))
     # Get the path of the script to load
-    script_path = os.path.join(current_script_path, script_name + '.js')
+    script_path = os.path.join(current_script_path, script_name + ".js")
     # Check if the script exists
     if not os.path.exists(script_path):
-        raise ValueError(f"Script {script_name} not found in the folder {current_script_path}")
+        raise ValueError(
+            f"Script {script_name} not found in the folder {current_script_path}"
+        )
     # Load the content of the script
-    with open(script_path, 'r') as f:
+    with open(script_path, "r") as f:
         script_content = f.read()
     return script_content


@@ -11,16 +11,16 @@ from rank_bm25 import BM25Okapi
 from nltk.tokenize import word_tokenize
 from nltk.corpus import stopwords
 from nltk.stem import WordNetLemmatizer
-from litellm import completion, batch_completion
+from litellm import batch_completion
 from .async_logger import AsyncLogger
 import litellm
 import pickle
 import hashlib  # <--- ADDED for file-hash
-from fnmatch import fnmatch
 import glob
 litellm.set_verbose = False
 def _compute_file_hash(file_path: Path) -> str:
     """Compute MD5 hash for the file's entire content."""
     hash_md5 = hashlib.md5()
@@ -29,13 +29,14 @@ def _compute_file_hash(file_path: Path) -> str:
             hash_md5.update(chunk)
     return hash_md5.hexdigest()
 class AsyncLLMTextManager:
     def __init__(
         self,
         docs_dir: Path,
         logger: Optional[AsyncLogger] = None,
         max_concurrent_calls: int = 5,
-        batch_size: int = 3
+        batch_size: int = 3,
     ) -> None:
         self.docs_dir = docs_dir
         self.logger = logger
@@ -51,7 +52,7 @@ class AsyncLLMTextManager:
         contents = []
         for file_path in doc_batch:
             try:
-                with open(file_path, 'r', encoding='utf-8') as f:
+                with open(file_path, "r", encoding="utf-8") as f:
                     contents.append(f.read())
             except Exception as e:
                 self.logger.error(f"Error reading {file_path}: {str(e)}")
@@ -77,43 +78,53 @@ Wrap your response in <index>...</index> tags.
         # Prepare messages for batch processing
         messages_list = [
             [
-                {"role": "user", "content": f"{prompt}\n\nGenerate index for this documentation:\n\n{content}"}
+                {
+                    "role": "user",
+                    "content": f"{prompt}\n\nGenerate index for this documentation:\n\n{content}",
+                }
             ]
-            for content in contents if content
+            for content in contents
+            if content
         ]
         try:
             responses = batch_completion(
                 model="anthropic/claude-3-5-sonnet-latest",
                 messages=messages_list,
-                logger_fn=None
+                logger_fn=None,
             )
             # Process responses and save index files
             for response, file_path in zip(responses, doc_batch):
                 try:
                     index_content_match = re.search(
-                        r'<index>(.*?)</index>',
+                        r"<index>(.*?)</index>",
                         response.choices[0].message.content,
-                        re.DOTALL
+                        re.DOTALL,
                     )
                     if not index_content_match:
-                        self.logger.warning(f"No <index>...</index> content found for {file_path}")
+                        self.logger.warning(
+                            f"No <index>...</index> content found for {file_path}"
+                        )
                         continue
                     index_content = re.sub(
                         r"\n\s*\n", "\n", index_content_match.group(1)
                     ).strip()
                     if index_content:
-                        index_file = file_path.with_suffix('.q.md')
-                        with open(index_file, 'w', encoding='utf-8') as f:
+                        index_file = file_path.with_suffix(".q.md")
+                        with open(index_file, "w", encoding="utf-8") as f:
                             f.write(index_content)
                         self.logger.info(f"Created index file: {index_file}")
                     else:
-                        self.logger.warning(f"No index content found in response for {file_path}")
+                        self.logger.warning(
+                            f"No index content found in response for {file_path}"
+                        )
                 except Exception as e:
-                    self.logger.error(f"Error processing response for {file_path}: {str(e)}")
+                    self.logger.error(
+                        f"Error processing response for {file_path}: {str(e)}"
+                    )
         except Exception as e:
             self.logger.error(f"Error in batch completion: {str(e)}")
@@ -171,7 +182,12 @@ Wrap your response in <index>...</index> tags.
         lemmatizer = WordNetLemmatizer()
         stop_words = set(stopwords.words("english")) - {
-            "how", "what", "when", "where", "why", "which",
+            "how",
+            "what",
+            "when",
+            "where",
+            "why",
+            "which",
         }
         tokens = []
@@ -222,7 +238,9 @@ Wrap your response in <index>...</index> tags.
         self.logger.info("Checking which .q.md files need (re)indexing...")
         # Gather all .q.md files
-        q_files = [self.docs_dir / f for f in os.listdir(self.docs_dir) if f.endswith(".q.md")]
+        q_files = [
+            self.docs_dir / f for f in os.listdir(self.docs_dir) if f.endswith(".q.md")
+        ]
         # We'll store known (unchanged) facts in these lists
         existing_facts: List[str] = []
@@ -243,7 +261,9 @@ Wrap your response in <index>...</index> tags.
             # Otherwise, load the existing cache and compare hash
             cache = self._load_or_create_token_cache(qf)
             # If the .q.tokens was out of date (i.e. changed hash), we reindex
-            if len(cache["facts"]) == 0 or cache.get("content_hash") != _compute_file_hash(qf):
+            if len(cache["facts"]) == 0 or cache.get(
+                "content_hash"
+            ) != _compute_file_hash(qf):
                 needSet.append(qf)
             else:
                 # File is unchanged → retrieve cached token data
@@ -255,20 +275,29 @@ Wrap your response in <index>...</index> tags.
         if not needSet and not clear_cache:
             # If no file needs reindexing, try loading existing index
             if self.maybe_load_bm25_index(clear_cache=False):
-                self.logger.info("No new/changed .q.md files found. Using existing BM25 index.")
+                self.logger.info(
+                    "No new/changed .q.md files found. Using existing BM25 index."
+                )
                 return
             else:
                 # If there's no existing index, we must build a fresh index from the old caches
-                self.logger.info("No existing BM25 index found. Building from cached facts.")
+                self.logger.info(
+                    "No existing BM25 index found. Building from cached facts."
+                )
                 if existing_facts:
-                    self.logger.info(f"Building BM25 index with {len(existing_facts)} cached facts.")
+                    self.logger.info(
+                        f"Building BM25 index with {len(existing_facts)} cached facts."
+                    )
                     self.bm25_index = BM25Okapi(existing_tokens)
                     self.tokenized_facts = existing_facts
                     with open(self.bm25_index_file, "wb") as f:
-                        pickle.dump({
-                            "bm25_index": self.bm25_index,
-                            "tokenized_facts": self.tokenized_facts
-                        }, f)
+                        pickle.dump(
+                            {
+                                "bm25_index": self.bm25_index,
+                                "tokenized_facts": self.tokenized_facts,
+                            },
+                            f,
+                        )
                 else:
                     self.logger.warning("No facts found at all. Index remains empty.")
                 return
@@ -311,7 +340,9 @@ Wrap your response in <index>...</index> tags.
                     self._save_token_cache(file, fresh_cache)
                     mem_usage = process.memory_info().rss / 1024 / 1024
-                    self.logger.debug(f"Memory usage after {file.name}: {mem_usage:.2f}MB")
+                    self.logger.debug(
+                        f"Memory usage after {file.name}: {mem_usage:.2f}MB"
+                    )
             except Exception as e:
                 self.logger.error(f"Error processing {file}: {str(e)}")
@@ -328,40 +359,49 @@ Wrap your response in <index>...</index> tags.
         all_tokens = existing_tokens + new_tokens
         # 3) Build BM25 index from combined facts
-        self.logger.info(f"Building BM25 index with {len(all_facts)} total facts (old + new).")
+        self.logger.info(
+            f"Building BM25 index with {len(all_facts)} total facts (old + new)."
+        )
         self.bm25_index = BM25Okapi(all_tokens)
         self.tokenized_facts = all_facts
         # 4) Save the updated BM25 index to disk
         with open(self.bm25_index_file, "wb") as f:
-            pickle.dump({
-                "bm25_index": self.bm25_index,
-                "tokenized_facts": self.tokenized_facts
-            }, f)
+            pickle.dump(
+                {
+                    "bm25_index": self.bm25_index,
+                    "tokenized_facts": self.tokenized_facts,
+                },
+                f,
+            )
         final_mem = process.memory_info().rss / 1024 / 1024
         self.logger.info(f"Search index updated. Final memory usage: {final_mem:.2f}MB")
-    async def generate_index_files(self, force_generate_facts: bool = False, clear_bm25_cache: bool = False) -> None:
+    async def generate_index_files(
+        self, force_generate_facts: bool = False, clear_bm25_cache: bool = False
+    ) -> None:
         """
         Generate index files for all documents in parallel batches
         Args:
             force_generate_facts (bool): If True, regenerate indexes even if they exist
            clear_bm25_cache (bool): If True, clear existing BM25 index cache
         """
         self.logger.info("Starting index generation for documentation files.")
         md_files = [
-            self.docs_dir / f for f in os.listdir(self.docs_dir)
-            if f.endswith('.md') and not any(f.endswith(x) for x in ['.q.md', '.xs.md'])
+            self.docs_dir / f
+            for f in os.listdir(self.docs_dir)
+            if f.endswith(".md") and not any(f.endswith(x) for x in [".q.md", ".xs.md"])
         ]
         # Filter out files that already have .q files unless force=True
         if not force_generate_facts:
             md_files = [
-                f for f in md_files
-                if not (self.docs_dir / f.name.replace('.md', '.q.md')).exists()
+                f
+                for f in md_files
+                if not (self.docs_dir / f.name.replace(".md", ".q.md")).exists()
             ]
         if not md_files:
@@ -369,8 +409,10 @@ Wrap your response in <index>...</index> tags.
         else:
             # Process documents in batches
             for i in range(0, len(md_files), self.batch_size):
-                batch = md_files[i:i + self.batch_size]
-                self.logger.info(f"Processing batch {i//self.batch_size + 1}/{(len(md_files)//self.batch_size) + 1}")
+                batch = md_files[i : i + self.batch_size]
+                self.logger.info(
+                    f"Processing batch {i//self.batch_size + 1}/{(len(md_files)//self.batch_size) + 1}"
+                )
                 await self._process_document_batch(batch)
         self.logger.info("Index generation complete, building/updating search index.")
@@ -378,21 +420,31 @@ Wrap your response in <index>...</index> tags.
     def generate(self, sections: List[str], mode: str = "extended") -> str:
         # Get all markdown files
-        all_files = glob.glob(str(self.docs_dir / "[0-9]*.md")) + \
-            glob.glob(str(self.docs_dir / "[0-9]*.xs.md"))
+        all_files = glob.glob(str(self.docs_dir / "[0-9]*.md")) + glob.glob(
+            str(self.docs_dir / "[0-9]*.xs.md")
+        )
         # Extract base names without extensions
-        base_docs = {Path(f).name.split('.')[0] for f in all_files
-                     if not Path(f).name.endswith('.q.md')}
+        base_docs = {
+            Path(f).name.split(".")[0]
+            for f in all_files
+            if not Path(f).name.endswith(".q.md")
+        }
         # Filter by sections if provided
         if sections:
-            base_docs = {doc for doc in base_docs
-                         if any(section.lower() in doc.lower() for section in sections)}
+            base_docs = {
+                doc
+                for doc in base_docs
+                if any(section.lower() in doc.lower() for section in sections)
+            }
         # Get file paths based on mode
         files = []
-        for doc in sorted(base_docs, key=lambda x: int(x.split('_')[0]) if x.split('_')[0].isdigit() else 999999):
+        for doc in sorted(
+            base_docs,
+            key=lambda x: int(x.split("_")[0]) if x.split("_")[0].isdigit() else 999999,
+        ):
             if mode == "condensed":
                 xs_file = self.docs_dir / f"{doc}.xs.md"
                 regular_file = self.docs_dir / f"{doc}.md"
@@ -404,7 +456,7 @@ Wrap your response in <index>...</index> tags.
         content = []
         for file in files:
             try:
-                with open(file, 'r', encoding='utf-8') as f:
+                with open(file, "r", encoding="utf-8") as f:
                     fname = Path(file).name
                     content.append(f"{'#'*20}\n# {fname}\n{'#'*20}\n\n{f.read()}")
             except Exception as e:
@@ -443,15 +495,9 @@ Wrap your response in <index>...</index> tags.
         for file, _ in ranked_files:
             main_doc = str(file).replace(".q.md", ".md")
             if os.path.exists(self.docs_dir / main_doc):
-                with open(self.docs_dir / main_doc, "r", encoding='utf-8') as f:
+                with open(self.docs_dir / main_doc, "r", encoding="utf-8") as f:
                     only_file_name = main_doc.split("/")[-1]
-                    content = [
-                        "#" * 20,
-                        f"# {only_file_name}",
-                        "#" * 20,
-                        "",
-                        f.read()
-                    ]
+                    content = ["#" * 20, f"# {only_file_name}", "#" * 20, "", f.read()]
                     results.append("\n".join(content))
         return "\n\n---\n\n".join(results)
@@ -482,7 +528,9 @@ Wrap your response in <index>...</index> tags.
             if len(components) == 3:
                 code_ref = components[2].strip()
                 code_tokens = self.preprocess_text(code_ref)
-                code_match_score = len(set(query_tokens) & set(code_tokens)) / len(query_tokens)
+                code_match_score = len(set(query_tokens) & set(code_tokens)) / len(
+                    query_tokens
+                )
                 file_data[file_path]["total_score"] += score
                 file_data[file_path]["match_count"] += 1


@@ -2,77 +2,94 @@ from abc import ABC, abstractmethod
 from typing import Optional, Dict, Any, Tuple
 from .models import MarkdownGenerationResult
 from .html2text import CustomHTML2Text
-from .content_filter_strategy import RelevantContentFilter, BM25ContentFilter
+from .content_filter_strategy import RelevantContentFilter
 import re
 from urllib.parse import urljoin
 # Pre-compile the regex pattern
 LINK_PATTERN = re.compile(r'!?\[([^\]]+)\]\(([^)]+?)(?:\s+"([^"]*)")?\)')
 def fast_urljoin(base: str, url: str) -> str:
     """Fast URL joining for common cases."""
-    if url.startswith(('http://', 'https://', 'mailto:', '//')):
+    if url.startswith(("http://", "https://", "mailto:", "//")):
         return url
-    if url.startswith('/'):
+    if url.startswith("/"):
         # Handle absolute paths
-        if base.endswith('/'):
+        if base.endswith("/"):
             return base[:-1] + url
         return base + url
     return urljoin(base, url)
 class MarkdownGenerationStrategy(ABC):
     """Abstract base class for markdown generation strategies."""
-    def __init__(self, content_filter: Optional[RelevantContentFilter] = None, options: Optional[Dict[str, Any]] = None):
+    def __init__(
+        self,
+        content_filter: Optional[RelevantContentFilter] = None,
+        options: Optional[Dict[str, Any]] = None,
+    ):
         self.content_filter = content_filter
         self.options = options or {}
     @abstractmethod
-    def generate_markdown(self,
-                          cleaned_html: str,
-                          base_url: str = "",
-                          html2text_options: Optional[Dict[str, Any]] = None,
-                          content_filter: Optional[RelevantContentFilter] = None,
-                          citations: bool = True,
-                          **kwargs) -> MarkdownGenerationResult:
+    def generate_markdown(
+        self,
+        cleaned_html: str,
+        base_url: str = "",
+        html2text_options: Optional[Dict[str, Any]] = None,
+        content_filter: Optional[RelevantContentFilter] = None,
+        citations: bool = True,
+        **kwargs,
+    ) -> MarkdownGenerationResult:
         """Generate markdown from cleaned HTML."""
         pass
 class DefaultMarkdownGenerator(MarkdownGenerationStrategy):
     """
     Default implementation of markdown generation strategy.
     How it works:
     1. Generate raw markdown from cleaned HTML.
     2. Convert links to citations.
     3. Generate fit markdown if content filter is provided.
     4. Return MarkdownGenerationResult.
     Args:
         content_filter (Optional[RelevantContentFilter]): Content filter for generating fit markdown.
         options (Optional[Dict[str, Any]]): Additional options for markdown generation. Defaults to None.
     Returns:
         MarkdownGenerationResult: Result containing raw markdown, fit markdown, fit HTML, and references markdown.
     """
-    def __init__(self, content_filter: Optional[RelevantContentFilter] = None, options: Optional[Dict[str, Any]] = None):
+    def __init__(
+        self,
+        content_filter: Optional[RelevantContentFilter] = None,
+        options: Optional[Dict[str, Any]] = None,
+    ):
         super().__init__(content_filter, options)
-    def convert_links_to_citations(self, markdown: str, base_url: str = "") -> Tuple[str, str]:
+    def convert_links_to_citations(
+        self, markdown: str, base_url: str = ""
+    ) -> Tuple[str, str]:
         """
         Convert links in markdown to citations.
         How it works:
         1. Find all links in the markdown.
         2. Convert links to citations.
         3. Return converted markdown and references markdown.
         Note:
         This function uses a regex pattern to find links in markdown.
         Args:
             markdown (str): Markdown text.
             base_url (str): Base URL for URL joins.
         Returns:
             Tuple[str, str]: Converted markdown and references markdown.
         """
@@ -81,57 +98,65 @@ class DefaultMarkdownGenerator(MarkdownGenerationStrategy):
         parts = []
         last_end = 0
         counter = 1
         for match in LINK_PATTERN.finditer(markdown):
-            parts.append(markdown[last_end:match.start()])
+            parts.append(markdown[last_end : match.start()])
             text, url, title = match.groups()
             # Use cached URL if available, otherwise compute and cache
-            if base_url and not url.startswith(('http://', 'https://', 'mailto:')):
+            if base_url and not url.startswith(("http://", "https://", "mailto:")):
                 if url not in url_cache:
                     url_cache[url] = fast_urljoin(base_url, url)
                 url = url_cache[url]
             if url not in link_map:
                 desc = []
-                if title: desc.append(title)
-                if text and text != title: desc.append(text)
+                if title:
+                    desc.append(title)
+                if text and text != title:
+                    desc.append(text)
                 link_map[url] = (counter, ": " + " - ".join(desc) if desc else "")
                 counter += 1
             num = link_map[url][0]
-            parts.append(f"{text}⟨{num}⟩" if not match.group(0).startswith('!') else f"![{text}⟨{num}⟩]")
+            parts.append(
+                f"{text}⟨{num}⟩"
+                if not match.group(0).startswith("!")
+                else f"![{text}⟨{num}⟩]"
+            )
             last_end = match.end()
         parts.append(markdown[last_end:])
-        converted_text = ''.join(parts)
+        converted_text = "".join(parts)
         # Pre-build reference strings
         references = ["\n\n## References\n\n"]
         references.extend(
             f"⟨{num}⟩ {url}{desc}\n"
             for url, (num, desc) in sorted(link_map.items(), key=lambda x: x[1][0])
         )
-        return converted_text, ''.join(references)
-    def generate_markdown(self,
-                          cleaned_html: str,
-                          base_url: str = "",
-                          html2text_options: Optional[Dict[str, Any]] = None,
-                          options: Optional[Dict[str, Any]] = None,
-                          content_filter: Optional[RelevantContentFilter] = None,
-                          citations: bool = True,
-                          **kwargs) -> MarkdownGenerationResult:
+        return converted_text, "".join(references)
+    def generate_markdown(
+        self,
+        cleaned_html: str,
+        base_url: str = "",
+        html2text_options: Optional[Dict[str, Any]] = None,
+        options: Optional[Dict[str, Any]] = None,
+        content_filter: Optional[RelevantContentFilter] = None,
+        citations: bool = True,
+        **kwargs,
+    ) -> MarkdownGenerationResult:
         """
         Generate markdown with citations from cleaned HTML.
         How it works:
         1. Generate raw markdown from cleaned HTML.
         2. Convert links to citations.
         3. Generate fit markdown if content filter is provided.
         4. Return MarkdownGenerationResult.
         Args:
             cleaned_html (str): Cleaned HTML content.
             base_url (str): Base URL for URL joins.
@@ -139,7 +164,7 @@ class DefaultMarkdownGenerator(MarkdownGenerationStrategy):
             options (Optional[Dict[str, Any]]): Additional options for markdown generation.
             content_filter (Optional[RelevantContentFilter]): Content filter for generating fit markdown.
             citations (bool): Whether to generate citations.
         Returns:
             MarkdownGenerationResult: Result containing raw markdown, fit markdown, fit HTML, and references markdown.
         """
@@ -147,16 +172,16 @@ class DefaultMarkdownGenerator(MarkdownGenerationStrategy):
         # Initialize HTML2Text with default options for better conversion
         h = CustomHTML2Text(baseurl=base_url)
         default_options = {
-            'body_width': 0,  # Disable text wrapping
-            'ignore_emphasis': False,
-            'ignore_links': False,
-            'ignore_images': False,
-            'protect_links': True,
-            'single_line_break': True,
-            'mark_code': True,
-            'escape_snob': False
+            "body_width": 0,  # Disable text wrapping
+            "ignore_emphasis": False,
+            "ignore_links": False,
+            "ignore_images": False,
+            "protect_links": True,
+            "single_line_break": True,
+            "mark_code": True,
+            "escape_snob": False,
         }
         # Update with custom options if provided
         if html2text_options:
             default_options.update(html2text_options)
@@ -164,7 +189,7 @@ class DefaultMarkdownGenerator(MarkdownGenerationStrategy):
             default_options.update(options)
         elif self.options:
             default_options.update(self.options)
         h.update_params(**default_options)
         # Ensure we have valid input
@@ -178,17 +203,18 @@ class DefaultMarkdownGenerator(MarkdownGenerationStrategy):
             raw_markdown = h.handle(cleaned_html)
         except Exception as e:
             raw_markdown = f"Error converting HTML to markdown: {str(e)}"
-        raw_markdown = raw_markdown.replace(' ```', '```')
+        raw_markdown = raw_markdown.replace(" ```", "```")
         # Convert links to citations
         markdown_with_citations: str = raw_markdown
         references_markdown: str = ""
         if citations:
             try:
-                markdown_with_citations, references_markdown = self.convert_links_to_citations(
-                    raw_markdown, base_url
-                )
+                (
+                    markdown_with_citations,
+                    references_markdown,
+                ) = self.convert_links_to_citations(raw_markdown, base_url)
             except Exception as e:
                 markdown_with_citations = raw_markdown
                 references_markdown = f"Error generating citations: {str(e)}"
@@ -200,7 +226,9 @@ class DefaultMarkdownGenerator(MarkdownGenerationStrategy):
             try:
                 content_filter = content_filter or self.content_filter
                 filtered_html = content_filter.filter_content(cleaned_html)
-                filtered_html = '\n'.join('<div>{}</div>'.format(s) for s in filtered_html)
+                filtered_html = "\n".join(
+                    "<div>{}</div>".format(s) for s in filtered_html
+                )
                 fit_markdown = h.handle(filtered_html)
             except Exception as e:
                 fit_markdown = f"Error generating fit markdown: {str(e)}"


@@ -1,13 +1,11 @@
 import os
 import asyncio
-import logging
 from pathlib import Path
 import aiosqlite
 from typing import Optional
 import xxhash
 import aiofiles
 import shutil
-import time
 from datetime import datetime
 from .async_logger import AsyncLogger, LogLevel
@@ -17,18 +15,19 @@ logger = AsyncLogger(log_level=LogLevel.DEBUG, verbose=True)
 # logging.basicConfig(level=logging.INFO)
 # logger = logging.getLogger(__name__)
 class DatabaseMigration:
     def __init__(self, db_path: str):
         self.db_path = db_path
         self.content_paths = self._ensure_content_dirs(os.path.dirname(db_path))
     def _ensure_content_dirs(self, base_path: str) -> dict:
         dirs = {
-            'html': 'html_content',
-            'cleaned': 'cleaned_html',
-            'markdown': 'markdown_content',
-            'extracted': 'extracted_content',
-            'screenshots': 'screenshots'
+            "html": "html_content",
+            "cleaned": "cleaned_html",
+            "markdown": "markdown_content",
+            "extracted": "extracted_content",
+            "screenshots": "screenshots",
         }
         content_paths = {}
         for key, dirname in dirs.items():
@@ -47,43 +46,55 @@ class DatabaseMigration:
     async def _store_content(self, content: str, content_type: str) -> str:
         if not content:
             return ""
         content_hash = self._generate_content_hash(content)
         file_path = os.path.join(self.content_paths[content_type], content_hash)
         if not os.path.exists(file_path):
-            async with aiofiles.open(file_path, 'w', encoding='utf-8') as f:
+            async with aiofiles.open(file_path, "w", encoding="utf-8") as f:
                 await f.write(content)
         return content_hash
     async def migrate_database(self):
         """Migrate existing database to file-based storage"""
         # logger.info("Starting database migration...")
         logger.info("Starting database migration...", tag="INIT")
         try:
             async with aiosqlite.connect(self.db_path) as db:
                 # Get all rows
                 async with db.execute(
-                    '''SELECT url, html, cleaned_html, markdown,
-                    extracted_content, screenshot FROM crawled_data'''
+                    """SELECT url, html, cleaned_html, markdown,
+                    extracted_content, screenshot FROM crawled_data"""
                 ) as cursor:
                     rows = await cursor.fetchall()
                 migrated_count = 0
                 for row in rows:
-                    url, html, cleaned_html, markdown, extracted_content, screenshot = row
+                    (
+                        url,
+                        html,
+                        cleaned_html,
+                        markdown,
+                        extracted_content,
+                        screenshot,
+                    ) = row
                     # Store content in files and get hashes
-                    html_hash = await self._store_content(html, 'html')
-                    cleaned_hash = await self._store_content(cleaned_html, 'cleaned')
-                    markdown_hash = await self._store_content(markdown, 'markdown')
-                    extracted_hash = await self._store_content(extracted_content, 'extracted')
-                    screenshot_hash = await self._store_content(screenshot, 'screenshots')
+                    html_hash = await self._store_content(html, "html")
+                    cleaned_hash = await self._store_content(cleaned_html, "cleaned")
+                    markdown_hash = await self._store_content(markdown, "markdown")
+                    extracted_hash = await self._store_content(
+                        extracted_content, "extracted"
+                    )
+                    screenshot_hash = await self._store_content(
+                        screenshot, "screenshots"
+                    )
                     # Update database with hashes
-                    await db.execute('''
+                    await db.execute(
+                        """
                         UPDATE crawled_data
                         SET html = ?,
                             cleaned_html = ?,
@@ -91,40 +102,51 @@ class DatabaseMigration:
                             extracted_content = ?,
                             screenshot = ?
                         WHERE url = ?
-                    ''', (html_hash, cleaned_hash, markdown_hash,
-                          extracted_hash, screenshot_hash, url))
+                        """,
+                        (
+                            html_hash,
+                            cleaned_hash,
+                            markdown_hash,
+                            extracted_hash,
+                            screenshot_hash,
+                            url,
+                        ),
+                    )
                     migrated_count += 1
                     if migrated_count % 100 == 0:
                         logger.info(f"Migrated {migrated_count} records...", tag="INIT")
                 await db.commit()
-                logger.success(f"Migration completed. {migrated_count} records processed.", tag="COMPLETE")
+                logger.success(
+                    f"Migration completed. {migrated_count} records processed.",
+                    tag="COMPLETE",
+                )
         except Exception as e:
             # logger.error(f"Migration failed: {e}")
             logger.error(
                 message="Migration failed: {error}",
                 tag="ERROR",
-                params={"error": str(e)}
+                params={"error": str(e)},
             )
             raise e
 async def backup_database(db_path: str) -> str:
     """Create backup of existing database"""
     if not os.path.exists(db_path):
         logger.info("No existing database found. Skipping backup.", tag="INIT")
         return None
     # Create backup with timestamp
-    timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
+    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
     backup_path = f"{db_path}.backup_{timestamp}"
     try:
         # Wait for any potential write operations to finish
         await asyncio.sleep(1)
         # Create backup
         shutil.copy2(db_path, backup_path)
         logger.info(f"Database backup created at: {backup_path}", tag="COMPLETE")
@@ -132,37 +154,41 @@ async def backup_database(db_path: str) -> str:
     except Exception as e:
         # logger.error(f"Backup failed: {e}")
         logger.error(
-            message="Migration failed: {error}",
-            tag="ERROR",
-            params={"error": str(e)}
-        )
+            message="Migration failed: {error}", tag="ERROR", params={"error": str(e)}
+        )
         raise e
 async def run_migration(db_path: Optional[str] = None):
     """Run database migration"""
     if db_path is None:
         db_path = os.path.join(Path.home(), ".crawl4ai", "crawl4ai.db")
     if not os.path.exists(db_path):
         logger.info("No existing database found. Skipping migration.", tag="INIT")
         return
     # Create backup first
     backup_path = await backup_database(db_path)
     if not backup_path:
         return
     migration = DatabaseMigration(db_path)
     await migration.migrate_database()
 def main():
     """CLI entry point for migration"""
     import argparse
-    parser = argparse.ArgumentParser(description='Migrate Crawl4AI database to file-based storage')
-    parser.add_argument('--db-path', help='Custom database path')
+    parser = argparse.ArgumentParser(
+        description="Migrate Crawl4AI database to file-based storage"
+    )
+    parser.add_argument("--db-path", help="Custom database path")
     args = parser.parse_args()
     asyncio.run(run_migration(args.db_path))
 if __name__ == "__main__":
     main()


@@ -2,109 +2,125 @@ from functools import lru_cache
 from pathlib import Path
 import subprocess, os
 import shutil
-import tarfile
 from .model_loader import *
 import argparse
-import urllib.request
 from crawl4ai.config import MODEL_REPO_BRANCH
 __location__ = os.path.realpath(os.path.join(os.getcwd(), os.path.dirname(__file__)))
 @lru_cache()
 def get_available_memory(device):
     import torch
-    if device.type == 'cuda':
+    if device.type == "cuda":
         return torch.cuda.get_device_properties(device).total_memory
-    elif device.type == 'mps':
-        return 48 * 1024 ** 3  # Assuming 8GB for MPS, as a conservative estimate
+    elif device.type == "mps":
+        return 48 * 1024**3  # Assuming 8GB for MPS, as a conservative estimate
     else:
         return 0
 @lru_cache()
 def calculate_batch_size(device):
     available_memory = get_available_memory(device)
-    if device.type == 'cpu':
+    if device.type == "cpu":
         return 16
-    elif device.type in ['cuda', 'mps']:
+    elif device.type in ["cuda", "mps"]:
         # Adjust these thresholds based on your model size and available memory
-        if available_memory >= 31 * 1024 ** 3:  # > 32GB
+        if available_memory >= 31 * 1024**3:  # > 32GB
             return 256
-        elif available_memory >= 15 * 1024 ** 3:  # > 16GB to 32GB
+        elif available_memory >= 15 * 1024**3:  # > 16GB to 32GB
             return 128
-        elif available_memory >= 8 * 1024 ** 3:  # 8GB to 16GB
+        elif available_memory >= 8 * 1024**3:  # 8GB to 16GB
             return 64
         else:
             return 32
     else:
         return 16  # Default batch size
 @lru_cache()
 def get_device():
     import torch
     if torch.cuda.is_available():
-        device = torch.device('cuda')
+        device = torch.device("cuda")
     elif torch.backends.mps.is_available():
-        device = torch.device('mps')
+        device = torch.device("mps")
     else:
-        device = torch.device('cpu')
+        device = torch.device("cpu")
     return device
 def set_model_device(model):
     device = get_device()
     model.to(device)
     return model, device
 @lru_cache()
 def get_home_folder():
-    home_folder = os.path.join(os.getenv("CRAWL4_AI_BASE_DIRECTORY", Path.home()), ".crawl4ai")
+    home_folder = os.path.join(
+        os.getenv("CRAWL4_AI_BASE_DIRECTORY", Path.home()), ".crawl4ai"
+    )
     os.makedirs(home_folder, exist_ok=True)
     os.makedirs(f"{home_folder}/cache", exist_ok=True)
     os.makedirs(f"{home_folder}/models", exist_ok=True)
     return home_folder
 @lru_cache()
 def load_bert_base_uncased():
-    from transformers import BertTokenizer, BertModel, AutoTokenizer, AutoModel
-    tokenizer = BertTokenizer.from_pretrained('bert-base-uncased', resume_download=None)
-    model = BertModel.from_pretrained('bert-base-uncased', resume_download=None)
+    from transformers import BertTokenizer, BertModel
+    tokenizer = BertTokenizer.from_pretrained("bert-base-uncased", resume_download=None)
+    model = BertModel.from_pretrained("bert-base-uncased", resume_download=None)
     model.eval()
     model, device = set_model_device(model)
     return tokenizer, model
 @lru_cache()
 def load_HF_embedding_model(model_name="BAAI/bge-small-en-v1.5") -> tuple:
     """Load the Hugging Face model for embedding.
     Args:
         model_name (str, optional): The model name to load. Defaults to "BAAI/bge-small-en-v1.5".
     Returns:
         tuple: The tokenizer and model.
     """
-    from transformers import BertTokenizer, BertModel, AutoTokenizer, AutoModel
+    from transformers import AutoTokenizer, AutoModel
     tokenizer = AutoTokenizer.from_pretrained(model_name, resume_download=None)
     model = AutoModel.from_pretrained(model_name, resume_download=None)
     model.eval()
     model, device = set_model_device(model)
     return tokenizer, model
 @lru_cache()
 def load_text_classifier():
     from transformers import AutoTokenizer, AutoModelForSequenceClassification
     from transformers import pipeline
-    import torch
-    tokenizer = AutoTokenizer.from_pretrained("dstefa/roberta-base_topic_classification_nyt_news")
-    model = AutoModelForSequenceClassification.from_pretrained("dstefa/roberta-base_topic_classification_nyt_news")
+    tokenizer = AutoTokenizer.from_pretrained(
+        "dstefa/roberta-base_topic_classification_nyt_news"
+    )
+    model = AutoModelForSequenceClassification.from_pretrained(
+        "dstefa/roberta-base_topic_classification_nyt_news"
+    )
     model.eval()
     model, device = set_model_device(model)
     pipe = pipeline("text-classification", model=model, tokenizer=tokenizer)
     return pipe
 @lru_cache()
 def load_text_multilabel_classifier():
     from transformers import AutoModelForSequenceClassification, AutoTokenizer
-    import numpy as np
     from scipy.special import expit
     import torch
@@ -116,18 +132,27 @@ def load_text_multilabel_classifier():
     # else:
     #     device = torch.device("cpu")
     # # return load_spacy_model(), torch.device("cpu")
     MODEL = "cardiffnlp/tweet-topic-21-multi"
     tokenizer = AutoTokenizer.from_pretrained(MODEL, resume_download=None)
-    model = AutoModelForSequenceClassification.from_pretrained(MODEL, resume_download=None)
+    model = AutoModelForSequenceClassification.from_pretrained(
+        MODEL, resume_download=None
+    )
     model.eval()
     model, device = set_model_device(model)
     class_mapping = model.config.id2label
     def _classifier(texts, threshold=0.5, max_length=64):
-        tokens = tokenizer(texts, return_tensors='pt', padding=True, truncation=True, max_length=max_length)
-        tokens = {key: val.to(device) for key, val in tokens.items()}  # Move tokens to the selected device
+        tokens = tokenizer(
+            texts,
+            return_tensors="pt",
+            padding=True,
+            truncation=True,
+            max_length=max_length,
+        )
+        tokens = {
+            key: val.to(device) for key, val in tokens.items()
+        }  # Move tokens to the selected device
         with torch.no_grad():
             output = model(**tokens)
@@ -138,35 +163,41 @@ def load_text_multilabel_classifier():
         batch_labels = []
         for prediction in predictions:
-            labels = [class_mapping[i] for i, value in enumerate(prediction) if value == 1]
+            labels = [
+                class_mapping[i] for i, value in enumerate(prediction) if value == 1
+            ]
             batch_labels.append(labels)
         return batch_labels
     return _classifier, device
 @lru_cache()
 def load_nltk_punkt():
     import nltk
     try:
-        nltk.data.find('tokenizers/punkt')
+        nltk.data.find("tokenizers/punkt")
     except LookupError:
-        nltk.download('punkt')
-    return nltk.data.find('tokenizers/punkt')
+        nltk.download("punkt")
+    return nltk.data.find("tokenizers/punkt")
 @lru_cache()
 def load_spacy_model():
     import spacy
     name = "models/reuters"
     home_folder = get_home_folder()
     model_folder = Path(home_folder) / name
     # Check if the model directory already exists
     if not (model_folder.exists() and any(model_folder.iterdir())):
         repo_url = "https://github.com/unclecode/crawl4ai.git"
         branch = MODEL_REPO_BRANCH
         repo_folder = Path(home_folder) / "crawl4ai"
         print("[LOG] ⏬ Downloading Spacy model for the first time...")
         # Remove existing repo folder if it exists
@@ -176,7 +207,9 @@ def load_spacy_model():
             if model_folder.exists():
                 shutil.rmtree(model_folder)
         except PermissionError:
-            print("[WARNING] Unable to remove existing folders. Please manually delete the following folders and try again:")
+            print(
+                "[WARNING] Unable to remove existing folders. Please manually delete the following folders and try again:"
+            )
             print(f"- {repo_folder}")
             print(f"- {model_folder}")
             return None
@@ -187,7 +220,7 @@ def load_spacy_model():
             ["git", "clone", "-b", branch, repo_url, str(repo_folder)],
             stdout=subprocess.DEVNULL,
             stderr=subprocess.DEVNULL,
-            check=True
+            check=True,
         )
         # Create the models directory if it doesn't exist
@@ -215,6 +248,7 @@ def load_spacy_model():
         print(f"Error loading spacy model: {e}")
         return None
 def download_all_models(remove_existing=False):
     """Download all models required for Crawl4AI."""
     if remove_existing:
@@ -243,14 +277,20 @@ def download_all_models(remove_existing=False):
     load_nltk_punkt()
     print("[LOG] ✅ All models downloaded successfully.")
 def main():
     print("[LOG] Welcome to the Crawl4AI Model Downloader!")
     print("[LOG] This script will download all the models required for Crawl4AI.")
     parser = argparse.ArgumentParser(description="Crawl4AI Model Downloader")
-    parser.add_argument('--remove-existing', action='store_true', help="Remove existing models before downloading")
+    parser.add_argument(
+        "--remove-existing",
+        action="store_true",
+        help="Remove existing models before downloading",
+    )
     args = parser.parse_args()
     download_all_models(remove_existing=args.remove_existing)
 if __name__ == "__main__":
     main()


@@ -1,18 +1,12 @@
from pydantic import BaseModel, HttpUrl from pydantic import BaseModel, HttpUrl
from typing import List, Dict, Optional, Callable, Awaitable, Union, Tuple, Any from typing import List, Dict, Optional, Callable, Awaitable, Union, Any
from enum import Enum from enum import Enum
from dataclasses import dataclass, field
from .ssl_certificate import SSLCertificate
from dataclasses import dataclass from dataclasses import dataclass
from .ssl_certificate import SSLCertificate
from datetime import datetime from datetime import datetime
from enum import Enum
from typing import Optional
from datetime import timedelta from datetime import timedelta
############################### ###############################
# Dispatcher Models # Dispatcher Models
############################### ###############################
@@ -22,6 +16,7 @@ class DomainState:
current_delay: float = 0 current_delay: float = 0
fail_count: int = 0 fail_count: int = 0
@dataclass @dataclass
class CrawlerTaskResult: class CrawlerTaskResult:
task_id: str task_id: str
@@ -33,12 +28,14 @@ class CrawlerTaskResult:
end_time: datetime end_time: datetime
error_message: str = "" error_message: str = ""
class CrawlStatus(Enum): class CrawlStatus(Enum):
QUEUED = "QUEUED" QUEUED = "QUEUED"
IN_PROGRESS = "IN_PROGRESS" IN_PROGRESS = "IN_PROGRESS"
COMPLETED = "COMPLETED" COMPLETED = "COMPLETED"
FAILED = "FAILED" FAILED = "FAILED"
@dataclass @dataclass
class CrawlStats: class CrawlStats:
task_id: str task_id: str
@@ -49,7 +46,7 @@ class CrawlStats:
memory_usage: float = 0.0 memory_usage: float = 0.0
peak_memory: float = 0.0 peak_memory: float = 0.0
error_message: str = "" error_message: str = ""
@property @property
def duration(self) -> str: def duration(self) -> str:
if not self.start_time: if not self.start_time:
@@ -58,26 +55,29 @@ class CrawlStats:
duration = end - self.start_time duration = end - self.start_time
return str(timedelta(seconds=int(duration.total_seconds()))) return str(timedelta(seconds=int(duration.total_seconds())))
class DisplayMode(Enum): class DisplayMode(Enum):
DETAILED = "DETAILED" DETAILED = "DETAILED"
AGGREGATED = "AGGREGATED" AGGREGATED = "AGGREGATED"
############################### ###############################
# Crawler Models # Crawler Models
############################### ###############################
@dataclass @dataclass
class TokenUsage: class TokenUsage:
completion_tokens: int = 0 completion_tokens: int = 0
prompt_tokens: int = 0 prompt_tokens: int = 0
total_tokens: int = 0 total_tokens: int = 0
completion_tokens_details: Optional[dict] = None completion_tokens_details: Optional[dict] = None
prompt_tokens_details: Optional[dict] = None prompt_tokens_details: Optional[dict] = None
class UrlModel(BaseModel): class UrlModel(BaseModel):
url: HttpUrl url: HttpUrl
forced: bool = False forced: bool = False
class MarkdownGenerationResult(BaseModel): class MarkdownGenerationResult(BaseModel):
raw_markdown: str raw_markdown: str
markdown_with_citations: str markdown_with_citations: str
@@ -85,6 +85,7 @@ class MarkdownGenerationResult(BaseModel):
fit_markdown: Optional[str] = None fit_markdown: Optional[str] = None
fit_html: Optional[str] = None fit_html: Optional[str] = None
class DispatchResult(BaseModel): class DispatchResult(BaseModel):
task_id: str task_id: str
memory_usage: float memory_usage: float
@@ -92,6 +93,8 @@ class DispatchResult(BaseModel):
start_time: datetime start_time: datetime
end_time: datetime end_time: datetime
error_message: str = "" error_message: str = ""
class CrawlResult(BaseModel): class CrawlResult(BaseModel):
url: str url: str
html: str html: str
@@ -101,7 +104,7 @@ class CrawlResult(BaseModel):
    links: Dict[str, List[Dict]] = {}
    downloaded_files: Optional[List[str]] = None
    screenshot: Optional[str] = None
    pdf: Optional[bytes] = None
    markdown: Optional[Union[str, MarkdownGenerationResult]] = None
    markdown_v2: Optional[MarkdownGenerationResult] = None
    fit_markdown: Optional[str] = None
@@ -114,9 +117,11 @@ class CrawlResult(BaseModel):
    status_code: Optional[int] = None
    ssl_certificate: Optional[SSLCertificate] = None
    dispatch_result: Optional[DispatchResult] = None

    class Config:
        arbitrary_types_allowed = True


class AsyncCrawlResponse(BaseModel):
    html: str
    response_headers: Dict[str, str]
@@ -130,6 +135,7 @@ class AsyncCrawlResponse(BaseModel):
    class Config:
        arbitrary_types_allowed = True


###############################
# Scraping Models
###############################
@@ -143,21 +149,29 @@ class MediaItem(BaseModel):
    format: Optional[str] = None
    width: Optional[int] = None


class Link(BaseModel):
    href: str
    text: str
    title: Optional[str] = None
    base_domain: str


class Media(BaseModel):
    images: List[MediaItem] = []
    videos: List[
        MediaItem
    ] = []  # Using MediaItem model for now, can be extended with Video model if needed
    audios: List[
        MediaItem
    ] = []  # Using MediaItem model for now, can be extended with Audio model if needed


class Links(BaseModel):
    internal: List[Link] = []
    external: List[Link] = []


class ScrapingResult(BaseModel):
    cleaned_html: str
    success: bool
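

# Example usage (illustrative sketch; values are placeholders, and only the
# fields shown in this diff are assumed required):
if __name__ == "__main__":
    result = CrawlResult(url="https://example.com", html="<html></html>", success=True)
    # markdown is Union[str, MarkdownGenerationResult]; handle both shapes
    text = (
        result.markdown.raw_markdown
        if isinstance(result.markdown, MarkdownGenerationResult)
        else (result.markdown or "")
    )
    print(repr(text))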
View File
@@ -13,10 +13,10 @@ from pathlib import Path
class SSLCertificate:
    """
    A class representing an SSL certificate with methods to export in various formats.

    Attributes:
        cert_info (Dict[str, Any]): The certificate information.

    Methods:
        from_url(url: str, timeout: int = 10) -> Optional['SSLCertificate']: Create SSLCertificate instance from a URL.
        from_file(file_path: str) -> Optional['SSLCertificate']: Create SSLCertificate instance from a file.
@@ -26,32 +26,35 @@ class SSLCertificate:
        export_as_json() -> Dict[str, Any]: Export the certificate as JSON format.
        export_as_text() -> str: Export the certificate as text format.
    """

    def __init__(self, cert_info: Dict[str, Any]):
        self._cert_info = self._decode_cert_data(cert_info)

    @staticmethod
    def from_url(url: str, timeout: int = 10) -> Optional["SSLCertificate"]:
        """
        Create SSLCertificate instance from a URL.

        Args:
            url (str): URL of the website.
            timeout (int): Timeout for the connection (default: 10).

        Returns:
            Optional[SSLCertificate]: SSLCertificate instance if successful, None otherwise.
        """
        try:
            hostname = urlparse(url).netloc
            if ":" in hostname:
                hostname = hostname.split(":")[0]

            context = ssl.create_default_context()
            with socket.create_connection((hostname, 443), timeout=timeout) as sock:
                with context.wrap_socket(sock, server_hostname=hostname) as ssock:
                    cert_binary = ssock.getpeercert(binary_form=True)
                    x509 = OpenSSL.crypto.load_certificate(
                        OpenSSL.crypto.FILETYPE_ASN1, cert_binary
                    )

                    cert_info = {
                        "subject": dict(x509.get_subject().get_components()),
                        "issuer": dict(x509.get_issuer().get_components()),
@@ -61,32 +64,33 @@ class SSLCertificate:
"not_after": x509.get_notAfter(), "not_after": x509.get_notAfter(),
"fingerprint": x509.digest("sha256").hex(), "fingerprint": x509.digest("sha256").hex(),
"signature_algorithm": x509.get_signature_algorithm(), "signature_algorithm": x509.get_signature_algorithm(),
"raw_cert": base64.b64encode(cert_binary) "raw_cert": base64.b64encode(cert_binary),
} }
# Add extensions # Add extensions
extensions = [] extensions = []
for i in range(x509.get_extension_count()): for i in range(x509.get_extension_count()):
ext = x509.get_extension(i) ext = x509.get_extension(i)
extensions.append({ extensions.append(
"name": ext.get_short_name(), {"name": ext.get_short_name(), "value": str(ext)}
"value": str(ext) )
})
cert_info["extensions"] = extensions cert_info["extensions"] = extensions
return SSLCertificate(cert_info) return SSLCertificate(cert_info)
except Exception as e: except Exception:
return None return None
@staticmethod @staticmethod
def _decode_cert_data(data: Any) -> Any: def _decode_cert_data(data: Any) -> Any:
"""Helper method to decode bytes in certificate data.""" """Helper method to decode bytes in certificate data."""
if isinstance(data, bytes): if isinstance(data, bytes):
return data.decode('utf-8') return data.decode("utf-8")
elif isinstance(data, dict): elif isinstance(data, dict):
return { return {
(k.decode('utf-8') if isinstance(k, bytes) else k): SSLCertificate._decode_cert_data(v) (
k.decode("utf-8") if isinstance(k, bytes) else k
): SSLCertificate._decode_cert_data(v)
for k, v in data.items() for k, v in data.items()
} }
elif isinstance(data, list): elif isinstance(data, list):
@@ -96,58 +100,57 @@ class SSLCertificate:
    def to_json(self, filepath: Optional[str] = None) -> Optional[str]:
        """
        Export certificate as JSON.

        Args:
            filepath (Optional[str]): Path to save the JSON file (default: None).

        Returns:
            Optional[str]: JSON string if successful, None otherwise.
        """
        json_str = json.dumps(self._cert_info, indent=2, ensure_ascii=False)
        if filepath:
            Path(filepath).write_text(json_str, encoding="utf-8")
            return None
        return json_str

    def to_pem(self, filepath: Optional[str] = None) -> Optional[str]:
        """
        Export certificate as PEM.

        Args:
            filepath (Optional[str]): Path to save the PEM file (default: None).

        Returns:
            Optional[str]: PEM string if successful, None otherwise.
        """
        try:
            x509 = OpenSSL.crypto.load_certificate(
                OpenSSL.crypto.FILETYPE_ASN1,
                base64.b64decode(self._cert_info["raw_cert"]),
            )
            pem_data = OpenSSL.crypto.dump_certificate(
                OpenSSL.crypto.FILETYPE_PEM, x509
            ).decode("utf-8")
            if filepath:
                Path(filepath).write_text(pem_data, encoding="utf-8")
                return None
            return pem_data
        except Exception:
            return None

    def to_der(self, filepath: Optional[str] = None) -> Optional[bytes]:
        """
        Export certificate as DER.

        Args:
            filepath (Optional[str]): Path to save the DER file (default: None).

        Returns:
            Optional[bytes]: DER bytes if successful, None otherwise.
        """
        try:
            der_data = base64.b64decode(self._cert_info["raw_cert"])
            if filepath:
                Path(filepath).write_bytes(der_data)
                return None
@@ -158,24 +161,24 @@ class SSLCertificate:
    @property
    def issuer(self) -> Dict[str, str]:
        """Get certificate issuer information."""
        return self._cert_info.get("issuer", {})

    @property
    def subject(self) -> Dict[str, str]:
        """Get certificate subject information."""
        return self._cert_info.get("subject", {})

    @property
    def valid_from(self) -> str:
        """Get certificate validity start date."""
        return self._cert_info.get("not_before", "")

    @property
    def valid_until(self) -> str:
        """Get certificate validity end date."""
        return self._cert_info.get("not_after", "")

    @property
    def fingerprint(self) -> str:
        """Get certificate fingerprint."""
        return self._cert_info.get("fingerprint", "")
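

# Example usage (illustrative sketch; assumes network access to a TLS host on 443):
if __name__ == "__main__":
    cert = SSLCertificate.from_url("https://example.com")
    if cert:  # from_url swallows errors and returns None on failure
        print(cert.issuer)
        print(cert.valid_until)
        cert.to_pem("cert.pem")  # writes the PEM to disk; returns None when given a path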
View File
@@ -6,7 +6,7 @@ import re
class UserAgentGenerator:
    """
    Generate random user agents with specified constraints.

    Attributes:
        desktop_platforms (dict): A dictionary of possible desktop platforms and their corresponding user agent strings.
        mobile_platforms (dict): A dictionary of possible mobile platforms and their corresponding user agent strings.
@@ -18,7 +18,7 @@ class UserAgentGenerator:
        safari_versions (list): A list of possible Safari browser versions.
        ios_versions (list): A list of possible iOS browser versions.
        android_versions (list): A list of possible Android browser versions.

    Methods:
        generate_user_agent(
            platform: Literal["desktop", "mobile"] = "desktop",
@@ -30,8 +30,9 @@ class UserAgentGenerator:
            safari_version: Optional[str] = None,
            ios_version: Optional[str] = None,
            android_version: Optional[str] = None
        ): Generates a random user agent string based on the specified parameters.
    """

    def __init__(self):
        # Previous platform definitions remain the same...
        self.desktop_platforms = {
@@ -47,7 +48,7 @@ class UserAgentGenerator:
"generic": "(X11; Linux x86_64)", "generic": "(X11; Linux x86_64)",
"ubuntu": "(X11; Ubuntu; Linux x86_64)", "ubuntu": "(X11; Ubuntu; Linux x86_64)",
"chrome_os": "(X11; CrOS x86_64 14541.0.0)", "chrome_os": "(X11; CrOS x86_64 14541.0.0)",
} },
} }
self.mobile_platforms = { self.mobile_platforms = {
@@ -60,26 +61,14 @@ class UserAgentGenerator:
"ios": { "ios": {
"iphone": "(iPhone; CPU iPhone OS 16_5 like Mac OS X)", "iphone": "(iPhone; CPU iPhone OS 16_5 like Mac OS X)",
"ipad": "(iPad; CPU OS 16_5 like Mac OS X)", "ipad": "(iPad; CPU OS 16_5 like Mac OS X)",
} },
} }
# Browser Combinations # Browser Combinations
self.browser_combinations = { self.browser_combinations = {
1: [ 1: [["chrome"], ["firefox"], ["safari"], ["edge"]],
["chrome"], 2: [["gecko", "firefox"], ["chrome", "safari"], ["webkit", "safari"]],
["firefox"], 3: [["chrome", "safari", "edge"], ["webkit", "chrome", "safari"]],
["safari"],
["edge"]
],
2: [
["gecko", "firefox"],
["chrome", "safari"],
["webkit", "safari"]
],
3: [
["chrome", "safari", "edge"],
["webkit", "chrome", "safari"]
]
} }
# Rendering Engines with versions # Rendering Engines with versions
@@ -90,7 +79,7 @@ class UserAgentGenerator:
"Gecko/20100101", "Gecko/20100101",
"Gecko/20100101", # Firefox usually uses this constant version "Gecko/20100101", # Firefox usually uses this constant version
"Gecko/2010010", "Gecko/2010010",
] ],
} }
# Browser Versions # Browser Versions
@@ -135,25 +124,25 @@ class UserAgentGenerator:
    def get_browser_stack(self, num_browsers: int = 1) -> List[str]:
        """
        Get a valid combination of browser versions.

        How it works:
        1. Check if the number of browsers is supported.
        2. Randomly choose a combination of browsers.
        3. Iterate through the combination and add browser versions.
        4. Return the browser stack.

        Args:
            num_browsers: Number of browser specifications (1-3)

        Returns:
            List[str]: A list of browser versions.
        """
        if num_browsers not in self.browser_combinations:
            raise ValueError(f"Unsupported number of browsers: {num_browsers}")

        combination = random.choice(self.browser_combinations[num_browsers])
        browser_stack = []

        for browser in combination:
            if browser == "chrome":
                browser_stack.append(random.choice(self.chrome_versions))
@@ -167,18 +156,20 @@ class UserAgentGenerator:
                browser_stack.append(random.choice(self.rendering_engines["gecko"]))
            elif browser == "webkit":
                browser_stack.append(self.rendering_engines["chrome_webkit"])

        return browser_stack

    def generate(
        self,
        device_type: Optional[Literal["desktop", "mobile"]] = None,
        os_type: Optional[str] = None,
        device_brand: Optional[str] = None,
        browser_type: Optional[Literal["chrome", "edge", "safari", "firefox"]] = None,
        num_browsers: int = 3,
    ) -> str:
        """
        Generate a random user agent with specified constraints.

        Args:
            device_type: 'desktop' or 'mobile'
            os_type: 'windows', 'macos', 'linux', 'android', 'ios'
@@ -188,23 +179,23 @@ class UserAgentGenerator:
""" """
# Get platform string # Get platform string
platform = self.get_random_platform(device_type, os_type, device_brand) platform = self.get_random_platform(device_type, os_type, device_brand)
# Start with Mozilla # Start with Mozilla
components = ["Mozilla/5.0", platform] components = ["Mozilla/5.0", platform]
# Add browser stack # Add browser stack
browser_stack = self.get_browser_stack(num_browsers) browser_stack = self.get_browser_stack(num_browsers)
# Add appropriate legacy token based on browser stack # Add appropriate legacy token based on browser stack
if "Firefox" in str(browser_stack): if "Firefox" in str(browser_stack):
components.append(random.choice(self.rendering_engines["gecko"])) components.append(random.choice(self.rendering_engines["gecko"]))
elif "Chrome" in str(browser_stack) or "Safari" in str(browser_stack): elif "Chrome" in str(browser_stack) or "Safari" in str(browser_stack):
components.append(self.rendering_engines["chrome_webkit"]) components.append(self.rendering_engines["chrome_webkit"])
components.append("(KHTML, like Gecko)") components.append("(KHTML, like Gecko)")
# Add browser versions # Add browser versions
components.extend(browser_stack) components.extend(browser_stack)
return " ".join(components) return " ".join(components)
def generate_with_client_hints(self, **kwargs) -> Tuple[str, str]: def generate_with_client_hints(self, **kwargs) -> Tuple[str, str]:
@@ -215,16 +206,20 @@ class UserAgentGenerator:
    def get_random_platform(self, device_type, os_type, device_brand):
        """Helper method to get random platform based on constraints"""
        platforms = (
            self.desktop_platforms
            if device_type == "desktop"
            else self.mobile_platforms
            if device_type == "mobile"
            else {**self.desktop_platforms, **self.mobile_platforms}
        )

        if os_type:
            for platform_group in [self.desktop_platforms, self.mobile_platforms]:
                if os_type in platform_group:
                    platforms = {os_type: platform_group[os_type]}
                    break

        os_key = random.choice(list(platforms.keys()))

        if device_brand and device_brand in platforms[os_key]:
            return platforms[os_key][device_brand]
@@ -233,73 +228,72 @@ class UserAgentGenerator:
    def parse_user_agent(self, user_agent: str) -> Dict[str, str]:
        """Parse a user agent string to extract browser and version information"""
        browsers = {
            "chrome": r"Chrome/(\d+)",
            "edge": r"Edg/(\d+)",
            "safari": r"Version/(\d+)",
            "firefox": r"Firefox/(\d+)",
        }

        result = {}
        for browser, pattern in browsers.items():
            match = re.search(pattern, user_agent)
            if match:
                result[browser] = match.group(1)

        return result

    def generate_client_hints(self, user_agent: str) -> str:
        """Generate Sec-CH-UA header value based on user agent string"""
        browsers = self.parse_user_agent(user_agent)

        # Client hints components
        hints = []

        # Handle different browser combinations
        if "chrome" in browsers:
            hints.append(f'"Chromium";v="{browsers["chrome"]}"')
            hints.append('"Not_A Brand";v="8"')

            if "edge" in browsers:
                hints.append(f'"Microsoft Edge";v="{browsers["edge"]}"')
            else:
                hints.append(f'"Google Chrome";v="{browsers["chrome"]}"')
        elif "firefox" in browsers:
            # Firefox doesn't typically send Sec-CH-UA
            return '""'
        elif "safari" in browsers:
            # Safari's format for client hints
            hints.append(f'"Safari";v="{browsers["safari"]}"')
            hints.append('"Not_A Brand";v="8"')

        return ", ".join(hints)


# Example usage:
if __name__ == "__main__":
    generator = UserAgentGenerator()

    print(generator.generate())

    print("\nSingle browser (Chrome):")
    print(generator.generate(num_browsers=1, browser_type="chrome"))

    print("\nTwo browsers (Gecko/Firefox):")
    print(generator.generate(num_browsers=2))

    print("\nThree browsers (Chrome/Safari/Edge):")
    print(generator.generate(num_browsers=3))

    print("\nFirefox on Linux:")
    print(
        generator.generate(
            device_type="desktop",
            os_type="linux",
            browser_type="firefox",
            num_browsers=2,
        )
    )

    print("\nChrome/Safari/Edge on Windows:")
    print(generator.generate(device_type="desktop", os_type="windows", num_browsers=3))
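
    # Sketch: pair a UA with matching client hints (assumes the tuple is ordered
    # (user_agent, sec_ch_ua); the method body is elided in this hunk, so the
    # ordering is an assumption).
    ua, hints = generator.generate_with_client_hints(device_type="desktop")
    print("\nHeaders:", {"User-Agent": ua, "Sec-CH-UA": hints})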
File diff suppressed because it is too large

View File
@@ -1,14 +1,14 @@
# version_manager.py
from pathlib import Path
from packaging import version
from . import __version__


class VersionManager:
    def __init__(self):
        self.home_dir = Path.home() / ".crawl4ai"
        self.version_file = self.home_dir / "version.txt"

    def get_installed_version(self):
        """Get the version recorded in home directory"""
        if not self.version_file.exists():
@@ -17,14 +17,13 @@ class VersionManager:
            return version.parse(self.version_file.read_text().strip())
        except:
            return None

    def update_version(self):
        """Update the version file to current library version"""
        self.version_file.write_text(__version__.__version__)

    def needs_update(self):
        """Check if database needs update based on version"""
        installed = self.get_installed_version()
        current = version.parse(__version__.__version__)
        return installed is None or installed < current
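

# Example usage (illustrative sketch of the check-then-update flow):
if __name__ == "__main__":
    vm = VersionManager()
    if vm.needs_update():
        # run any ~/.crawl4ai migration here (application-specific), then:
        vm.update_version()  # records the current library version in version.txt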
View File
@@ -1,9 +1,10 @@
import os, time

os.environ["TOKENIZERS_PARALLELISM"] = "false"

from pathlib import Path

from .models import UrlModel, CrawlResult
from .database import init_db, get_cached_url, cache_url
from .utils import *
from .chunking_strategy import *
from .extraction_strategy import *
@@ -14,31 +15,44 @@ from .content_scraping_strategy import WebScrapingStrategy
from .config import *
import warnings
import json

warnings.filterwarnings(
    "ignore",
    message='Field "model_name" has conflict with protected namespace "model_".',
)


class WebCrawler:
    def __init__(
        self,
        crawler_strategy: CrawlerStrategy = None,
        always_by_pass_cache: bool = False,
        verbose: bool = False,
    ):
        self.crawler_strategy = crawler_strategy or LocalSeleniumCrawlerStrategy(
            verbose=verbose
        )
        self.always_by_pass_cache = always_by_pass_cache
        self.crawl4ai_folder = os.path.join(
            os.getenv("CRAWL4_AI_BASE_DIRECTORY", Path.home()), ".crawl4ai"
        )
        os.makedirs(self.crawl4ai_folder, exist_ok=True)
        os.makedirs(f"{self.crawl4ai_folder}/cache", exist_ok=True)
        init_db()
        self.ready = False

    def warmup(self):
        print("[LOG] 🌤️ Warming up the WebCrawler")
        self.run(
            url="https://google.com/",
            word_count_threshold=5,
            extraction_strategy=NoExtractionStrategy(),
            bypass_cache=False,
            verbose=False,
        )
        self.ready = True
        print("[LOG] 🌞 WebCrawler is ready to crawl")

    def fetch_page(
        self,
        url_model: UrlModel,
@@ -80,6 +94,7 @@ class WebCrawler:
        **kwargs,
    ) -> List[CrawlResult]:
        extraction_strategy = extraction_strategy or NoExtractionStrategy()

        def fetch_page_wrapper(url_model, *args, **kwargs):
            return self.fetch_page(url_model, *args, **kwargs)
@@ -104,150 +119,176 @@ class WebCrawler:
        return results

    def run(
        self,
        url: str,
        word_count_threshold=MIN_WORD_THRESHOLD,
        extraction_strategy: ExtractionStrategy = None,
        chunking_strategy: ChunkingStrategy = RegexChunking(),
        bypass_cache: bool = False,
        css_selector: str = None,
        screenshot: bool = False,
        user_agent: str = None,
        verbose=True,
        **kwargs,
    ) -> CrawlResult:
        try:
            extraction_strategy = extraction_strategy or NoExtractionStrategy()
            extraction_strategy.verbose = verbose
            if not isinstance(extraction_strategy, ExtractionStrategy):
                raise ValueError("Unsupported extraction strategy")
            if not isinstance(chunking_strategy, ChunkingStrategy):
                raise ValueError("Unsupported chunking strategy")

            word_count_threshold = max(word_count_threshold, MIN_WORD_THRESHOLD)

            cached = None
            screenshot_data = None
            extracted_content = None
            if not bypass_cache and not self.always_by_pass_cache:
                cached = get_cached_url(url)

            if kwargs.get("warmup", True) and not self.ready:
                return None

            if cached:
                html = sanitize_input_encode(cached[1])
                extracted_content = sanitize_input_encode(cached[4])
                if screenshot:
                    screenshot_data = cached[9]
                    if not screenshot_data:
                        cached = None

            if not cached or not html:
                if user_agent:
                    self.crawler_strategy.update_user_agent(user_agent)
                t1 = time.time()
                html = sanitize_input_encode(self.crawler_strategy.crawl(url, **kwargs))
                t2 = time.time()
                if verbose:
                    print(
                        f"[LOG] 🚀 Crawling done for {url}, success: {bool(html)}, time taken: {t2 - t1:.2f} seconds"
                    )
                if screenshot:
                    screenshot_data = self.crawler_strategy.take_screenshot()

            crawl_result = self.process_html(
                url,
                html,
                extracted_content,
                word_count_threshold,
                extraction_strategy,
                chunking_strategy,
                css_selector,
                screenshot_data,
                verbose,
                bool(cached),
                **kwargs,
            )
            crawl_result.success = bool(html)
            return crawl_result
        except Exception as e:
            if not hasattr(e, "msg"):
                e.msg = str(e)
            print(f"[ERROR] 🚫 Failed to crawl {url}, error: {e.msg}")
            return CrawlResult(url=url, html="", success=False, error_message=e.msg)
    def process_html(
        self,
        url: str,
        html: str,
        extracted_content: str,
        word_count_threshold: int,
        extraction_strategy: ExtractionStrategy,
        chunking_strategy: ChunkingStrategy,
        css_selector: str,
        screenshot: bool,
        verbose: bool,
        is_cached: bool,
        **kwargs,
    ) -> CrawlResult:
        t = time.time()
        # Extract content from HTML
        try:
            t1 = time.time()
            scrapping_strategy = WebScrapingStrategy()
            extra_params = {
                k: v
                for k, v in kwargs.items()
                if k not in ["only_text", "image_description_min_word_threshold"]
            }
            result = scrapping_strategy.scrap(
                url,
                html,
                word_count_threshold=word_count_threshold,
                css_selector=css_selector,
                only_text=kwargs.get("only_text", False),
                image_description_min_word_threshold=kwargs.get(
                    "image_description_min_word_threshold",
                    IMAGE_DESCRIPTION_MIN_WORD_THRESHOLD,
                ),
                **extra_params,
            )

            # result = get_content_of_website_optimized(url, html, word_count_threshold, css_selector=css_selector, only_text=kwargs.get("only_text", False))
            if verbose:
                print(
                    f"[LOG] 🚀 Content extracted for {url}, success: True, time taken: {time.time() - t1:.2f} seconds"
                )

            if result is None:
                raise ValueError(f"Failed to extract content from the website: {url}")
        except InvalidCSSSelectorError as e:
            raise ValueError(str(e))

        cleaned_html = sanitize_input_encode(result.get("cleaned_html", ""))
        markdown = sanitize_input_encode(result.get("markdown", ""))
        media = result.get("media", [])
        links = result.get("links", [])
        metadata = result.get("metadata", {})

        if extracted_content is None:
            if verbose:
                print(
                    f"[LOG] 🔥 Extracting semantic blocks for {url}, Strategy: {extraction_strategy.name}"
                )

            sections = chunking_strategy.chunk(markdown)
            extracted_content = extraction_strategy.run(url, sections)
            extracted_content = json.dumps(
                extracted_content, indent=4, default=str, ensure_ascii=False
            )

        if verbose:
            print(
                f"[LOG] 🚀 Extraction done for {url}, time taken: {time.time() - t:.2f} seconds."
            )

        screenshot = None if not screenshot else screenshot

        if not is_cached:
            cache_url(
                url,
                html,
                cleaned_html,
                markdown,
                extracted_content,
                True,
                json.dumps(media),
                json.dumps(links),
                json.dumps(metadata),
                screenshot=screenshot,
            )

        return CrawlResult(
            url=url,
            html=html,
            cleaned_html=format_html(cleaned_html),
            markdown=markdown,
            media=media,
            links=links,
            metadata=metadata,
            screenshot=screenshot,
            extracted_content=extracted_content,
            success=True,
            error_message="",
        )
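

# Example usage (illustrative sketch; requires the optional selenium-based install).
# run() returns None until warmup() has set self.ready, unless warmup=False is passed.
if __name__ == "__main__":
    crawler = WebCrawler(verbose=True)
    crawler.warmup()
    result = crawler.run(url="https://example.com")
    if result and result.success:
        print(result.markdown[:200])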
View File
@@ -9,13 +9,11 @@ from crawl4ai.extraction_strategy import JsonCssExtractionStrategy
from crawl4ai.async_configs import BrowserConfig, CrawlerRunConfig
import json


async def extract_amazon_products():
    # Initialize browser config
    browser_config = BrowserConfig(browser_type="chromium", headless=True)

    # Initialize crawler config with JSON CSS extraction strategy
    crawler_config = CrawlerRunConfig(
        extraction_strategy=JsonCssExtractionStrategy(
@@ -27,74 +25,70 @@ async def extract_amazon_products():
"name": "asin", "name": "asin",
"selector": "", "selector": "",
"type": "attribute", "type": "attribute",
"attribute": "data-asin" "attribute": "data-asin",
},
{
"name": "title",
"selector": "h2 a span",
"type": "text"
}, },
{"name": "title", "selector": "h2 a span", "type": "text"},
{ {
"name": "url", "name": "url",
"selector": "h2 a", "selector": "h2 a",
"type": "attribute", "type": "attribute",
"attribute": "href" "attribute": "href",
}, },
{ {
"name": "image", "name": "image",
"selector": ".s-image", "selector": ".s-image",
"type": "attribute", "type": "attribute",
"attribute": "src" "attribute": "src",
}, },
{ {
"name": "rating", "name": "rating",
"selector": ".a-icon-star-small .a-icon-alt", "selector": ".a-icon-star-small .a-icon-alt",
"type": "text" "type": "text",
}, },
{ {
"name": "reviews_count", "name": "reviews_count",
"selector": "[data-csa-c-func-deps='aui-da-a-popover'] ~ span span", "selector": "[data-csa-c-func-deps='aui-da-a-popover'] ~ span span",
"type": "text" "type": "text",
}, },
{ {
"name": "price", "name": "price",
"selector": ".a-price .a-offscreen", "selector": ".a-price .a-offscreen",
"type": "text" "type": "text",
}, },
{ {
"name": "original_price", "name": "original_price",
"selector": ".a-price.a-text-price .a-offscreen", "selector": ".a-price.a-text-price .a-offscreen",
"type": "text" "type": "text",
}, },
{ {
"name": "sponsored", "name": "sponsored",
"selector": ".puis-sponsored-label-text", "selector": ".puis-sponsored-label-text",
"type": "exists" "type": "exists",
}, },
{ {
"name": "delivery_info", "name": "delivery_info",
"selector": "[data-cy='delivery-recipe'] .a-color-base", "selector": "[data-cy='delivery-recipe'] .a-color-base",
"type": "text", "type": "text",
"multiple": True "multiple": True,
} },
] ],
} }
) )
) )
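
    # Field "type" values used above: "text" reads the element's text, "attribute"
    # reads the named attribute, "exists" yields a boolean, and "multiple": True
    # collects every match into a list (see delivery_info).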
    # Example search URL (you should replace with your actual Amazon URL)
    url = "https://www.amazon.com/s?k=Samsung+Galaxy+Tab"

    # Use context manager for proper resource handling
    async with AsyncWebCrawler(config=browser_config) as crawler:
        # Extract the data
        result = await crawler.arun(url=url, config=crawler_config)

        # Process and print the results
        if result and result.extracted_content:
            # Parse the JSON string into a list of products
            products = json.loads(result.extracted_content)

            # Process each product in the list
            for product in products:
                print("\nProduct Details:")
@@ -105,10 +99,12 @@ async def extract_amazon_products():
print(f"Rating: {product.get('rating')}") print(f"Rating: {product.get('rating')}")
print(f"Reviews: {product.get('reviews_count')}") print(f"Reviews: {product.get('reviews_count')}")
print(f"Sponsored: {'Yes' if product.get('sponsored') else 'No'}") print(f"Sponsored: {'Yes' if product.get('sponsored') else 'No'}")
if product.get('delivery_info'): if product.get("delivery_info"):
print(f"Delivery: {' '.join(product['delivery_info'])}") print(f"Delivery: {' '.join(product['delivery_info'])}")
print("-" * 80) print("-" * 80)
if __name__ == "__main__": if __name__ == "__main__":
import asyncio import asyncio
asyncio.run(extract_amazon_products()) asyncio.run(extract_amazon_products())
View File
@@ -10,17 +10,17 @@ from crawl4ai.async_configs import BrowserConfig, CrawlerRunConfig
import json
from playwright.async_api import Page, BrowserContext


async def extract_amazon_products():
    # Initialize browser config
    browser_config = BrowserConfig(
        # browser_type="chromium",
        headless=True
    )

    # Initialize crawler config with JSON CSS extraction strategy nav-search-submit-button
    crawler_config = CrawlerRunConfig(
        cache_mode=CacheMode.BYPASS,
        extraction_strategy=JsonCssExtractionStrategy(
            schema={
                "name": "Amazon Product Search Results",
@@ -30,102 +30,105 @@ async def extract_amazon_products():
"name": "asin", "name": "asin",
"selector": "", "selector": "",
"type": "attribute", "type": "attribute",
"attribute": "data-asin" "attribute": "data-asin",
},
{
"name": "title",
"selector": "h2 a span",
"type": "text"
}, },
{"name": "title", "selector": "h2 a span", "type": "text"},
{ {
"name": "url", "name": "url",
"selector": "h2 a", "selector": "h2 a",
"type": "attribute", "type": "attribute",
"attribute": "href" "attribute": "href",
}, },
{ {
"name": "image", "name": "image",
"selector": ".s-image", "selector": ".s-image",
"type": "attribute", "type": "attribute",
"attribute": "src" "attribute": "src",
}, },
{ {
"name": "rating", "name": "rating",
"selector": ".a-icon-star-small .a-icon-alt", "selector": ".a-icon-star-small .a-icon-alt",
"type": "text" "type": "text",
}, },
{ {
"name": "reviews_count", "name": "reviews_count",
"selector": "[data-csa-c-func-deps='aui-da-a-popover'] ~ span span", "selector": "[data-csa-c-func-deps='aui-da-a-popover'] ~ span span",
"type": "text" "type": "text",
}, },
{ {
"name": "price", "name": "price",
"selector": ".a-price .a-offscreen", "selector": ".a-price .a-offscreen",
"type": "text" "type": "text",
}, },
{ {
"name": "original_price", "name": "original_price",
"selector": ".a-price.a-text-price .a-offscreen", "selector": ".a-price.a-text-price .a-offscreen",
"type": "text" "type": "text",
}, },
{ {
"name": "sponsored", "name": "sponsored",
"selector": ".puis-sponsored-label-text", "selector": ".puis-sponsored-label-text",
"type": "exists" "type": "exists",
}, },
{ {
"name": "delivery_info", "name": "delivery_info",
"selector": "[data-cy='delivery-recipe'] .a-color-base", "selector": "[data-cy='delivery-recipe'] .a-color-base",
"type": "text", "type": "text",
"multiple": True "multiple": True,
} },
] ],
} }
) ),
) )
url = "https://www.amazon.com/" url = "https://www.amazon.com/"
async def after_goto(page: Page, context: BrowserContext, url: str, response: dict, **kwargs): async def after_goto(
page: Page, context: BrowserContext, url: str, response: dict, **kwargs
):
"""Hook called after navigating to each URL""" """Hook called after navigating to each URL"""
print(f"[HOOK] after_goto - Successfully loaded: {url}") print(f"[HOOK] after_goto - Successfully loaded: {url}")
try: try:
# Wait for search box to be available # Wait for search box to be available
search_box = await page.wait_for_selector('#twotabsearchtextbox', timeout=1000) search_box = await page.wait_for_selector(
"#twotabsearchtextbox", timeout=1000
)
# Type the search query # Type the search query
await search_box.fill('Samsung Galaxy Tab') await search_box.fill("Samsung Galaxy Tab")
# Get the search button and prepare for navigation # Get the search button and prepare for navigation
search_button = await page.wait_for_selector('#nav-search-submit-button', timeout=1000) search_button = await page.wait_for_selector(
"#nav-search-submit-button", timeout=1000
)
# Click with navigation waiting # Click with navigation waiting
await search_button.click() await search_button.click()
# Wait for search results to load # Wait for search results to load
await page.wait_for_selector('[data-component-type="s-search-result"]', timeout=10000) await page.wait_for_selector(
'[data-component-type="s-search-result"]', timeout=10000
)
print("[HOOK] Search completed and results loaded!") print("[HOOK] Search completed and results loaded!")
except Exception as e: except Exception as e:
print(f"[HOOK] Error during search operation: {str(e)}") print(f"[HOOK] Error during search operation: {str(e)}")
return page return page
# Use context manager for proper resource handling # Use context manager for proper resource handling
async with AsyncWebCrawler(config=browser_config) as crawler: async with AsyncWebCrawler(config=browser_config) as crawler:
crawler.crawler_strategy.set_hook("after_goto", after_goto) crawler.crawler_strategy.set_hook("after_goto", after_goto)
# Extract the data # Extract the data
result = await crawler.arun(url=url, config=crawler_config) result = await crawler.arun(url=url, config=crawler_config)
# Process and print the results # Process and print the results
if result and result.extracted_content: if result and result.extracted_content:
# Parse the JSON string into a list of products # Parse the JSON string into a list of products
products = json.loads(result.extracted_content) products = json.loads(result.extracted_content)
# Process each product in the list # Process each product in the list
for product in products: for product in products:
print("\nProduct Details:") print("\nProduct Details:")
@@ -136,10 +139,12 @@ async def extract_amazon_products():
print(f"Rating: {product.get('rating')}") print(f"Rating: {product.get('rating')}")
print(f"Reviews: {product.get('reviews_count')}") print(f"Reviews: {product.get('reviews_count')}")
print(f"Sponsored: {'Yes' if product.get('sponsored') else 'No'}") print(f"Sponsored: {'Yes' if product.get('sponsored') else 'No'}")
if product.get('delivery_info'): if product.get("delivery_info"):
print(f"Delivery: {' '.join(product['delivery_info'])}") print(f"Delivery: {' '.join(product['delivery_info'])}")
print("-" * 80) print("-" * 80)
if __name__ == "__main__": if __name__ == "__main__":
import asyncio import asyncio
asyncio.run(extract_amazon_products()) asyncio.run(extract_amazon_products())
View File
@@ -8,7 +8,7 @@ from crawl4ai import AsyncWebCrawler, CacheMode
from crawl4ai.extraction_strategy import JsonCssExtractionStrategy
from crawl4ai.async_configs import BrowserConfig, CrawlerRunConfig
import json


async def extract_amazon_products():
    # Initialize browser config
@@ -16,7 +16,7 @@ async def extract_amazon_products():
        # browser_type="chromium",
        headless=True
    )

    js_code_to_search = """
    const task = async () => {
        document.querySelector('#twotabsearchtextbox').value = 'Samsung Galaxy Tab';
@@ -30,7 +30,7 @@ async def extract_amazon_products():
""" """
crawler_config = CrawlerRunConfig( crawler_config = CrawlerRunConfig(
cache_mode=CacheMode.BYPASS, cache_mode=CacheMode.BYPASS,
js_code = js_code_to_search, js_code=js_code_to_search,
wait_for='css:[data-component-type="s-search-result"]', wait_for='css:[data-component-type="s-search-result"]',
extraction_strategy=JsonCssExtractionStrategy( extraction_strategy=JsonCssExtractionStrategy(
schema={ schema={
@@ -41,75 +41,70 @@ async def extract_amazon_products():
"name": "asin", "name": "asin",
"selector": "", "selector": "",
"type": "attribute", "type": "attribute",
"attribute": "data-asin" "attribute": "data-asin",
},
{
"name": "title",
"selector": "h2 a span",
"type": "text"
}, },
{"name": "title", "selector": "h2 a span", "type": "text"},
{ {
"name": "url", "name": "url",
"selector": "h2 a", "selector": "h2 a",
"type": "attribute", "type": "attribute",
"attribute": "href" "attribute": "href",
}, },
{ {
"name": "image", "name": "image",
"selector": ".s-image", "selector": ".s-image",
"type": "attribute", "type": "attribute",
"attribute": "src" "attribute": "src",
}, },
{ {
"name": "rating", "name": "rating",
"selector": ".a-icon-star-small .a-icon-alt", "selector": ".a-icon-star-small .a-icon-alt",
"type": "text" "type": "text",
}, },
{ {
"name": "reviews_count", "name": "reviews_count",
"selector": "[data-csa-c-func-deps='aui-da-a-popover'] ~ span span", "selector": "[data-csa-c-func-deps='aui-da-a-popover'] ~ span span",
"type": "text" "type": "text",
}, },
{ {
"name": "price", "name": "price",
"selector": ".a-price .a-offscreen", "selector": ".a-price .a-offscreen",
"type": "text" "type": "text",
}, },
{ {
"name": "original_price", "name": "original_price",
"selector": ".a-price.a-text-price .a-offscreen", "selector": ".a-price.a-text-price .a-offscreen",
"type": "text" "type": "text",
}, },
{ {
"name": "sponsored", "name": "sponsored",
"selector": ".puis-sponsored-label-text", "selector": ".puis-sponsored-label-text",
"type": "exists" "type": "exists",
}, },
{ {
"name": "delivery_info", "name": "delivery_info",
"selector": "[data-cy='delivery-recipe'] .a-color-base", "selector": "[data-cy='delivery-recipe'] .a-color-base",
"type": "text", "type": "text",
"multiple": True "multiple": True,
} },
] ],
} }
) ),
) )
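
    # Note: cache_mode=CacheMode.BYPASS forces a fresh fetch, js_code runs inside
    # the loaded page to submit the search, and wait_for holds extraction until
    # the configured CSS selector matches, so the schema runs on the results page.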
    # Example search URL (you should replace with your actual Amazon URL)
    url = "https://www.amazon.com/"

    # Use context manager for proper resource handling
    async with AsyncWebCrawler(config=browser_config) as crawler:
        # Extract the data
        result = await crawler.arun(url=url, config=crawler_config)

        # Process and print the results
        if result and result.extracted_content:
            # Parse the JSON string into a list of products
            products = json.loads(result.extracted_content)

            # Process each product in the list
            for product in products:
                print("\nProduct Details:")
@@ -120,10 +115,12 @@ async def extract_amazon_products():
print(f"Rating: {product.get('rating')}") print(f"Rating: {product.get('rating')}")
print(f"Reviews: {product.get('reviews_count')}") print(f"Reviews: {product.get('reviews_count')}")
print(f"Sponsored: {'Yes' if product.get('sponsored') else 'No'}") print(f"Sponsored: {'Yes' if product.get('sponsored') else 'No'}")
if product.get('delivery_info'): if product.get("delivery_info"):
print(f"Delivery: {' '.join(product['delivery_info'])}") print(f"Delivery: {' '.join(product['delivery_info'])}")
print("-" * 80) print("-" * 80)
if __name__ == "__main__": if __name__ == "__main__":
import asyncio import asyncio
asyncio.run(extract_amazon_products()) asyncio.run(extract_amazon_products())
View File
@@ -1,12 +1,16 @@
# File: async_webcrawler_multiple_urls_example.py
import os, sys

# append 2 parent directories to sys.path to import crawl4ai
parent_dir = os.path.dirname(
    os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
)
sys.path.append(parent_dir)

import asyncio
from crawl4ai import AsyncWebCrawler


async def main():
    # Initialize the AsyncWebCrawler
    async with AsyncWebCrawler(verbose=True) as crawler:
@@ -16,7 +20,7 @@ async def main():
"https://python.org", "https://python.org",
"https://github.com", "https://github.com",
"https://stackoverflow.com", "https://stackoverflow.com",
"https://news.ycombinator.com" "https://news.ycombinator.com",
] ]
# Set up crawling parameters # Set up crawling parameters
@@ -27,7 +31,7 @@ async def main():
            urls=urls,
            word_count_threshold=word_count_threshold,
            bypass_cache=True,
            verbose=True,
        )

        # Process the results
@@ -36,7 +40,9 @@ async def main():
print(f"Successfully crawled: {result.url}") print(f"Successfully crawled: {result.url}")
print(f"Title: {result.metadata.get('title', 'N/A')}") print(f"Title: {result.metadata.get('title', 'N/A')}")
print(f"Word count: {len(result.markdown.split())}") print(f"Word count: {len(result.markdown.split())}")
print(f"Number of links: {len(result.links.get('internal', [])) + len(result.links.get('external', []))}") print(
f"Number of links: {len(result.links.get('internal', [])) + len(result.links.get('external', []))}"
)
print(f"Number of images: {len(result.media.get('images', []))}") print(f"Number of images: {len(result.media.get('images', []))}")
print("---") print("---")
else: else:
@@ -44,5 +50,6 @@ async def main():
print(f"Error: {result.error_message}") print(f"Error: {result.error_message}")
print("---") print("---")
if __name__ == "__main__": if __name__ == "__main__":
asyncio.run(main()) asyncio.run(main())
View File
@@ -6,10 +6,8 @@ This example demonstrates optimal browser usage patterns in Crawl4AI:
""" """
import asyncio import asyncio
import os
from typing import List from typing import List
from crawl4ai import AsyncWebCrawler, BrowserConfig, CrawlerRunConfig from crawl4ai import AsyncWebCrawler, BrowserConfig, CrawlerRunConfig
from crawl4ai.content_filter_strategy import PruningContentFilter
from crawl4ai.markdown_generation_strategy import DefaultMarkdownGenerator from crawl4ai.markdown_generation_strategy import DefaultMarkdownGenerator
View File
@@ -1,31 +1,32 @@
import os, time import os, time
# append the path to the root of the project # append the path to the root of the project
import sys import sys
import asyncio import asyncio
sys.path.append(os.path.join(os.path.dirname(__file__), '..', '..'))
sys.path.append(os.path.join(os.path.dirname(__file__), "..", ".."))
from firecrawl import FirecrawlApp from firecrawl import FirecrawlApp
from crawl4ai import AsyncWebCrawler from crawl4ai import AsyncWebCrawler
__data__ = os.path.join(os.path.dirname(__file__), '..', '..') + '/.data'
__data__ = os.path.join(os.path.dirname(__file__), "..", "..") + "/.data"
async def compare(): async def compare():
app = FirecrawlApp(api_key=os.environ['FIRECRAWL_API_KEY']) app = FirecrawlApp(api_key=os.environ["FIRECRAWL_API_KEY"])
    # Test Firecrawl with a simple crawl # Test Firecrawl with a simple crawl
start = time.time() start = time.time()
scrape_status = app.scrape_url( scrape_status = app.scrape_url(
'https://www.nbcnews.com/business', "https://www.nbcnews.com/business", params={"formats": ["markdown", "html"]}
params={'formats': ['markdown', 'html']}
) )
end = time.time() end = time.time()
print(f"Time taken: {end - start} seconds") print(f"Time taken: {end - start} seconds")
print(len(scrape_status['markdown'])) print(len(scrape_status["markdown"]))
# save the markdown content with provider name # save the markdown content with provider name
with open(f"{__data__}/firecrawl_simple.md", "w") as f: with open(f"{__data__}/firecrawl_simple.md", "w") as f:
f.write(scrape_status['markdown']) f.write(scrape_status["markdown"])
# Count how many "cldnry.s-nbcnews.com" are in the markdown # Count how many "cldnry.s-nbcnews.com" are in the markdown
print(scrape_status['markdown'].count("cldnry.s-nbcnews.com")) print(scrape_status["markdown"].count("cldnry.s-nbcnews.com"))
async with AsyncWebCrawler() as crawler: async with AsyncWebCrawler() as crawler:
start = time.time() start = time.time()
@@ -33,13 +34,13 @@ async def compare():
url="https://www.nbcnews.com/business", url="https://www.nbcnews.com/business",
# js_code=["const loadMoreButton = Array.from(document.querySelectorAll('button')).find(button => button.textContent.includes('Load More')); loadMoreButton && loadMoreButton.click();"], # js_code=["const loadMoreButton = Array.from(document.querySelectorAll('button')).find(button => button.textContent.includes('Load More')); loadMoreButton && loadMoreButton.click();"],
word_count_threshold=0, word_count_threshold=0,
bypass_cache=True, bypass_cache=True,
verbose=False verbose=False,
) )
end = time.time() end = time.time()
print(f"Time taken: {end - start} seconds") print(f"Time taken: {end - start} seconds")
print(len(result.markdown)) print(len(result.markdown))
# save the markdown content with provider name # save the markdown content with provider name
with open(f"{__data__}/crawl4ai_simple.md", "w") as f: with open(f"{__data__}/crawl4ai_simple.md", "w") as f:
f.write(result.markdown) f.write(result.markdown)
# count how many "cldnry.s-nbcnews.com" are in the markdown # count how many "cldnry.s-nbcnews.com" are in the markdown
@@ -48,10 +49,12 @@ async def compare():
start = time.time() start = time.time()
result = await crawler.arun( result = await crawler.arun(
url="https://www.nbcnews.com/business", url="https://www.nbcnews.com/business",
js_code=["const loadMoreButton = Array.from(document.querySelectorAll('button')).find(button => button.textContent.includes('Load More')); loadMoreButton && loadMoreButton.click();"], js_code=[
"const loadMoreButton = Array.from(document.querySelectorAll('button')).find(button => button.textContent.includes('Load More')); loadMoreButton && loadMoreButton.click();"
],
word_count_threshold=0, word_count_threshold=0,
bypass_cache=True, bypass_cache=True,
verbose=False verbose=False,
) )
end = time.time() end = time.time()
print(f"Time taken: {end - start} seconds") print(f"Time taken: {end - start} seconds")
@@ -61,7 +64,7 @@ async def compare():
f.write(result.markdown) f.write(result.markdown)
# count how many "cldnry.s-nbcnews.com" are in the markdown # count how many "cldnry.s-nbcnews.com" are in the markdown
print(result.markdown.count("cldnry.s-nbcnews.com")) print(result.markdown.count("cldnry.s-nbcnews.com"))
if __name__ == "__main__": if __name__ == "__main__":
asyncio.run(compare()) asyncio.run(compare())
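
The start/end timing that brackets every call in this comparison can be factored into a small helper. A sketch using only the standard library:

import time
from contextlib import contextmanager

@contextmanager
def timed(label: str):
    # Wall-clock timer matching the print format used above.
    start = time.time()
    try:
        yield
    finally:
        print(f"{label}: Time taken: {time.time() - start} seconds")

# Usage:
# with timed("firecrawl_simple"):
#     scrape_status = app.scrape_url(...)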


@@ -3,11 +3,18 @@ import time
from rich import print from rich import print
from rich.table import Table from rich.table import Table
from crawl4ai import ( from crawl4ai import (
AsyncWebCrawler, BrowserConfig, CrawlerRunConfig, AsyncWebCrawler,
MemoryAdaptiveDispatcher, SemaphoreDispatcher, BrowserConfig,
RateLimiter, CrawlerMonitor, DisplayMode, CacheMode CrawlerRunConfig,
MemoryAdaptiveDispatcher,
SemaphoreDispatcher,
RateLimiter,
CrawlerMonitor,
DisplayMode,
CacheMode,
) )
async def memory_adaptive(urls, browser_config, run_config): async def memory_adaptive(urls, browser_config, run_config):
"""Memory adaptive crawler with monitoring""" """Memory adaptive crawler with monitoring"""
start = time.perf_counter() start = time.perf_counter()
@@ -16,14 +23,16 @@ async def memory_adaptive(urls, browser_config, run_config):
memory_threshold_percent=70.0, memory_threshold_percent=70.0,
max_session_permit=10, max_session_permit=10,
monitor=CrawlerMonitor( monitor=CrawlerMonitor(
max_visible_rows=15, max_visible_rows=15, display_mode=DisplayMode.DETAILED
display_mode=DisplayMode.DETAILED ),
) )
results = await crawler.arun_many(
urls, config=run_config, dispatcher=dispatcher
) )
results = await crawler.arun_many(urls, config=run_config, dispatcher=dispatcher)
duration = time.perf_counter() - start duration = time.perf_counter() - start
return len(results), duration return len(results), duration
async def memory_adaptive_with_rate_limit(urls, browser_config, run_config): async def memory_adaptive_with_rate_limit(urls, browser_config, run_config):
"""Memory adaptive crawler with rate limiting""" """Memory adaptive crawler with rate limiting"""
start = time.perf_counter() start = time.perf_counter()
@@ -32,19 +41,19 @@ async def memory_adaptive_with_rate_limit(urls, browser_config, run_config):
memory_threshold_percent=70.0, memory_threshold_percent=70.0,
max_session_permit=10, max_session_permit=10,
rate_limiter=RateLimiter( rate_limiter=RateLimiter(
base_delay=(1.0, 2.0), base_delay=(1.0, 2.0), max_delay=30.0, max_retries=2
max_delay=30.0,
max_retries=2
), ),
monitor=CrawlerMonitor( monitor=CrawlerMonitor(
max_visible_rows=15, max_visible_rows=15, display_mode=DisplayMode.DETAILED
display_mode=DisplayMode.DETAILED ),
) )
results = await crawler.arun_many(
urls, config=run_config, dispatcher=dispatcher
) )
results = await crawler.arun_many(urls, config=run_config, dispatcher=dispatcher)
duration = time.perf_counter() - start duration = time.perf_counter() - start
return len(results), duration return len(results), duration
async def semaphore(urls, browser_config, run_config): async def semaphore(urls, browser_config, run_config):
"""Basic semaphore crawler""" """Basic semaphore crawler"""
start = time.perf_counter() start = time.perf_counter()
@@ -52,14 +61,16 @@ async def semaphore(urls, browser_config, run_config):
dispatcher = SemaphoreDispatcher( dispatcher = SemaphoreDispatcher(
semaphore_count=5, semaphore_count=5,
monitor=CrawlerMonitor( monitor=CrawlerMonitor(
max_visible_rows=15, max_visible_rows=15, display_mode=DisplayMode.DETAILED
display_mode=DisplayMode.DETAILED ),
) )
results = await crawler.arun_many(
urls, config=run_config, dispatcher=dispatcher
) )
results = await crawler.arun_many(urls, config=run_config, dispatcher=dispatcher)
duration = time.perf_counter() - start duration = time.perf_counter() - start
return len(results), duration return len(results), duration
async def semaphore_with_rate_limit(urls, browser_config, run_config): async def semaphore_with_rate_limit(urls, browser_config, run_config):
"""Semaphore crawler with rate limiting""" """Semaphore crawler with rate limiting"""
start = time.perf_counter() start = time.perf_counter()
@@ -67,19 +78,19 @@ async def semaphore_with_rate_limit(urls, browser_config, run_config):
dispatcher = SemaphoreDispatcher( dispatcher = SemaphoreDispatcher(
semaphore_count=5, semaphore_count=5,
rate_limiter=RateLimiter( rate_limiter=RateLimiter(
base_delay=(1.0, 2.0), base_delay=(1.0, 2.0), max_delay=30.0, max_retries=2
max_delay=30.0,
max_retries=2
), ),
monitor=CrawlerMonitor( monitor=CrawlerMonitor(
max_visible_rows=15, max_visible_rows=15, display_mode=DisplayMode.DETAILED
display_mode=DisplayMode.DETAILED ),
) )
results = await crawler.arun_many(
urls, config=run_config, dispatcher=dispatcher
) )
results = await crawler.arun_many(urls, config=run_config, dispatcher=dispatcher)
duration = time.perf_counter() - start duration = time.perf_counter() - start
return len(results), duration return len(results), duration
def create_performance_table(results): def create_performance_table(results):
"""Creates a rich table showing performance results""" """Creates a rich table showing performance results"""
table = Table(title="Crawler Strategy Performance Comparison") table = Table(title="Crawler Strategy Performance Comparison")
@@ -89,18 +100,16 @@ def create_performance_table(results):
table.add_column("URLs/second", justify="right", style="magenta") table.add_column("URLs/second", justify="right", style="magenta")
sorted_results = sorted(results.items(), key=lambda x: x[1][1]) sorted_results = sorted(results.items(), key=lambda x: x[1][1])
for strategy, (urls_crawled, duration) in sorted_results: for strategy, (urls_crawled, duration) in sorted_results:
urls_per_second = urls_crawled / duration urls_per_second = urls_crawled / duration
table.add_row( table.add_row(
strategy, strategy, str(urls_crawled), f"{duration:.2f}", f"{urls_per_second:.2f}"
str(urls_crawled),
f"{duration:.2f}",
f"{urls_per_second:.2f}"
) )
return table return table
async def main(): async def main():
urls = [f"https://example.com/page{i}" for i in range(1, 20)] urls = [f"https://example.com/page{i}" for i in range(1, 20)]
browser_config = BrowserConfig(headless=True, verbose=False) browser_config = BrowserConfig(headless=True, verbose=False)
@@ -108,14 +117,19 @@ async def main():
results = { results = {
"Memory Adaptive": await memory_adaptive(urls, browser_config, run_config), "Memory Adaptive": await memory_adaptive(urls, browser_config, run_config),
"Memory Adaptive + Rate Limit": await memory_adaptive_with_rate_limit(urls, browser_config, run_config), "Memory Adaptive + Rate Limit": await memory_adaptive_with_rate_limit(
urls, browser_config, run_config
),
"Semaphore": await semaphore(urls, browser_config, run_config), "Semaphore": await semaphore(urls, browser_config, run_config),
"Semaphore + Rate Limit": await semaphore_with_rate_limit(urls, browser_config, run_config), "Semaphore + Rate Limit": await semaphore_with_rate_limit(
urls, browser_config, run_config
),
} }
table = create_performance_table(results) table = create_performance_table(results)
print("\nPerformance Summary:") print("\nPerformance Summary:")
print(table) print(table)
if __name__ == "__main__": if __name__ == "__main__":
asyncio.run(main()) asyncio.run(main())
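
Stripped of the benchmarking scaffolding, the dispatcher wiring shared by the four variants looks like this. A sketch reusing the same illustrative thresholds and delays as above:

import asyncio
from crawl4ai import (
    AsyncWebCrawler,
    BrowserConfig,
    CrawlerRunConfig,
    MemoryAdaptiveDispatcher,
    RateLimiter,
    CacheMode,
)

async def main():
    dispatcher = MemoryAdaptiveDispatcher(
        memory_threshold_percent=70.0,  # throttle when memory use crosses 70%
        max_session_permit=10,          # upper bound on concurrent sessions
        rate_limiter=RateLimiter(base_delay=(1.0, 2.0), max_delay=30.0, max_retries=2),
    )
    run_config = CrawlerRunConfig(cache_mode=CacheMode.BYPASS)
    async with AsyncWebCrawler(config=BrowserConfig(headless=True)) as crawler:
        results = await crawler.arun_many(
            ["https://example.com"], config=run_config, dispatcher=dispatcher
        )
        print(f"{len(results)} results")

asyncio.run(main())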


@@ -6,63 +6,80 @@ import base64
import os import os
from typing import Dict, Any from typing import Dict, Any
class Crawl4AiTester: class Crawl4AiTester:
def __init__(self, base_url: str = "http://localhost:11235", api_token: str = None): def __init__(self, base_url: str = "http://localhost:11235", api_token: str = None):
self.base_url = base_url self.base_url = base_url
self.api_token = api_token or os.getenv('CRAWL4AI_API_TOKEN') or "test_api_code" # Check environment variable as fallback self.api_token = (
self.headers = {'Authorization': f'Bearer {self.api_token}'} if self.api_token else {} api_token or os.getenv("CRAWL4AI_API_TOKEN") or "test_api_code"
) # Check environment variable as fallback
def submit_and_wait(self, request_data: Dict[str, Any], timeout: int = 300) -> Dict[str, Any]: self.headers = (
{"Authorization": f"Bearer {self.api_token}"} if self.api_token else {}
)
def submit_and_wait(
self, request_data: Dict[str, Any], timeout: int = 300
) -> Dict[str, Any]:
# Submit crawl job # Submit crawl job
response = requests.post(f"{self.base_url}/crawl", json=request_data, headers=self.headers) response = requests.post(
f"{self.base_url}/crawl", json=request_data, headers=self.headers
)
if response.status_code == 403: if response.status_code == 403:
raise Exception("API token is invalid or missing") raise Exception("API token is invalid or missing")
task_id = response.json()["task_id"] task_id = response.json()["task_id"]
print(f"Task ID: {task_id}") print(f"Task ID: {task_id}")
# Poll for result # Poll for result
start_time = time.time() start_time = time.time()
while True: while True:
if time.time() - start_time > timeout: if time.time() - start_time > timeout:
raise TimeoutError(f"Task {task_id} did not complete within {timeout} seconds") raise TimeoutError(
f"Task {task_id} did not complete within {timeout} seconds"
result = requests.get(f"{self.base_url}/task/{task_id}", headers=self.headers) )
result = requests.get(
f"{self.base_url}/task/{task_id}", headers=self.headers
)
status = result.json() status = result.json()
if status["status"] == "failed": if status["status"] == "failed":
print("Task failed:", status.get("error")) print("Task failed:", status.get("error"))
raise Exception(f"Task failed: {status.get('error')}") raise Exception(f"Task failed: {status.get('error')}")
if status["status"] == "completed": if status["status"] == "completed":
return status return status
time.sleep(2) time.sleep(2)
def submit_sync(self, request_data: Dict[str, Any]) -> Dict[str, Any]: def submit_sync(self, request_data: Dict[str, Any]) -> Dict[str, Any]:
response = requests.post(f"{self.base_url}/crawl_sync", json=request_data, headers=self.headers, timeout=60) response = requests.post(
f"{self.base_url}/crawl_sync",
json=request_data,
headers=self.headers,
timeout=60,
)
if response.status_code == 408: if response.status_code == 408:
raise TimeoutError("Task did not complete within server timeout") raise TimeoutError("Task did not complete within server timeout")
response.raise_for_status() response.raise_for_status()
return response.json() return response.json()
def crawl_direct(self, request_data: Dict[str, Any]) -> Dict[str, Any]: def crawl_direct(self, request_data: Dict[str, Any]) -> Dict[str, Any]:
"""Directly crawl without using task queue""" """Directly crawl without using task queue"""
response = requests.post( response = requests.post(
f"{self.base_url}/crawl_direct", f"{self.base_url}/crawl_direct", json=request_data, headers=self.headers
json=request_data,
headers=self.headers
) )
response.raise_for_status() response.raise_for_status()
return response.json() return response.json()
def test_docker_deployment(version="basic"): def test_docker_deployment(version="basic"):
tester = Crawl4AiTester( tester = Crawl4AiTester(
base_url="http://localhost:11235" , base_url="http://localhost:11235",
# base_url="https://api.crawl4ai.com" # just for example # base_url="https://api.crawl4ai.com" # just for example
# api_token="test" # just for example # api_token="test" # just for example
) )
print(f"Testing Crawl4AI Docker {version} version") print(f"Testing Crawl4AI Docker {version} version")
# Health check with timeout and retry # Health check with timeout and retry
max_retries = 5 max_retries = 5
for i in range(max_retries): for i in range(max_retries):
@@ -70,19 +87,19 @@ def test_docker_deployment(version="basic"):
health = requests.get(f"{tester.base_url}/health", timeout=10) health = requests.get(f"{tester.base_url}/health", timeout=10)
print("Health check:", health.json()) print("Health check:", health.json())
break break
except requests.exceptions.RequestException as e: except requests.exceptions.RequestException:
if i == max_retries - 1: if i == max_retries - 1:
print(f"Failed to connect after {max_retries} attempts") print(f"Failed to connect after {max_retries} attempts")
sys.exit(1) sys.exit(1)
print(f"Waiting for service to start (attempt {i+1}/{max_retries})...") print(f"Waiting for service to start (attempt {i+1}/{max_retries})...")
time.sleep(5) time.sleep(5)
# Test cases based on version # Test cases based on version
test_basic_crawl_direct(tester) test_basic_crawl_direct(tester)
test_basic_crawl(tester) test_basic_crawl(tester)
test_basic_crawl(tester) test_basic_crawl(tester)
test_basic_crawl_sync(tester) test_basic_crawl_sync(tester)
if version in ["full", "transformer"]: if version in ["full", "transformer"]:
test_cosine_extraction(tester) test_cosine_extraction(tester)
@@ -92,49 +109,52 @@ def test_docker_deployment(version="basic"):
test_llm_extraction(tester) test_llm_extraction(tester)
test_llm_with_ollama(tester) test_llm_with_ollama(tester)
test_screenshot(tester) test_screenshot(tester)
def test_basic_crawl(tester: Crawl4AiTester): def test_basic_crawl(tester: Crawl4AiTester):
print("\n=== Testing Basic Crawl ===") print("\n=== Testing Basic Crawl ===")
request = { request = {
"urls": "https://www.nbcnews.com/business", "urls": "https://www.nbcnews.com/business",
"priority": 10, "priority": 10,
"session_id": "test" "session_id": "test",
} }
result = tester.submit_and_wait(request) result = tester.submit_and_wait(request)
print(f"Basic crawl result length: {len(result['result']['markdown'])}") print(f"Basic crawl result length: {len(result['result']['markdown'])}")
assert result["result"]["success"] assert result["result"]["success"]
assert len(result["result"]["markdown"]) > 0 assert len(result["result"]["markdown"]) > 0
def test_basic_crawl_sync(tester: Crawl4AiTester): def test_basic_crawl_sync(tester: Crawl4AiTester):
print("\n=== Testing Basic Crawl (Sync) ===") print("\n=== Testing Basic Crawl (Sync) ===")
request = { request = {
"urls": "https://www.nbcnews.com/business", "urls": "https://www.nbcnews.com/business",
"priority": 10, "priority": 10,
"session_id": "test" "session_id": "test",
} }
result = tester.submit_sync(request) result = tester.submit_sync(request)
print(f"Basic crawl result length: {len(result['result']['markdown'])}") print(f"Basic crawl result length: {len(result['result']['markdown'])}")
assert result['status'] == 'completed' assert result["status"] == "completed"
assert result['result']['success'] assert result["result"]["success"]
assert len(result['result']['markdown']) > 0 assert len(result["result"]["markdown"]) > 0
def test_basic_crawl_direct(tester: Crawl4AiTester): def test_basic_crawl_direct(tester: Crawl4AiTester):
print("\n=== Testing Basic Crawl (Direct) ===") print("\n=== Testing Basic Crawl (Direct) ===")
request = { request = {
"urls": "https://www.nbcnews.com/business", "urls": "https://www.nbcnews.com/business",
"priority": 10, "priority": 10,
# "session_id": "test" # "session_id": "test"
"cache_mode": "bypass" # or "enabled", "disabled", "read_only", "write_only" "cache_mode": "bypass", # or "enabled", "disabled", "read_only", "write_only"
} }
result = tester.crawl_direct(request) result = tester.crawl_direct(request)
print(f"Basic crawl result length: {len(result['result']['markdown'])}") print(f"Basic crawl result length: {len(result['result']['markdown'])}")
assert result['result']['success'] assert result["result"]["success"]
assert len(result['result']['markdown']) > 0 assert len(result["result"]["markdown"]) > 0
def test_js_execution(tester: Crawl4AiTester): def test_js_execution(tester: Crawl4AiTester):
print("\n=== Testing JS Execution ===") print("\n=== Testing JS Execution ===")
request = { request = {
@@ -144,32 +164,29 @@ def test_js_execution(tester: Crawl4AiTester):
"const loadMoreButton = Array.from(document.querySelectorAll('button')).find(button => button.textContent.includes('Load More')); loadMoreButton && loadMoreButton.click();" "const loadMoreButton = Array.from(document.querySelectorAll('button')).find(button => button.textContent.includes('Load More')); loadMoreButton && loadMoreButton.click();"
], ],
"wait_for": "article.tease-card:nth-child(10)", "wait_for": "article.tease-card:nth-child(10)",
"crawler_params": { "crawler_params": {"headless": True},
"headless": True
}
} }
result = tester.submit_and_wait(request) result = tester.submit_and_wait(request)
print(f"JS execution result length: {len(result['result']['markdown'])}") print(f"JS execution result length: {len(result['result']['markdown'])}")
assert result["result"]["success"] assert result["result"]["success"]
def test_css_selector(tester: Crawl4AiTester): def test_css_selector(tester: Crawl4AiTester):
print("\n=== Testing CSS Selector ===") print("\n=== Testing CSS Selector ===")
request = { request = {
"urls": "https://www.nbcnews.com/business", "urls": "https://www.nbcnews.com/business",
"priority": 7, "priority": 7,
"css_selector": ".wide-tease-item__description", "css_selector": ".wide-tease-item__description",
"crawler_params": { "crawler_params": {"headless": True},
"headless": True "extra": {"word_count_threshold": 10},
},
"extra": {"word_count_threshold": 10}
} }
result = tester.submit_and_wait(request) result = tester.submit_and_wait(request)
print(f"CSS selector result length: {len(result['result']['markdown'])}") print(f"CSS selector result length: {len(result['result']['markdown'])}")
assert result["result"]["success"] assert result["result"]["success"]
def test_structured_extraction(tester: Crawl4AiTester): def test_structured_extraction(tester: Crawl4AiTester):
print("\n=== Testing Structured Extraction ===") print("\n=== Testing Structured Extraction ===")
schema = { schema = {
@@ -190,21 +207,16 @@ def test_structured_extraction(tester: Crawl4AiTester):
"name": "price", "name": "price",
"selector": "td:nth-child(2)", "selector": "td:nth-child(2)",
"type": "text", "type": "text",
} },
], ],
} }
request = { request = {
"urls": "https://www.coinbase.com/explore", "urls": "https://www.coinbase.com/explore",
"priority": 9, "priority": 9,
"extraction_config": { "extraction_config": {"type": "json_css", "params": {"schema": schema}},
"type": "json_css",
"params": {
"schema": schema
}
}
} }
result = tester.submit_and_wait(request) result = tester.submit_and_wait(request)
extracted = json.loads(result["result"]["extracted_content"]) extracted = json.loads(result["result"]["extracted_content"])
print(f"Extracted {len(extracted)} items") print(f"Extracted {len(extracted)} items")
@@ -212,6 +224,7 @@ def test_structured_extraction(tester: Crawl4AiTester):
assert result["result"]["success"] assert result["result"]["success"]
assert len(extracted) > 0 assert len(extracted) > 0
def test_llm_extraction(tester: Crawl4AiTester): def test_llm_extraction(tester: Crawl4AiTester):
print("\n=== Testing LLM Extraction ===") print("\n=== Testing LLM Extraction ===")
schema = { schema = {
@@ -219,20 +232,20 @@ def test_llm_extraction(tester: Crawl4AiTester):
"properties": { "properties": {
"model_name": { "model_name": {
"type": "string", "type": "string",
"description": "Name of the OpenAI model." "description": "Name of the OpenAI model.",
}, },
"input_fee": { "input_fee": {
"type": "string", "type": "string",
"description": "Fee for input token for the OpenAI model." "description": "Fee for input token for the OpenAI model.",
}, },
"output_fee": { "output_fee": {
"type": "string", "type": "string",
"description": "Fee for output token for the OpenAI model." "description": "Fee for output token for the OpenAI model.",
} },
}, },
"required": ["model_name", "input_fee", "output_fee"] "required": ["model_name", "input_fee", "output_fee"],
} }
request = { request = {
"urls": "https://openai.com/api/pricing", "urls": "https://openai.com/api/pricing",
"priority": 8, "priority": 8,
@@ -243,12 +256,12 @@ def test_llm_extraction(tester: Crawl4AiTester):
"api_token": os.getenv("OPENAI_API_KEY"), "api_token": os.getenv("OPENAI_API_KEY"),
"schema": schema, "schema": schema,
"extraction_type": "schema", "extraction_type": "schema",
"instruction": """From the crawled content, extract all mentioned model names along with their fees for input and output tokens.""" "instruction": """From the crawled content, extract all mentioned model names along with their fees for input and output tokens.""",
} },
}, },
"crawler_params": {"word_count_threshold": 1} "crawler_params": {"word_count_threshold": 1},
} }
try: try:
result = tester.submit_and_wait(request) result = tester.submit_and_wait(request)
extracted = json.loads(result["result"]["extracted_content"]) extracted = json.loads(result["result"]["extracted_content"])
@@ -258,6 +271,7 @@ def test_llm_extraction(tester: Crawl4AiTester):
except Exception as e: except Exception as e:
print(f"LLM extraction test failed (might be due to missing API key): {str(e)}") print(f"LLM extraction test failed (might be due to missing API key): {str(e)}")
def test_llm_with_ollama(tester: Crawl4AiTester): def test_llm_with_ollama(tester: Crawl4AiTester):
print("\n=== Testing LLM with Ollama ===") print("\n=== Testing LLM with Ollama ===")
schema = { schema = {
@@ -265,20 +279,20 @@ def test_llm_with_ollama(tester: Crawl4AiTester):
"properties": { "properties": {
"article_title": { "article_title": {
"type": "string", "type": "string",
"description": "The main title of the news article" "description": "The main title of the news article",
}, },
"summary": { "summary": {
"type": "string", "type": "string",
"description": "A brief summary of the article content" "description": "A brief summary of the article content",
}, },
"main_topics": { "main_topics": {
"type": "array", "type": "array",
"items": {"type": "string"}, "items": {"type": "string"},
"description": "Main topics or themes discussed in the article" "description": "Main topics or themes discussed in the article",
} },
} },
} }
request = { request = {
"urls": "https://www.nbcnews.com/business", "urls": "https://www.nbcnews.com/business",
"priority": 8, "priority": 8,
@@ -288,13 +302,13 @@ def test_llm_with_ollama(tester: Crawl4AiTester):
"provider": "ollama/llama2", "provider": "ollama/llama2",
"schema": schema, "schema": schema,
"extraction_type": "schema", "extraction_type": "schema",
"instruction": "Extract the main article information including title, summary, and main topics." "instruction": "Extract the main article information including title, summary, and main topics.",
} },
}, },
"extra": {"word_count_threshold": 1}, "extra": {"word_count_threshold": 1},
"crawler_params": {"verbose": True} "crawler_params": {"verbose": True},
} }
try: try:
result = tester.submit_and_wait(request) result = tester.submit_and_wait(request)
extracted = json.loads(result["result"]["extracted_content"]) extracted = json.loads(result["result"]["extracted_content"])
@@ -303,6 +317,7 @@ def test_llm_with_ollama(tester: Crawl4AiTester):
except Exception as e: except Exception as e:
print(f"Ollama extraction test failed: {str(e)}") print(f"Ollama extraction test failed: {str(e)}")
def test_cosine_extraction(tester: Crawl4AiTester): def test_cosine_extraction(tester: Crawl4AiTester):
print("\n=== Testing Cosine Extraction ===") print("\n=== Testing Cosine Extraction ===")
request = { request = {
@@ -314,11 +329,11 @@ def test_cosine_extraction(tester: Crawl4AiTester):
"semantic_filter": "business finance economy", "semantic_filter": "business finance economy",
"word_count_threshold": 10, "word_count_threshold": 10,
"max_dist": 0.2, "max_dist": 0.2,
"top_k": 3 "top_k": 3,
} },
} },
} }
try: try:
result = tester.submit_and_wait(request) result = tester.submit_and_wait(request)
extracted = json.loads(result["result"]["extracted_content"]) extracted = json.loads(result["result"]["extracted_content"])
@@ -328,30 +343,30 @@ def test_cosine_extraction(tester: Crawl4AiTester):
except Exception as e: except Exception as e:
print(f"Cosine extraction test failed: {str(e)}") print(f"Cosine extraction test failed: {str(e)}")
def test_screenshot(tester: Crawl4AiTester): def test_screenshot(tester: Crawl4AiTester):
print("\n=== Testing Screenshot ===") print("\n=== Testing Screenshot ===")
request = { request = {
"urls": "https://www.nbcnews.com/business", "urls": "https://www.nbcnews.com/business",
"priority": 5, "priority": 5,
"screenshot": True, "screenshot": True,
"crawler_params": { "crawler_params": {"headless": True},
"headless": True
}
} }
result = tester.submit_and_wait(request) result = tester.submit_and_wait(request)
print("Screenshot captured:", bool(result["result"]["screenshot"])) print("Screenshot captured:", bool(result["result"]["screenshot"]))
if result["result"]["screenshot"]: if result["result"]["screenshot"]:
# Save screenshot # Save screenshot
screenshot_data = base64.b64decode(result["result"]["screenshot"]) screenshot_data = base64.b64decode(result["result"]["screenshot"])
with open("test_screenshot.jpg", "wb") as f: with open("test_screenshot.jpg", "wb") as f:
f.write(screenshot_data) f.write(screenshot_data)
print("Screenshot saved as test_screenshot.jpg") print("Screenshot saved as test_screenshot.jpg")
assert result["result"]["success"] assert result["result"]["success"]
if __name__ == "__main__": if __name__ == "__main__":
version = sys.argv[1] if len(sys.argv) > 1 else "basic" version = sys.argv[1] if len(sys.argv) > 1 else "basic"
# version = "full" # version = "full"
test_docker_deployment(version) test_docker_deployment(version)
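
Outside the test harness, the same endpoints can be hit with plain requests calls. A minimal sketch against the queue-less /crawl_direct endpoint, assuming the local Docker deployment from the example is running:

import os
import requests

base_url = "http://localhost:11235"
token = os.getenv("CRAWL4AI_API_TOKEN", "test_api_code")
headers = {"Authorization": f"Bearer {token}"} if token else {}

response = requests.post(
    f"{base_url}/crawl_direct",
    json={"urls": "https://example.com", "priority": 5, "cache_mode": "bypass"},
    headers=headers,
)
response.raise_for_status()
print(len(response.json()["result"]["markdown"]), "markdown characters")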


@@ -9,18 +9,17 @@ This example shows how to:
import asyncio import asyncio
import os import os
from typing import Dict, Any
from crawl4ai import AsyncWebCrawler, BrowserConfig, CrawlerRunConfig, CacheMode from crawl4ai import AsyncWebCrawler, BrowserConfig, CrawlerRunConfig, CacheMode
from crawl4ai.extraction_strategy import ( from crawl4ai.extraction_strategy import (
LLMExtractionStrategy, LLMExtractionStrategy,
JsonCssExtractionStrategy, JsonCssExtractionStrategy,
JsonXPathExtractionStrategy JsonXPathExtractionStrategy,
) )
from crawl4ai.chunking_strategy import RegexChunking, IdentityChunking
from crawl4ai.content_filter_strategy import PruningContentFilter from crawl4ai.content_filter_strategy import PruningContentFilter
from crawl4ai.markdown_generation_strategy import DefaultMarkdownGenerator from crawl4ai.markdown_generation_strategy import DefaultMarkdownGenerator
async def run_extraction(crawler: AsyncWebCrawler, url: str, strategy, name: str): async def run_extraction(crawler: AsyncWebCrawler, url: str, strategy, name: str):
"""Helper function to run extraction with proper configuration""" """Helper function to run extraction with proper configuration"""
try: try:
@@ -30,78 +29,90 @@ async def run_extraction(crawler: AsyncWebCrawler, url: str, strategy, name: str
extraction_strategy=strategy, extraction_strategy=strategy,
markdown_generator=DefaultMarkdownGenerator( markdown_generator=DefaultMarkdownGenerator(
content_filter=PruningContentFilter() # For fit_markdown support content_filter=PruningContentFilter() # For fit_markdown support
) ),
) )
# Run the crawler # Run the crawler
result = await crawler.arun(url=url, config=config) result = await crawler.arun(url=url, config=config)
if result.success: if result.success:
print(f"\n=== {name} Results ===") print(f"\n=== {name} Results ===")
print(f"Extracted Content: {result.extracted_content}") print(f"Extracted Content: {result.extracted_content}")
print(f"Raw Markdown Length: {len(result.markdown_v2.raw_markdown)}") print(f"Raw Markdown Length: {len(result.markdown_v2.raw_markdown)}")
print(f"Citations Markdown Length: {len(result.markdown_v2.markdown_with_citations)}") print(
f"Citations Markdown Length: {len(result.markdown_v2.markdown_with_citations)}"
)
else: else:
print(f"Error in {name}: Crawl failed") print(f"Error in {name}: Crawl failed")
except Exception as e: except Exception as e:
print(f"Error in {name}: {str(e)}") print(f"Error in {name}: {str(e)}")
async def main(): async def main():
# Example URL (replace with actual URL) # Example URL (replace with actual URL)
url = "https://example.com/product-page" url = "https://example.com/product-page"
# Configure browser settings # Configure browser settings
browser_config = BrowserConfig( browser_config = BrowserConfig(headless=True, verbose=True)
headless=True,
verbose=True
)
# Initialize extraction strategies # Initialize extraction strategies
# 1. LLM Extraction with different input formats # 1. LLM Extraction with different input formats
markdown_strategy = LLMExtractionStrategy( markdown_strategy = LLMExtractionStrategy(
provider="openai/gpt-4o-mini", provider="openai/gpt-4o-mini",
api_token=os.getenv("OPENAI_API_KEY"), api_token=os.getenv("OPENAI_API_KEY"),
instruction="Extract product information including name, price, and description" instruction="Extract product information including name, price, and description",
) )
html_strategy = LLMExtractionStrategy( html_strategy = LLMExtractionStrategy(
input_format="html", input_format="html",
provider="openai/gpt-4o-mini", provider="openai/gpt-4o-mini",
api_token=os.getenv("OPENAI_API_KEY"), api_token=os.getenv("OPENAI_API_KEY"),
instruction="Extract product information from HTML including structured data" instruction="Extract product information from HTML including structured data",
) )
fit_markdown_strategy = LLMExtractionStrategy( fit_markdown_strategy = LLMExtractionStrategy(
input_format="fit_markdown", input_format="fit_markdown",
provider="openai/gpt-4o-mini", provider="openai/gpt-4o-mini",
api_token=os.getenv("OPENAI_API_KEY"), api_token=os.getenv("OPENAI_API_KEY"),
instruction="Extract product information from cleaned markdown" instruction="Extract product information from cleaned markdown",
) )
# 2. JSON CSS Extraction (automatically uses HTML input) # 2. JSON CSS Extraction (automatically uses HTML input)
css_schema = { css_schema = {
"baseSelector": ".product", "baseSelector": ".product",
"fields": [ "fields": [
{"name": "title", "selector": "h1.product-title", "type": "text"}, {"name": "title", "selector": "h1.product-title", "type": "text"},
{"name": "price", "selector": ".price", "type": "text"}, {"name": "price", "selector": ".price", "type": "text"},
{"name": "description", "selector": ".description", "type": "text"} {"name": "description", "selector": ".description", "type": "text"},
] ],
} }
css_strategy = JsonCssExtractionStrategy(schema=css_schema) css_strategy = JsonCssExtractionStrategy(schema=css_schema)
# 3. JSON XPath Extraction (automatically uses HTML input) # 3. JSON XPath Extraction (automatically uses HTML input)
xpath_schema = { xpath_schema = {
"baseSelector": "//div[@class='product']", "baseSelector": "//div[@class='product']",
"fields": [ "fields": [
{"name": "title", "selector": ".//h1[@class='product-title']/text()", "type": "text"}, {
{"name": "price", "selector": ".//span[@class='price']/text()", "type": "text"}, "name": "title",
{"name": "description", "selector": ".//div[@class='description']/text()", "type": "text"} "selector": ".//h1[@class='product-title']/text()",
] "type": "text",
},
{
"name": "price",
"selector": ".//span[@class='price']/text()",
"type": "text",
},
{
"name": "description",
"selector": ".//div[@class='description']/text()",
"type": "text",
},
],
} }
xpath_strategy = JsonXPathExtractionStrategy(schema=xpath_schema) xpath_strategy = JsonXPathExtractionStrategy(schema=xpath_schema)
# Use context manager for proper resource handling # Use context manager for proper resource handling
async with AsyncWebCrawler(config=browser_config) as crawler: async with AsyncWebCrawler(config=browser_config) as crawler:
# Run all strategies # Run all strategies
@@ -111,5 +122,6 @@ async def main():
await run_extraction(crawler, url, css_strategy, "CSS Extraction") await run_extraction(crawler, url, css_strategy, "CSS Extraction")
await run_extraction(crawler, url, xpath_strategy, "XPath Extraction") await run_extraction(crawler, url, xpath_strategy, "XPath Extraction")
if __name__ == "__main__": if __name__ == "__main__":
asyncio.run(main()) asyncio.run(main())
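
As a standalone reference, the CSS path through the example above condenses to the following sketch (the URL and selectors are hypothetical, matching the placeholders used in the example):

import asyncio
import json
from crawl4ai import AsyncWebCrawler, CrawlerRunConfig, CacheMode
from crawl4ai.extraction_strategy import JsonCssExtractionStrategy

schema = {
    "baseSelector": ".product",  # one matched node per extracted record
    "fields": [
        {"name": "title", "selector": "h1.product-title", "type": "text"},
        {"name": "price", "selector": ".price", "type": "text"},
    ],
}

async def main():
    config = CrawlerRunConfig(
        cache_mode=CacheMode.BYPASS,
        extraction_strategy=JsonCssExtractionStrategy(schema=schema),
    )
    async with AsyncWebCrawler() as crawler:
        result = await crawler.arun(url="https://example.com/product-page", config=config)
        if result.success:
            print(json.loads(result.extracted_content))

asyncio.run(main())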


@@ -1,20 +1,23 @@
import asyncio import asyncio
from crawl4ai import * from crawl4ai import *
async def main(): async def main():
browser_config = BrowserConfig(headless=True, verbose=True) browser_config = BrowserConfig(headless=True, verbose=True)
async with AsyncWebCrawler(config=browser_config) as crawler: async with AsyncWebCrawler(config=browser_config) as crawler:
crawler_config = CrawlerRunConfig( crawler_config = CrawlerRunConfig(
cache_mode=CacheMode.BYPASS, cache_mode=CacheMode.BYPASS,
markdown_generator=DefaultMarkdownGenerator( markdown_generator=DefaultMarkdownGenerator(
content_filter=PruningContentFilter(threshold=0.48, threshold_type="fixed", min_word_threshold=0) content_filter=PruningContentFilter(
) threshold=0.48, threshold_type="fixed", min_word_threshold=0
)
),
) )
result = await crawler.arun( result = await crawler.arun(
url="https://www.helloworld.org", url="https://www.helloworld.org", config=crawler_config
config=crawler_config
) )
print(result.markdown_v2.raw_markdown[:500]) print(result.markdown_v2.raw_markdown[:500])
if __name__ == "__main__": if __name__ == "__main__":
asyncio.run(main()) asyncio.run(main())
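
To see what the PruningContentFilter actually removes, compare the raw and filtered outputs side by side. A sketch, assuming the generated result exposes fit_markdown alongside the raw_markdown used above (the quickstart further down prints both lengths):

import asyncio
from crawl4ai import AsyncWebCrawler, CrawlerRunConfig, CacheMode
from crawl4ai.content_filter_strategy import PruningContentFilter
from crawl4ai.markdown_generation_strategy import DefaultMarkdownGenerator

async def main():
    config = CrawlerRunConfig(
        cache_mode=CacheMode.BYPASS,
        markdown_generator=DefaultMarkdownGenerator(
            content_filter=PruningContentFilter(threshold=0.48, threshold_type="fixed")
        ),
    )
    async with AsyncWebCrawler() as crawler:
        result = await crawler.arun(url="https://www.helloworld.org", config=config)
        md = result.markdown_v2
        print("raw:", len(md.raw_markdown), "chars; fit:", len(md.fit_markdown), "chars")

asyncio.run(main())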


@@ -1,19 +1,18 @@
from crawl4ai import AsyncWebCrawler, BrowserConfig, CrawlerRunConfig, CacheMode from crawl4ai import AsyncWebCrawler, BrowserConfig, CrawlerRunConfig, CacheMode
from playwright.async_api import Page, BrowserContext from playwright.async_api import Page, BrowserContext
async def main(): async def main():
print("🔗 Hooks Example: Demonstrating different hook use cases") print("🔗 Hooks Example: Demonstrating different hook use cases")
# Configure browser settings # Configure browser settings
browser_config = BrowserConfig( browser_config = BrowserConfig(headless=True)
headless=True
)
# Configure crawler settings # Configure crawler settings
crawler_run_config = CrawlerRunConfig( crawler_run_config = CrawlerRunConfig(
js_code="window.scrollTo(0, document.body.scrollHeight);", js_code="window.scrollTo(0, document.body.scrollHeight);",
wait_for="body", wait_for="body",
cache_mode=CacheMode.BYPASS cache_mode=CacheMode.BYPASS,
) )
# Create crawler instance # Create crawler instance
@@ -30,16 +29,22 @@ async def main():
"""Hook called after a new page and context are created""" """Hook called after a new page and context are created"""
print("[HOOK] on_page_context_created - New page created!") print("[HOOK] on_page_context_created - New page created!")
# Example: Set default viewport size # Example: Set default viewport size
await context.add_cookies([{ await context.add_cookies(
'name': 'session_id', [
'value': 'example_session', {
'domain': '.example.com', "name": "session_id",
'path': '/' "value": "example_session",
}]) "domain": ".example.com",
"path": "/",
}
]
)
await page.set_viewport_size({"width": 1080, "height": 800}) await page.set_viewport_size({"width": 1080, "height": 800})
return page return page
async def on_user_agent_updated(page: Page, context: BrowserContext, user_agent: str, **kwargs): async def on_user_agent_updated(
page: Page, context: BrowserContext, user_agent: str, **kwargs
):
"""Hook called when the user agent is updated""" """Hook called when the user agent is updated"""
print(f"[HOOK] on_user_agent_updated - New user agent: {user_agent}") print(f"[HOOK] on_user_agent_updated - New user agent: {user_agent}")
return page return page
@@ -53,17 +58,17 @@ async def main():
"""Hook called before navigating to each URL""" """Hook called before navigating to each URL"""
print(f"[HOOK] before_goto - About to visit: {url}") print(f"[HOOK] before_goto - About to visit: {url}")
# Example: Add custom headers for the request # Example: Add custom headers for the request
await page.set_extra_http_headers({ await page.set_extra_http_headers({"Custom-Header": "my-value"})
"Custom-Header": "my-value"
})
return page return page
async def after_goto(page: Page, context: BrowserContext, url: str, response: dict, **kwargs): async def after_goto(
page: Page, context: BrowserContext, url: str, response: dict, **kwargs
):
"""Hook called after navigating to each URL""" """Hook called after navigating to each URL"""
print(f"[HOOK] after_goto - Successfully loaded: {url}") print(f"[HOOK] after_goto - Successfully loaded: {url}")
# Example: Wait for a specific element to be loaded # Example: Wait for a specific element to be loaded
try: try:
await page.wait_for_selector('.content', timeout=1000) await page.wait_for_selector(".content", timeout=1000)
print("Content element found!") print("Content element found!")
except: except:
print("Content element not found, continuing anyway") print("Content element not found, continuing anyway")
@@ -76,7 +81,9 @@ async def main():
await page.evaluate("window.scrollTo(0, document.body.scrollHeight);") await page.evaluate("window.scrollTo(0, document.body.scrollHeight);")
return page return page
async def before_return_html(page: Page, context: BrowserContext, html:str, **kwargs): async def before_return_html(
page: Page, context: BrowserContext, html: str, **kwargs
):
"""Hook called before returning the HTML content""" """Hook called before returning the HTML content"""
print(f"[HOOK] before_return_html - Got HTML content (length: {len(html)})") print(f"[HOOK] before_return_html - Got HTML content (length: {len(html)})")
# Example: You could modify the HTML content here if needed # Example: You could modify the HTML content here if needed
@@ -84,7 +91,9 @@ async def main():
# Set all the hooks # Set all the hooks
crawler.crawler_strategy.set_hook("on_browser_created", on_browser_created) crawler.crawler_strategy.set_hook("on_browser_created", on_browser_created)
crawler.crawler_strategy.set_hook("on_page_context_created", on_page_context_created) crawler.crawler_strategy.set_hook(
"on_page_context_created", on_page_context_created
)
crawler.crawler_strategy.set_hook("on_user_agent_updated", on_user_agent_updated) crawler.crawler_strategy.set_hook("on_user_agent_updated", on_user_agent_updated)
crawler.crawler_strategy.set_hook("on_execution_started", on_execution_started) crawler.crawler_strategy.set_hook("on_execution_started", on_execution_started)
crawler.crawler_strategy.set_hook("before_goto", before_goto) crawler.crawler_strategy.set_hook("before_goto", before_goto)
@@ -95,13 +104,15 @@ async def main():
await crawler.start() await crawler.start()
# Example usage: crawl a simple website # Example usage: crawl a simple website
url = 'https://example.com' url = "https://example.com"
result = await crawler.arun(url, config=crawler_run_config) result = await crawler.arun(url, config=crawler_run_config)
print(f"\nCrawled URL: {result.url}") print(f"\nCrawled URL: {result.url}")
print(f"HTML length: {len(result.html)}") print(f"HTML length: {len(result.html)}")
await crawler.close() await crawler.close()
if __name__ == "__main__": if __name__ == "__main__":
import asyncio import asyncio
asyncio.run(main())
asyncio.run(main())
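
For one-off instrumentation, a hook does not need a full function definition; the quickstart further down registers a bare lambda the same way. A sketch:

import asyncio
from crawl4ai import AsyncWebCrawler

async def main():
    async with AsyncWebCrawler() as crawler:
        # Log each navigation; extra hook arguments are absorbed by **kwargs.
        crawler.crawler_strategy.set_hook(
            "before_goto",
            lambda page, context, **kwargs: print("[Hook] Navigating..."),
        )
        result = await crawler.arun(url="https://example.com")
        print(f"HTML length: {len(result.html)}")

asyncio.run(main())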


@@ -1,6 +1,7 @@
import asyncio import asyncio
from crawl4ai import AsyncWebCrawler, AsyncPlaywrightCrawlerStrategy from crawl4ai import AsyncWebCrawler, AsyncPlaywrightCrawlerStrategy
async def main(): async def main():
# Example 1: Setting language when creating the crawler # Example 1: Setting language when creating the crawler
crawler1 = AsyncWebCrawler( crawler1 = AsyncWebCrawler(
@@ -9,11 +10,15 @@ async def main():
) )
) )
result1 = await crawler1.arun("https://www.example.com") result1 = await crawler1.arun("https://www.example.com")
print("Example 1 result:", result1.extracted_content[:100]) # Print first 100 characters print(
"Example 1 result:", result1.extracted_content[:100]
) # Print first 100 characters
# Example 2: Setting language before crawling # Example 2: Setting language before crawling
crawler2 = AsyncWebCrawler() crawler2 = AsyncWebCrawler()
crawler2.crawler_strategy.headers["Accept-Language"] = "es-ES,es;q=0.9,en-US;q=0.8,en;q=0.7" crawler2.crawler_strategy.headers[
"Accept-Language"
] = "es-ES,es;q=0.9,en-US;q=0.8,en;q=0.7"
result2 = await crawler2.arun("https://www.example.com") result2 = await crawler2.arun("https://www.example.com")
print("Example 2 result:", result2.extracted_content[:100]) print("Example 2 result:", result2.extracted_content[:100])
@@ -21,7 +26,7 @@ async def main():
crawler3 = AsyncWebCrawler() crawler3 = AsyncWebCrawler()
result3 = await crawler3.arun( result3 = await crawler3.arun(
"https://www.example.com", "https://www.example.com",
headers={"Accept-Language": "de-DE,de;q=0.9,en-US;q=0.8,en;q=0.7"} headers={"Accept-Language": "de-DE,de;q=0.9,en-US;q=0.8,en;q=0.7"},
) )
print("Example 3 result:", result3.extracted_content[:100]) print("Example 3 result:", result3.extracted_content[:100])
@@ -31,15 +36,15 @@ async def main():
("https://www.example.org", "es-ES,es;q=0.9"), ("https://www.example.org", "es-ES,es;q=0.9"),
("https://www.example.net", "de-DE,de;q=0.9"), ("https://www.example.net", "de-DE,de;q=0.9"),
] ]
crawler4 = AsyncWebCrawler() crawler4 = AsyncWebCrawler()
results = await asyncio.gather(*[ results = await asyncio.gather(
crawler4.arun(url, headers={"Accept-Language": lang}) *[crawler4.arun(url, headers={"Accept-Language": lang}) for url, lang in urls]
for url, lang in urls )
])
for url, result in zip([u for u, _ in urls], results): for url, result in zip([u for u, _ in urls], results):
print(f"Result for {url}:", result.extracted_content[:100]) print(f"Result for {url}:", result.extracted_content[:100])
if __name__ == "__main__": if __name__ == "__main__":
asyncio.run(main()) asyncio.run(main())


@@ -3,32 +3,37 @@ from crawl4ai.crawler_strategy import *
import asyncio import asyncio
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
url = r'https://openai.com/api/pricing/' url = r"https://openai.com/api/pricing/"
class OpenAIModelFee(BaseModel): class OpenAIModelFee(BaseModel):
model_name: str = Field(..., description="Name of the OpenAI model.") model_name: str = Field(..., description="Name of the OpenAI model.")
input_fee: str = Field(..., description="Fee for input token for the OpenAI model.") input_fee: str = Field(..., description="Fee for input token for the OpenAI model.")
output_fee: str = Field(..., description="Fee for output token for the OpenAI model.") output_fee: str = Field(
..., description="Fee for output token for the OpenAI model."
)
from crawl4ai import AsyncWebCrawler from crawl4ai import AsyncWebCrawler
async def main(): async def main():
# Use AsyncWebCrawler # Use AsyncWebCrawler
async with AsyncWebCrawler() as crawler: async with AsyncWebCrawler() as crawler:
result = await crawler.arun( result = await crawler.arun(
url=url, url=url,
word_count_threshold=1, word_count_threshold=1,
extraction_strategy= LLMExtractionStrategy( extraction_strategy=LLMExtractionStrategy(
# provider= "openai/gpt-4o", api_token = os.getenv('OPENAI_API_KEY'), # provider= "openai/gpt-4o", api_token = os.getenv('OPENAI_API_KEY'),
provider= "groq/llama-3.1-70b-versatile", api_token = os.getenv('GROQ_API_KEY'), provider="groq/llama-3.1-70b-versatile",
api_token=os.getenv("GROQ_API_KEY"),
schema=OpenAIModelFee.model_json_schema(), schema=OpenAIModelFee.model_json_schema(),
extraction_type="schema", extraction_type="schema",
instruction="From the crawled content, extract all mentioned model names along with their " \ instruction="From the crawled content, extract all mentioned model names along with their "
"fees for input and output tokens. Make sure not to miss anything in the entire content. " \ "fees for input and output tokens. Make sure not to miss anything in the entire content. "
'One extracted model JSON format should look like this: ' \ "One extracted model JSON format should look like this: "
'{ "model_name": "GPT-4", "input_fee": "US$10.00 / 1M tokens", "output_fee": "US$30.00 / 1M tokens" }' '{ "model_name": "GPT-4", "input_fee": "US$10.00 / 1M tokens", "output_fee": "US$30.00 / 1M tokens" }',
), ),
) )
print("Success:", result.success) print("Success:", result.success)
model_fees = json.loads(result.extracted_content) model_fees = json.loads(result.extracted_content)
@@ -37,4 +42,5 @@ async def main():
with open(".data/data.json", "w", encoding="utf-8") as f: with open(".data/data.json", "w", encoding="utf-8") as f:
f.write(result.extracted_content) f.write(result.extracted_content)
asyncio.run(main()) asyncio.run(main())
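
Since the extraction is prompted with OpenAIModelFee.model_json_schema(), the returned JSON can be validated back through the same model. A sketch with a stand-in payload (the literal here is illustrative, not crawled output):

import json
from pydantic import BaseModel, Field

class OpenAIModelFee(BaseModel):
    model_name: str = Field(..., description="Name of the OpenAI model.")
    input_fee: str = Field(..., description="Fee for input token for the OpenAI model.")
    output_fee: str = Field(..., description="Fee for output token for the OpenAI model.")

raw = '[{ "model_name": "GPT-4", "input_fee": "US$10.00 / 1M tokens", "output_fee": "US$30.00 / 1M tokens" }]'
fees = [OpenAIModelFee.model_validate(item) for item in json.loads(raw)]
print(fees[0].model_name, fees[0].input_fee)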


@@ -8,12 +8,12 @@ import asyncio
import time import time
import json import json
import re import re
from typing import Dict, List from typing import Dict
from bs4 import BeautifulSoup from bs4 import BeautifulSoup
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
from crawl4ai import AsyncWebCrawler, CacheMode, BrowserConfig, CrawlerRunConfig from crawl4ai import AsyncWebCrawler, CacheMode, BrowserConfig, CrawlerRunConfig
from crawl4ai.markdown_generation_strategy import DefaultMarkdownGenerator from crawl4ai.markdown_generation_strategy import DefaultMarkdownGenerator
from crawl4ai.content_filter_strategy import BM25ContentFilter, PruningContentFilter from crawl4ai.content_filter_strategy import PruningContentFilter
from crawl4ai.extraction_strategy import ( from crawl4ai.extraction_strategy import (
JsonCssExtractionStrategy, JsonCssExtractionStrategy,
LLMExtractionStrategy, LLMExtractionStrategy,
@@ -62,6 +62,7 @@ async def clean_content():
print(f"Full Markdown Length: {full_markdown_length}") print(f"Full Markdown Length: {full_markdown_length}")
print(f"Fit Markdown Length: {fit_markdown_length}") print(f"Fit Markdown Length: {fit_markdown_length}")
async def link_analysis(): async def link_analysis():
crawler_config = CrawlerRunConfig( crawler_config = CrawlerRunConfig(
cache_mode=CacheMode.ENABLED, cache_mode=CacheMode.ENABLED,
@@ -76,9 +77,10 @@ async def link_analysis():
print(f"Found {len(result.links['internal'])} internal links") print(f"Found {len(result.links['internal'])} internal links")
print(f"Found {len(result.links['external'])} external links") print(f"Found {len(result.links['external'])} external links")
for link in result.links['internal'][:5]: for link in result.links["internal"][:5]:
print(f"Href: {link['href']}\nText: {link['text']}\n") print(f"Href: {link['href']}\nText: {link['text']}\n")
# JavaScript Execution Example # JavaScript Execution Example
async def simple_example_with_running_js_code(): async def simple_example_with_running_js_code():
print("\n--- Executing JavaScript and Using CSS Selectors ---") print("\n--- Executing JavaScript and Using CSS Selectors ---")
@@ -112,25 +114,29 @@ async def simple_example_with_css_selector():
) )
print(result.markdown[:500]) print(result.markdown[:500])
async def media_handling(): async def media_handling():
crawler_config = CrawlerRunConfig(cache_mode=CacheMode.BYPASS, exclude_external_images=True, screenshot=True) crawler_config = CrawlerRunConfig(
cache_mode=CacheMode.BYPASS, exclude_external_images=True, screenshot=True
)
async with AsyncWebCrawler() as crawler: async with AsyncWebCrawler() as crawler:
result = await crawler.arun( result = await crawler.arun(
url="https://www.nbcnews.com/business", url="https://www.nbcnews.com/business", config=crawler_config
config=crawler_config
) )
for img in result.media['images'][:5]: for img in result.media["images"][:5]:
print(f"Image URL: {img['src']}, Alt: {img['alt']}, Score: {img['score']}") print(f"Image URL: {img['src']}, Alt: {img['alt']}, Score: {img['score']}")
async def custom_hook_workflow(verbose=True): async def custom_hook_workflow(verbose=True):
async with AsyncWebCrawler() as crawler: async with AsyncWebCrawler() as crawler:
# Set a 'before_goto' hook to run custom code just before navigation # Set a 'before_goto' hook to run custom code just before navigation
crawler.crawler_strategy.set_hook("before_goto", lambda page, context: print("[Hook] Preparing to navigate...")) crawler.crawler_strategy.set_hook(
"before_goto",
lambda page, context: print("[Hook] Preparing to navigate..."),
)
# Perform the crawl operation # Perform the crawl operation
result = await crawler.arun( result = await crawler.arun(url="https://crawl4ai.com")
url="https://crawl4ai.com"
)
print(result.markdown_v2.raw_markdown[:500].replace("\n", " -- ")) print(result.markdown_v2.raw_markdown[:500].replace("\n", " -- "))
@@ -412,21 +418,22 @@ async def cosine_similarity_extraction():
cache_mode=CacheMode.BYPASS, cache_mode=CacheMode.BYPASS,
extraction_strategy=CosineStrategy( extraction_strategy=CosineStrategy(
word_count_threshold=10, word_count_threshold=10,
max_dist=0.2, # Maximum distance between two words max_dist=0.2, # Maximum distance between two words
linkage_method="ward", # Linkage method for hierarchical clustering (ward, complete, average, single) linkage_method="ward", # Linkage method for hierarchical clustering (ward, complete, average, single)
top_k=3, # Number of top keywords to extract top_k=3, # Number of top keywords to extract
sim_threshold=0.3, # Similarity threshold for clustering sim_threshold=0.3, # Similarity threshold for clustering
semantic_filter="McDonald's economic impact, American consumer trends", # Keywords to filter the content semantically using embeddings semantic_filter="McDonald's economic impact, American consumer trends", # Keywords to filter the content semantically using embeddings
verbose=True verbose=True,
), ),
) )
async with AsyncWebCrawler() as crawler: async with AsyncWebCrawler() as crawler:
result = await crawler.arun( result = await crawler.arun(
url="https://www.nbcnews.com/business/consumer/how-mcdonalds-e-coli-crisis-inflation-politics-reflect-american-story-rcna177156", url="https://www.nbcnews.com/business/consumer/how-mcdonalds-e-coli-crisis-inflation-politics-reflect-american-story-rcna177156",
config=crawl_config config=crawl_config,
) )
print(json.loads(result.extracted_content)[:5]) print(json.loads(result.extracted_content)[:5])
# Browser Comparison # Browser Comparison
async def crawl_custom_browser_type(): async def crawl_custom_browser_type():
print("\n--- Browser Comparison ---") print("\n--- Browser Comparison ---")
@@ -484,39 +491,42 @@ async def crawl_with_user_simulation():
result = await crawler.arun(url="YOUR-URL-HERE", config=crawler_config) result = await crawler.arun(url="YOUR-URL-HERE", config=crawler_config)
print(result.markdown) print(result.markdown)
async def ssl_certification(): async def ssl_certification():
# Configure crawler to fetch SSL certificate # Configure crawler to fetch SSL certificate
config = CrawlerRunConfig( config = CrawlerRunConfig(
fetch_ssl_certificate=True, fetch_ssl_certificate=True,
cache_mode=CacheMode.BYPASS # Bypass cache to always get fresh certificates cache_mode=CacheMode.BYPASS, # Bypass cache to always get fresh certificates
) )
async with AsyncWebCrawler() as crawler: async with AsyncWebCrawler() as crawler:
result = await crawler.arun( result = await crawler.arun(url="https://example.com", config=config)
url='https://example.com',
config=config
)
if result.success and result.ssl_certificate: if result.success and result.ssl_certificate:
cert = result.ssl_certificate cert = result.ssl_certificate
# 1. Access certificate properties directly # 1. Access certificate properties directly
print("\nCertificate Information:") print("\nCertificate Information:")
print(f"Issuer: {cert.issuer.get('CN', '')}") print(f"Issuer: {cert.issuer.get('CN', '')}")
print(f"Valid until: {cert.valid_until}") print(f"Valid until: {cert.valid_until}")
print(f"Fingerprint: {cert.fingerprint}") print(f"Fingerprint: {cert.fingerprint}")
# 2. Export certificate in different formats # 2. Export certificate in different formats
cert.to_json(os.path.join(tmp_dir, "certificate.json")) # For analysis cert.to_json(os.path.join(tmp_dir, "certificate.json")) # For analysis
print("\nCertificate exported to:") print("\nCertificate exported to:")
print(f"- JSON: {os.path.join(tmp_dir, 'certificate.json')}") print(f"- JSON: {os.path.join(tmp_dir, 'certificate.json')}")
pem_data = cert.to_pem(os.path.join(tmp_dir, "certificate.pem")) # For web servers pem_data = cert.to_pem(
os.path.join(tmp_dir, "certificate.pem")
) # For web servers
print(f"- PEM: {os.path.join(tmp_dir, 'certificate.pem')}") print(f"- PEM: {os.path.join(tmp_dir, 'certificate.pem')}")
der_data = cert.to_der(os.path.join(tmp_dir, "certificate.der")) # For Java apps der_data = cert.to_der(
os.path.join(tmp_dir, "certificate.der")
) # For Java apps
print(f"- DER: {os.path.join(tmp_dir, 'certificate.der')}") print(f"- DER: {os.path.join(tmp_dir, 'certificate.der')}")
# Speed Comparison # Speed Comparison
async def speed_comparison(): async def speed_comparison():
print("\n--- Speed Comparison ---") print("\n--- Speed Comparison ---")


@@ -1,6 +1,10 @@
import os, sys import os, sys
# append parent directory to system path # append parent directory to system path
sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))); os.environ['FIRECRAWL_API_KEY'] = "fc-84b370ccfad44beabc686b38f1769692"; sys.path.append(
os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
)
os.environ["FIRECRAWL_API_KEY"] = "fc-84b370ccfad44beabc686b38f1769692"
import asyncio import asyncio
# import nest_asyncio # import nest_asyncio
@@ -15,7 +19,7 @@ from bs4 import BeautifulSoup
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
from crawl4ai import AsyncWebCrawler, CacheMode from crawl4ai import AsyncWebCrawler, CacheMode
from crawl4ai.markdown_generation_strategy import DefaultMarkdownGenerator from crawl4ai.markdown_generation_strategy import DefaultMarkdownGenerator
from crawl4ai.content_filter_strategy import BM25ContentFilter, PruningContentFilter from crawl4ai.content_filter_strategy import PruningContentFilter
from crawl4ai.extraction_strategy import ( from crawl4ai.extraction_strategy import (
JsonCssExtractionStrategy, JsonCssExtractionStrategy,
LLMExtractionStrategy, LLMExtractionStrategy,
@@ -32,9 +36,12 @@ print("Website: https://crawl4ai.com")
async def simple_crawl(): async def simple_crawl():
print("\n--- Basic Usage ---") print("\n--- Basic Usage ---")
async with AsyncWebCrawler(verbose=True) as crawler: async with AsyncWebCrawler(verbose=True) as crawler:
result = await crawler.arun(url="https://www.nbcnews.com/business", cache_mode= CacheMode.BYPASS) result = await crawler.arun(
url="https://www.nbcnews.com/business", cache_mode=CacheMode.BYPASS
)
print(result.markdown[:500]) # Print first 500 characters print(result.markdown[:500]) # Print first 500 characters
async def simple_example_with_running_js_code(): async def simple_example_with_running_js_code():
print("\n--- Executing JavaScript and Using CSS Selectors ---") print("\n--- Executing JavaScript and Using CSS Selectors ---")
# New code to handle the wait_for parameter # New code to handle the wait_for parameter
@@ -57,6 +64,7 @@ async def simple_example_with_running_js_code():
) )
print(result.markdown[:500]) # Print first 500 characters print(result.markdown[:500]) # Print first 500 characters
async def simple_example_with_css_selector(): async def simple_example_with_css_selector():
print("\n--- Using CSS Selectors ---") print("\n--- Using CSS Selectors ---")
async with AsyncWebCrawler(verbose=True) as crawler: async with AsyncWebCrawler(verbose=True) as crawler:
@@ -67,42 +75,44 @@ async def simple_example_with_css_selector():
) )
print(result.markdown[:500]) # Print first 500 characters print(result.markdown[:500]) # Print first 500 characters
async def use_proxy(): async def use_proxy():
print("\n--- Using a Proxy ---") print("\n--- Using a Proxy ---")
print( print(
"Note: Replace 'http://your-proxy-url:port' with a working proxy to run this example." "Note: Replace 'http://your-proxy-url:port' with a working proxy to run this example."
) )
# Uncomment and modify the following lines to use a proxy # Uncomment and modify the following lines to use a proxy
async with AsyncWebCrawler(verbose=True, proxy="http://your-proxy-url:port") as crawler: async with AsyncWebCrawler(
verbose=True, proxy="http://your-proxy-url:port"
) as crawler:
result = await crawler.arun( result = await crawler.arun(
url="https://www.nbcnews.com/business", url="https://www.nbcnews.com/business", cache_mode=CacheMode.BYPASS
cache_mode= CacheMode.BYPASS
) )
if result.success: if result.success:
print(result.markdown[:500]) # Print first 500 characters print(result.markdown[:500]) # Print first 500 characters
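A runnable variant of the proxy example can read the proxy URL from the environment so it degrades gracefully when none is configured. A sketch under that assumption (HTTP_PROXY_URL is an illustrative variable name, not a crawl4ai setting):

import os
from crawl4ai import AsyncWebCrawler, CacheMode

async def use_proxy_from_env():
    proxy = os.environ.get("HTTP_PROXY_URL")  # hypothetical env var for this sketch
    if not proxy:
        print("No proxy configured; skipping proxy example.")
        return
    async with AsyncWebCrawler(verbose=True, proxy=proxy) as crawler:
        result = await crawler.arun(
            url="https://www.nbcnews.com/business", cache_mode=CacheMode.BYPASS
        )
        if result.success:
            print(result.markdown[:500])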
async def capture_and_save_screenshot(url: str, output_path: str): async def capture_and_save_screenshot(url: str, output_path: str):
async with AsyncWebCrawler(verbose=True) as crawler: async with AsyncWebCrawler(verbose=True) as crawler:
result = await crawler.arun( result = await crawler.arun(
url=url, url=url, screenshot=True, cache_mode=CacheMode.BYPASS
screenshot=True,
cache_mode= CacheMode.BYPASS
) )
if result.success and result.screenshot: if result.success and result.screenshot:
import base64 import base64
# Decode the base64 screenshot data # Decode the base64 screenshot data
screenshot_data = base64.b64decode(result.screenshot) screenshot_data = base64.b64decode(result.screenshot)
# Save the screenshot as a JPEG file # Save the screenshot as a JPEG file
with open(output_path, 'wb') as f: with open(output_path, "wb") as f:
f.write(screenshot_data) f.write(screenshot_data)
print(f"Screenshot saved successfully to {output_path}") print(f"Screenshot saved successfully to {output_path}")
else: else:
print("Failed to capture screenshot") print("Failed to capture screenshot")
class OpenAIModelFee(BaseModel): class OpenAIModelFee(BaseModel):
model_name: str = Field(..., description="Name of the OpenAI model.") model_name: str = Field(..., description="Name of the OpenAI model.")
input_fee: str = Field(..., description="Fee for input token for the OpenAI model.") input_fee: str = Field(..., description="Fee for input token for the OpenAI model.")
@@ -110,16 +120,19 @@ class OpenAIModelFee(BaseModel):
..., description="Fee for output token for the OpenAI model." ..., description="Fee for output token for the OpenAI model."
) )
async def extract_structured_data_using_llm(provider: str, api_token: str = None, extra_headers: Dict[str, str] = None):
async def extract_structured_data_using_llm(
provider: str, api_token: str = None, extra_headers: Dict[str, str] = None
):
print(f"\n--- Extracting Structured Data with {provider} ---") print(f"\n--- Extracting Structured Data with {provider} ---")
if api_token is None and provider != "ollama": if api_token is None and provider != "ollama":
print(f"API token is required for {provider}. Skipping this example.") print(f"API token is required for {provider}. Skipping this example.")
return return
# extra_args = {} # extra_args = {}
extra_args={ extra_args = {
"temperature": 0, "temperature": 0,
"top_p": 0.9, "top_p": 0.9,
"max_tokens": 2000, "max_tokens": 2000,
# any other supported parameters for litellm # any other supported parameters for litellm
@@ -139,52 +152,49 @@ async def extract_structured_data_using_llm(provider: str, api_token: str = None
instruction="""From the crawled content, extract all mentioned model names along with their fees for input and output tokens. instruction="""From the crawled content, extract all mentioned model names along with their fees for input and output tokens.
Do not miss any models in the entire content. One extracted model JSON format should look like this: Do not miss any models in the entire content. One extracted model JSON format should look like this:
{"model_name": "GPT-4", "input_fee": "US$10.00 / 1M tokens", "output_fee": "US$30.00 / 1M tokens"}.""", {"model_name": "GPT-4", "input_fee": "US$10.00 / 1M tokens", "output_fee": "US$30.00 / 1M tokens"}.""",
extra_args=extra_args extra_args=extra_args,
), ),
cache_mode=CacheMode.BYPASS, cache_mode=CacheMode.BYPASS,
) )
print(result.extracted_content) print(result.extracted_content)
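Because the instruction pins the output to the OpenAIModelFee shape, the extracted content can be validated back through the same Pydantic model. A minimal sketch, assuming result.extracted_content is a JSON array of fee records:

import json

def parse_fees(extracted_content: str) -> list:
    # Each record is expected to match the schema passed to LLMExtractionStrategy
    records = json.loads(extracted_content)
    return [OpenAIModelFee(**record) for record in records]

# fees = parse_fees(result.extracted_content)
# print(fees[0].model_name)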
async def extract_structured_data_using_css_extractor(): async def extract_structured_data_using_css_extractor():
print("\n--- Using JsonCssExtractionStrategy for Fast Structured Output ---") print("\n--- Using JsonCssExtractionStrategy for Fast Structured Output ---")
schema = { schema = {
"name": "KidoCode Courses", "name": "KidoCode Courses",
"baseSelector": "section.charge-methodology .w-tab-content > div", "baseSelector": "section.charge-methodology .w-tab-content > div",
"fields": [ "fields": [
{ {
"name": "section_title", "name": "section_title",
"selector": "h3.heading-50", "selector": "h3.heading-50",
"type": "text", "type": "text",
}, },
{ {
"name": "section_description", "name": "section_description",
"selector": ".charge-content", "selector": ".charge-content",
"type": "text", "type": "text",
}, },
{ {
"name": "course_name", "name": "course_name",
"selector": ".text-block-93", "selector": ".text-block-93",
"type": "text", "type": "text",
}, },
{ {
"name": "course_description", "name": "course_description",
"selector": ".course-content-text", "selector": ".course-content-text",
"type": "text", "type": "text",
}, },
{ {
"name": "course_icon", "name": "course_icon",
"selector": ".image-92", "selector": ".image-92",
"type": "attribute", "type": "attribute",
"attribute": "src" "attribute": "src",
} },
] ],
} }
async with AsyncWebCrawler( async with AsyncWebCrawler(headless=True, verbose=True) as crawler:
headless=True,
verbose=True
) as crawler:
# Create the JavaScript that handles clicking multiple times # Create the JavaScript that handles clicking multiple times
js_click_tabs = """ js_click_tabs = """
(async () => { (async () => {
@@ -198,19 +208,20 @@ async def extract_structured_data_using_css_extractor():
await new Promise(r => setTimeout(r, 500)); await new Promise(r => setTimeout(r, 500));
} }
})(); })();
""" """
result = await crawler.arun( result = await crawler.arun(
url="https://www.kidocode.com/degrees/technology", url="https://www.kidocode.com/degrees/technology",
extraction_strategy=JsonCssExtractionStrategy(schema, verbose=True), extraction_strategy=JsonCssExtractionStrategy(schema, verbose=True),
js_code=[js_click_tabs], js_code=[js_click_tabs],
cache_mode=CacheMode.BYPASS cache_mode=CacheMode.BYPASS,
) )
courses = json.loads(result.extracted_content) courses = json.loads(result.extracted_content)
print(f"Successfully extracted {len(courses)} course records") print(f"Successfully extracted {len(courses)} course records")
print(json.dumps(courses[0], indent=2)) print(json.dumps(courses[0], indent=2))
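The schema above follows a simple contract: baseSelector yields one extracted object per match, and each field is either text or an attribute. A reduced sketch of the same idea against a generic page (the selectors are placeholders, not taken from any real site):

minimal_schema = {
    "name": "Headlines",
    "baseSelector": "article",  # one extracted object per <article> element
    "fields": [
        {"name": "title", "selector": "h2", "type": "text"},
        {"name": "link", "selector": "a", "type": "attribute", "attribute": "href"},
    ],
}
# strategy = JsonCssExtractionStrategy(minimal_schema, verbose=True)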
# Advanced Session-Based Crawling with Dynamic Content 🔄 # Advanced Session-Based Crawling with Dynamic Content 🔄
async def crawl_dynamic_content_pages_method_1(): async def crawl_dynamic_content_pages_method_1():
print("\n--- Advanced Multi-Page Crawling with JavaScript Execution ---") print("\n--- Advanced Multi-Page Crawling with JavaScript Execution ---")
@@ -267,6 +278,7 @@ async def crawl_dynamic_content_pages_method_1():
await crawler.crawler_strategy.kill_session(session_id) await crawler.crawler_strategy.kill_session(session_id)
print(f"Successfully crawled {len(all_commits)} commits across 3 pages") print(f"Successfully crawled {len(all_commits)} commits across 3 pages")
async def crawl_dynamic_content_pages_method_2(): async def crawl_dynamic_content_pages_method_2():
print("\n--- Advanced Multi-Page Crawling with JavaScript Execution ---") print("\n--- Advanced Multi-Page Crawling with JavaScript Execution ---")
@@ -334,8 +346,11 @@ async def crawl_dynamic_content_pages_method_2():
await crawler.crawler_strategy.kill_session(session_id) await crawler.crawler_strategy.kill_session(session_id)
print(f"Successfully crawled {len(all_commits)} commits across 3 pages") print(f"Successfully crawled {len(all_commits)} commits across 3 pages")
async def crawl_dynamic_content_pages_method_3(): async def crawl_dynamic_content_pages_method_3():
print("\n--- Advanced Multi-Page Crawling with JavaScript Execution using `wait_for` ---") print(
"\n--- Advanced Multi-Page Crawling with JavaScript Execution using `wait_for` ---"
)
async with AsyncWebCrawler(verbose=True) as crawler: async with AsyncWebCrawler(verbose=True) as crawler:
url = "https://github.com/microsoft/TypeScript/commits/main" url = "https://github.com/microsoft/TypeScript/commits/main"
@@ -357,7 +372,7 @@ async def crawl_dynamic_content_pages_method_3():
const firstCommit = commits[0].textContent.trim(); const firstCommit = commits[0].textContent.trim();
return firstCommit !== window.firstCommit; return firstCommit !== window.firstCommit;
}""" }"""
schema = { schema = {
"name": "Commit Extractor", "name": "Commit Extractor",
"baseSelector": "li.Box-sc-g0xbh4-0", "baseSelector": "li.Box-sc-g0xbh4-0",
@@ -395,40 +410,53 @@ async def crawl_dynamic_content_pages_method_3():
await crawler.crawler_strategy.kill_session(session_id) await crawler.crawler_strategy.kill_session(session_id)
print(f"Successfully crawled {len(all_commits)} commits across 3 pages") print(f"Successfully crawled {len(all_commits)} commits across 3 pages")
async def crawl_custom_browser_type(): async def crawl_custom_browser_type():
# Use Firefox # Use Firefox
start = time.time() start = time.time()
async with AsyncWebCrawler(browser_type="firefox", verbose=True, headless = True) as crawler: async with AsyncWebCrawler(
result = await crawler.arun(url="https://www.example.com", cache_mode= CacheMode.BYPASS) browser_type="firefox", verbose=True, headless=True
) as crawler:
result = await crawler.arun(
url="https://www.example.com", cache_mode=CacheMode.BYPASS
)
print(result.markdown[:500]) print(result.markdown[:500])
print("Time taken: ", time.time() - start) print("Time taken: ", time.time() - start)
# Use WebKit # Use WebKit
start = time.time() start = time.time()
async with AsyncWebCrawler(browser_type="webkit", verbose=True, headless = True) as crawler: async with AsyncWebCrawler(
result = await crawler.arun(url="https://www.example.com", cache_mode= CacheMode.BYPASS) browser_type="webkit", verbose=True, headless=True
) as crawler:
result = await crawler.arun(
url="https://www.example.com", cache_mode=CacheMode.BYPASS
)
print(result.markdown[:500]) print(result.markdown[:500])
print("Time taken: ", time.time() - start) print("Time taken: ", time.time() - start)
# Use Chromium (default) # Use Chromium (default)
start = time.time() start = time.time()
async with AsyncWebCrawler(verbose=True, headless = True) as crawler: async with AsyncWebCrawler(verbose=True, headless=True) as crawler:
result = await crawler.arun(url="https://www.example.com", cache_mode= CacheMode.BYPASS) result = await crawler.arun(
url="https://www.example.com", cache_mode=CacheMode.BYPASS
)
print(result.markdown[:500]) print(result.markdown[:500])
print("Time taken: ", time.time() - start) print("Time taken: ", time.time() - start)
async def crawl_with_user_simulation(): async def crawl_with_user_simulation():
async with AsyncWebCrawler(verbose=True, headless=True) as crawler: async with AsyncWebCrawler(verbose=True, headless=True) as crawler:
url = "YOUR-URL-HERE" url = "YOUR-URL-HERE"
result = await crawler.arun( result = await crawler.arun(
url=url, url=url,
cache_mode=CacheMode.BYPASS, cache_mode=CacheMode.BYPASS,
magic = True, # Automatically detects and removes overlays, popups, and other elements that block content magic=True, # Automatically detects and removes overlays, popups, and other elements that block content
# simulate_user = True,# Causes a series of random mouse movements and clicks to simulate user interaction # simulate_user = True,# Causes a series of random mouse movements and clicks to simulate user interaction
# override_navigator = True # Overrides the navigator object to make it look like a real user # override_navigator = True # Overrides the navigator object to make it look like a real user
) )
print(result.markdown) print(result.markdown)
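For reference, the two commented-out flags can be combined with magic. A sketch using only parameters already named in the comments above:

async def crawl_with_all_simulation_flags(crawler, url: str):
    return await crawler.arun(
        url=url,
        cache_mode=CacheMode.BYPASS,
        magic=True,  # remove overlays and popups automatically
        simulate_user=True,  # random mouse movements and clicks
        override_navigator=True,  # mask automation hints in the navigator object
    )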
async def speed_comparison(): async def speed_comparison():
# print("\n--- Speed Comparison ---") # print("\n--- Speed Comparison ---")
@@ -439,18 +467,18 @@ async def speed_comparison():
# print() # print()
# Simulated Firecrawl performance # Simulated Firecrawl performance
from firecrawl import FirecrawlApp from firecrawl import FirecrawlApp
app = FirecrawlApp(api_key=os.environ['FIRECRAWL_API_KEY'])
app = FirecrawlApp(api_key=os.environ["FIRECRAWL_API_KEY"])
start = time.time() start = time.time()
scrape_status = app.scrape_url( scrape_status = app.scrape_url(
'https://www.nbcnews.com/business', "https://www.nbcnews.com/business", params={"formats": ["markdown", "html"]}
params={'formats': ['markdown', 'html']}
) )
end = time.time() end = time.time()
print("Firecrawl:") print("Firecrawl:")
print(f"Time taken: {end - start:.2f} seconds") print(f"Time taken: {end - start:.2f} seconds")
print(f"Content length: {len(scrape_status['markdown'])} characters") print(f"Content length: {len(scrape_status['markdown'])} characters")
print(f"Images found: {scrape_status['markdown'].count('cldnry.s-nbcnews.com')}") print(f"Images found: {scrape_status['markdown'].count('cldnry.s-nbcnews.com')}")
print() print()
async with AsyncWebCrawler() as crawler: async with AsyncWebCrawler() as crawler:
# Crawl4AI simple crawl # Crawl4AI simple crawl
@@ -474,7 +502,9 @@ async def speed_comparison():
url="https://www.nbcnews.com/business", url="https://www.nbcnews.com/business",
word_count_threshold=0, word_count_threshold=0,
markdown_generator=DefaultMarkdownGenerator( markdown_generator=DefaultMarkdownGenerator(
content_filter = PruningContentFilter(threshold=0.48, threshold_type="fixed", min_word_threshold=0) content_filter=PruningContentFilter(
threshold=0.48, threshold_type="fixed", min_word_threshold=0
)
# content_filter=BM25ContentFilter(user_query=None, bm25_threshold=1.0) # content_filter=BM25ContentFilter(user_query=None, bm25_threshold=1.0)
), ),
cache_mode=CacheMode.BYPASS, cache_mode=CacheMode.BYPASS,
@@ -498,7 +528,9 @@ async def speed_comparison():
word_count_threshold=0, word_count_threshold=0,
cache_mode=CacheMode.BYPASS, cache_mode=CacheMode.BYPASS,
markdown_generator=DefaultMarkdownGenerator( markdown_generator=DefaultMarkdownGenerator(
content_filter = PruningContentFilter(threshold=0.48, threshold_type="fixed", min_word_threshold=0) content_filter=PruningContentFilter(
threshold=0.48, threshold_type="fixed", min_word_threshold=0
)
# content_filter=BM25ContentFilter(user_query=None, bm25_threshold=1.0) # content_filter=BM25ContentFilter(user_query=None, bm25_threshold=1.0)
), ),
verbose=False, verbose=False,
@@ -520,11 +552,12 @@ async def speed_comparison():
print("If you run these tests in an environment with better network conditions,") print("If you run these tests in an environment with better network conditions,")
print("you may observe an even more significant speed advantage for Crawl4AI.") print("you may observe an even more significant speed advantage for Crawl4AI.")
async def generate_knowledge_graph(): async def generate_knowledge_graph():
class Entity(BaseModel): class Entity(BaseModel):
name: str name: str
description: str description: str
class Relationship(BaseModel): class Relationship(BaseModel):
entity1: Entity entity1: Entity
entity2: Entity entity2: Entity
@@ -536,11 +569,11 @@ async def generate_knowledge_graph():
relationships: List[Relationship] relationships: List[Relationship]
extraction_strategy = LLMExtractionStrategy( extraction_strategy = LLMExtractionStrategy(
provider='openai/gpt-4o-mini', # Or any other provider, including Ollama and open source models provider="openai/gpt-4o-mini", # Or any other provider, including Ollama and open source models
api_token=os.getenv('OPENAI_API_KEY'), # In case of Ollama just pass "no-token" api_token=os.getenv("OPENAI_API_KEY"), # In case of Ollama just pass "no-token"
schema=KnowledgeGraph.model_json_schema(), schema=KnowledgeGraph.model_json_schema(),
extraction_type="schema", extraction_type="schema",
instruction="""Extract entities and relationships from the given text.""" instruction="""Extract entities and relationships from the given text.""",
) )
async with AsyncWebCrawler() as crawler: async with AsyncWebCrawler() as crawler:
url = "https://paulgraham.com/love.html" url = "https://paulgraham.com/love.html"
@@ -554,27 +587,22 @@ async def generate_knowledge_graph():
with open(os.path.join(__location__, "kb.json"), "w") as f: with open(os.path.join(__location__, "kb.json"), "w") as f:
f.write(result.extracted_content) f.write(result.extracted_content)
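Since extraction_type="schema" was driven by KnowledgeGraph.model_json_schema(), the saved kb.json should round-trip through the same model. A sketch, hedged on the output being a single JSON object (some strategies emit a list of extracted blocks, hence the fallback):

import json

def load_knowledge_graph(path: str) -> KnowledgeGraph:
    with open(path) as f:
        data = json.load(f)
    if isinstance(data, list):  # tolerate a list of extracted blocks
        data = data[0]
    return KnowledgeGraph(**data)

# kg = load_knowledge_graph(os.path.join(__location__, "kb.json"))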
async def fit_markdown_remove_overlay(): async def fit_markdown_remove_overlay():
async with AsyncWebCrawler( async with AsyncWebCrawler(
headless=True, # Set to False to see what is happening headless=True, # Set to False to see what is happening
verbose=True, verbose=True,
user_agent_mode="random", user_agent_mode="random",
user_agent_generator_config={ user_agent_generator_config={"device_type": "mobile", "os_type": "android"},
"device_type": "mobile",
"os_type": "android"
},
) as crawler: ) as crawler:
result = await crawler.arun( result = await crawler.arun(
url='https://www.kidocode.com/degrees/technology', url="https://www.kidocode.com/degrees/technology",
cache_mode=CacheMode.BYPASS, cache_mode=CacheMode.BYPASS,
markdown_generator=DefaultMarkdownGenerator( markdown_generator=DefaultMarkdownGenerator(
content_filter=PruningContentFilter( content_filter=PruningContentFilter(
threshold=0.48, threshold_type="fixed", min_word_threshold=0 threshold=0.48, threshold_type="fixed", min_word_threshold=0
), ),
options={ options={"ignore_links": True},
"ignore_links": True
}
), ),
# markdown_generator=DefaultMarkdownGenerator( # markdown_generator=DefaultMarkdownGenerator(
# content_filter=BM25ContentFilter(user_query="", bm25_threshold=1.0), # content_filter=BM25ContentFilter(user_query="", bm25_threshold=1.0),
@@ -583,31 +611,38 @@ async def fit_markdown_remove_overlay():
# } # }
# ), # ),
) )
if result.success: if result.success:
print(len(result.markdown_v2.raw_markdown)) print(len(result.markdown_v2.raw_markdown))
print(len(result.markdown_v2.markdown_with_citations)) print(len(result.markdown_v2.markdown_with_citations))
print(len(result.markdown_v2.fit_markdown)) print(len(result.markdown_v2.fit_markdown))
# Save clean html # Save clean html
with open(os.path.join(__location__, "output/cleaned_html.html"), "w") as f: with open(os.path.join(__location__, "output/cleaned_html.html"), "w") as f:
f.write(result.cleaned_html) f.write(result.cleaned_html)
with open(os.path.join(__location__, "output/output_raw_markdown.md"), "w") as f: with open(
os.path.join(__location__, "output/output_raw_markdown.md"), "w"
) as f:
f.write(result.markdown_v2.raw_markdown) f.write(result.markdown_v2.raw_markdown)
with open(os.path.join(__location__, "output/output_markdown_with_citations.md"), "w") as f: with open(
os.path.join(__location__, "output/output_markdown_with_citations.md"),
"w",
) as f:
f.write(result.markdown_v2.markdown_with_citations) f.write(result.markdown_v2.markdown_with_citations)
with open(os.path.join(__location__, "output/output_fit_markdown.md"), "w") as f: with open(
os.path.join(__location__, "output/output_fit_markdown.md"), "w"
) as f:
f.write(result.markdown_v2.fit_markdown) f.write(result.markdown_v2.fit_markdown)
print("Done") print("Done")
async def main(): async def main():
# await extract_structured_data_using_llm("openai/gpt-4o", os.getenv("OPENAI_API_KEY")) # await extract_structured_data_using_llm("openai/gpt-4o", os.getenv("OPENAI_API_KEY"))
# await simple_crawl() # await simple_crawl()
# await simple_example_with_running_js_code() # await simple_example_with_running_js_code()
# await simple_example_with_css_selector() # await simple_example_with_css_selector()
@@ -618,7 +653,7 @@ async def main():
# LLM extraction examples # LLM extraction examples
# await extract_structured_data_using_llm() # await extract_structured_data_using_llm()
# await extract_structured_data_using_llm("huggingface/meta-llama/Meta-Llama-3.1-8B-Instruct", os.getenv("HUGGINGFACE_API_KEY")) # await extract_structured_data_using_llm("huggingface/meta-llama/Meta-Llama-3.1-8B-Instruct", os.getenv("HUGGINGFACE_API_KEY"))
# await extract_structured_data_using_llm("ollama/llama3.2") # await extract_structured_data_using_llm("ollama/llama3.2")
# You always can pass custom headers to the extraction strategy # You always can pass custom headers to the extraction strategy
# custom_headers = { # custom_headers = {
@@ -626,13 +661,13 @@ async def main():
# "X-Custom-Header": "Some-Value" # "X-Custom-Header": "Some-Value"
# } # }
# await extract_structured_data_using_llm(extra_headers=custom_headers) # await extract_structured_data_using_llm(extra_headers=custom_headers)
# await crawl_dynamic_content_pages_method_1() # await crawl_dynamic_content_pages_method_1()
# await crawl_dynamic_content_pages_method_2() # await crawl_dynamic_content_pages_method_2()
await crawl_dynamic_content_pages_method_3() await crawl_dynamic_content_pages_method_3()
# await crawl_custom_browser_type() # await crawl_custom_browser_type()
# await speed_comparison() # await speed_comparison()

View File

@@ -10,15 +10,17 @@ from functools import lru_cache
console = Console() console = Console()
@lru_cache() @lru_cache()
def create_crawler(): def create_crawler():
crawler = WebCrawler(verbose=True) crawler = WebCrawler(verbose=True)
crawler.warmup() crawler.warmup()
return crawler return crawler
def print_result(result): def print_result(result):
# Print each key on one line with just the first 20 characters of its value and three dots # Print each key on one line with just the first 20 characters of its value and three dots
console.print(f"\t[bold]Result:[/bold]") console.print("\t[bold]Result:[/bold]")
for key, value in result.model_dump().items(): for key, value in result.model_dump().items():
if isinstance(value, str) and value: if isinstance(value, str) and value:
console.print(f"\t{key}: [green]{value[:20]}...[/green]") console.print(f"\t{key}: [green]{value[:20]}...[/green]")
@@ -33,18 +35,27 @@ def cprint(message, press_any_key=False):
console.print("Press any key to continue...", style="") console.print("Press any key to continue...", style="")
input() input()
def basic_usage(crawler): def basic_usage(crawler):
cprint("🛠️ [bold cyan]Basic Usage: Simply provide a URL and let Crawl4ai do the magic![/bold cyan]") cprint(
result = crawler.run(url="https://www.nbcnews.com/business", only_text = True) "🛠️ [bold cyan]Basic Usage: Simply provide a URL and let Crawl4ai do the magic![/bold cyan]"
)
result = crawler.run(url="https://www.nbcnews.com/business", only_text=True)
cprint("[LOG] 📦 [bold yellow]Basic crawl result:[/bold yellow]") cprint("[LOG] 📦 [bold yellow]Basic crawl result:[/bold yellow]")
print_result(result) print_result(result)
def basic_usage_some_params(crawler): def basic_usage_some_params(crawler):
cprint("🛠️ [bold cyan]Basic Usage: Simply provide a URL and let Crawl4ai do the magic![/bold cyan]") cprint(
result = crawler.run(url="https://www.nbcnews.com/business", word_count_threshold=1, only_text = True) "🛠️ [bold cyan]Basic Usage: Simply provide a URL and let Crawl4ai do the magic![/bold cyan]"
)
result = crawler.run(
url="https://www.nbcnews.com/business", word_count_threshold=1, only_text=True
)
cprint("[LOG] 📦 [bold yellow]Basic crawl result:[/bold yellow]") cprint("[LOG] 📦 [bold yellow]Basic crawl result:[/bold yellow]")
print_result(result) print_result(result)
def screenshot_usage(crawler): def screenshot_usage(crawler):
cprint("\n📸 [bold cyan]Let's take a screenshot of the page![/bold cyan]") cprint("\n📸 [bold cyan]Let's take a screenshot of the page![/bold cyan]")
result = crawler.run(url="https://www.nbcnews.com/business", screenshot=True) result = crawler.run(url="https://www.nbcnews.com/business", screenshot=True)
@@ -55,16 +66,23 @@ def screenshot_usage(crawler):
cprint("Screenshot saved to 'screenshot.png'!") cprint("Screenshot saved to 'screenshot.png'!")
print_result(result) print_result(result)
def understanding_parameters(crawler): def understanding_parameters(crawler):
cprint("\n🧠 [bold cyan]Understanding 'bypass_cache' and 'include_raw_html' parameters:[/bold cyan]") cprint(
cprint("By default, Crawl4ai caches the results of your crawls. This means that subsequent crawls of the same URL will be much faster! Let's see this in action.") "\n🧠 [bold cyan]Understanding 'bypass_cache' and 'include_raw_html' parameters:[/bold cyan]"
)
cprint(
"By default, Crawl4ai caches the results of your crawls. This means that subsequent crawls of the same URL will be much faster! Let's see this in action."
)
# First crawl (reads from cache) # First crawl (reads from cache)
cprint("1⃣ First crawl (caches the result):", True) cprint("1⃣ First crawl (caches the result):", True)
start_time = time.time() start_time = time.time()
result = crawler.run(url="https://www.nbcnews.com/business") result = crawler.run(url="https://www.nbcnews.com/business")
end_time = time.time() end_time = time.time()
cprint(f"[LOG] 📦 [bold yellow]First crawl took {end_time - start_time} seconds and result (from cache):[/bold yellow]") cprint(
f"[LOG] 📦 [bold yellow]First crawl took {end_time - start_time} seconds and result (from cache):[/bold yellow]"
)
print_result(result) print_result(result)
# Force to crawl again # Force to crawl again
@@ -72,169 +90,232 @@ def understanding_parameters(crawler):
start_time = time.time() start_time = time.time()
result = crawler.run(url="https://www.nbcnews.com/business", bypass_cache=True) result = crawler.run(url="https://www.nbcnews.com/business", bypass_cache=True)
end_time = time.time() end_time = time.time()
cprint(f"[LOG] 📦 [bold yellow]Second crawl took {end_time - start_time} seconds and result (forced to crawl):[/bold yellow]") cprint(
f"[LOG] 📦 [bold yellow]Second crawl took {end_time - start_time} seconds and result (forced to crawl):[/bold yellow]"
)
print_result(result) print_result(result)
def add_chunking_strategy(crawler): def add_chunking_strategy(crawler):
# Adding a chunking strategy: RegexChunking # Adding a chunking strategy: RegexChunking
cprint("\n🧩 [bold cyan]Let's add a chunking strategy: RegexChunking![/bold cyan]", True) cprint(
cprint("RegexChunking is a simple chunking strategy that splits the text based on a given regex pattern. Let's see it in action!") "\n🧩 [bold cyan]Let's add a chunking strategy: RegexChunking![/bold cyan]",
True,
)
cprint(
"RegexChunking is a simple chunking strategy that splits the text based on a given regex pattern. Let's see it in action!"
)
result = crawler.run( result = crawler.run(
url="https://www.nbcnews.com/business", url="https://www.nbcnews.com/business",
chunking_strategy=RegexChunking(patterns=["\n\n"]) chunking_strategy=RegexChunking(patterns=["\n\n"]),
) )
cprint("[LOG] 📦 [bold yellow]RegexChunking result:[/bold yellow]") cprint("[LOG] 📦 [bold yellow]RegexChunking result:[/bold yellow]")
print_result(result) print_result(result)
# Adding another chunking strategy: NlpSentenceChunking # Adding another chunking strategy: NlpSentenceChunking
cprint("\n🔍 [bold cyan]Time to explore another chunking strategy: NlpSentenceChunking![/bold cyan]", True) cprint(
cprint("NlpSentenceChunking uses NLP techniques to split the text into sentences. Let's see how it performs!") "\n🔍 [bold cyan]Time to explore another chunking strategy: NlpSentenceChunking![/bold cyan]",
True,
)
cprint(
"NlpSentenceChunking uses NLP techniques to split the text into sentences. Let's see how it performs!"
)
result = crawler.run( result = crawler.run(
url="https://www.nbcnews.com/business", url="https://www.nbcnews.com/business", chunking_strategy=NlpSentenceChunking()
chunking_strategy=NlpSentenceChunking()
) )
cprint("[LOG] 📦 [bold yellow]NlpSentenceChunking result:[/bold yellow]") cprint("[LOG] 📦 [bold yellow]NlpSentenceChunking result:[/bold yellow]")
print_result(result) print_result(result)
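To see RegexChunking in isolation, it can be applied to a plain string. A sketch, assuming chunk() is the splitting entry point of ChunkingStrategy subclasses (verify the method name against crawl4ai.chunking_strategy before relying on it):

from crawl4ai.chunking_strategy import RegexChunking

chunker = RegexChunking(patterns=["\n\n"])  # split on blank lines
# pieces = chunker.chunk("First paragraph.\n\nSecond paragraph.")
# print(pieces)  # expected: ["First paragraph.", "Second paragraph."]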
def add_extraction_strategy(crawler): def add_extraction_strategy(crawler):
# Adding an extraction strategy: CosineStrategy # Adding an extraction strategy: CosineStrategy
cprint("\n🧠 [bold cyan]Let's get smarter with an extraction strategy: CosineStrategy![/bold cyan]", True) cprint(
cprint("CosineStrategy uses cosine similarity to extract semantically similar blocks of text. Let's see it in action!") "\n🧠 [bold cyan]Let's get smarter with an extraction strategy: CosineStrategy![/bold cyan]",
True,
)
cprint(
"CosineStrategy uses cosine similarity to extract semantically similar blocks of text. Let's see it in action!"
)
result = crawler.run( result = crawler.run(
url="https://www.nbcnews.com/business", url="https://www.nbcnews.com/business",
extraction_strategy=CosineStrategy(word_count_threshold=10, max_dist=0.2, linkage_method="ward", top_k=3, sim_threshold = 0.3, verbose=True) extraction_strategy=CosineStrategy(
word_count_threshold=10,
max_dist=0.2,
linkage_method="ward",
top_k=3,
sim_threshold=0.3,
verbose=True,
),
) )
cprint("[LOG] 📦 [bold yellow]CosineStrategy result:[/bold yellow]") cprint("[LOG] 📦 [bold yellow]CosineStrategy result:[/bold yellow]")
print_result(result) print_result(result)
# Using semantic_filter with CosineStrategy # Using semantic_filter with CosineStrategy
cprint("You can pass other parameters like 'semantic_filter' to the CosineStrategy to extract semantically similar blocks of text. Let's see it in action!") cprint(
"You can pass other parameters like 'semantic_filter' to the CosineStrategy to extract semantically similar blocks of text. Let's see it in action!"
)
result = crawler.run( result = crawler.run(
url="https://www.nbcnews.com/business", url="https://www.nbcnews.com/business",
extraction_strategy=CosineStrategy( extraction_strategy=CosineStrategy(
semantic_filter="inflation rent prices", semantic_filter="inflation rent prices",
) ),
)
cprint(
"[LOG] 📦 [bold yellow]CosineStrategy result with semantic filter:[/bold yellow]"
) )
cprint("[LOG] 📦 [bold yellow]CosineStrategy result with semantic filter:[/bold yellow]")
print_result(result) print_result(result)
def add_llm_extraction_strategy(crawler): def add_llm_extraction_strategy(crawler):
# Adding an LLM extraction strategy without instructions # Adding an LLM extraction strategy without instructions
cprint("\n🤖 [bold cyan]Time to bring in the big guns: LLMExtractionStrategy without instructions![/bold cyan]", True) cprint(
cprint("LLMExtractionStrategy uses a large language model to extract relevant information from the web page. Let's see it in action!") "\n🤖 [bold cyan]Time to bring in the big guns: LLMExtractionStrategy without instructions![/bold cyan]",
True,
)
cprint(
"LLMExtractionStrategy uses a large language model to extract relevant information from the web page. Let's see it in action!"
)
result = crawler.run( result = crawler.run(
url="https://www.nbcnews.com/business", url="https://www.nbcnews.com/business",
extraction_strategy=LLMExtractionStrategy(provider="openai/gpt-4o", api_token=os.getenv('OPENAI_API_KEY')) extraction_strategy=LLMExtractionStrategy(
provider="openai/gpt-4o", api_token=os.getenv("OPENAI_API_KEY")
),
)
cprint(
"[LOG] 📦 [bold yellow]LLMExtractionStrategy (no instructions) result:[/bold yellow]"
) )
cprint("[LOG] 📦 [bold yellow]LLMExtractionStrategy (no instructions) result:[/bold yellow]")
print_result(result) print_result(result)
# Adding an LLM extraction strategy with instructions # Adding an LLM extraction strategy with instructions
cprint("\n📜 [bold cyan]Let's make it even more interesting: LLMExtractionStrategy with instructions![/bold cyan]", True) cprint(
cprint("Let's say we are only interested in financial news. Let's see how LLMExtractionStrategy performs with instructions!") "\n📜 [bold cyan]Let's make it even more interesting: LLMExtractionStrategy with instructions![/bold cyan]",
True,
)
cprint(
"Let's say we are only interested in financial news. Let's see how LLMExtractionStrategy performs with instructions!"
)
result = crawler.run( result = crawler.run(
url="https://www.nbcnews.com/business", url="https://www.nbcnews.com/business",
extraction_strategy=LLMExtractionStrategy( extraction_strategy=LLMExtractionStrategy(
provider="openai/gpt-4o", provider="openai/gpt-4o",
api_token=os.getenv('OPENAI_API_KEY'), api_token=os.getenv("OPENAI_API_KEY"),
instruction="I am interested in only financial news" instruction="I am interested in only financial news",
) ),
)
cprint(
"[LOG] 📦 [bold yellow]LLMExtractionStrategy (with instructions) result:[/bold yellow]"
) )
cprint("[LOG] 📦 [bold yellow]LLMExtractionStrategy (with instructions) result:[/bold yellow]")
print_result(result) print_result(result)
result = crawler.run( result = crawler.run(
url="https://www.nbcnews.com/business", url="https://www.nbcnews.com/business",
extraction_strategy=LLMExtractionStrategy( extraction_strategy=LLMExtractionStrategy(
provider="openai/gpt-4o", provider="openai/gpt-4o",
api_token=os.getenv('OPENAI_API_KEY'), api_token=os.getenv("OPENAI_API_KEY"),
instruction="Extract only content related to technology" instruction="Extract only content related to technology",
) ),
)
cprint(
"[LOG] 📦 [bold yellow]LLMExtractionStrategy (with technology instruction) result:[/bold yellow]"
) )
cprint("[LOG] 📦 [bold yellow]LLMExtractionStrategy (with technology instruction) result:[/bold yellow]")
print_result(result) print_result(result)
def targeted_extraction(crawler): def targeted_extraction(crawler):
# Using a CSS selector to extract only H2 tags # Using a CSS selector to extract only H2 tags
cprint("\n🎯 [bold cyan]Targeted extraction: Let's use a CSS selector to extract only H2 tags![/bold cyan]", True) cprint(
result = crawler.run( "\n🎯 [bold cyan]Targeted extraction: Let's use a CSS selector to extract only H2 tags![/bold cyan]",
url="https://www.nbcnews.com/business", True,
css_selector="h2"
) )
result = crawler.run(url="https://www.nbcnews.com/business", css_selector="h2")
cprint("[LOG] 📦 [bold yellow]CSS Selector (H2 tags) result:[/bold yellow]") cprint("[LOG] 📦 [bold yellow]CSS Selector (H2 tags) result:[/bold yellow]")
print_result(result) print_result(result)
def interactive_extraction(crawler): def interactive_extraction(crawler):
# Passing JavaScript code to interact with the page # Passing JavaScript code to interact with the page
cprint("\n🖱️ [bold cyan]Let's get interactive: Passing JavaScript code to click 'Load More' button![/bold cyan]", True) cprint(
cprint("In this example we try to click the 'Load More' button on the page using JavaScript code.") "\n🖱️ [bold cyan]Let's get interactive: Passing JavaScript code to click 'Load More' button![/bold cyan]",
True,
)
cprint(
"In this example we try to click the 'Load More' button on the page using JavaScript code."
)
js_code = """ js_code = """
const loadMoreButton = Array.from(document.querySelectorAll('button')).find(button => button.textContent.includes('Load More')); const loadMoreButton = Array.from(document.querySelectorAll('button')).find(button => button.textContent.includes('Load More'));
loadMoreButton && loadMoreButton.click(); loadMoreButton && loadMoreButton.click();
""" """
# crawler_strategy = LocalSeleniumCrawlerStrategy(js_code=js_code) # crawler_strategy = LocalSeleniumCrawlerStrategy(js_code=js_code)
# crawler = WebCrawler(crawler_strategy=crawler_strategy, always_by_pass_cache=True) # crawler = WebCrawler(crawler_strategy=crawler_strategy, always_by_pass_cache=True)
result = crawler.run( result = crawler.run(url="https://www.nbcnews.com/business", js=js_code)
url="https://www.nbcnews.com/business", cprint(
js = js_code "[LOG] 📦 [bold yellow]JavaScript Code (Load More button) result:[/bold yellow]"
) )
cprint("[LOG] 📦 [bold yellow]JavaScript Code (Load More button) result:[/bold yellow]")
print_result(result) print_result(result)
def multiple_scripts(crawler): def multiple_scripts(crawler):
# Passing JavaScript code to interact with the page # Passing JavaScript code to interact with the page
cprint("\n🖱️ [bold cyan]Let's get interactive: Passing JavaScript code to click 'Load More' button![/bold cyan]", True) cprint(
cprint("In this example we try to click the 'Load More' button on the page using JavaScript code.") "\n🖱️ [bold cyan]Let's get interactive: Passing JavaScript code to click 'Load More' button![/bold cyan]",
js_code = [""" True,
)
cprint(
"In this example we try to click the 'Load More' button on the page using JavaScript code."
)
js_code = [
"""
const loadMoreButton = Array.from(document.querySelectorAll('button')).find(button => button.textContent.includes('Load More')); const loadMoreButton = Array.from(document.querySelectorAll('button')).find(button => button.textContent.includes('Load More'));
loadMoreButton && loadMoreButton.click(); loadMoreButton && loadMoreButton.click();
"""] * 2 """
] * 2
# crawler_strategy = LocalSeleniumCrawlerStrategy(js_code=js_code) # crawler_strategy = LocalSeleniumCrawlerStrategy(js_code=js_code)
# crawler = WebCrawler(crawler_strategy=crawler_strategy, always_by_pass_cache=True) # crawler = WebCrawler(crawler_strategy=crawler_strategy, always_by_pass_cache=True)
result = crawler.run( result = crawler.run(url="https://www.nbcnews.com/business", js=js_code)
url="https://www.nbcnews.com/business", cprint(
js = js_code "[LOG] 📦 [bold yellow]JavaScript Code (Load More button) result:[/bold yellow]"
) )
cprint("[LOG] 📦 [bold yellow]JavaScript Code (Load More button) result:[/bold yellow]")
print_result(result) print_result(result)
def using_crawler_hooks(crawler): def using_crawler_hooks(crawler):
# Example usage of the hooks for authentication and setting a cookie # Example usage of the hooks for authentication and setting a cookie
def on_driver_created(driver): def on_driver_created(driver):
print("[HOOK] on_driver_created") print("[HOOK] on_driver_created")
# Example customization: maximize the window # Example customization: maximize the window
driver.maximize_window() driver.maximize_window()
# Example customization: logging in to a hypothetical website # Example customization: logging in to a hypothetical website
driver.get('https://example.com/login') driver.get("https://example.com/login")
from selenium.webdriver.support.ui import WebDriverWait from selenium.webdriver.support.ui import WebDriverWait
from selenium.webdriver.common.by import By from selenium.webdriver.common.by import By
from selenium.webdriver.support import expected_conditions as EC from selenium.webdriver.support import expected_conditions as EC
WebDriverWait(driver, 10).until( WebDriverWait(driver, 10).until(
EC.presence_of_element_located((By.NAME, 'username')) EC.presence_of_element_located((By.NAME, "username"))
) )
driver.find_element(By.NAME, 'username').send_keys('testuser') driver.find_element(By.NAME, "username").send_keys("testuser")
driver.find_element(By.NAME, 'password').send_keys('password123') driver.find_element(By.NAME, "password").send_keys("password123")
driver.find_element(By.NAME, 'login').click() driver.find_element(By.NAME, "login").click()
WebDriverWait(driver, 10).until( WebDriverWait(driver, 10).until(
EC.presence_of_element_located((By.ID, 'welcome')) EC.presence_of_element_located((By.ID, "welcome"))
) )
# Add a custom cookie # Add a custom cookie
driver.add_cookie({'name': 'test_cookie', 'value': 'cookie_value'}) driver.add_cookie({"name": "test_cookie", "value": "cookie_value"})
return driver return driver
def before_get_url(driver): def before_get_url(driver):
print("[HOOK] before_get_url") print("[HOOK] before_get_url")
# Example customization: add a custom header # Example customization: add a custom header
# Enable Network domain for sending headers # Enable Network domain for sending headers
driver.execute_cdp_cmd('Network.enable', {}) driver.execute_cdp_cmd("Network.enable", {})
# Add a custom header # Add a custom header
driver.execute_cdp_cmd('Network.setExtraHTTPHeaders', {'headers': {'X-Test-Header': 'test'}}) driver.execute_cdp_cmd(
"Network.setExtraHTTPHeaders", {"headers": {"X-Test-Header": "test"}}
)
return driver return driver
def after_get_url(driver): def after_get_url(driver):
print("[HOOK] after_get_url") print("[HOOK] after_get_url")
# Example customization: log the URL # Example customization: log the URL
@@ -246,48 +327,59 @@ def using_crawler_hooks(crawler):
# Example customization: log the HTML # Example customization: log the HTML
print(len(html)) print(len(html))
return driver return driver
cprint("\n🔗 [bold cyan]Using Crawler Hooks: Let's see how we can customize the crawler using hooks![/bold cyan]", True) cprint(
"\n🔗 [bold cyan]Using Crawler Hooks: Let's see how we can customize the crawler using hooks![/bold cyan]",
True,
)
crawler_strategy = LocalSeleniumCrawlerStrategy(verbose=True) crawler_strategy = LocalSeleniumCrawlerStrategy(verbose=True)
crawler_strategy.set_hook('on_driver_created', on_driver_created) crawler_strategy.set_hook("on_driver_created", on_driver_created)
crawler_strategy.set_hook('before_get_url', before_get_url) crawler_strategy.set_hook("before_get_url", before_get_url)
crawler_strategy.set_hook('after_get_url', after_get_url) crawler_strategy.set_hook("after_get_url", after_get_url)
crawler_strategy.set_hook('before_return_html', before_return_html) crawler_strategy.set_hook("before_return_html", before_return_html)
crawler = WebCrawler(verbose=True, crawler_strategy=crawler_strategy) crawler = WebCrawler(verbose=True, crawler_strategy=crawler_strategy)
crawler.warmup() crawler.warmup()
result = crawler.run(url="https://example.com") result = crawler.run(url="https://example.com")
cprint("[LOG] 📦 [bold yellow]Crawler Hooks result:[/bold yellow]") cprint("[LOG] 📦 [bold yellow]Crawler Hooks result:[/bold yellow]")
print_result(result= result) print_result(result=result)
def using_crawler_hooks_delay_example(crawler): def using_crawler_hooks_delay_example(crawler):
def delay(driver): def delay(driver):
print("Delaying for 5 seconds...") print("Delaying for 5 seconds...")
time.sleep(5) time.sleep(5)
print("Resuming...") print("Resuming...")
def create_crawler(): def create_crawler():
crawler_strategy = LocalSeleniumCrawlerStrategy(verbose=True) crawler_strategy = LocalSeleniumCrawlerStrategy(verbose=True)
crawler_strategy.set_hook('after_get_url', delay) crawler_strategy.set_hook("after_get_url", delay)
crawler = WebCrawler(verbose=True, crawler_strategy=crawler_strategy) crawler = WebCrawler(verbose=True, crawler_strategy=crawler_strategy)
crawler.warmup() crawler.warmup()
return crawler return crawler
cprint("\n🔗 [bold cyan]Using Crawler Hooks: Let's add a delay after fetching the url to make sure entire page is fetched.[/bold cyan]") cprint(
"\n🔗 [bold cyan]Using Crawler Hooks: Let's add a delay after fetching the url to make sure entire page is fetched.[/bold cyan]"
)
crawler = create_crawler() crawler = create_crawler()
result = crawler.run(url="https://google.com", bypass_cache=True) result = crawler.run(url="https://google.com", bypass_cache=True)
cprint("[LOG] 📦 [bold yellow]Crawler Hooks result:[/bold yellow]") cprint("[LOG] 📦 [bold yellow]Crawler Hooks result:[/bold yellow]")
print_result(result) print_result(result)
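The same set_hook pattern generalizes to any of the hook points used above (on_driver_created, before_get_url, after_get_url, before_return_html). A compact sketch wiring a logging hook:

def log_url(driver):
    # driver is the underlying Selenium WebDriver instance
    print(f"[HOOK] current URL: {driver.current_url}")
    return driver

strategy = LocalSeleniumCrawlerStrategy(verbose=True)
strategy.set_hook("after_get_url", log_url)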
def main(): def main():
cprint("🌟 [bold green]Welcome to the Crawl4ai Quickstart Guide! Let's dive into some web crawling fun! 🌐[/bold green]") cprint(
cprint("⛳️ [bold cyan]First Step: Create an instance of WebCrawler and call the `warmup()` function.[/bold cyan]") "🌟 [bold green]Welcome to the Crawl4ai Quickstart Guide! Let's dive into some web crawling fun! 🌐[/bold green]"
cprint("If this is the first time you're running Crawl4ai, this might take a few seconds to load required model files.") )
cprint(
"⛳️ [bold cyan]First Step: Create an instance of WebCrawler and call the `warmup()` function.[/bold cyan]"
)
cprint(
"If this is the first time you're running Crawl4ai, this might take a few seconds to load required model files."
)
crawler = create_crawler() crawler = create_crawler()
@@ -295,7 +387,7 @@ def main():
basic_usage(crawler) basic_usage(crawler)
# basic_usage_some_params(crawler) # basic_usage_some_params(crawler)
understanding_parameters(crawler) understanding_parameters(crawler)
crawler.always_by_pass_cache = True crawler.always_by_pass_cache = True
screenshot_usage(crawler) screenshot_usage(crawler)
add_chunking_strategy(crawler) add_chunking_strategy(crawler)
@@ -305,8 +397,10 @@ def main():
interactive_extraction(crawler) interactive_extraction(crawler)
multiple_scripts(crawler) multiple_scripts(crawler)
cprint("\n🎉 [bold green]Congratulations! You've made it through the Crawl4ai Quickstart Guide! Now go forth and crawl the web like a pro! 🕸️[/bold green]") cprint(
"\n🎉 [bold green]Congratulations! You've made it through the Crawl4ai Quickstart Guide! Now go forth and crawl the web like a pro! 🕸️[/bold green]"
)
if __name__ == "__main__": if __name__ == "__main__":
main() main()

View File

@@ -11,7 +11,9 @@ from groq import Groq
# Import threadpools to run the crawl_url function in a separate thread # Import threadpools to run the crawl_url function in a separate thread
from concurrent.futures import ThreadPoolExecutor from concurrent.futures import ThreadPoolExecutor
client = AsyncOpenAI(base_url="https://api.groq.com/openai/v1", api_key=os.getenv("GROQ_API_KEY")) client = AsyncOpenAI(
base_url="https://api.groq.com/openai/v1", api_key=os.getenv("GROQ_API_KEY")
)
# Instrument the OpenAI client # Instrument the OpenAI client
cl.instrument_openai() cl.instrument_openai()
@@ -25,41 +27,39 @@ settings = {
"presence_penalty": 0, "presence_penalty": 0,
} }
def extract_urls(text): def extract_urls(text):
url_pattern = re.compile(r'(https?://\S+)') url_pattern = re.compile(r"(https?://\S+)")
return url_pattern.findall(text) return url_pattern.findall(text)
def crawl_url(url): def crawl_url(url):
data = { data = {
"urls": [url], "urls": [url],
"include_raw_html": True, "include_raw_html": True,
"word_count_threshold": 10, "word_count_threshold": 10,
"extraction_strategy": "NoExtractionStrategy", "extraction_strategy": "NoExtractionStrategy",
"chunking_strategy": "RegexChunking" "chunking_strategy": "RegexChunking",
} }
response = requests.post("https://crawl4ai.com/crawl", json=data) response = requests.post("https://crawl4ai.com/crawl", json=data)
response_data = response.json() response_data = response.json()
response_data = response_data['results'][0] response_data = response_data["results"][0]
return response_data['markdown'] return response_data["markdown"]
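crawl_url assumes a successful response and a populated results list. A slightly defensive variant of the same call (same endpoint and payload):

def crawl_url_safe(url: str) -> str:
    data = {
        "urls": [url],
        "include_raw_html": True,
        "word_count_threshold": 10,
        "extraction_strategy": "NoExtractionStrategy",
        "chunking_strategy": "RegexChunking",
    }
    response = requests.post("https://crawl4ai.com/crawl", json=data)
    response.raise_for_status()  # surface HTTP errors early
    results = response.json().get("results") or []
    return results[0]["markdown"] if results else ""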
@cl.on_chat_start @cl.on_chat_start
async def on_chat_start(): async def on_chat_start():
cl.user_session.set("session", { cl.user_session.set("session", {"history": [], "context": {}})
"history": [], await cl.Message(content="Welcome to the chat! How can I assist you today?").send()
"context": {}
})
await cl.Message(
content="Welcome to the chat! How can I assist you today?"
).send()
@cl.on_message @cl.on_message
async def on_message(message: cl.Message): async def on_message(message: cl.Message):
user_session = cl.user_session.get("session") user_session = cl.user_session.get("session")
# Extract URLs from the user's message # Extract URLs from the user's message
urls = extract_urls(message.content) urls = extract_urls(message.content)
futures = [] futures = []
with ThreadPoolExecutor() as executor: with ThreadPoolExecutor() as executor:
for url in urls: for url in urls:
@@ -69,16 +69,9 @@ async def on_message(message: cl.Message):
for url, result in zip(urls, results): for url, result in zip(urls, results):
ref_number = f"REF_{len(user_session['context']) + 1}" ref_number = f"REF_{len(user_session['context']) + 1}"
user_session["context"][ref_number] = { user_session["context"][ref_number] = {"url": url, "content": result}
"url": url,
"content": result
}
user_session["history"].append({"role": "user", "content": message.content})
user_session["history"].append({
"role": "user",
"content": message.content
})
# Create a system message that includes the context # Create a system message that includes the context
context_messages = [ context_messages = [
@@ -95,26 +88,17 @@ async def on_message(message: cl.Message):
"If not, there is no need to add a references section. " "If not, there is no need to add a references section. "
"At the end of your response, provide a reference section listing the URLs and their REF numbers only if sources from the appendices were used.\n\n" "At the end of your response, provide a reference section listing the URLs and their REF numbers only if sources from the appendices were used.\n\n"
"\n\n".join(context_messages) "\n\n".join(context_messages)
) ),
} }
else: else:
system_message = { system_message = {"role": "system", "content": "You are a helpful assistant."}
"role": "system",
"content": "You are a helpful assistant."
}
msg = cl.Message(content="") msg = cl.Message(content="")
await msg.send() await msg.send()
# Get response from the LLM # Get response from the LLM
stream = await client.chat.completions.create( stream = await client.chat.completions.create(
messages=[ messages=[system_message, *user_session["history"]], stream=True, **settings
system_message,
*user_session["history"]
],
stream=True,
**settings
) )
assistant_response = "" assistant_response = ""
@@ -124,10 +108,7 @@ async def on_message(message: cl.Message):
await msg.stream_token(token) await msg.stream_token(token)
# Add assistant message to the history # Add assistant message to the history
user_session["history"].append({ user_session["history"].append({"role": "assistant", "content": assistant_response})
"role": "assistant",
"content": assistant_response
})
await msg.update() await msg.update()
# Append the reference section to the assistant's response # Append the reference section to the assistant's response
@@ -154,10 +135,11 @@ async def on_audio_chunk(chunk: cl.AudioChunk):
pass pass
@cl.step(type="tool") @cl.step(type="tool")
async def speech_to_text(audio_file): async def speech_to_text(audio_file):
cli = Groq() cli = Groq()
response = await client.audio.transcriptions.create( response = await client.audio.transcriptions.create(
model="whisper-large-v3", file=audio_file model="whisper-large-v3", file=audio_file
) )
@@ -172,24 +154,19 @@ async def on_audio_end(elements: list[ElementBased]):
audio_buffer.seek(0) # Move the file pointer to the beginning audio_buffer.seek(0) # Move the file pointer to the beginning
audio_file = audio_buffer.read() audio_file = audio_buffer.read()
audio_mime_type: str = cl.user_session.get("audio_mime_type") audio_mime_type: str = cl.user_session.get("audio_mime_type")
start_time = time.time() start_time = time.time()
whisper_input = (audio_buffer.name, audio_file, audio_mime_type) whisper_input = (audio_buffer.name, audio_file, audio_mime_type)
transcription = await speech_to_text(whisper_input) transcription = await speech_to_text(whisper_input)
end_time = time.time() end_time = time.time()
print(f"Transcription took {end_time - start_time} seconds") print(f"Transcription took {end_time - start_time} seconds")
user_msg = cl.Message( user_msg = cl.Message(author="You", type="user_message", content=transcription)
author="You",
type="user_message",
content=transcription
)
await user_msg.send() await user_msg.send()
await on_message(user_msg) await on_message(user_msg)
if __name__ == "__main__": if __name__ == "__main__":
from chainlit.cli import run_chainlit from chainlit.cli import run_chainlit
run_chainlit(__file__) run_chainlit(__file__)

View File

@@ -1,4 +1,3 @@
import requests, base64, os

data = {
@@ -6,59 +5,50 @@ data = {
    "screenshot": True,
}
response = requests.post("https://crawl4ai.com/crawl", json=data)
result = response.json()["results"][0]
print(result.keys())
# dict_keys(['url', 'html', 'success', 'cleaned_html', 'media',
#            'links', 'screenshot', 'markdown', 'extracted_content',
#            'metadata', 'error_message'])

with open("screenshot.png", "wb") as f:
    f.write(base64.b64decode(result["screenshot"]))

# Example of filtering the content using CSS selectors
data = {
    "urls": ["https://www.nbcnews.com/business"],
    "css_selector": "article",
    "screenshot": True,
}

# Example of executing a JS script on the page before extracting the content
data = {
    "urls": ["https://www.nbcnews.com/business"],
    "screenshot": True,
    "js": [
        """
        const loadMoreButton = Array.from(document.querySelectorAll('button')).
            find(button => button.textContent.includes('Load More'));
        loadMoreButton && loadMoreButton.click();
        """
    ],
}

# Example of using a custom extraction strategy
data = {
    "urls": ["https://www.nbcnews.com/business"],
    "extraction_strategy": "CosineStrategy",
    "extraction_strategy_args": {"semantic_filter": "inflation rent prices"},
}

# Example of using LLM to extract content
data = {
    "urls": ["https://www.nbcnews.com/business"],
    "extraction_strategy": "LLMExtractionStrategy",
    "extraction_strategy_args": {
        "provider": "groq/llama3-8b-8192",
        "api_token": os.environ.get("GROQ_API_KEY"),
        "instruction": """I am interested in only financial news,
        and translate them in French.""",
    },
}
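
# A minimal sketch (illustrative, not part of the library) wrapping the same
# endpoint in a reusable helper. The URL and payload shapes come from the
# examples above; the helper name, timeout, and error handling are assumptions.
import requests

def crawl(payload: dict, endpoint: str = "https://crawl4ai.com/crawl") -> dict:
    """POST a crawl request and return the first result; raise on HTTP errors."""
    response = requests.post(endpoint, json=payload, timeout=120)
    response.raise_for_status()
    return response.json()["results"][0]

result = crawl({"urls": ["https://www.nbcnews.com/business"], "screenshot": True})
print(result["success"])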

View File

@@ -5,42 +5,47 @@ import os
from crawl4ai import AsyncWebCrawler, CrawlerRunConfig, CacheMode

# Create tmp directory if it doesn't exist
parent_dir = os.path.dirname(
    os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
)
tmp_dir = os.path.join(parent_dir, "tmp")
os.makedirs(tmp_dir, exist_ok=True)


async def main():
    # Configure crawler to fetch SSL certificate
    config = CrawlerRunConfig(
        fetch_ssl_certificate=True,
        cache_mode=CacheMode.BYPASS,  # Bypass cache to always get fresh certificates
    )

    async with AsyncWebCrawler() as crawler:
        result = await crawler.arun(url="https://example.com", config=config)

        if result.success and result.ssl_certificate:
            cert = result.ssl_certificate

            # 1. Access certificate properties directly
            print("\nCertificate Information:")
            print(f"Issuer: {cert.issuer.get('CN', '')}")
            print(f"Valid until: {cert.valid_until}")
            print(f"Fingerprint: {cert.fingerprint}")

            # 2. Export certificate in different formats
            cert.to_json(os.path.join(tmp_dir, "certificate.json"))  # For analysis
            print("\nCertificate exported to:")
            print(f"- JSON: {os.path.join(tmp_dir, 'certificate.json')}")

            pem_data = cert.to_pem(
                os.path.join(tmp_dir, "certificate.pem")
            )  # For web servers
            print(f"- PEM: {os.path.join(tmp_dir, 'certificate.pem')}")

            der_data = cert.to_der(
                os.path.join(tmp_dir, "certificate.der")
            )  # For Java apps
            print(f"- DER: {os.path.join(tmp_dir, 'certificate.der')}")


if __name__ == "__main__":
    asyncio.run(main())
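
# A quick follow-up sketch for inspecting the exported file. It assumes the
# JSON written by cert.to_json() mirrors the properties printed above
# (issuer, valid_until, fingerprint); the field names here are illustrative.
import json

with open(os.path.join(tmp_dir, "certificate.json")) as f:
    cert_data = json.load(f)
print(cert_data.get("issuer"), cert_data.get("valid_until"))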

View File

@@ -1,39 +1,41 @@
import os
import json
from crawl4ai.web_crawler import WebCrawler
from crawl4ai.chunking_strategy import *
from crawl4ai.extraction_strategy import *
from crawl4ai.crawler_strategy import *

url = r"https://marketplace.visualstudio.com/items?itemName=Unclecode.groqopilot"
crawler = WebCrawler()
crawler.warmup()

from pydantic import BaseModel, Field


class PageSummary(BaseModel):
    title: str = Field(..., description="Title of the page.")
    summary: str = Field(..., description="Summary of the page.")
    brief_summary: str = Field(..., description="Brief summary of the page.")
    keywords: list = Field(..., description="Keywords assigned to the page.")


result = crawler.run(
    url=url,
    word_count_threshold=1,
    extraction_strategy=LLMExtractionStrategy(
        provider="openai/gpt-4o",
        api_token=os.getenv("OPENAI_API_KEY"),
        schema=PageSummary.model_json_schema(),
        extraction_type="schema",
        apply_chunking=False,
        instruction="From the crawled content, extract the following details: "
        "1. Title of the page "
        "2. Summary of the page, which is a detailed summary "
        "3. Brief summary of the page, which is a paragraph text "
        "4. Keywords assigned to the page, which is a list of keywords. "
        "The extracted JSON format should look like this: "
        '{ "title": "Page Title", "summary": "Detailed summary of the page.", "brief_summary": "Brief summary in a paragraph.", "keywords": ["keyword1", "keyword2", "keyword3"] }',
    ),
    bypass_cache=True,
)
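
# A sketch of validating the extracted JSON against the same Pydantic model.
# It assumes result.extracted_content holds the JSON string produced by the
# schema-based extraction above; LLM extractions often return a list of
# blocks, so both shapes are handled here.
import json

data = json.loads(result.extracted_content)
blocks = data if isinstance(data, list) else [data]
summaries = [PageSummary.model_validate(b) for b in blocks]
print(summaries[0].title, summaries[0].keywords)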

View File

@@ -1,4 +1,5 @@
import os, sys import os, sys
# append the parent directory to the sys.path # append the parent directory to the sys.path
parent_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) parent_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
sys.path.append(parent_dir) sys.path.append(parent_dir)
@@ -13,19 +14,18 @@ import json
from crawl4ai import AsyncWebCrawler, CacheMode
from crawl4ai.content_filter_strategy import BM25ContentFilter


# 1. File Download Processing Example
async def download_example():
    """Example of downloading files from Python.org"""
    # downloads_path = os.path.join(os.getcwd(), "downloads")
    downloads_path = os.path.join(Path.home(), ".crawl4ai", "downloads")
    os.makedirs(downloads_path, exist_ok=True)

    print(f"Downloads will be saved to: {downloads_path}")

    async with AsyncWebCrawler(
        accept_downloads=True, downloads_path=downloads_path, verbose=True
    ) as crawler:
        result = await crawler.arun(
            url="https://www.python.org/downloads/",
@@ -40,9 +40,9 @@ async def download_example():
            }
            """,
            delay_before_return_html=1,  # Wait 1 second to ensure the download starts
            cache_mode=CacheMode.BYPASS,
        )

        if result.downloaded_files:
            print("\nDownload successful!")
            print("Downloaded files:")
@@ -52,25 +52,26 @@ async def download_example():
        else:
            print("\nNo files were downloaded")


# 2. Local File and Raw HTML Processing Example
async def local_and_raw_html_example():
    """Example of processing local files and raw HTML"""
    # Create a sample HTML file
    sample_file = os.path.join(__data__, "sample.html")
    with open(sample_file, "w") as f:
        f.write(
            """
        <html><body>
            <h1>Test Content</h1>
            <p>This is a test paragraph.</p>
        </body></html>
        """
        )

    async with AsyncWebCrawler(verbose=True) as crawler:
        # Process local file
        local_result = await crawler.arun(url=f"file://{os.path.abspath(sample_file)}")

        # Process raw HTML
        raw_html = """
        <html><body>
@@ -78,16 +79,15 @@ async def local_and_raw_html_example():
            <p>This is a test of raw HTML processing.</p>
        </body></html>
        """
        raw_result = await crawler.arun(url=f"raw:{raw_html}")

        # Clean up
        os.remove(sample_file)

        print("Local file content:", local_result.markdown)
        print("\nRaw HTML content:", raw_result.markdown)


# 3. Enhanced Markdown Generation Example
async def markdown_generation_example():
    """Example of enhanced markdown generation with citations and LLM-friendly features"""
@@ -97,58 +97,66 @@ async def markdown_generation_example():
            # user_query="History and cultivation",
            bm25_threshold=1.0,
        )

        result = await crawler.arun(
            url="https://en.wikipedia.org/wiki/Apple",
            css_selector="main div#bodyContent",
            content_filter=content_filter,
            cache_mode=CacheMode.BYPASS,
        )

        from crawl4ai.content_filter_strategy import BM25ContentFilter

        result = await crawler.arun(
            url="https://en.wikipedia.org/wiki/Apple",
            css_selector="main div#bodyContent",
            content_filter=BM25ContentFilter(),
        )
        print(result.markdown_v2.fit_markdown)

        print("\nMarkdown Generation Results:")
        print(f"1. Original markdown length: {len(result.markdown)}")
        print("2. New markdown versions (markdown_v2):")
        print(f"   - Raw markdown length: {len(result.markdown_v2.raw_markdown)}")
        print(
            f"   - Citations markdown length: {len(result.markdown_v2.markdown_with_citations)}"
        )
        print(
            f"   - References section length: {len(result.markdown_v2.references_markdown)}"
        )
        if result.markdown_v2.fit_markdown:
            print(
                f"   - Filtered markdown length: {len(result.markdown_v2.fit_markdown)}"
            )

        # Save examples to files
        output_dir = os.path.join(__data__, "markdown_examples")
        os.makedirs(output_dir, exist_ok=True)

        # Save different versions
        with open(os.path.join(output_dir, "1_raw_markdown.md"), "w") as f:
            f.write(result.markdown_v2.raw_markdown)
        with open(os.path.join(output_dir, "2_citations_markdown.md"), "w") as f:
            f.write(result.markdown_v2.markdown_with_citations)
        with open(os.path.join(output_dir, "3_references.md"), "w") as f:
            f.write(result.markdown_v2.references_markdown)
        if result.markdown_v2.fit_markdown:
            with open(os.path.join(output_dir, "4_filtered_markdown.md"), "w") as f:
                f.write(result.markdown_v2.fit_markdown)

        print(f"\nMarkdown examples saved to: {output_dir}")

        # Show a sample of citations and references
        print("\nSample of markdown with citations:")
        print(result.markdown_v2.markdown_with_citations[:500] + "...\n")
        print("Sample of references:")
        print(
            "\n".join(result.markdown_v2.references_markdown.split("\n")[:10]) + "..."
        )


# 4. Browser Management Example
async def browser_management_example():
@@ -156,38 +164,38 @@ async def browser_management_example():
    # Use the specified user directory path
    user_data_dir = os.path.join(Path.home(), ".crawl4ai", "browser_profile")
    os.makedirs(user_data_dir, exist_ok=True)

    print(f"Browser profile will be saved to: {user_data_dir}")

    async with AsyncWebCrawler(
        use_managed_browser=True,
        user_data_dir=user_data_dir,
        headless=False,
        verbose=True,
    ) as crawler:
        result = await crawler.arun(
            url="https://crawl4ai.com",
            # session_id="persistent_session_1",
            cache_mode=CacheMode.BYPASS,
        )
        # Use GitHub as an example - it's a good test for browser management
        # because it requires proper browser handling
        result = await crawler.arun(
            url="https://github.com/trending",
            # session_id="persistent_session_1",
            cache_mode=CacheMode.BYPASS,
        )
        print("\nBrowser session result:", result.success)
        if result.success:
            print("Page title:", result.metadata.get("title", "No title found"))


# 5. API Usage Example
async def api_example():
    """Example of using the new API endpoints"""
    api_token = os.getenv("CRAWL4AI_API_TOKEN") or "test_api_code"
    headers = {"Authorization": f"Bearer {api_token}"}
    async with aiohttp.ClientSession() as session:
        # Submit crawl job
        crawl_request = {
@@ -199,25 +207,17 @@ async def api_example():
                    "name": "Hacker News Articles",
                    "baseSelector": ".athing",
                    "fields": [
                        {"name": "title", "selector": ".title a", "type": "text"},
                        {"name": "score", "selector": ".score", "type": "text"},
                        {
                            "name": "url",
                            "selector": ".title a",
                            "type": "attribute",
                            "attribute": "href",
                        },
                    ],
                }
            },
        },
        "crawler_params": {
            "headless": True,
@@ -227,51 +227,50 @@ async def api_example():
                # "screenshot": True,
                # "magic": True
            }
        }

        async with session.post(
            "http://localhost:11235/crawl", json=crawl_request, headers=headers
        ) as response:
            task_data = await response.json()
            task_id = task_data["task_id"]

            # Check task status
            while True:
                async with session.get(
                    f"http://localhost:11235/task/{task_id}", headers=headers
                ) as status_response:
                    result = await status_response.json()
                    print(f"Task status: {result['status']}")

                    if result["status"] == "completed":
                        print("Task completed!")
                        print("Results:")
                        news = json.loads(result["results"][0]["extracted_content"])
                        print(json.dumps(news[:4], indent=2))
                        break
                    else:
                        await asyncio.sleep(1)


# Main execution
async def main():
    # print("Running Crawl4AI feature examples...")
    # print("\n1. Running Download Example:")
    # await download_example()
    # print("\n2. Running Markdown Generation Example:")
    # await markdown_generation_example()
    # # print("\n3. Running Local and Raw HTML Example:")
    # await local_and_raw_html_example()
    # # print("\n4. Running Browser Management Example:")
    await browser_management_example()
    # print("\n5. Running API Example:")
    await api_example()


if __name__ == "__main__":
    asyncio.run(main())

View File

@@ -10,18 +10,17 @@ import asyncio
import os
import json
import re
from typing import List
from crawl4ai import (
    AsyncWebCrawler,
    BrowserConfig,
    CrawlerRunConfig,
    CacheMode,
    LLMExtractionStrategy,
    JsonCssExtractionStrategy,
)
from crawl4ai.content_filter_strategy import RelevantContentFilter
from crawl4ai.markdown_generation_strategy import DefaultMarkdownGenerator
from bs4 import BeautifulSoup

# Sample HTML for demonstrations
@@ -52,17 +51,18 @@ SAMPLE_HTML = """
    </div>
"""


async def demo_ssl_features():
    """
    Enhanced SSL & Security Features Demo
    -----------------------------------
    This example demonstrates the new SSL certificate handling and security features:
    1. Custom certificate paths
    2. SSL verification options
    3. HTTPS error handling
    4. Certificate validation configurations

    These features are particularly useful when:
    - Working with self-signed certificates
    - Dealing with corporate proxies
@@ -76,14 +76,11 @@ async def demo_ssl_features():
    run_config = CrawlerRunConfig(
        cache_mode=CacheMode.BYPASS,
        fetch_ssl_certificate=True,  # Enable SSL certificate fetching
    )

    async with AsyncWebCrawler(config=browser_config) as crawler:
        result = await crawler.arun(url="https://example.com", config=run_config)
        print(f"SSL Crawl Success: {result.success}")
        result.ssl_certificate.to_json(
            os.path.join(os.getcwd(), "ssl_certificate.json")
@@ -91,11 +88,12 @@ async def demo_ssl_features():
        if not result.success:
            print(f"SSL Error: {result.error_message}")


async def demo_content_filtering():
    """
    Smart Content Filtering Demo
    ----------------------
    Demonstrates advanced content filtering capabilities:
    1. Custom filter to identify and extract specific content
    2. Integration with markdown generation
@@ -110,87 +108,90 @@ async def demo_content_filtering():
            super().__init__()
            # Add news-specific patterns
            self.negative_patterns = re.compile(
                r"nav|footer|header|sidebar|ads|comment|share|related|recommended|popular|trending",
                re.I,
            )
            self.min_word_count = 30  # Higher threshold for news content

        def filter_content(
            self, html: str, min_word_threshold: int = None
        ) -> List[str]:
            """
            Implements news-specific content filtering logic.

            Args:
                html (str): HTML content to be filtered
                min_word_threshold (int, optional): Minimum word count threshold

            Returns:
                List[str]: List of filtered HTML content blocks
            """
            if not html or not isinstance(html, str):
                return []

            soup = BeautifulSoup(html, "lxml")
            if not soup.body:
                soup = BeautifulSoup(f"<body>{html}</body>", "lxml")

            body = soup.find("body")

            # Extract chunks with metadata
            chunks = self.extract_text_chunks(
                body, min_word_threshold or self.min_word_count
            )

            # Filter chunks based on news-specific criteria
            filtered_chunks = []
            for _, text, tag_type, element in chunks:
                # Skip if element has negative class/id
                if self.is_excluded(element):
                    continue

                # Headers are important in news articles
                if tag_type == "header":
                    filtered_chunks.append(self.clean_element(element))
                    continue

                # For content, check word count and link density
                text = element.get_text(strip=True)
                if len(text.split()) >= (min_word_threshold or self.min_word_count):
                    # Calculate link density
                    links_text = " ".join(
                        a.get_text(strip=True) for a in element.find_all("a")
                    )
                    link_density = len(links_text) / len(text) if text else 1

                    # Accept if link density is reasonable
                    if link_density < 0.5:
                        filtered_chunks.append(self.clean_element(element))

            return filtered_chunks

    # Create markdown generator with custom filter
    markdown_gen = DefaultMarkdownGenerator(content_filter=CustomNewsFilter())

    run_config = CrawlerRunConfig(
        markdown_generator=markdown_gen, cache_mode=CacheMode.BYPASS
    )

    async with AsyncWebCrawler() as crawler:
        result = await crawler.arun(
            url="https://news.ycombinator.com", config=run_config
        )
        print("Filtered Content Sample:")
        print(result.markdown[:500])  # Show first 500 chars
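
# Illustrative numbers for the link-density gate in filter_content() above:
# a mostly-text block passes the 0.5 threshold, while a nav-like block that
# is mostly anchor text would exceed it and be dropped.
text_len = 300                         # characters of visible text in a block
links_len = 45                         # characters of text inside <a> tags
link_density = links_len / text_len    # 0.15 < 0.5 -> block is kept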


async def demo_json_extraction():
    """
    Improved JSON Extraction Demo
    ---------------------------
    Demonstrates the enhanced JSON extraction capabilities:
    1. Base element attributes extraction
    2. Complex nested structures
    3. Multiple extraction patterns

    Key features shown:
    - Extracting attributes from base elements (href, data-* attributes)
    - Processing repeated patterns
@@ -206,7 +207,7 @@ async def demo_json_extraction():
            "baseSelector": "div.article-list",
            "baseFields": [
                {"name": "list_id", "type": "attribute", "attribute": "data-list-id"},
                {"name": "category", "type": "attribute", "attribute": "data-category"},
            ],
            "fields": [
                {
@@ -214,8 +215,16 @@ async def demo_json_extraction():
                    "selector": "article.post",
                    "type": "nested_list",
                    "baseFields": [
                        {
                            "name": "post_id",
                            "type": "attribute",
                            "attribute": "data-post-id",
                        },
                        {
                            "name": "author_id",
                            "type": "attribute",
                            "attribute": "data-author",
                        },
                    ],
                    "fields": [
                        {
@@ -223,60 +232,68 @@ async def demo_json_extraction():
                            "selector": "h2.title a",
                            "type": "text",
                            "baseFields": [
                                {
                                    "name": "url",
                                    "type": "attribute",
                                    "attribute": "href",
                                }
                            ],
                        },
                        {
                            "name": "author",
                            "selector": "div.meta a.author",
                            "type": "text",
                            "baseFields": [
                                {
                                    "name": "profile_url",
                                    "type": "attribute",
                                    "attribute": "href",
                                }
                            ],
                        },
                        {"name": "date", "selector": "span.date", "type": "text"},
                        {
                            "name": "read_more",
                            "selector": "a.read-more",
                            "type": "nested",
                            "fields": [
                                {"name": "text", "type": "text"},
                                {
                                    "name": "url",
                                    "type": "attribute",
                                    "attribute": "href",
                                },
                            ],
                        },
                    ],
                },
            ],
        }
    )

    # Demonstrate extraction from raw HTML
    run_config = CrawlerRunConfig(
        extraction_strategy=json_strategy, cache_mode=CacheMode.BYPASS
    )

    async with AsyncWebCrawler() as crawler:
        result = await crawler.arun(
            url="raw:" + SAMPLE_HTML,  # Use raw: prefix for raw HTML
            config=run_config,
        )
        print("Extracted Content:")
        print(result.extracted_content)
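
# A minimal schema using the same baseFields idea, for contrast with the large
# one above: baseFields pull attributes off the matched element itself, while
# fields use child selectors. The schema name and field names are illustrative.
mini_schema = {
    "name": "Read-more links",
    "baseSelector": "a.read-more",
    "baseFields": [
        {"name": "url", "type": "attribute", "attribute": "href"},
    ],
    "fields": [
        {"name": "text", "type": "text"},
    ],
}
mini_strategy = JsonCssExtractionStrategy(mini_schema)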


async def demo_input_formats():
    """
    Input Format Handling Demo
    ----------------------
    Demonstrates how LLM extraction can work with different input formats:
    1. Markdown (default) - Good for simple text extraction
    2. HTML - Better when you need structure and attributes

    This example shows how HTML input can be beneficial when:
    - You need to understand the DOM structure
    - You want to extract both visible text and HTML attributes
@@ -350,7 +367,7 @@ async def demo_input_formats():
        </footer>
    </div>
    """

    # Use raw:// prefix to pass HTML content directly
    url = f"raw://{dummy_html}"
@@ -359,18 +376,30 @@ async def demo_input_formats():
    # Define our schema using Pydantic
    class JobRequirement(BaseModel):
        category: str = Field(
            description="Category of the requirement (e.g., Technical, Soft Skills)"
        )
        items: List[str] = Field(
            description="List of specific requirements in this category"
        )
        priority: str = Field(
            description="Priority level (Required/Preferred) based on the HTML class or context"
        )

    class JobPosting(BaseModel):
        title: str = Field(description="Job title")
        department: str = Field(description="Department or team")
        location: str = Field(description="Job location, including remote options")
        salary_range: Optional[str] = Field(description="Salary range if specified")
        requirements: List[JobRequirement] = Field(
            description="Categorized job requirements"
        )
        application_deadline: Optional[str] = Field(
            description="Application deadline if specified"
        )
        contact_info: Optional[dict] = Field(
            description="Contact information from footer or contact section"
        )

    # First try with markdown (default)
    markdown_strategy = LLMExtractionStrategy(
@@ -382,7 +411,7 @@ async def demo_input_formats():
        Extract job posting details into structured data. Focus on the visible text content
        and organize requirements into categories.
        """,
        input_format="markdown",  # default
    )

    # Then with HTML for better structure understanding
@@ -400,34 +429,25 @@ async def demo_input_formats():
        Use HTML attributes and classes to enhance extraction accuracy.
        """,
        input_format="html",  # explicitly use HTML
    )

    async with AsyncWebCrawler() as crawler:
        # Try with markdown first
        markdown_config = CrawlerRunConfig(extraction_strategy=markdown_strategy)
        markdown_result = await crawler.arun(url=url, config=markdown_config)
        print("\nMarkdown-based Extraction Result:")
        items = json.loads(markdown_result.extracted_content)
        print(json.dumps(items, indent=2))

        # Then with HTML for better structure understanding
        html_config = CrawlerRunConfig(extraction_strategy=html_strategy)
        html_result = await crawler.arun(url=url, config=html_config)
        print("\nHTML-based Extraction Result:")
        items = json.loads(html_result.extracted_content)
        print(json.dumps(items, indent=2))


# Main execution
async def main():
    print("Crawl4AI v0.4.24 Feature Walkthrough")
@@ -439,5 +459,6 @@ async def main():
    await demo_json_extraction()
    # await demo_input_formats()


if __name__ == "__main__":
    asyncio.run(main())

main.py
View File

@@ -1,14 +1,9 @@
import asyncio, os
from fastapi import FastAPI, HTTPException, Request
from fastapi.responses import HTMLResponse, JSONResponse
from fastapi.staticfiles import StaticFiles
from fastapi.middleware.cors import CORSMiddleware
from fastapi.templating import Jinja2Templates
from fastapi.responses import RedirectResponse
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
from fastapi import Depends, Security
@@ -18,13 +13,10 @@ from typing import Optional, List, Dict, Any, Union
import psutil
import time
import uuid
import math
import logging
from enum import Enum
from dataclasses import dataclass
from crawl4ai import AsyncWebCrawler, CrawlResult, CacheMode
from crawl4ai.config import MIN_WORD_THRESHOLD
from crawl4ai.extraction_strategy import (
@@ -38,30 +30,36 @@ __location__ = os.path.realpath(os.path.join(os.getcwd(), os.path.dirname(__file
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


class TaskStatus(str, Enum):
    PENDING = "pending"
    PROCESSING = "processing"
    COMPLETED = "completed"
    FAILED = "failed"


class CrawlerType(str, Enum):
    BASIC = "basic"
    LLM = "llm"
    COSINE = "cosine"
    JSON_CSS = "json_css"


class ExtractionConfig(BaseModel):
    type: CrawlerType
    params: Dict[str, Any] = {}


class ChunkingStrategy(BaseModel):
    type: str
    params: Dict[str, Any] = {}


class ContentFilter(BaseModel):
    type: str = "bm25"
    params: Dict[str, Any] = {}


class CrawlRequest(BaseModel):
    urls: Union[HttpUrl, List[HttpUrl]]
    word_count_threshold: int = MIN_WORD_THRESHOLD
@@ -77,9 +75,10 @@ class CrawlRequest(BaseModel):
    session_id: Optional[str] = None
    cache_mode: Optional[CacheMode] = CacheMode.ENABLED
    priority: int = Field(default=5, ge=1, le=10)
    ttl: Optional[int] = 3600
    crawler_params: Dict[str, Any] = {}


@dataclass
class TaskInfo:
    id: str
@@ -89,6 +88,7 @@ class TaskInfo:
    created_at: float = time.time()
    ttl: int = 3600


class ResourceMonitor:
    def __init__(self, max_concurrent_tasks: int = 10):
        self.max_concurrent_tasks = max_concurrent_tasks
@@ -106,7 +106,9 @@ class ResourceMonitor:
        mem_usage = psutil.virtual_memory().percent / 100
        cpu_usage = psutil.cpu_percent() / 100

        memory_factor = max(
            0, (self.memory_threshold - mem_usage) / self.memory_threshold
        )
        cpu_factor = max(0, (self.cpu_threshold - cpu_usage) / self.cpu_threshold)

        self._last_available_slots = math.floor(
@@ -116,6 +118,7 @@ class ResourceMonitor:
        return self._last_available_slots


class TaskManager:
    def __init__(self, cleanup_interval: int = 300):
        self.tasks: Dict[str, TaskInfo] = {}
@@ -149,12 +152,16 @@ class TaskManager:
        except asyncio.TimeoutError:
            try:
                # Then try low priority
                _, task_id = await asyncio.wait_for(
                    self.low_priority.get(), timeout=0.1
                )
                return task_id
            except asyncio.TimeoutError:
                return None

    def update_task(
        self, task_id: str, status: TaskStatus, result: Any = None, error: str = None
    ):
        if task_id in self.tasks:
            task_info = self.tasks[task_id]
            task_info.status = status
@@ -180,6 +187,7 @@ class TaskManager:
        except Exception as e:
            logger.error(f"Error in cleanup loop: {e}")


class CrawlerPool:
    def __init__(self, max_size: int = 10):
        self.max_size = max_size
@@ -222,6 +230,7 @@ class CrawlerPool:
            await crawler.__aexit__(None, None, None)
        self.active_crawlers.clear()


class CrawlerService:
    def __init__(self, max_concurrent_tasks: int = 10):
        self.resource_monitor = ResourceMonitor(max_concurrent_tasks)
@@ -258,10 +267,10 @@ class CrawlerService:
    async def submit_task(self, request: CrawlRequest) -> str:
        task_id = str(uuid.uuid4())
        await self.task_manager.add_task(task_id, request.priority, request.ttl or 3600)

        # Store request data with task
        self.task_manager.tasks[task_id].request = request

        return task_id

    async def _process_queue(self):
@@ -286,9 +295,11 @@ class CrawlerService:
                    try:
                        crawler = await self.crawler_pool.acquire(**request.crawler_params)
                        extraction_strategy = self._create_extraction_strategy(
                            request.extraction_config
                        )

                        if isinstance(request.urls, list):
                            results = await crawler.arun_many(
                                urls=[str(url) for url in request.urls],
@@ -318,16 +329,21 @@ class CrawlerService:
                            )

                        await self.crawler_pool.release(crawler)
                        self.task_manager.update_task(
                            task_id, TaskStatus.COMPLETED, results
                        )

                    except Exception as e:
                        logger.error(f"Error processing task {task_id}: {str(e)}")
                        self.task_manager.update_task(
                            task_id, TaskStatus.FAILED, error=str(e)
                        )

            except Exception as e:
                logger.error(f"Error in queue processing: {str(e)}")
                await asyncio.sleep(1)


app = FastAPI(title="Crawl4AI API")

# CORS configuration
@@ -344,6 +360,7 @@ app.add_middleware(
security = HTTPBearer()
CRAWL4AI_API_TOKEN = os.getenv("CRAWL4AI_API_TOKEN")


async def verify_token(credentials: HTTPAuthorizationCredentials = Security(security)):
    if not CRAWL4AI_API_TOKEN:
        return credentials  # No token verification if CRAWL4AI_API_TOKEN is not set
@@ -351,10 +368,12 @@ async def verify_token(credentials: HTTPAuthorizationCredentials = Security(secu
        raise HTTPException(status_code=401, detail="Invalid token")
    return credentials


def secure_endpoint():
    """Returns security dependency only if CRAWL4AI_API_TOKEN is set"""
    return Depends(verify_token) if CRAWL4AI_API_TOKEN else None


# Check if site directory exists
if os.path.exists(__location__ + "/site"):
    # Mount the site directory as a static directory
@@ -364,14 +383,17 @@ site_templates = Jinja2Templates(directory=__location__ + "/site")
crawler_service = CrawlerService()


@app.on_event("startup")
async def startup_event():
    await crawler_service.start()


@app.on_event("shutdown")
async def shutdown_event():
    await crawler_service.stop()


@app.get("/")
def read_root():
    if os.path.exists(__location__ + "/site"):
@@ -379,12 +401,16 @@ def read_root():
    # Return a json response
    return {"message": "Crawl4AI API service is running"}


@app.post("/crawl", dependencies=[secure_endpoint()] if CRAWL4AI_API_TOKEN else [])
async def crawl(request: CrawlRequest) -> Dict[str, str]:
    task_id = await crawler_service.submit_task(request)
    return {"task_id": task_id}


@app.get(
    "/task/{task_id}", dependencies=[secure_endpoint()] if CRAWL4AI_API_TOKEN else []
)
async def get_task_status(task_id: str):
    task_info = crawler_service.task_manager.get_task(task_id)
    if not task_info:
@@ -406,36 +432,45 @@ async def get_task_status(task_id: str):
    return response


@app.post("/crawl_sync", dependencies=[secure_endpoint()] if CRAWL4AI_API_TOKEN else [])
async def crawl_sync(request: CrawlRequest) -> Dict[str, Any]:
    task_id = await crawler_service.submit_task(request)

    # Wait up to 60 seconds for task completion
    for _ in range(60):
        task_info = crawler_service.task_manager.get_task(task_id)
        if not task_info:
            raise HTTPException(status_code=404, detail="Task not found")

        if task_info.status == TaskStatus.COMPLETED:
            # Return same format as /task/{task_id} endpoint
            if isinstance(task_info.result, list):
                return {
                    "status": task_info.status,
                    "results": [result.dict() for result in task_info.result],
                }
            return {"status": task_info.status, "result": task_info.result.dict()}

        if task_info.status == TaskStatus.FAILED:
            raise HTTPException(status_code=500, detail=task_info.error)

        await asyncio.sleep(1)

    # If we get here, task didn't complete within timeout
    raise HTTPException(status_code=408, detail="Task timed out")


@app.post(
    "/crawl_direct", dependencies=[secure_endpoint()] if CRAWL4AI_API_TOKEN else []
)
async def crawl_direct(request: CrawlRequest) -> Dict[str, Any]:
    try:
        crawler = await crawler_service.crawler_pool.acquire(**request.crawler_params)
        extraction_strategy = crawler_service._create_extraction_strategy(
            request.extraction_config
        )
        try:
            if isinstance(request.urls, list):
                results = await crawler.arun_many(
@@ -470,7 +505,8 @@ async def crawl_direct(request: CrawlRequest) -> Dict[str, Any]:
    except Exception as e:
        logger.error(f"Error in direct crawl: {str(e)}")
        raise HTTPException(status_code=500, detail=str(e))


@app.get("/health")
async def health_check():
    available_slots = await crawler_service.resource_monitor.get_available_slots()
@@ -482,6 +518,8 @@ async def health_check():
        "cpu_usage": psutil.cpu_percent(),
    }


if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=11235)
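
# A sketch of calling the protected endpoints from a client, assuming the
# server above is running locally on port 11235 with CRAWL4AI_API_TOKEN set.
# It mirrors the Bearer-header shape checked by verify_token(); the payload
# fields come from CrawlRequest.
import os
import requests

token = os.getenv("CRAWL4AI_API_TOKEN", "test_api_code")
headers = {"Authorization": f"Bearer {token}"}
resp = requests.post(
    "http://localhost:11235/crawl",
    json={"urls": "https://example.com", "priority": 5},
    headers=headers,
)
print(resp.json())  # expected shape: {"task_id": "..."}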

View File

@@ -51,9 +51,7 @@ setup(
    author_email="unclecode@kidocode.com",
    license="MIT",
    packages=find_packages(),
    package_data={"crawl4ai": ["js_snippet/*.js"]},
    classifiers=[
        "Development Status :: 3 - Alpha",
        "Intended Audience :: Developers",

View File

@@ -1,17 +1,18 @@
import os
import sys
import asyncio
from crawl4ai import AsyncWebCrawler, CacheMode
from crawl4ai.markdown_generation_strategy import DefaultMarkdownGenerator

parent_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
sys.path.append(parent_dir)
__location__ = os.path.realpath(os.path.join(os.getcwd(), os.path.dirname(__file__)))


# Assuming that the changes made allow different configurations
# for managed browser, persistent context, and so forth.
async def test_default_headless():
    async with AsyncWebCrawler(
        headless=True,
@@ -24,13 +25,14 @@ async def test_default_headless():
        # Testing normal ephemeral context
    ) as crawler:
        result = await crawler.arun(
            url="https://www.kidocode.com/degrees/technology",
            cache_mode=CacheMode.BYPASS,
            markdown_generator=DefaultMarkdownGenerator(options={"ignore_links": True}),
        )
        print("[test_default_headless] success:", result.success)
        print("HTML length:", len(result.html if result.html else ""))


async def test_managed_browser_persistent():
    # Treating use_persistent_context=True as managed_browser scenario.
    async with AsyncWebCrawler(
@@ -44,13 +46,14 @@ async def test_managed_browser_persistent():
        # This should store and reuse profile data across runs
    ) as crawler:
        result = await crawler.arun(
            url="https://www.google.com",
            cache_mode=CacheMode.BYPASS,
            markdown_generator=DefaultMarkdownGenerator(options={"ignore_links": True}),
        )
        print("[test_managed_browser_persistent] success:", result.success)
        print("HTML length:", len(result.html if result.html else ""))


async def test_session_reuse():
    # Test creating a session, using it for multiple calls
    session_id = "my_session"
@@ -62,25 +65,25 @@ async def test_session_reuse():
        use_managed_browser=False,
        use_persistent_context=False,
    ) as crawler:
        # First call: create session
        result1 = await crawler.arun(
            url="https://www.example.com",
            cache_mode=CacheMode.BYPASS,
            session_id=session_id,
            markdown_generator=DefaultMarkdownGenerator(options={"ignore_links": True}),
        )
        print("[test_session_reuse first call] success:", result1.success)

        # Second call: same session, possibly cookie retained
        result2 = await crawler.arun(
            url="https://www.example.com/about",
            cache_mode=CacheMode.BYPASS,
            session_id=session_id,
            markdown_generator=DefaultMarkdownGenerator(options={"ignore_links": True}),
        )
        print("[test_session_reuse second call] success:", result2.success)


async def test_magic_mode():
    # Test magic mode with override_navigator and simulate_user
    async with AsyncWebCrawler(
@@ -95,13 +98,14 @@ async def test_magic_mode():
        simulate_user=True,
    ) as crawler:
        result = await crawler.arun(
            url="https://www.kidocode.com/degrees/business",
            cache_mode=CacheMode.BYPASS,
            markdown_generator=DefaultMarkdownGenerator(options={"ignore_links": True}),
        )
        print("[test_magic_mode] success:", result.success)
        print("HTML length:", len(result.html if result.html else ""))


async def test_proxy_settings():
    # Test with a proxy (if available) to ensure code runs with proxy
    async with AsyncWebCrawler(
@@ -113,14 +117,15 @@ async def test_proxy_settings():
        use_persistent_context=False,
    ) as crawler:
        result = await crawler.arun(
            url="https://httpbin.org/ip",
            cache_mode=CacheMode.BYPASS,
            markdown_generator=DefaultMarkdownGenerator(options={"ignore_links": True}),
        )
        print("[test_proxy_settings] success:", result.success)
        if result.success:
            print("HTML preview:", result.html[:200] if result.html else "")


async def test_ignore_https_errors():
    # Test ignore HTTPS errors with a self-signed or invalid cert domain
    # This is just conceptual, the domain should be one that triggers SSL error.
@@ -134,12 +139,13 @@ async def test_ignore_https_errors():
        use_persistent_context=False,
    ) as crawler:
        result = await crawler.arun(
            url="https://self-signed.badssl.com/",
            cache_mode=CacheMode.BYPASS,
            markdown_generator=DefaultMarkdownGenerator(options={"ignore_links": True}),
        )
        print("[test_ignore_https_errors] success:", result.success)


async def main():
    print("Running tests...")
    # await test_default_headless()
@@ -149,5 +155,6 @@ async def main():
    # await test_proxy_settings()
await test_ignore_https_errors() await test_ignore_https_errors()
if __name__ == "__main__": if __name__ == "__main__":
asyncio.run(main()) asyncio.run(main())
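
These smoke tests drive AsyncWebCrawler entirely through legacy keyword arguments; the config-object tests in the next file cover the same options via BrowserConfig and CrawlerRunConfig. A minimal sketch of the equivalent setup, using only parameter names that appear elsewhere in this commit (the exact legacy-to-config mapping is an assumption):

import asyncio

from crawl4ai import AsyncWebCrawler, CacheMode
from crawl4ai.async_configs import BrowserConfig, CrawlerRunConfig


async def demo():
    # headless / ignore_https_errors live on BrowserConfig,
    # cache_mode on CrawlerRunConfig (assumed mapping).
    browser_config = BrowserConfig(headless=True, ignore_https_errors=True)
    run_config = CrawlerRunConfig(cache_mode=CacheMode.BYPASS)
    async with AsyncWebCrawler(config=browser_config) as crawler:
        result = await crawler.arun("https://self-signed.badssl.com/", config=run_config)
        print(result.success)


if __name__ == "__main__":
    asyncio.run(demo())

Keeping browser-level options on BrowserConfig and per-run options on CrawlerRunConfig matches how the category tests below split them.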

View File

@@ -1,15 +1,16 @@
import os, sys

parent_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
sys.path.append(parent_dir)
__location__ = os.path.realpath(os.path.join(os.getcwd(), os.path.dirname(__file__)))
import asyncio
from crawl4ai import AsyncWebCrawler, CacheMode
from crawl4ai.async_configs import BrowserConfig, CrawlerRunConfig
from crawl4ai.content_filter_strategy import PruningContentFilter
from crawl4ai.extraction_strategy import JsonCssExtractionStrategy
from crawl4ai.chunking_strategy import RegexChunking


# Category 1: Browser Configuration Tests
async def test_browser_config_object():
@@ -21,29 +22,31 @@ async def test_browser_config_object():
        viewport_height=1080,
        use_managed_browser=True,
        user_agent_mode="random",
        user_agent_generator_config={"device_type": "desktop", "os_type": "windows"},
    )
    async with AsyncWebCrawler(config=browser_config, verbose=True) as crawler:
        result = await crawler.arun("https://example.com", cache_mode=CacheMode.BYPASS)
        assert result.success, "Browser config crawl failed"
        assert len(result.html) > 0, "No HTML content retrieved"


async def test_browser_performance_config():
    """Test browser configurations focused on performance"""
    browser_config = BrowserConfig(
        text_mode=True,
        light_mode=True,
        extra_args=["--disable-gpu", "--disable-software-rasterizer"],
        ignore_https_errors=True,
        java_script_enabled=False,
    )
    async with AsyncWebCrawler(config=browser_config) as crawler:
        result = await crawler.arun("https://example.com")
        assert result.success, "Performance optimized crawl failed"
        assert result.status_code == 200, "Unexpected status code"


# Category 2: Content Processing Tests
async def test_content_extraction_config():
    """Test content extraction with various strategies"""
@@ -53,24 +56,20 @@ async def test_content_extraction_config():
            schema={
                "name": "article",
                "baseSelector": "div",
                "fields": [{"name": "title", "selector": "h1", "type": "text"}],
            }
        ),
        chunking_strategy=RegexChunking(),
        content_filter=PruningContentFilter(),
    )
    async with AsyncWebCrawler() as crawler:
        result = await crawler.arun(
            "https://example.com/article", config=crawler_config
        )
        assert result.extracted_content is not None, "Content extraction failed"
        assert "title" in result.extracted_content, "Missing expected content field"


# Category 3: Cache and Session Management Tests
async def test_cache_and_session_management():
@@ -79,25 +78,20 @@ async def test_cache_and_session_management():
    crawler_config = CrawlerRunConfig(
        cache_mode=CacheMode.WRITE_ONLY,
        process_iframes=True,
        remove_overlay_elements=True,
    )
    async with AsyncWebCrawler(config=browser_config) as crawler:
        # First request - should write to cache
        result1 = await crawler.arun("https://example.com", config=crawler_config)

        # Second request - should use fresh fetch due to WRITE_ONLY mode
        result2 = await crawler.arun("https://example.com", config=crawler_config)

        assert result1.success and result2.success, "Cache mode crawl failed"
        assert result1.html == result2.html, "Inconsistent results between requests"


# Category 4: Media Handling Tests
async def test_media_handling_config():
    """Test configurations related to media handling"""
@@ -107,24 +101,22 @@ async def test_media_handling_config():
        viewport_width=1920,
        viewport_height=1080,
        accept_downloads=True,
        downloads_path=os.path.expanduser("~/.crawl4ai/downloads"),
    )
    crawler_config = CrawlerRunConfig(
        screenshot=True,
        pdf=True,
        adjust_viewport_to_content=True,
        wait_for_images=True,
        screenshot_height_threshold=20000,
    )
    async with AsyncWebCrawler(config=browser_config) as crawler:
        result = await crawler.arun("https://example.com", config=crawler_config)
        assert result.screenshot is not None, "Screenshot capture failed"
        assert result.pdf is not None, "PDF generation failed"


# Category 5: Anti-Bot and Site Interaction Tests
async def test_antibot_config():
    """Test configurations for handling anti-bot measures"""
@@ -135,76 +127,64 @@ async def test_antibot_config():
        wait_for="js:()=>document.querySelector('body')",
        delay_before_return_html=1.0,
        log_console=True,
        cache_mode=CacheMode.BYPASS,
    )
    async with AsyncWebCrawler() as crawler:
        result = await crawler.arun("https://example.com", config=crawler_config)
        assert result.success, "Anti-bot measure handling failed"


# Category 6: Parallel Processing Tests
async def test_parallel_processing():
    """Test parallel processing capabilities"""
    crawler_config = CrawlerRunConfig(mean_delay=0.5, max_range=1.0, semaphore_count=5)

    urls = ["https://example.com/1", "https://example.com/2", "https://example.com/3"]

    async with AsyncWebCrawler() as crawler:
        results = await crawler.arun_many(urls, config=crawler_config)
        assert len(results) == len(urls), "Not all URLs were processed"
        assert all(r.success for r in results), "Some parallel requests failed"


# Category 7: Backwards Compatibility Tests
async def test_legacy_parameter_support():
    """Test that legacy parameters still work"""
    async with AsyncWebCrawler(
        headless=True, browser_type="chromium", viewport_width=1024, viewport_height=768
    ) as crawler:
        result = await crawler.arun(
            "https://example.com",
            screenshot=True,
            word_count_threshold=200,
            bypass_cache=True,
            css_selector=".main-content",
        )
        assert result.success, "Legacy parameter support failed"


# Category 8: Mixed Configuration Tests
async def test_mixed_config_usage():
    """Test mixing new config objects with legacy parameters"""
    browser_config = BrowserConfig(headless=True)
    crawler_config = CrawlerRunConfig(screenshot=True)

    async with AsyncWebCrawler(
        config=browser_config,
        verbose=True,  # legacy parameter
    ) as crawler:
        result = await crawler.arun(
            "https://example.com",
            config=crawler_config,
            cache_mode=CacheMode.BYPASS,  # legacy parameter
            css_selector="body",  # legacy parameter
        )
        assert result.success, "Mixed configuration usage failed"


if __name__ == "__main__":

    async def run_tests():
        test_functions = [
            test_browser_config_object,
@@ -217,7 +197,7 @@ if __name__ == "__main__":
            # test_legacy_parameter_support,
            # test_mixed_config_usage
        ]

        for test in test_functions:
            print(f"\nRunning {test.__name__}...")
            try:
@@ -227,5 +207,5 @@ if __name__ == "__main__":
                print(f"{test.__name__} failed: {str(e)}")
            except Exception as e:
                print(f"{test.__name__} error: {str(e)}")

    asyncio.run(run_tests())
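
test_content_extraction_config asserts on result.extracted_content as a raw string, while the chunking and LLM tests later in this commit decode the same field with json.loads first. A sketch of the decoded access pattern (the helper name is hypothetical; it assumes extracted_content is a JSON-encoded list, as those tests treat it):

import json


def extracted_titles(result):
    # extracted_content arrives as a JSON string; decode before field checks.
    items = json.loads(result.extracted_content)
    return [item.get("title") for item in items if isinstance(item, dict)]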

View File

@@ -4,7 +4,6 @@ import asyncio
import shutil
from typing import List
import tempfile

# Add the parent directory to the Python path
parent_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
@@ -12,28 +11,27 @@ sys.path.append(parent_dir)
from crawl4ai.async_webcrawler import AsyncWebCrawler


class TestDownloads:
    def __init__(self):
        self.temp_dir = tempfile.mkdtemp(prefix="crawl4ai_test_")
        self.download_dir = os.path.join(self.temp_dir, "downloads")
        os.makedirs(self.download_dir, exist_ok=True)
        self.results: List[str] = []

    def cleanup(self):
        shutil.rmtree(self.temp_dir)

    def log_result(self, test_name: str, success: bool, message: str = ""):
        result = f"{'✅' if success else '❌'} {test_name}: {message}"
        self.results.append(result)
        print(result)

    async def test_basic_download(self):
        """Test basic file download functionality"""
        try:
            async with AsyncWebCrawler(
                accept_downloads=True, downloads_path=self.download_dir, verbose=True
            ) as crawler:
                # Python.org downloads page typically has stable download links
                result = await crawler.arun(
@@ -42,14 +40,19 @@ class TestDownloads:
                    // Click first download link
                    const downloadLink = document.querySelector('a[href$=".exe"]');
                    if (downloadLink) downloadLink.click();
                    """,
                )
                success = (
                    result.downloaded_files is not None
                    and len(result.downloaded_files) > 0
                )
                self.log_result(
                    "Basic Download",
                    success,
                    f"Downloaded {len(result.downloaded_files or [])} files"
                    if success
                    else "No files downloaded",
                )
        except Exception as e:
            self.log_result("Basic Download", False, str(e))
@@ -59,27 +62,32 @@ class TestDownloads:
        try:
            user_data_dir = os.path.join(self.temp_dir, "user_data")
            os.makedirs(user_data_dir, exist_ok=True)

            async with AsyncWebCrawler(
                accept_downloads=True,
                downloads_path=self.download_dir,
                use_persistent_context=True,
                user_data_dir=user_data_dir,
                verbose=True,
            ) as crawler:
                result = await crawler.arun(
                    url="https://www.python.org/downloads/",
                    js_code="""
                        const downloadLink = document.querySelector('a[href$=".exe"]');
                        if (downloadLink) downloadLink.click();
                    """,
                )
                success = (
                    result.downloaded_files is not None
                    and len(result.downloaded_files) > 0
                )
                self.log_result(
                    "Persistent Context Download",
                    success,
                    f"Downloaded {len(result.downloaded_files or [])} files"
                    if success
                    else "No files downloaded",
                )
        except Exception as e:
            self.log_result("Persistent Context Download", False, str(e))
@@ -88,9 +96,7 @@ class TestDownloads:
        """Test multiple simultaneous downloads"""
        try:
            async with AsyncWebCrawler(
                accept_downloads=True, downloads_path=self.download_dir, verbose=True
            ) as crawler:
                result = await crawler.arun(
                    url="https://www.python.org/downloads/",
@@ -98,14 +104,19 @@ class TestDownloads:
                    // Click multiple download links
                    const downloadLinks = document.querySelectorAll('a[href$=".exe"]');
                    downloadLinks.forEach(link => link.click());
                    """,
                )
                success = (
                    result.downloaded_files is not None
                    and len(result.downloaded_files) > 1
                )
                self.log_result(
                    "Multiple Downloads",
                    success,
                    f"Downloaded {len(result.downloaded_files or [])} files"
                    if success
                    else "Not enough files downloaded",
                )
        except Exception as e:
            self.log_result("Multiple Downloads", False, str(e))
@@ -113,49 +124,51 @@ class TestDownloads:
    async def test_different_browsers(self):
        """Test downloads across different browser types"""
        browsers = ["chromium", "firefox", "webkit"]

        for browser_type in browsers:
            try:
                async with AsyncWebCrawler(
                    accept_downloads=True,
                    downloads_path=self.download_dir,
                    browser_type=browser_type,
                    verbose=True,
                ) as crawler:
                    result = await crawler.arun(
                        url="https://www.python.org/downloads/",
                        js_code="""
                            const downloadLink = document.querySelector('a[href$=".exe"]');
                            if (downloadLink) downloadLink.click();
                        """,
                    )
                    success = (
                        result.downloaded_files is not None
                        and len(result.downloaded_files) > 0
                    )
                    self.log_result(
                        f"{browser_type.title()} Download",
                        success,
                        f"Downloaded {len(result.downloaded_files or [])} files"
                        if success
                        else "No files downloaded",
                    )
            except Exception as e:
                self.log_result(f"{browser_type.title()} Download", False, str(e))

    async def test_edge_cases(self):
        """Test various edge cases"""
        # Test 1: Downloads without specifying download path
        try:
            async with AsyncWebCrawler(accept_downloads=True, verbose=True) as crawler:
                result = await crawler.arun(
                    url="https://www.python.org/downloads/",
                    js_code="document.querySelector('a[href$=\".exe\"]').click()",
                )
                self.log_result(
                    "Default Download Path",
                    True,
                    f"Downloaded to default path: {result.downloaded_files[0] if result.downloaded_files else 'None'}",
                )
        except Exception as e:
            self.log_result("Default Download Path", False, str(e))
@@ -165,31 +178,34 @@ class TestDownloads:
            async with AsyncWebCrawler(
                accept_downloads=True,
                downloads_path="/invalid/path/that/doesnt/exist",
                verbose=True,
            ) as crawler:
                result = await crawler.arun(
                    url="https://www.python.org/downloads/",
                    js_code="document.querySelector('a[href$=\".exe\"]').click()",
                )
                self.log_result(
                    "Invalid Download Path", False, "Should have raised an error"
                )
        except Exception:
            self.log_result(
                "Invalid Download Path", True, "Correctly handled invalid path"
            )

        # Test 3: Download with accept_downloads=False
        try:
            async with AsyncWebCrawler(accept_downloads=False, verbose=True) as crawler:
                result = await crawler.arun(
                    url="https://www.python.org/downloads/",
                    js_code="document.querySelector('a[href$=\".exe\"]').click()",
                )
                success = result.downloaded_files is None
                self.log_result(
                    "Disabled Downloads",
                    success,
                    "Correctly ignored downloads"
                    if success
                    else "Unexpectedly downloaded files",
                )
        except Exception as e:
            self.log_result("Disabled Downloads", False, str(e))
@@ -197,33 +213,35 @@ class TestDownloads:
    async def run_all_tests(self):
        """Run all test cases"""
        print("\n🧪 Running Download Tests...\n")

        test_methods = [
            self.test_basic_download,
            self.test_persistent_context_download,
            self.test_multiple_downloads,
            self.test_different_browsers,
            self.test_edge_cases,
        ]

        for test in test_methods:
            print(f"\n📝 Running {test.__doc__}...")
            await test()
            await asyncio.sleep(2)  # Brief pause between tests

        print("\n📊 Test Results Summary:")
        for result in self.results:
            print(result)

        successes = len([r for r in self.results if "✅" in r])
        total = len(self.results)
        print(f"\nTotal: {successes}/{total} tests passed")

        self.cleanup()


async def main():
    tester = TestDownloads()
    await tester.run_all_tests()


if __name__ == "__main__":
    asyncio.run(main())
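
The download tests only count entries in result.downloaded_files. A quick on-disk check can also catch zero-byte or vanished files; this is a sketch with a hypothetical helper name, assuming downloaded_files holds filesystem paths as the "Downloaded to default path" message implies:

import os
from typing import List, Optional


def verify_downloads(downloaded_files: Optional[List[str]]) -> bool:
    # True only when every reported download exists and is non-empty on disk.
    if not downloaded_files:
        return False
    return all(os.path.isfile(p) and os.path.getsize(p) > 0 for p in downloaded_files)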

View File

@@ -1,15 +1,17 @@
import os
import sys
import pytest
import time

# Add the parent directory to the Python path
parent_dir = os.path.dirname(
    os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
)
sys.path.append(parent_dir)

from crawl4ai.async_webcrawler import AsyncWebCrawler


@pytest.mark.asyncio
async def test_successful_crawl():
    async with AsyncWebCrawler(verbose=True) as crawler:
@@ -21,6 +23,7 @@ async def test_successful_crawl():
        assert result.markdown
        assert result.cleaned_html


@pytest.mark.asyncio
async def test_invalid_url():
    async with AsyncWebCrawler(verbose=True) as crawler:
@@ -29,19 +32,21 @@ async def test_invalid_url():
        assert not result.success
        assert result.error_message


@pytest.mark.asyncio
async def test_multiple_urls():
    async with AsyncWebCrawler(verbose=True) as crawler:
        urls = [
            "https://www.nbcnews.com/business",
            "https://www.example.com",
            "https://www.python.org",
        ]
        results = await crawler.arun_many(urls=urls, bypass_cache=True)
        assert len(results) == len(urls)
        assert all(result.success for result in results)
        assert all(result.html for result in results)


@pytest.mark.asyncio
async def test_javascript_execution():
    async with AsyncWebCrawler(verbose=True) as crawler:
@@ -51,6 +56,7 @@ async def test_javascript_execution():
        assert result.success
        assert "<h1>Modified by JS</h1>" in result.html


@pytest.mark.asyncio
async def test_concurrent_crawling_performance():
    async with AsyncWebCrawler(verbose=True) as crawler:
@@ -59,23 +65,26 @@ async def test_concurrent_crawling_performance():
            "https://www.example.com",
            "https://www.python.org",
            "https://www.github.com",
            "https://www.stackoverflow.com",
        ]
        start_time = time.time()
        results = await crawler.arun_many(urls=urls, bypass_cache=True)
        end_time = time.time()

        total_time = end_time - start_time
        print(f"Total time for concurrent crawling: {total_time:.2f} seconds")

        assert all(result.success for result in results)
        assert len(results) == len(urls)

        # Assert that concurrent crawling is faster than sequential
        # This multiplier may need adjustment based on the number of URLs and their complexity
        assert (
            total_time < len(urls) * 5
        ), f"Concurrent crawling not significantly faster: {total_time:.2f} seconds"


# Entry point for debugging
if __name__ == "__main__":
    pytest.main([__file__, "-v"])

View File

@@ -9,74 +9,79 @@ sys.path.append(parent_dir)
from crawl4ai.async_webcrawler import AsyncWebCrawler


@pytest.mark.asyncio
async def test_caching():
    async with AsyncWebCrawler(verbose=True) as crawler:
        url = "https://www.nbcnews.com/business"

        # First crawl (should not use cache)
        start_time = asyncio.get_event_loop().time()
        result1 = await crawler.arun(url=url, bypass_cache=True)
        end_time = asyncio.get_event_loop().time()
        time_taken1 = end_time - start_time

        assert result1.success

        # Second crawl (should use cache)
        start_time = asyncio.get_event_loop().time()
        result2 = await crawler.arun(url=url, bypass_cache=False)
        end_time = asyncio.get_event_loop().time()
        time_taken2 = end_time - start_time

        assert result2.success
        assert time_taken2 < time_taken1  # Cached result should be faster


@pytest.mark.asyncio
async def test_bypass_cache():
    async with AsyncWebCrawler(verbose=True) as crawler:
        url = "https://www.nbcnews.com/business"

        # First crawl
        result1 = await crawler.arun(url=url, bypass_cache=False)
        assert result1.success

        # Second crawl with bypass_cache=True
        result2 = await crawler.arun(url=url, bypass_cache=True)
        assert result2.success

        # Content should be different (or at least, not guaranteed to be the same)
        assert result1.html != result2.html or result1.markdown != result2.markdown


@pytest.mark.asyncio
async def test_clear_cache():
    async with AsyncWebCrawler(verbose=True) as crawler:
        url = "https://www.nbcnews.com/business"

        # Crawl and cache
        await crawler.arun(url=url, bypass_cache=False)

        # Clear cache
        await crawler.aclear_cache()

        # Check cache size
        cache_size = await crawler.aget_cache_size()
        assert cache_size == 0


@pytest.mark.asyncio
async def test_flush_cache():
    async with AsyncWebCrawler(verbose=True) as crawler:
        url = "https://www.nbcnews.com/business"

        # Crawl and cache
        await crawler.arun(url=url, bypass_cache=False)

        # Flush cache
        await crawler.aflush_cache()

        # Check cache size
        cache_size = await crawler.aget_cache_size()
        assert cache_size == 0


# Entry point for debugging
if __name__ == "__main__":
    pytest.main([__file__, "-v"])
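
These cache tests still use the legacy bypass_cache flag, whereas the config-based tests earlier in the commit use the CacheMode enum. A minimal sketch of the enum form (assuming bypass_cache=True corresponds to CacheMode.BYPASS, which is how the mixed-configuration test treats the two):

import asyncio

from crawl4ai import AsyncWebCrawler, CacheMode


async def fetch_fresh(url: str):
    # Equivalent of bypass_cache=True under the enum-based API used
    # elsewhere in this commit.
    async with AsyncWebCrawler(verbose=True) as crawler:
        return await crawler.arun(url=url, cache_mode=CacheMode.BYPASS)


if __name__ == "__main__":
    result = asyncio.run(fetch_fresh("https://www.example.com"))
    print(result.success)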

View File

@@ -1,7 +1,6 @@
import os
import sys
import pytest
import json

# Add the parent directory to the Python path
@@ -9,8 +8,9 @@ parent_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
sys.path.append(parent_dir)

from crawl4ai.async_webcrawler import AsyncWebCrawler
from crawl4ai.chunking_strategy import RegexChunking
from crawl4ai.extraction_strategy import LLMExtractionStrategy


@pytest.mark.asyncio
async def test_regex_chunking():
@@ -18,15 +18,14 @@ async def test_regex_chunking():
        url = "https://www.nbcnews.com/business"
        chunking_strategy = RegexChunking(patterns=["\n\n"])
        result = await crawler.arun(
            url=url, chunking_strategy=chunking_strategy, bypass_cache=True
        )
        assert result.success
        assert result.extracted_content
        chunks = json.loads(result.extracted_content)
        assert len(chunks) > 1  # Ensure multiple chunks were created


# @pytest.mark.asyncio
# async def test_cosine_strategy():
#     async with AsyncWebCrawler(verbose=True) as crawler:
@@ -43,25 +42,25 @@ async def test_regex_chunking():
#         assert len(extracted_data) > 0
#         assert all('tags' in item for item in extracted_data)


@pytest.mark.asyncio
async def test_llm_extraction_strategy():
    async with AsyncWebCrawler(verbose=True) as crawler:
        url = "https://www.nbcnews.com/business"
        extraction_strategy = LLMExtractionStrategy(
            provider="openai/gpt-4o-mini",
            api_token=os.getenv("OPENAI_API_KEY"),
            instruction="Extract only content related to technology",
        )
        result = await crawler.arun(
            url=url, extraction_strategy=extraction_strategy, bypass_cache=True
        )
        assert result.success
        assert result.extracted_content
        extracted_data = json.loads(result.extracted_content)
        assert len(extracted_data) > 0
        assert all("content" in item for item in extracted_data)


# @pytest.mark.asyncio
# async def test_combined_chunking_and_extraction():
@@ -84,4 +83,4 @@ async def test_llm_extraction_strategy():

# Entry point for debugging
if __name__ == "__main__":
    pytest.main([__file__, "-v"])
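
test_llm_extraction_strategy needs a live OPENAI_API_KEY and will fail without one. A small guard that skips it cleanly is sketched below (standard pytest.mark.skipif; applying it to the test above is an assumption, since the test's full decorator stack sits outside this hunk):

import os

import pytest

# Skip LLM-backed tests when no key is exported (env var name taken from the
# test above; wiring this in as an extra decorator is an assumption).
requires_openai = pytest.mark.skipif(
    not os.getenv("OPENAI_API_KEY"),
    reason="OPENAI_API_KEY not set",
)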

View File

@@ -1,8 +1,6 @@
import os
import sys
import pytest

# Add the parent directory to the Python path
parent_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
@@ -10,6 +8,7 @@ sys.path.append(parent_dir)
from crawl4ai.async_webcrawler import AsyncWebCrawler


@pytest.mark.asyncio
async def test_extract_markdown():
    async with AsyncWebCrawler(verbose=True) as crawler:
@@ -20,6 +19,7 @@ async def test_extract_markdown():
        assert isinstance(result.markdown, str)
        assert len(result.markdown) > 0


@pytest.mark.asyncio
async def test_extract_cleaned_html():
    async with AsyncWebCrawler(verbose=True) as crawler:
@@ -30,6 +30,7 @@ async def test_extract_cleaned_html():
        assert isinstance(result.cleaned_html, str)
        assert len(result.cleaned_html) > 0


@pytest.mark.asyncio
async def test_extract_media():
    async with AsyncWebCrawler(verbose=True) as crawler:
@@ -46,6 +47,7 @@ async def test_extract_media():
            assert "alt" in image
            assert "type" in image


@pytest.mark.asyncio
async def test_extract_links():
    async with AsyncWebCrawler(verbose=True) as crawler:
@@ -63,6 +65,7 @@ async def test_extract_links():
            assert "href" in link
            assert "text" in link


@pytest.mark.asyncio
async def test_extract_metadata():
    async with AsyncWebCrawler(verbose=True) as crawler:
@@ -75,16 +78,20 @@ async def test_extract_metadata():
        assert "title" in metadata
        assert isinstance(metadata["title"], str)


@pytest.mark.asyncio
async def test_css_selector_extraction():
    async with AsyncWebCrawler(verbose=True) as crawler:
        url = "https://www.nbcnews.com/business"
        css_selector = "h1, h2, h3"
        result = await crawler.arun(
            url=url, bypass_cache=True, css_selector=css_selector
        )
        assert result.success
        assert result.markdown
        assert all(heading in result.markdown for heading in ["#", "##", "###"])


# Entry point for debugging
if __name__ == "__main__":
    pytest.main([__file__, "-v"])

View File

@@ -1,7 +1,6 @@
import os, sys
import pytest
from bs4 import BeautifulSoup

# Add the parent directory to the Python path
parent_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
@@ -9,6 +8,7 @@ sys.path.append(parent_dir)
from crawl4ai.content_filter_strategy import BM25ContentFilter


@pytest.fixture
def basic_html():
    return """
@@ -28,6 +28,7 @@ def basic_html():
    </html>
    """


@pytest.fixture
def wiki_html():
    return """
@@ -46,6 +47,7 @@ def wiki_html():
    </html>
    """


@pytest.fixture
def no_meta_html():
    return """
@@ -57,26 +59,27 @@ def no_meta_html():
    </html>
    """


class TestBM25ContentFilter:
    def test_basic_extraction(self, basic_html):
        """Test basic content extraction functionality"""
        filter = BM25ContentFilter()
        contents = filter.filter_content(basic_html)

        assert contents, "Should extract content"
        assert len(contents) >= 1, "Should extract at least one content block"
        assert "long paragraph" in " ".join(contents).lower()
        assert "navigation" not in " ".join(contents).lower()

    def test_user_query_override(self, basic_html):
        """Test that user query overrides metadata extraction"""
        user_query = "specific test query"
        filter = BM25ContentFilter(user_query=user_query)

        # Access internal state to verify query usage
        soup = BeautifulSoup(basic_html, "lxml")
        extracted_query = filter.extract_page_query(soup.find("head"))

        assert extracted_query == user_query
        assert "Test description" not in extracted_query

@@ -84,8 +87,8 @@ class TestBM25ContentFilter:
        """Test that headers are properly extracted despite length"""
        filter = BM25ContentFilter()
        contents = filter.filter_content(wiki_html)

        combined_content = " ".join(contents).lower()
        assert "section 1" in combined_content, "Should include section header"
        assert "article title" in combined_content, "Should include main title"

@@ -93,9 +96,11 @@ class TestBM25ContentFilter:
        """Test fallback behavior when no metadata is present"""
        filter = BM25ContentFilter()
        contents = filter.filter_content(no_meta_html)

        assert contents, "Should extract content even without metadata"
        assert "First paragraph" in " ".join(
            contents
        ), "Should use first paragraph content"

    def test_empty_input(self):
        """Test handling of empty input"""
@@ -108,29 +113,30 @@ class TestBM25ContentFilter:
        malformed_html = "<p>Unclosed paragraph<div>Nested content</p></div>"
        filter = BM25ContentFilter()

        contents = filter.filter_content(malformed_html)
        assert isinstance(contents, list), "Should return list even with malformed HTML"

    def test_threshold_behavior(self, basic_html):
        """Test different BM25 threshold values"""
        strict_filter = BM25ContentFilter(bm25_threshold=2.0)
        lenient_filter = BM25ContentFilter(bm25_threshold=0.5)

        strict_contents = strict_filter.filter_content(basic_html)
        lenient_contents = lenient_filter.filter_content(basic_html)

        assert len(strict_contents) <= len(
            lenient_contents
        ), "Strict threshold should extract fewer elements"

    def test_html_cleaning(self, basic_html):
        """Test HTML cleaning functionality"""
        filter = BM25ContentFilter()
        contents = filter.filter_content(basic_html)

        cleaned_content = " ".join(contents)
        assert "class=" not in cleaned_content, "Should remove class attributes"
        assert "style=" not in cleaned_content, "Should remove style attributes"
        assert "<script" not in cleaned_content, "Should remove script tags"

    def test_large_content(self):
        """Test handling of large content blocks"""
@@ -143,9 +149,9 @@ class TestBM25ContentFilter:
        contents = filter.filter_content(large_html)
        assert contents, "Should handle large content blocks"

    @pytest.mark.parametrize(
        "unwanted_tag", ["script", "style", "nav", "footer", "header"]
    )
    def test_excluded_tags(self, unwanted_tag):
        """Test that specific tags are properly excluded"""
        html = f"""
@@ -156,20 +162,22 @@ class TestBM25ContentFilter:
        """
        filter = BM25ContentFilter()
        contents = filter.filter_content(html)

        combined_content = " ".join(contents).lower()
        assert "should not appear" not in combined_content

    def test_performance(self, basic_html):
        """Test performance with timer"""
        filter = BM25ContentFilter()

        import time

        start = time.perf_counter()
        filter.filter_content(basic_html)
        duration = time.perf_counter() - start

        assert duration < 1.0, f"Processing took too long: {duration:.2f} seconds"


if __name__ == "__main__":
    pytest.main([__file__])
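
Outside pytest, the filter can be driven directly. A minimal sketch built only from the calls these tests make (the constructor keywords user_query and bm25_threshold, and filter_content returning a list of content blocks); the sample HTML is made up:

from crawl4ai.content_filter_strategy import BM25ContentFilter

html = "<html><head><title>Demo</title></head><body><p>A long paragraph of real article text about BM25 scoring.</p></body></html>"

# Same entry point the tests exercise: returns a list of retained blocks.
bm25_filter = BM25ContentFilter(user_query="bm25 scoring", bm25_threshold=0.5)
blocks = bm25_filter.filter_content(html)
print(" ".join(blocks))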

View File

@@ -1,12 +1,12 @@
import os, sys import os, sys
import pytest import pytest
from bs4 import BeautifulSoup
parent_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) parent_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
sys.path.append(parent_dir) sys.path.append(parent_dir)
from crawl4ai.content_filter_strategy import PruningContentFilter from crawl4ai.content_filter_strategy import PruningContentFilter
@pytest.fixture @pytest.fixture
def basic_html(): def basic_html():
return """ return """
@@ -22,6 +22,7 @@ def basic_html():
</html> </html>
""" """
@pytest.fixture @pytest.fixture
def link_heavy_html(): def link_heavy_html():
return """ return """
@@ -40,6 +41,7 @@ def link_heavy_html():
</html> </html>
""" """
@pytest.fixture @pytest.fixture
def mixed_content_html(): def mixed_content_html():
return """ return """
@@ -60,13 +62,14 @@ def mixed_content_html():
</html> </html>
""" """
class TestPruningContentFilter: class TestPruningContentFilter:
def test_basic_pruning(self, basic_html): def test_basic_pruning(self, basic_html):
"""Test basic content pruning functionality""" """Test basic content pruning functionality"""
filter = PruningContentFilter(min_word_threshold=5) filter = PruningContentFilter(min_word_threshold=5)
contents = filter.filter_content(basic_html) contents = filter.filter_content(basic_html)
combined_content = ' '.join(contents).lower() combined_content = " ".join(contents).lower()
assert "high-quality paragraph" in combined_content assert "high-quality paragraph" in combined_content
assert "sidebar content" not in combined_content assert "sidebar content" not in combined_content
assert "share buttons" not in combined_content assert "share buttons" not in combined_content
@@ -75,40 +78,42 @@ class TestPruningContentFilter:
"""Test minimum word threshold filtering""" """Test minimum word threshold filtering"""
filter = PruningContentFilter(min_word_threshold=10) filter = PruningContentFilter(min_word_threshold=10)
contents = filter.filter_content(mixed_content_html) contents = filter.filter_content(mixed_content_html)
combined_content = ' '.join(contents).lower() combined_content = " ".join(contents).lower()
assert "short summary" not in combined_content assert "short summary" not in combined_content
assert "long high-quality paragraph" in combined_content assert "long high-quality paragraph" in combined_content
assert "short comment" not in combined_content assert "short comment" not in combined_content
def test_threshold_types(self, basic_html): def test_threshold_types(self, basic_html):
"""Test fixed vs dynamic thresholds""" """Test fixed vs dynamic thresholds"""
fixed_filter = PruningContentFilter(threshold_type='fixed', threshold=0.48) fixed_filter = PruningContentFilter(threshold_type="fixed", threshold=0.48)
dynamic_filter = PruningContentFilter(threshold_type='dynamic', threshold=0.45) dynamic_filter = PruningContentFilter(threshold_type="dynamic", threshold=0.45)
fixed_contents = fixed_filter.filter_content(basic_html) fixed_contents = fixed_filter.filter_content(basic_html)
dynamic_contents = dynamic_filter.filter_content(basic_html) dynamic_contents = dynamic_filter.filter_content(basic_html)
assert len(fixed_contents) != len(dynamic_contents), \ assert len(fixed_contents) != len(
"Fixed and dynamic thresholds should yield different results" dynamic_contents
), "Fixed and dynamic thresholds should yield different results"
def test_link_density_impact(self, link_heavy_html): def test_link_density_impact(self, link_heavy_html):
"""Test handling of link-heavy content""" """Test handling of link-heavy content"""
filter = PruningContentFilter(threshold_type='dynamic') filter = PruningContentFilter(threshold_type="dynamic")
contents = filter.filter_content(link_heavy_html) contents = filter.filter_content(link_heavy_html)
combined_content = ' '.join(contents).lower() combined_content = " ".join(contents).lower()
assert "good content paragraph" in combined_content assert "good content paragraph" in combined_content
assert len([c for c in contents if 'href' in c]) < 2, \ assert (
"Should prune link-heavy sections" len([c for c in contents if "href" in c]) < 2
), "Should prune link-heavy sections"
def test_tag_importance(self, mixed_content_html): def test_tag_importance(self, mixed_content_html):
"""Test tag importance in scoring""" """Test tag importance in scoring"""
filter = PruningContentFilter(threshold_type='dynamic') filter = PruningContentFilter(threshold_type="dynamic")
contents = filter.filter_content(mixed_content_html) contents = filter.filter_content(mixed_content_html)
has_article = any('article' in c.lower() for c in contents) has_article = any("article" in c.lower() for c in contents)
has_h1 = any('h1' in c.lower() for c in contents) has_h1 = any("h1" in c.lower() for c in contents)
assert has_article or has_h1, "Should retain important tags" assert has_article or has_h1, "Should retain important tags"
def test_empty_input(self): def test_empty_input(self):
@@ -127,26 +132,31 @@ class TestPruningContentFilter:
    def test_performance(self, basic_html):
        """Test performance with timer"""
        filter = PruningContentFilter()
        import time

        start = time.perf_counter()
        filter.filter_content(basic_html)
        duration = time.perf_counter() - start

        # Extra strict on performance since milliseconds matter here
        assert duration < 0.1, f"Processing took too long: {duration:.3f} seconds"

    @pytest.mark.parametrize(
        "threshold,expected_count",
        [
            (0.3, 4),  # Very lenient
            (0.48, 2),  # Default
            (0.7, 1),  # Very strict
        ],
    )
    def test_threshold_levels(self, mixed_content_html, threshold, expected_count):
        """Test different threshold levels"""
        filter = PruningContentFilter(threshold_type="fixed", threshold=threshold)
        contents = filter.filter_content(mixed_content_html)
        assert (
            len(contents) <= expected_count
        ), f"Expected {expected_count} or fewer elements with threshold {threshold}"

    def test_consistent_output(self, basic_html):
        """Test output consistency across multiple runs"""
@@ -155,5 +165,6 @@ class TestPruningContentFilter:
        second_run = filter.filter_content(basic_html)
        assert first_run == second_run, "Output should be consistent"


if __name__ == "__main__":
    pytest.main([__file__])
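A minimal, self-contained sketch of the PruningContentFilter API these tests exercise; the sample HTML and parameter values below are illustrative assumptions, not fixtures from this repo.

# Hedged usage sketch of PruningContentFilter, mirroring the tests above.
# filter_content(html) returns a list of retained HTML fragments.
from crawl4ai import PruningContentFilter

sample_html = "<article><h1>Title</h1><p>A long, high-quality paragraph with many words.</p></article>"
content_filter = PruningContentFilter(
    threshold_type="fixed", threshold=0.48, min_word_threshold=10
)
fragments = content_filter.filter_content(sample_html)
print(" ".join(fragments).lower())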
View File
@@ -1,22 +1,24 @@
import os
import sys
import time
import csv
from tabulate import tabulate
from dataclasses import dataclass
from typing import List

parent_dir = os.path.dirname(
    os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
)
sys.path.append(parent_dir)
__location__ = os.path.realpath(os.path.join(os.getcwd(), os.path.dirname(__file__)))

from crawl4ai.content_scraping_strategy import WebScrapingStrategy
from crawl4ai.content_scraping_strategy import (
    WebScrapingStrategy as WebScrapingStrategyCurrent,
)

# from crawl4ai.content_scrapping_strategy_current import WebScrapingStrategy as WebScrapingStrategyCurrent


@dataclass
class TestResult:
    name: str
@@ -27,69 +29,71 @@ class TestResult:
    markdown_length: int
    execution_time: float


class StrategyTester:
    def __init__(self):
        self.new_scraper = WebScrapingStrategy()
        self.current_scraper = WebScrapingStrategyCurrent()

        with open(__location__ + "/sample_wikipedia.html", "r", encoding="utf-8") as f:
            self.WIKI_HTML = f.read()

        self.results = {"new": [], "current": []}

    def run_test(self, name: str, **kwargs) -> tuple[TestResult, TestResult]:
        results = []

        for scraper in [self.new_scraper, self.current_scraper]:
            start_time = time.time()
            result = scraper._get_content_of_website_optimized(
                url="https://en.wikipedia.org/wiki/Test", html=self.WIKI_HTML, **kwargs
            )
            execution_time = time.time() - start_time

            test_result = TestResult(
                name=name,
                success=result["success"],
                images=len(result["media"]["images"]),
                internal_links=len(result["links"]["internal"]),
                external_links=len(result["links"]["external"]),
                markdown_length=len(result["markdown"]),
                execution_time=execution_time,
            )
            results.append(test_result)

        return results[0], results[1]  # new, current

    def run_all_tests(self):
        test_cases = [
            ("Basic Extraction", {}),
            ("Exclude Tags", {"excluded_tags": ["table", "div.infobox", "div.navbox"]}),
            ("Word Threshold", {"word_count_threshold": 50}),
            ("CSS Selector", {"css_selector": "div.mw-parser-output > p"}),
            (
                "Link Exclusions",
                {
                    "exclude_external_links": True,
                    "exclude_social_media_links": True,
                    "exclude_domains": ["facebook.com", "twitter.com"],
                },
            ),
            (
                "Media Handling",
                {
                    "exclude_external_images": True,
                    "image_description_min_word_threshold": 20,
                },
            ),
            ("Text Only", {"only_text": True, "remove_forms": True}),
            ("HTML Cleaning", {"clean_html": True, "keep_data_attributes": True}),
            (
                "HTML2Text Options",
                {
                    "html2text": {
                        "skip_internal_links": True,
                        "single_line_break": True,
                        "mark_code": True,
                        "preserve_tags": ["pre", "code"],
                    }
                },
            ),
        ]

        all_results = []
@@ -99,64 +103,117 @@ class StrategyTester:
                all_results.append((name, new_result, current_result))
            except Exception as e:
                print(f"Error in {name}: {str(e)}")

        self.save_results_to_csv(all_results)
        self.print_comparison_table(all_results)

    def save_results_to_csv(self, all_results: List[tuple]):
        csv_file = os.path.join(__location__, "strategy_comparison_results.csv")
        with open(csv_file, "w", newline="") as f:
            writer = csv.writer(f)
            writer.writerow(
                [
                    "Test Name",
                    "Strategy",
                    "Success",
                    "Images",
                    "Internal Links",
                    "External Links",
                    "Markdown Length",
                    "Execution Time",
                ]
            )
            for name, new_result, current_result in all_results:
                writer.writerow(
                    [
                        name,
                        "New",
                        new_result.success,
                        new_result.images,
                        new_result.internal_links,
                        new_result.external_links,
                        new_result.markdown_length,
                        f"{new_result.execution_time:.3f}",
                    ]
                )
                writer.writerow(
                    [
                        name,
                        "Current",
                        current_result.success,
                        current_result.images,
                        current_result.internal_links,
                        current_result.external_links,
                        current_result.markdown_length,
                        f"{current_result.execution_time:.3f}",
                    ]
                )

    def print_comparison_table(self, all_results: List[tuple]):
        table_data = []
        headers = [
            "Test Name",
            "Strategy",
            "Success",
            "Images",
            "Internal Links",
            "External Links",
            "Markdown Length",
            "Time (s)",
        ]

        for name, new_result, current_result in all_results:
            # Check for differences
            differences = []
            if new_result.images != current_result.images:
                differences.append("images")
            if new_result.internal_links != current_result.internal_links:
                differences.append("internal_links")
            if new_result.external_links != current_result.external_links:
                differences.append("external_links")
            if new_result.markdown_length != current_result.markdown_length:
                differences.append("markdown")

            # Add row for new strategy
            new_row = [
                name,
                "New",
                new_result.success,
                new_result.images,
                new_result.internal_links,
                new_result.external_links,
                new_result.markdown_length,
                f"{new_result.execution_time:.3f}",
            ]
            table_data.append(new_row)

            # Add row for current strategy
            current_row = [
                "",
                "Current",
                current_result.success,
                current_result.images,
                current_result.internal_links,
                current_result.external_links,
                current_result.markdown_length,
                f"{current_result.execution_time:.3f}",
            ]
            table_data.append(current_row)

            # Add difference summary if any
            if differences:
                table_data.append(
                    ["", "⚠️ Differences", ", ".join(differences), "", "", "", "", ""]
                )

            # Add empty row for better readability
            table_data.append([""] * len(headers))

        print("\nStrategy Comparison Results:")
        print(tabulate(table_data, headers=headers, tablefmt="grid"))


if __name__ == "__main__":
    tester = StrategyTester()
    tester.run_all_tests()
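As a quick, standalone illustration of the grid layout print_comparison_table produces, here is a minimal tabulate snippet; the row values are invented for the example.

# Standalone demo of tabulate's "grid" format used above; numbers are made up.
from tabulate import tabulate

headers = ["Test Name", "Strategy", "Images", "Time (s)"]
rows = [
    ["Basic Extraction", "New", 42, "0.120"],
    ["", "Current", 42, "0.145"],
]
print(tabulate(rows, headers=headers, tablefmt="grid"))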
View File
@@ -1,14 +1,13 @@
import os
import sys
import pytest

# Add the parent directory to the Python path
parent_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
sys.path.append(parent_dir)

from crawl4ai.async_webcrawler import AsyncWebCrawler


@pytest.mark.asyncio
async def test_custom_user_agent():
@@ -20,6 +19,7 @@ async def test_custom_user_agent():
        assert result.success
        assert custom_user_agent in result.html


@pytest.mark.asyncio
async def test_custom_headers():
    async with AsyncWebCrawler(verbose=True) as crawler:
@@ -31,6 +31,7 @@ async def test_custom_headers():
assert "X-Test-Header" in result.html assert "X-Test-Header" in result.html
assert "TestValue" in result.html assert "TestValue" in result.html
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_javascript_execution(): async def test_javascript_execution():
async with AsyncWebCrawler(verbose=True) as crawler: async with AsyncWebCrawler(verbose=True) as crawler:
@@ -40,19 +41,22 @@ async def test_javascript_execution():
        assert result.success
        assert "<h1>Modified by JS</h1>" in result.html


@pytest.mark.asyncio
async def test_hook_execution():
    async with AsyncWebCrawler(verbose=True) as crawler:

        async def test_hook(page):
            await page.evaluate("document.body.style.backgroundColor = 'red';")
            return page

        crawler.crawler_strategy.set_hook("after_goto", test_hook)
        url = "https://www.example.com"
        result = await crawler.arun(url=url, bypass_cache=True)
        assert result.success
        assert "background-color: red" in result.html


@pytest.mark.asyncio
async def test_screenshot():
    async with AsyncWebCrawler(verbose=True) as crawler:
@@ -63,6 +67,7 @@ async def test_screenshot():
        assert isinstance(result.screenshot, str)
        assert len(result.screenshot) > 0


# Entry point for debugging
if __name__ == "__main__":
    pytest.main([__file__, "-v"])
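A runnable sketch of the hook mechanism test_hook_execution relies on: an async callback registered under "after_goto" receives the Playwright page after navigation. The hook body here is an illustrative assumption, not the library's own example.

# Hedged sketch of the after_goto hook pattern, mirroring test_hook_execution.
import asyncio
from crawl4ai.async_webcrawler import AsyncWebCrawler


async def main():
    async with AsyncWebCrawler(verbose=True) as crawler:

        async def highlight(page):
            # Illustrative page mutation; any Playwright page API could go here.
            await page.evaluate("document.body.style.backgroundColor = 'red';")
            return page

        crawler.crawler_strategy.set_hook("after_goto", highlight)
        result = await crawler.arun(url="https://www.example.com", bypass_cache=True)
        print(result.success)


asyncio.run(main())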
View File
@@ -1,8 +1,6 @@
import os
import sys
import pytest

# Add the parent directory to the Python path
parent_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
@@ -10,6 +8,7 @@ sys.path.append(parent_dir)
from crawl4ai.async_webcrawler import AsyncWebCrawler


@pytest.mark.asyncio
async def test_cache_url():
    async with AsyncWebCrawler(verbose=True) as crawler:
@@ -23,6 +22,7 @@ async def test_cache_url():
        assert result2.success
        assert result2.html == result1.html


@pytest.mark.asyncio
async def test_bypass_cache():
    async with AsyncWebCrawler(verbose=True) as crawler:
@@ -34,25 +34,29 @@ async def test_bypass_cache():
        # Second run bypassing cache
        result2 = await crawler.arun(url=url, bypass_cache=True)
        assert result2.success
        assert (
            result2.html != result1.html
        )  # Content might be different due to dynamic nature of websites


@pytest.mark.asyncio
async def test_cache_size():
    async with AsyncWebCrawler(verbose=True) as crawler:
        initial_size = await crawler.aget_cache_size()
        url = "https://www.nbcnews.com/business"
        await crawler.arun(url=url, bypass_cache=True)
        new_size = await crawler.aget_cache_size()
        assert new_size == initial_size + 1


@pytest.mark.asyncio
async def test_clear_cache():
    async with AsyncWebCrawler(verbose=True) as crawler:
        url = "https://www.example.org"
        await crawler.arun(url=url, bypass_cache=True)
        initial_size = await crawler.aget_cache_size()
        assert initial_size > 0
@@ -60,12 +64,13 @@ async def test_clear_cache():
        new_size = await crawler.aget_cache_size()
        assert new_size == 0


@pytest.mark.asyncio
async def test_flush_cache():
    async with AsyncWebCrawler(verbose=True) as crawler:
        url = "https://www.example.net"
        await crawler.arun(url=url, bypass_cache=True)
        initial_size = await crawler.aget_cache_size()
        assert initial_size > 0
@@ -75,8 +80,11 @@ async def test_flush_cache():
        # Try to retrieve the previously cached URL
        result = await crawler.arun(url=url, bypass_cache=False)
        assert (
            result.success
        )  # The crawler should still succeed, but it will fetch the content anew


# Entry point for debugging
if __name__ == "__main__":
    pytest.main([__file__, "-v"])
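The cache tests above imply a simple round-trip workflow. The sketch below uses only calls visible in this diff (arun with bypass_cache, and aget_cache_size), with an assumed target URL.

# Minimal cache round-trip, assuming the same API surface the tests use.
import asyncio
from crawl4ai.async_webcrawler import AsyncWebCrawler


async def main():
    async with AsyncWebCrawler(verbose=True) as crawler:
        url = "https://www.example.com"  # assumed target
        before = await crawler.aget_cache_size()
        await crawler.arun(url=url, bypass_cache=True)  # fetch and cache
        cached = await crawler.arun(url=url, bypass_cache=False)  # may hit the cache
        after = await crawler.aget_cache_size()
        print(cached.success, after - before)


asyncio.run(main())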
View File
@@ -1,114 +1,133 @@
import pytest
import time
from crawl4ai import (
    AsyncWebCrawler,
    BrowserConfig,
    CrawlerRunConfig,
    MemoryAdaptiveDispatcher,
    SemaphoreDispatcher,
    RateLimiter,
    CrawlerMonitor,
    DisplayMode,
    CacheMode,
)


@pytest.fixture
def browser_config():
    return BrowserConfig(headless=True, verbose=False)


@pytest.fixture
def run_config():
    return CrawlerRunConfig(cache_mode=CacheMode.BYPASS, verbose=False)


@pytest.fixture
def test_urls():
    return [
        "http://example.com",
        "http://example.com/page1",
        "http://example.com/page2",
    ]
@pytest.mark.asyncio
class TestDispatchStrategies:
    async def test_memory_adaptive_basic(self, browser_config, run_config, test_urls):
        async with AsyncWebCrawler(config=browser_config) as crawler:
            dispatcher = MemoryAdaptiveDispatcher(
                memory_threshold_percent=70.0, max_session_permit=2, check_interval=0.1
            )
            results = await crawler.arun_many(
                test_urls, config=run_config, dispatcher=dispatcher
            )
            assert len(results) == len(test_urls)
            assert all(r.success for r in results)

    async def test_memory_adaptive_with_rate_limit(
        self, browser_config, run_config, test_urls
    ):
        async with AsyncWebCrawler(config=browser_config) as crawler:
            dispatcher = MemoryAdaptiveDispatcher(
                memory_threshold_percent=70.0,
                max_session_permit=2,
                check_interval=0.1,
                rate_limiter=RateLimiter(
                    base_delay=(0.1, 0.2), max_delay=1.0, max_retries=2
                ),
            )
            results = await crawler.arun_many(
                test_urls, config=run_config, dispatcher=dispatcher
            )
            assert len(results) == len(test_urls)
            assert all(r.success for r in results)

    async def test_semaphore_basic(self, browser_config, run_config, test_urls):
        async with AsyncWebCrawler(config=browser_config) as crawler:
            dispatcher = SemaphoreDispatcher(semaphore_count=2)
            results = await crawler.arun_many(
                test_urls, config=run_config, dispatcher=dispatcher
            )
            assert len(results) == len(test_urls)
            assert all(r.success for r in results)

    async def test_semaphore_with_rate_limit(
        self, browser_config, run_config, test_urls
    ):
        async with AsyncWebCrawler(config=browser_config) as crawler:
            dispatcher = SemaphoreDispatcher(
                semaphore_count=2,
                rate_limiter=RateLimiter(
                    base_delay=(0.1, 0.2), max_delay=1.0, max_retries=2
                ),
            )
            results = await crawler.arun_many(
                test_urls, config=run_config, dispatcher=dispatcher
            )
            assert len(results) == len(test_urls)
            assert all(r.success for r in results)

    async def test_memory_adaptive_memory_error(
        self, browser_config, run_config, test_urls
    ):
        async with AsyncWebCrawler(config=browser_config) as crawler:
            dispatcher = MemoryAdaptiveDispatcher(
                memory_threshold_percent=1.0,  # Set unrealistically low threshold
                max_session_permit=2,
                check_interval=0.1,
                memory_wait_timeout=1.0,  # Short timeout for testing
            )
            with pytest.raises(MemoryError):
                await crawler.arun_many(
                    test_urls, config=run_config, dispatcher=dispatcher
                )

    async def test_empty_urls(self, browser_config, run_config):
        async with AsyncWebCrawler(config=browser_config) as crawler:
            dispatcher = MemoryAdaptiveDispatcher(max_session_permit=2)
            results = await crawler.arun_many(
                [], config=run_config, dispatcher=dispatcher
            )
            assert len(results) == 0

    async def test_single_url(self, browser_config, run_config):
        async with AsyncWebCrawler(config=browser_config) as crawler:
            dispatcher = MemoryAdaptiveDispatcher(max_session_permit=2)
            results = await crawler.arun_many(
                ["http://example.com"], config=run_config, dispatcher=dispatcher
            )
            assert len(results) == 1
            assert results[0].success

    async def test_invalid_urls(self, browser_config, run_config):
        async with AsyncWebCrawler(config=browser_config) as crawler:
            dispatcher = MemoryAdaptiveDispatcher(max_session_permit=2)
            results = await crawler.arun_many(
                ["http://invalid.url.that.doesnt.exist"],
                config=run_config,
                dispatcher=dispatcher,
            )
            assert len(results) == 1
            assert not results[0].success
@@ -121,27 +140,31 @@ class TestDispatchStrategies:
                    base_delay=(0.1, 0.2),
                    max_delay=1.0,
                    max_retries=2,
                    rate_limit_codes=[200],  # Force rate limiting for testing
                ),
            )
            start_time = time.time()
            results = await crawler.arun_many(
                urls, config=run_config, dispatcher=dispatcher
            )
            duration = time.time() - start_time
            assert len(results) == len(urls)
            assert duration > 1.0  # Ensure rate limiting caused delays

    async def test_monitor_integration(self, browser_config, run_config, test_urls):
        async with AsyncWebCrawler(config=browser_config) as crawler:
            monitor = CrawlerMonitor(
                max_visible_rows=5, display_mode=DisplayMode.DETAILED
            )
            dispatcher = MemoryAdaptiveDispatcher(max_session_permit=2, monitor=monitor)
            results = await crawler.arun_many(
                test_urls, config=run_config, dispatcher=dispatcher
            )
            assert len(results) == len(test_urls)

            # Check monitor stats
            assert len(monitor.stats) == len(test_urls)
            assert all(stat.end_time is not None for stat in monitor.stats.values())


if __name__ == "__main__":
    pytest.main([__file__, "-v", "--asyncio-mode=auto"])
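Putting the dispatcher pieces together, here is a plausible combined configuration drawn only from parameters these tests exercise; the URLs and values are illustrative.

# Hedged sketch: memory-adaptive dispatch with rate limiting and a monitor.
import asyncio
from crawl4ai import (
    AsyncWebCrawler, BrowserConfig, CrawlerRunConfig, CacheMode,
    MemoryAdaptiveDispatcher, RateLimiter, CrawlerMonitor, DisplayMode,
)


async def main():
    dispatcher = MemoryAdaptiveDispatcher(
        memory_threshold_percent=70.0,
        max_session_permit=4,
        rate_limiter=RateLimiter(base_delay=(0.5, 1.0), max_delay=10.0, max_retries=2),
        monitor=CrawlerMonitor(max_visible_rows=10, display_mode=DisplayMode.DETAILED),
    )
    async with AsyncWebCrawler(config=BrowserConfig(headless=True)) as crawler:
        results = await crawler.arun_many(
            ["https://example.com", "https://example.org"],  # illustrative URLs
            config=CrawlerRunConfig(cache_mode=CacheMode.BYPASS),
            dispatcher=dispatcher,
        )
        print(sum(r.success for r in results), "succeeded")


asyncio.run(main())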
View File
@@ -2,9 +2,9 @@ import os
import re
import sys
import pytest
from bs4 import BeautifulSoup
import asyncio

# Add the parent directory to the Python path
parent_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
sys.path.append(parent_dir)
@@ -59,19 +59,21 @@ from crawl4ai.async_webcrawler import AsyncWebCrawler
#         assert result.success
#         assert "github" in result.html.lower()


# Add this test to your existing test file
@pytest.mark.asyncio
async def test_typescript_commits_multi_page():
    first_commit = ""

    async def on_execution_started(page):
        nonlocal first_commit
        try:
            # Check if the page's first commit h4 text differs from the previously
            # seen first commit (use document.querySelector('li.Box-sc-g0xbh4-0 h4'))
            while True:
                await page.wait_for_selector("li.Box-sc-g0xbh4-0 h4")
                commit = await page.query_selector("li.Box-sc-g0xbh4-0 h4")
                commit = await commit.evaluate("(element) => element.textContent")
                commit = re.sub(r"\s+", "", commit)
                if commit and commit != first_commit:
                    first_commit = commit
                    break
@@ -79,9 +81,8 @@ async def test_typescript_commits_multi_page():
        except Exception as e:
            print(f"Warning: New content didn't appear after JavaScript execution: {e}")

    async with AsyncWebCrawler(verbose=True) as crawler:
        crawler.crawler_strategy.set_hook("on_execution_started", on_execution_started)

        url = "https://github.com/microsoft/TypeScript/commits/main"
        session_id = "typescript_commits_session"
@@ -97,19 +98,21 @@ async def test_typescript_commits_multi_page():
                url=url,  # Only use URL for the first page
                session_id=session_id,
                css_selector="li.Box-sc-g0xbh4-0",
                js=js_next_page
                if page > 0
                else None,  # Don't click 'next' on the first page
                bypass_cache=True,
                js_only=page > 0,  # Use js_only for subsequent pages
            )

            assert result.success, f"Failed to crawl page {page + 1}"

            # Parse the HTML and extract commits
            soup = BeautifulSoup(result.cleaned_html, "html.parser")
            commits = soup.select("li")
            # Take the first commit, find its h4, and extract its text
            first_commit = commits[0].find("h4").text
            first_commit = re.sub(r"\s+", "", first_commit)
            all_commits.extend(commits)

            print(f"Page {page + 1}: Found {len(commits)} commits")
@@ -118,10 +121,13 @@ async def test_typescript_commits_multi_page():
        await crawler.crawler_strategy.kill_session(session_id)

        # Assertions
        assert (
            len(all_commits) >= 90
        ), f"Expected at least 90 commits, but got {len(all_commits)}"
        print(f"Successfully crawled {len(all_commits)} commits across 3 pages")


# Entry point for debugging
if __name__ == "__main__":
    pytest.main([__file__, "-v"])
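The multi-page test demonstrates session-based pagination: the first arun navigates, later calls reuse the same browser page via session_id and run only JavaScript via js_only. A condensed sketch follows, with a hypothetical js_next_page selector standing in for the site-specific snippet elided by the hunk.

# Hedged pagination sketch using the session pattern from the test above.
import asyncio
from crawl4ai.async_webcrawler import AsyncWebCrawler

# Hypothetical "next page" click; the real test uses a GitHub-specific snippet.
js_next_page = "document.querySelector('a[rel=\"next\"]')?.click();"


async def main():
    async with AsyncWebCrawler(verbose=True) as crawler:
        session_id = "pagination_session"
        for page in range(3):
            result = await crawler.arun(
                url="https://example.com/items",  # illustrative URL
                session_id=session_id,
                js=js_next_page if page > 0 else None,
                js_only=page > 0,
                bypass_cache=True,
            )
            print(f"page {page + 1}: {result.success}")
        await crawler.crawler_strategy.kill_session(session_id)


asyncio.run(main())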
View File
@@ -75,4 +75,4 @@
# # Entry point for debugging
# if __name__ == "__main__":
#     pytest.main([__file__, "-v"])
View File
@@ -1,11 +1,15 @@
import json
import time
from bs4 import BeautifulSoup
from crawl4ai.content_scraping_strategy import (
    WebScrapingStrategy,
    LXMLWebScrapingStrategy,
)
from typing import Dict, List, Tuple
import difflib
from lxml import html as lhtml, etree


def normalize_dom(element):
    """
    Recursively normalizes an lxml HTML element:
@@ -15,7 +19,7 @@ def normalize_dom(element):
    Returns the same element (mutated).
    """
    # Remove comment nodes
    comments = element.xpath("//comment()")
    for c in comments:
        p = c.getparent()
        if p is not None:
@@ -45,7 +49,7 @@ def strip_html_body(root):
""" """
If 'root' is <html>, find its <body> child and move all of <body>'s children If 'root' is <html>, find its <body> child and move all of <body>'s children
into a new <div>. Return that <div>. into a new <div>. Return that <div>.
If 'root' is <body>, similarly move all of its children into a new <div> and return it. If 'root' is <body>, similarly move all of its children into a new <div> and return it.
Otherwise, return 'root' as-is. Otherwise, return 'root' as-is.
@@ -53,8 +57,8 @@ def strip_html_body(root):
    tag_name = (root.tag or "").lower()

    # Case 1: The root is <html>
    if tag_name == "html":
        bodies = root.xpath("./body")
        if bodies:
            body = bodies[0]
            new_div = lhtml.Element("div")
@@ -66,7 +70,7 @@ def strip_html_body(root):
        return root

    # Case 2: The root is <body>
    elif tag_name == "body":
        new_div = lhtml.Element("div")
        for child in root:
            new_div.append(child)
@@ -92,7 +96,9 @@ def compare_nodes(node1, node2, differences, path="/"):
    attrs1 = list(node1.attrib.items())
    attrs2 = list(node2.attrib.items())
    if attrs1 != attrs2:
        differences.append(
            f"Attribute mismatch at {path}/{node1.tag}: {attrs1} vs. {attrs2}"
        )

    # 3) Compare text (trim or unify whitespace as needed)
    text1 = (node1.text or "").strip()
@@ -102,7 +108,9 @@ def compare_nodes(node1, node2, differences, path="/"):
text2 = " ".join(text2.split()) text2 = " ".join(text2.split())
if text1 != text2: if text1 != text2:
# If you prefer ignoring newlines or multiple whitespace, do a more robust cleanup # If you prefer ignoring newlines or multiple whitespace, do a more robust cleanup
differences.append(f"Text mismatch at {path}/{node1.tag}: '{text1}' vs. '{text2}'") differences.append(
f"Text mismatch at {path}/{node1.tag}: '{text1}' vs. '{text2}'"
)
# 4) Compare number of children # 4) Compare number of children
children1 = list(node1) children1 = list(node1)
@@ -123,7 +131,9 @@ def compare_nodes(node1, node2, differences, path="/"):
    tail1 = (node1.tail or "").strip()
    tail2 = (node2.tail or "").strip()
    if tail1 != tail2:
        differences.append(
            f"Tail mismatch after {path}/{node1.tag}: '{tail1}' vs. '{tail2}'"
        )


def compare_html_structurally(html1, html2):
@@ -156,11 +166,11 @@ def compare_html_structurally(html1, html2):
    return differences


def generate_large_html(n_elements=1000):
    html = ["<!DOCTYPE html><html><head></head><body>"]
    for i in range(n_elements):
        html.append(
            f"""
        <div class="article">
            <h2>Heading {i}</h2>
            <p>This is paragraph {i} with some content and a <a href="http://example.com/{i}">link</a></p>
@@ -170,13 +180,15 @@ def generate_large_html(n_elements=1000):
            <li>List item {i}.2</li>
            </ul>
        </div>
        """
        )
    html.append("</body></html>")
    return "".join(html)


def generate_complicated_html():
    """
    HTML with multiple domains, forms, data attributes,
    various images, comments, style, and noscript to test all parameter toggles.
    """
    return """
@@ -258,7 +270,7 @@ def generate_complicated_html():
def get_test_scenarios():
    """
    Returns a dictionary of parameter sets (test scenarios) for the scraper.

    Each scenario name maps to a dictionary of keyword arguments
    that will be passed into scrap() for testing various features.
    """
    TEST_SCENARIOS = {
@@ -341,7 +353,7 @@ def get_test_scenarios():
# "exclude_external_links": True # "exclude_external_links": True
# }, # },
# "comprehensive_removal": { # "comprehensive_removal": {
# # Exclude multiple tags, remove forms & comments, # # Exclude multiple tags, remove forms & comments,
# # and also remove targeted selectors # # and also remove targeted selectors
# "excluded_tags": ["aside", "noscript", "script"], # "excluded_tags": ["aside", "noscript", "script"],
# "excluded_selector": "#promo-section, .social-widget", # "excluded_selector": "#promo-section, .social-widget",
@@ -352,19 +364,18 @@ def get_test_scenarios():
    return TEST_SCENARIOS


class ScraperEquivalenceTester:
    def __init__(self):
        self.test_cases = {
            "basic": self.generate_basic_html(),
            "complex": self.generate_complex_html(),
            "malformed": self.generate_malformed_html(),
            # 'real_world': self.load_real_samples()
        }

    def generate_basic_html(self):
        return generate_large_html(1000)  # Your existing function

    def generate_complex_html(self):
        return """
        <html><body>
@@ -384,7 +395,7 @@ class ScraperEquivalenceTester:
            </div>
        </body></html>
        """

    def generate_malformed_html(self):
        return """
        <div>Unclosed div
@@ -395,139 +406,139 @@ class ScraperEquivalenceTester:
        <!-- Malformed comment -- > -->
        <![CDATA[Test CDATA]]>
        """

    def load_real_samples(self):
        # Load some real-world HTML samples you've collected
        samples = {
            "article": open("tests/samples/article.html").read(),
            "product": open("tests/samples/product.html").read(),
            "blog": open("tests/samples/blog.html").read(),
        }
        return samples

    def deep_compare_links(self, old_links: Dict, new_links: Dict) -> List[str]:
        """Detailed comparison of link structures"""
        differences = []

        for category in ["internal", "external"]:
            old_urls = {link["href"] for link in old_links[category]}
            new_urls = {link["href"] for link in new_links[category]}

            missing = old_urls - new_urls
            extra = new_urls - old_urls

            if missing:
                differences.append(f"Missing {category} links: {missing}")
            if extra:
                differences.append(f"Extra {category} links: {extra}")

            # Compare link attributes for common URLs
            common = old_urls & new_urls
            for url in common:
                old_link = next(l for l in old_links[category] if l["href"] == url)
                new_link = next(l for l in new_links[category] if l["href"] == url)
                for attr in ["text", "title"]:
                    if old_link[attr] != new_link[attr]:
                        differences.append(
                            f"Link attribute mismatch for {url} - {attr}:"
                            f" old='{old_link[attr]}' vs new='{new_link[attr]}'"
                        )
        return differences

    def deep_compare_media(self, old_media: Dict, new_media: Dict) -> List[str]:
        """Detailed comparison of media elements"""
        differences = []

        for media_type in ["images", "videos", "audios"]:
            old_srcs = {item["src"] for item in old_media[media_type]}
            new_srcs = {item["src"] for item in new_media[media_type]}

            missing = old_srcs - new_srcs
            extra = new_srcs - old_srcs

            if missing:
                differences.append(f"Missing {media_type}: {missing}")
            if extra:
                differences.append(f"Extra {media_type}: {extra}")

            # Compare media attributes for common sources
            common = old_srcs & new_srcs
            for src in common:
                old_item = next(m for m in old_media[media_type] if m["src"] == src)
                new_item = next(m for m in new_media[media_type] if m["src"] == src)
                for attr in ["alt", "description"]:
                    if old_item.get(attr) != new_item.get(attr):
                        differences.append(
                            f"{media_type} attribute mismatch for {src} - {attr}:"
                            f" old='{old_item.get(attr)}' vs new='{new_item.get(attr)}'"
                        )
        return differences

    def compare_html_content(self, old_html: str, new_html: str) -> List[str]:
        """Compare HTML content structure and text"""
        # return compare_html_structurally(old_html, new_html)
        differences = []

        def normalize_html(html: str) -> Tuple[str, str]:
            soup = BeautifulSoup(html, "lxml")
            # Get both structure and text
            structure = " ".join(tag.name for tag in soup.find_all())
            text = " ".join(soup.get_text().split())
            return structure, text

        old_structure, old_text = normalize_html(old_html)
        new_structure, new_text = normalize_html(new_html)

        # Compare structure
        if abs(len(old_structure) - len(new_structure)) > 100:
            # if old_structure != new_structure:
            diff = difflib.unified_diff(
                old_structure.split(), new_structure.split(), lineterm=""
            )
            differences.append("HTML structure differences:\n" + "\n".join(diff))

        # Compare text content
        if abs(len(old_text) - len(new_text)) > 100:
            # if old_text != new_text:
            # Show detailed text differences
            text_diff = difflib.unified_diff(
                old_text.split(), new_text.split(), lineterm=""
            )
            differences.append("Text content differences:\n" + "\n".join(text_diff))

        return differences

    def compare_results(
        self, old_result: Dict, new_result: Dict
    ) -> Dict[str, List[str]]:
        """Comprehensive comparison of scraper outputs"""
        differences = {}

        # Compare links
        link_differences = self.deep_compare_links(
            old_result["links"], new_result["links"]
        )
        if link_differences:
            differences["links"] = link_differences

        # Compare media
        media_differences = self.deep_compare_media(
            old_result["media"], new_result["media"]
        )
        if media_differences:
            differences["media"] = media_differences

        # Compare HTML
        html_differences = self.compare_html_content(
            old_result["cleaned_html"], new_result["cleaned_html"]
        )
        if html_differences:
            differences["html"] = html_differences

        return differences

    def run_tests(self) -> Dict:
@@ -535,52 +546,49 @@ class ScraperEquivalenceTester:
        # We'll still keep some "test_cases" logic from above (basic, complex, malformed).
        # But we add a new section for the complicated HTML scenarios.

        results = {"tests": [], "summary": {"passed": 0, "failed": 0}}

        # 1) First, run the existing 3 built-in test cases (basic, complex, malformed).
        # for case_name, html in self.test_cases.items():
        #     print(f"\nTesting built-in case: {case_name}...")
        #     original = WebScrapingStrategy()
        #     lxml = LXMLWebScrapingStrategy()

        #     start = time.time()
        #     orig_result = original.scrap("http://test.com", html)
        #     orig_time = time.time() - start
        #     print("\nOriginal Mode:")
        #     print(f"Cleaned HTML size: {len(orig_result['cleaned_html'])/1024:.2f} KB")
        #     print(f"Images: {len(orig_result['media']['images'])}")
        #     print(f"External links: {len(orig_result['links']['external'])}")
        #     print(f"Times - Original: {orig_time:.3f}s")

        #     start = time.time()
        #     lxml_result = lxml.scrap("http://test.com", html)
        #     lxml_time = time.time() - start
        #     print("\nLXML Mode:")
        #     print(f"Cleaned HTML size: {len(lxml_result['cleaned_html'])/1024:.2f} KB")
        #     print(f"Images: {len(lxml_result['media']['images'])}")
        #     print(f"External links: {len(lxml_result['links']['external'])}")
        #     print(f"Times - LXML: {lxml_time:.3f}s")

        #     # Compare
        #     diffs = {}
        #     link_diff = self.deep_compare_links(orig_result['links'], lxml_result['links'])
        #     if link_diff:
        #         diffs['links'] = link_diff
        #     media_diff = self.deep_compare_media(orig_result['media'], lxml_result['media'])
        #     if media_diff:
        #         diffs['media'] = media_diff
        #     html_diff = self.compare_html_content(orig_result['cleaned_html'], lxml_result['cleaned_html'])
        #     if html_diff:
        #         diffs['html'] = html_diff

        #     test_result = {
        #         'case': case_name,
        #         'lxml_mode': {
@@ -590,7 +598,7 @@ class ScraperEquivalenceTester:
        #         'original_time': orig_time
        #     }
        #     results['tests'].append(test_result)

        #     if not diffs:
        #         results['summary']['passed'] += 1
        #     else:
@@ -599,50 +607,55 @@ class ScraperEquivalenceTester:
        # 2) Now, run the complicated HTML with multiple parameter scenarios.
        complicated_html = generate_complicated_html()
        print("\n=== Testing complicated HTML with multiple parameter scenarios ===")

        # Create the scrapers once (or you can re-create if needed)
        original = WebScrapingStrategy()
        lxml = LXMLWebScrapingStrategy()

        for scenario_name, params in get_test_scenarios().items():
            print(f"\nScenario: {scenario_name}")

            start = time.time()
            orig_result = original.scrap("http://test.com", complicated_html, **params)
            orig_time = time.time() - start

            start = time.time()
            lxml_result = lxml.scrap("http://test.com", complicated_html, **params)
            lxml_time = time.time() - start

            diffs = {}
            link_diff = self.deep_compare_links(
                orig_result["links"], lxml_result["links"]
            )
            if link_diff:
                diffs["links"] = link_diff

            media_diff = self.deep_compare_media(
                orig_result["media"], lxml_result["media"]
            )
            if media_diff:
                diffs["media"] = media_diff

            html_diff = self.compare_html_content(
                orig_result["cleaned_html"], lxml_result["cleaned_html"]
            )
            if html_diff:
                diffs["html"] = html_diff

            test_result = {
                "case": f"complicated_{scenario_name}",
                "lxml_mode": {"differences": diffs, "execution_time": lxml_time},
                "original_time": orig_time,
            }
            results["tests"].append(test_result)

            if not diffs:
                results["summary"]["passed"] += 1
                print(
                    f"✅ [OK] No differences found. Time(Orig: {orig_time:.3f}s, LXML: {lxml_time:.3f}s)"
                )
            else:
                results["summary"]["failed"] += 1
                print("❌ Differences found:")
                for category, dlist in diffs.items():
                    print(f"  {category}:")
@@ -657,20 +670,22 @@ class ScraperEquivalenceTester:
print(f"Total Cases: {len(results['tests'])}") print(f"Total Cases: {len(results['tests'])}")
print(f"Passed: {results['summary']['passed']}") print(f"Passed: {results['summary']['passed']}")
print(f"Failed: {results['summary']['failed']}") print(f"Failed: {results['summary']['failed']}")
for test in results['tests']: for test in results["tests"]:
print(f"\nTest Case: {test['case']}") print(f"\nTest Case: {test['case']}")
if not test['lxml_mode']['differences']: if not test["lxml_mode"]["differences"]:
print("✅ All implementations produced identical results") print("✅ All implementations produced identical results")
print(f"Times - Original: {test['original_time']:.3f}s, " print(
f"LXML: {test['lxml_mode']['execution_time']:.3f}s") f"Times - Original: {test['original_time']:.3f}s, "
f"LXML: {test['lxml_mode']['execution_time']:.3f}s"
)
else: else:
print("❌ Differences found:") print("❌ Differences found:")
if test['lxml_mode']['differences']: if test["lxml_mode"]["differences"]:
print("\nLXML Mode Differences:") print("\nLXML Mode Differences:")
for category, diffs in test['lxml_mode']['differences'].items(): for category, diffs in test["lxml_mode"]["differences"].items():
print(f"\n{category}:") print(f"\n{category}:")
for diff in diffs: for diff in diffs:
print(f" - {diff}") print(f" - {diff}")
@@ -680,11 +695,11 @@ def main():
tester = ScraperEquivalenceTester() tester = ScraperEquivalenceTester()
results = tester.run_tests() results = tester.run_tests()
tester.print_report(results) tester.print_report(results)
# Save detailed results for debugging # Save detailed results for debugging
with open('scraper_equivalence_results.json', 'w') as f: with open("scraper_equivalence_results.json", "w") as f:
json.dump(results, f, indent=2) json.dump(results, f, indent=2)
if __name__ == "__main__": if __name__ == "__main__":
main() main()
View File
@@ -4,10 +4,10 @@
# - **State:** open # - **State:** open
import os, sys, time import os, sys, time
parent_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) parent_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
sys.path.append(parent_dir) sys.path.append(parent_dir)
__location__ = os.path.realpath( os.path.join(os.getcwd(), os.path.dirname(__file__))) __location__ = os.path.realpath(os.path.join(os.getcwd(), os.path.dirname(__file__)))
import asyncio
import os import os
import time import time
from typing import Dict, Any from typing import Dict, Any
@@ -16,18 +16,18 @@ from crawl4ai.markdown_generation_strategy import DefaultMarkdownGenerator
# Get current directory # Get current directory
__location__ = os.path.realpath(os.path.join(os.getcwd(), os.path.dirname(__file__))) __location__ = os.path.realpath(os.path.join(os.getcwd(), os.path.dirname(__file__)))
def print_test_result(name: str, result: Dict[str, Any], execution_time: float): def print_test_result(name: str, result: Dict[str, Any], execution_time: float):
"""Helper function to print test results.""" """Helper function to print test results."""
print(f"\n{'='*20} {name} {'='*20}") print(f"\n{'='*20} {name} {'='*20}")
print(f"Execution time: {execution_time:.4f} seconds") print(f"Execution time: {execution_time:.4f} seconds")
# Save markdown to files # Save markdown to files
for key, content in result.items(): for key, content in result.items():
if isinstance(content, str): if isinstance(content, str):
with open(__location__ + f"/output/{name.lower()}_{key}.md", "w") as f: with open(__location__ + f"/output/{name.lower()}_{key}.md", "w") as f:
f.write(content) f.write(content)
# # Print first few lines of each markdown version # # Print first few lines of each markdown version
# for key, content in result.items(): # for key, content in result.items():
# if isinstance(content, str): # if isinstance(content, str):
@@ -36,32 +36,39 @@ def print_test_result(name: str, result: Dict[str, Any], execution_time: float):
# print(preview) # print(preview)
# print(f"Total length: {len(content)} characters") # print(f"Total length: {len(content)} characters")
def test_basic_markdown_conversion(): def test_basic_markdown_conversion():
"""Test basic markdown conversion with links.""" """Test basic markdown conversion with links."""
with open(__location__ + "/data/wikipedia.html", "r") as f: with open(__location__ + "/data/wikipedia.html", "r") as f:
cleaned_html = f.read() cleaned_html = f.read()
generator = DefaultMarkdownGenerator() generator = DefaultMarkdownGenerator()
start_time = time.perf_counter() start_time = time.perf_counter()
result = generator.generate_markdown( result = generator.generate_markdown(
cleaned_html=cleaned_html, cleaned_html=cleaned_html, base_url="https://en.wikipedia.org"
base_url="https://en.wikipedia.org"
) )
execution_time = time.perf_counter() - start_time execution_time = time.perf_counter() - start_time
print_test_result("Basic Markdown Conversion", { print_test_result(
'raw': result.raw_markdown, "Basic Markdown Conversion",
'with_citations': result.markdown_with_citations, {
'references': result.references_markdown "raw": result.raw_markdown,
}, execution_time) "with_citations": result.markdown_with_citations,
"references": result.references_markdown,
},
execution_time,
)
# Basic assertions # Basic assertions
assert result.raw_markdown, "Raw markdown should not be empty" assert result.raw_markdown, "Raw markdown should not be empty"
assert result.markdown_with_citations, "Markdown with citations should not be empty" assert result.markdown_with_citations, "Markdown with citations should not be empty"
assert result.references_markdown, "References should not be empty" assert result.references_markdown, "References should not be empty"
assert "" in result.markdown_with_citations, "Citations should use ⟨⟩ brackets" assert "" in result.markdown_with_citations, "Citations should use ⟨⟩ brackets"
assert "## References" in result.references_markdown, "Should contain references section" assert (
"## References" in result.references_markdown
), "Should contain references section"
def test_relative_links(): def test_relative_links():
"""Test handling of relative links with base URL.""" """Test handling of relative links with base URL."""
@@ -69,97 +76,106 @@ def test_relative_links():
Here's a [relative link](/wiki/Apple) and an [absolute link](https://example.com). Here's a [relative link](/wiki/Apple) and an [absolute link](https://example.com).
Also an [image](/images/test.png) and another [page](/wiki/Banana). Also an [image](/images/test.png) and another [page](/wiki/Banana).
""" """
generator = DefaultMarkdownGenerator() generator = DefaultMarkdownGenerator()
result = generator.generate_markdown( result = generator.generate_markdown(
cleaned_html=markdown, cleaned_html=markdown, base_url="https://en.wikipedia.org"
base_url="https://en.wikipedia.org"
) )
assert "https://en.wikipedia.org/wiki/Apple" in result.references_markdown assert "https://en.wikipedia.org/wiki/Apple" in result.references_markdown
assert "https://example.com" in result.references_markdown assert "https://example.com" in result.references_markdown
assert "https://en.wikipedia.org/images/test.png" in result.references_markdown assert "https://en.wikipedia.org/images/test.png" in result.references_markdown
def test_duplicate_links(): def test_duplicate_links():
"""Test handling of duplicate links.""" """Test handling of duplicate links."""
markdown = """ markdown = """
Here's a [link](/test) and another [link](/test) and a [different link](/other). Here's a [link](/test) and another [link](/test) and a [different link](/other).
""" """
generator = DefaultMarkdownGenerator() generator = DefaultMarkdownGenerator()
result = generator.generate_markdown( result = generator.generate_markdown(
cleaned_html=markdown, cleaned_html=markdown, base_url="https://example.com"
base_url="https://example.com"
) )
# Count citations in markdown # Count citations in markdown
citations = result.markdown_with_citations.count("⟨1⟩") citations = result.markdown_with_citations.count("⟨1⟩")
assert citations == 2, "Same link should use same citation number" assert citations == 2, "Same link should use same citation number"
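# Illustrative expectation (assumed reference layout, not asserted above): both
# /test links share one numbered entry while /other gets its own, roughly
#   ⟨1⟩ https://example.com/test
#   ⟨2⟩ https://example.com/other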
def test_link_descriptions(): def test_link_descriptions():
"""Test handling of link titles and descriptions.""" """Test handling of link titles and descriptions."""
markdown = """ markdown = """
Here's a [link with title](/test "Test Title") and a [link with description](/other) to test. Here's a [link with title](/test "Test Title") and a [link with description](/other) to test.
""" """
generator = DefaultMarkdownGenerator() generator = DefaultMarkdownGenerator()
result = generator.generate_markdown( result = generator.generate_markdown(
cleaned_html=markdown, cleaned_html=markdown, base_url="https://example.com"
base_url="https://example.com"
) )
assert "Test Title" in result.references_markdown, "Link title should be in references" assert (
assert "link with description" in result.references_markdown, "Link text should be in references" "Test Title" in result.references_markdown
), "Link title should be in references"
assert (
"link with description" in result.references_markdown
), "Link text should be in references"
def test_performance_large_document(): def test_performance_large_document():
"""Test performance with large document.""" """Test performance with large document."""
with open(__location__ + "/data/wikipedia.md", "r") as f: with open(__location__ + "/data/wikipedia.md", "r") as f:
markdown = f.read() markdown = f.read()
# Test with multiple iterations # Test with multiple iterations
iterations = 5 iterations = 5
times = [] times = []
generator = DefaultMarkdownGenerator() generator = DefaultMarkdownGenerator()
for i in range(iterations): for i in range(iterations):
start_time = time.perf_counter() start_time = time.perf_counter()
result = generator.generate_markdown( result = generator.generate_markdown(
cleaned_html=markdown, cleaned_html=markdown, base_url="https://en.wikipedia.org"
base_url="https://en.wikipedia.org"
) )
end_time = time.perf_counter() end_time = time.perf_counter()
times.append(end_time - start_time) times.append(end_time - start_time)
avg_time = sum(times) / len(times) avg_time = sum(times) / len(times)
print(f"\n{'='*20} Performance Test {'='*20}") print(f"\n{'='*20} Performance Test {'='*20}")
print(f"Average execution time over {iterations} iterations: {avg_time:.4f} seconds") print(
f"Average execution time over {iterations} iterations: {avg_time:.4f} seconds"
)
print(f"Min time: {min(times):.4f} seconds") print(f"Min time: {min(times):.4f} seconds")
print(f"Max time: {max(times):.4f} seconds") print(f"Max time: {max(times):.4f} seconds")
def test_image_links(): def test_image_links():
"""Test handling of image links.""" """Test handling of image links."""
markdown = """ markdown = """
Here's an ![image](/image.png "Image Title") and another ![image](/other.jpg). Here's an ![image](/image.png "Image Title") and another ![image](/other.jpg).
And a regular [link](/page). And a regular [link](/page).
""" """
generator = DefaultMarkdownGenerator() generator = DefaultMarkdownGenerator()
result = generator.generate_markdown( result = generator.generate_markdown(
cleaned_html=markdown, cleaned_html=markdown, base_url="https://example.com"
base_url="https://example.com"
) )
assert "![" in result.markdown_with_citations, "Image markdown syntax should be preserved" assert (
assert "Image Title" in result.references_markdown, "Image title should be in references" "![" in result.markdown_with_citations
), "Image markdown syntax should be preserved"
assert (
"Image Title" in result.references_markdown
), "Image title should be in references"
if __name__ == "__main__": if __name__ == "__main__":
print("Running markdown generation strategy tests...") print("Running markdown generation strategy tests...")
test_basic_markdown_conversion() test_basic_markdown_conversion()
test_relative_links() test_relative_links()
test_duplicate_links() test_duplicate_links()
test_link_descriptions() test_link_descriptions()
test_performance_large_document() test_performance_large_document()
test_image_links() test_image_links()
View File
@@ -1,8 +1,6 @@
import os import os
import sys import sys
import pytest import pytest
import asyncio
import json
# Add the parent directory to the Python path # Add the parent directory to the Python path
parent_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) parent_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
@@ -10,24 +8,37 @@ sys.path.append(parent_dir)
from crawl4ai.async_webcrawler import AsyncWebCrawler from crawl4ai.async_webcrawler import AsyncWebCrawler
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_word_count_threshold(): async def test_word_count_threshold():
async with AsyncWebCrawler(verbose=True) as crawler: async with AsyncWebCrawler(verbose=True) as crawler:
url = "https://www.nbcnews.com/business" url = "https://www.nbcnews.com/business"
result_no_threshold = await crawler.arun(url=url, word_count_threshold=0, bypass_cache=True) result_no_threshold = await crawler.arun(
url=url, word_count_threshold=0, bypass_cache=True
)
result_with_threshold = await crawler.arun(url=url, word_count_threshold=50, bypass_cache=True) result_with_threshold = await crawler.arun(
url=url, word_count_threshold=50, bypass_cache=True
)
assert len(result_no_threshold.markdown) > len(result_with_threshold.markdown) assert len(result_no_threshold.markdown) > len(result_with_threshold.markdown)
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_css_selector(): async def test_css_selector():
async with AsyncWebCrawler(verbose=True) as crawler: async with AsyncWebCrawler(verbose=True) as crawler:
url = "https://www.nbcnews.com/business" url = "https://www.nbcnews.com/business"
css_selector = "h1, h2, h3" css_selector = "h1, h2, h3"
result = await crawler.arun(url=url, css_selector=css_selector, bypass_cache=True) result = await crawler.arun(
url=url, css_selector=css_selector, bypass_cache=True
)
assert result.success assert result.success
assert "<h1" in result.cleaned_html or "<h2" in result.cleaned_html or "<h3" in result.cleaned_html assert (
"<h1" in result.cleaned_html
or "<h2" in result.cleaned_html
or "<h3" in result.cleaned_html
)
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_javascript_execution(): async def test_javascript_execution():
@@ -36,59 +47,70 @@ async def test_javascript_execution():
# Crawl without JS # Crawl without JS
result_without_more = await crawler.arun(url=url, bypass_cache=True) result_without_more = await crawler.arun(url=url, bypass_cache=True)
js_code = ["const loadMoreButton = Array.from(document.querySelectorAll('button')).find(button => button.textContent.includes('Load More')); loadMoreButton && loadMoreButton.click();"] js_code = [
"const loadMoreButton = Array.from(document.querySelectorAll('button')).find(button => button.textContent.includes('Load More')); loadMoreButton && loadMoreButton.click();"
]
result_with_more = await crawler.arun(url=url, js=js_code, bypass_cache=True) result_with_more = await crawler.arun(url=url, js=js_code, bypass_cache=True)
assert result_with_more.success assert result_with_more.success
assert len(result_with_more.markdown) > len(result_without_more.markdown) assert len(result_with_more.markdown) > len(result_without_more.markdown)
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_screenshot(): async def test_screenshot():
async with AsyncWebCrawler(verbose=True) as crawler: async with AsyncWebCrawler(verbose=True) as crawler:
url = "https://www.nbcnews.com/business" url = "https://www.nbcnews.com/business"
result = await crawler.arun(url=url, screenshot=True, bypass_cache=True) result = await crawler.arun(url=url, screenshot=True, bypass_cache=True)
assert result.success assert result.success
assert result.screenshot assert result.screenshot
assert isinstance(result.screenshot, str) # Should be a base64 encoded string assert isinstance(result.screenshot, str) # Should be a base64 encoded string
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_custom_user_agent(): async def test_custom_user_agent():
async with AsyncWebCrawler(verbose=True) as crawler: async with AsyncWebCrawler(verbose=True) as crawler:
url = "https://www.nbcnews.com/business" url = "https://www.nbcnews.com/business"
custom_user_agent = "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.124 Safari/537.36 Crawl4AI/1.0" custom_user_agent = "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.124 Safari/537.36 Crawl4AI/1.0"
result = await crawler.arun(url=url, user_agent=custom_user_agent, bypass_cache=True) result = await crawler.arun(
url=url, user_agent=custom_user_agent, bypass_cache=True
)
assert result.success assert result.success
# Note: We can't directly verify the user agent in the result, but we can check if the crawl was successful # Note: We can't directly verify the user agent in the result, but we can check if the crawl was successful
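# A hedged workaround (assumes network access to an echo service): crawl
# https://httpbin.org/user-agent and check that the custom string is echoed back.
# echo = await crawler.arun(url="https://httpbin.org/user-agent",
#                           user_agent=custom_user_agent, bypass_cache=True)
# assert custom_user_agent in echo.html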
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_extract_media_and_links(): async def test_extract_media_and_links():
async with AsyncWebCrawler(verbose=True) as crawler: async with AsyncWebCrawler(verbose=True) as crawler:
url = "https://www.nbcnews.com/business" url = "https://www.nbcnews.com/business"
result = await crawler.arun(url=url, bypass_cache=True) result = await crawler.arun(url=url, bypass_cache=True)
assert result.success assert result.success
assert result.media assert result.media
assert isinstance(result.media, dict) assert isinstance(result.media, dict)
assert 'images' in result.media assert "images" in result.media
assert result.links assert result.links
assert isinstance(result.links, dict) assert isinstance(result.links, dict)
assert 'internal' in result.links and 'external' in result.links assert "internal" in result.links and "external" in result.links
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_metadata_extraction(): async def test_metadata_extraction():
async with AsyncWebCrawler(verbose=True) as crawler: async with AsyncWebCrawler(verbose=True) as crawler:
url = "https://www.nbcnews.com/business" url = "https://www.nbcnews.com/business"
result = await crawler.arun(url=url, bypass_cache=True) result = await crawler.arun(url=url, bypass_cache=True)
assert result.success assert result.success
assert result.metadata assert result.metadata
assert isinstance(result.metadata, dict) assert isinstance(result.metadata, dict)
# Check for common metadata fields # Check for common metadata fields
assert any(key in result.metadata for key in ['title', 'description', 'keywords']) assert any(
key in result.metadata for key in ["title", "description", "keywords"]
)
# Entry point for debugging # Entry point for debugging
if __name__ == "__main__": if __name__ == "__main__":
pytest.main([__file__, "-v"]) pytest.main([__file__, "-v"])
View File
@@ -1,7 +1,6 @@
import os import os
import sys import sys
import pytest import pytest
import asyncio
import time import time
# Add the parent directory to the Python path # Add the parent directory to the Python path
@@ -10,6 +9,7 @@ sys.path.append(parent_dir)
from crawl4ai.async_webcrawler import AsyncWebCrawler from crawl4ai.async_webcrawler import AsyncWebCrawler
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_crawl_speed(): async def test_crawl_speed():
async with AsyncWebCrawler(verbose=True) as crawler: async with AsyncWebCrawler(verbose=True) as crawler:
@@ -17,13 +17,14 @@ async def test_crawl_speed():
start_time = time.time() start_time = time.time()
result = await crawler.arun(url=url, bypass_cache=True) result = await crawler.arun(url=url, bypass_cache=True)
end_time = time.time() end_time = time.time()
assert result.success assert result.success
crawl_time = end_time - start_time crawl_time = end_time - start_time
print(f"Crawl time: {crawl_time:.2f} seconds") print(f"Crawl time: {crawl_time:.2f} seconds")
assert crawl_time < 10, f"Crawl took too long: {crawl_time:.2f} seconds" assert crawl_time < 10, f"Crawl took too long: {crawl_time:.2f} seconds"
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_concurrent_crawling_performance(): async def test_concurrent_crawling_performance():
async with AsyncWebCrawler(verbose=True) as crawler: async with AsyncWebCrawler(verbose=True) as crawler:
@@ -32,41 +33,47 @@ async def test_concurrent_crawling_performance():
"https://www.example.com", "https://www.example.com",
"https://www.python.org", "https://www.python.org",
"https://www.github.com", "https://www.github.com",
"https://www.stackoverflow.com" "https://www.stackoverflow.com",
] ]
start_time = time.time() start_time = time.time()
results = await crawler.arun_many(urls=urls, bypass_cache=True) results = await crawler.arun_many(urls=urls, bypass_cache=True)
end_time = time.time() end_time = time.time()
total_time = end_time - start_time total_time = end_time - start_time
print(f"Total time for concurrent crawling: {total_time:.2f} seconds") print(f"Total time for concurrent crawling: {total_time:.2f} seconds")
assert all(result.success for result in results) assert all(result.success for result in results)
assert len(results) == len(urls) assert len(results) == len(urls)
assert total_time < len(urls) * 5, f"Concurrent crawling not significantly faster: {total_time:.2f} seconds" assert (
total_time < len(urls) * 5
), f"Concurrent crawling not significantly faster: {total_time:.2f} seconds"
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_crawl_speed_with_caching(): async def test_crawl_speed_with_caching():
async with AsyncWebCrawler(verbose=True) as crawler: async with AsyncWebCrawler(verbose=True) as crawler:
url = "https://www.nbcnews.com/business" url = "https://www.nbcnews.com/business"
start_time = time.time() start_time = time.time()
result1 = await crawler.arun(url=url, bypass_cache=True) result1 = await crawler.arun(url=url, bypass_cache=True)
end_time = time.time() end_time = time.time()
first_crawl_time = end_time - start_time first_crawl_time = end_time - start_time
start_time = time.time() start_time = time.time()
result2 = await crawler.arun(url=url, bypass_cache=False) result2 = await crawler.arun(url=url, bypass_cache=False)
end_time = time.time() end_time = time.time()
second_crawl_time = end_time - start_time second_crawl_time = end_time - start_time
assert result1.success and result2.success assert result1.success and result2.success
print(f"First crawl time: {first_crawl_time:.2f} seconds") print(f"First crawl time: {first_crawl_time:.2f} seconds")
print(f"Second crawl time (cached): {second_crawl_time:.2f} seconds") print(f"Second crawl time (cached): {second_crawl_time:.2f} seconds")
assert second_crawl_time < first_crawl_time / 2, "Cached crawl not significantly faster" assert (
second_crawl_time < first_crawl_time / 2
), "Cached crawl not significantly faster"
if __name__ == "__main__": if __name__ == "__main__":
pytest.main([__file__, "-v"]) pytest.main([__file__, "-v"])
View File
@@ -1,7 +1,6 @@
import os import os
import sys import sys
import pytest import pytest
import asyncio
import base64 import base64
from PIL import Image from PIL import Image
import io import io
@@ -12,113 +11,112 @@ sys.path.append(parent_dir)
from crawl4ai.async_webcrawler import AsyncWebCrawler from crawl4ai.async_webcrawler import AsyncWebCrawler
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_basic_screenshot(): async def test_basic_screenshot():
async with AsyncWebCrawler(verbose=True) as crawler: async with AsyncWebCrawler(verbose=True) as crawler:
url = "https://example.com" # A static website url = "https://example.com" # A static website
result = await crawler.arun(url=url, bypass_cache=True, screenshot=True) result = await crawler.arun(url=url, bypass_cache=True, screenshot=True)
assert result.success assert result.success
assert result.screenshot is not None assert result.screenshot is not None
# Verify the screenshot is a valid image # Verify the screenshot is a valid image
image_data = base64.b64decode(result.screenshot) image_data = base64.b64decode(result.screenshot)
image = Image.open(io.BytesIO(image_data)) image = Image.open(io.BytesIO(image_data))
assert image.format == "PNG" assert image.format == "PNG"
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_screenshot_with_wait_for(): async def test_screenshot_with_wait_for():
async with AsyncWebCrawler(verbose=True) as crawler: async with AsyncWebCrawler(verbose=True) as crawler:
# Using a website with dynamic content # Using a website with dynamic content
url = "https://www.youtube.com" url = "https://www.youtube.com"
wait_for = "css:#content" # Wait for the main content to load wait_for = "css:#content" # Wait for the main content to load
result = await crawler.arun( result = await crawler.arun(
url=url, url=url, bypass_cache=True, screenshot=True, wait_for=wait_for
bypass_cache=True,
screenshot=True,
wait_for=wait_for
) )
assert result.success assert result.success
assert result.screenshot is not None assert result.screenshot is not None
# Verify the screenshot is a valid image # Verify the screenshot is a valid image
image_data = base64.b64decode(result.screenshot) image_data = base64.b64decode(result.screenshot)
image = Image.open(io.BytesIO(image_data)) image = Image.open(io.BytesIO(image_data))
assert image.format == "PNG" assert image.format == "PNG"
# You might want to add more specific checks here, like image dimensions # You might want to add more specific checks here, like image dimensions
# or even use image recognition to verify certain elements are present # or even use image recognition to verify certain elements are present
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_screenshot_with_js_wait_for(): async def test_screenshot_with_js_wait_for():
async with AsyncWebCrawler(verbose=True) as crawler: async with AsyncWebCrawler(verbose=True) as crawler:
url = "https://www.amazon.com" url = "https://www.amazon.com"
wait_for = "js:() => document.querySelector('#nav-logo-sprites') !== null" wait_for = "js:() => document.querySelector('#nav-logo-sprites') !== null"
result = await crawler.arun( result = await crawler.arun(
url=url, url=url, bypass_cache=True, screenshot=True, wait_for=wait_for
bypass_cache=True,
screenshot=True,
wait_for=wait_for
) )
assert result.success assert result.success
assert result.screenshot is not None assert result.screenshot is not None
image_data = base64.b64decode(result.screenshot) image_data = base64.b64decode(result.screenshot)
image = Image.open(io.BytesIO(image_data)) image = Image.open(io.BytesIO(image_data))
assert image.format == "PNG" assert image.format == "PNG"
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_screenshot_without_wait_for(): async def test_screenshot_without_wait_for():
async with AsyncWebCrawler(verbose=True) as crawler: async with AsyncWebCrawler(verbose=True) as crawler:
url = "https://www.nytimes.com" # A website with lots of dynamic content url = "https://www.nytimes.com" # A website with lots of dynamic content
result = await crawler.arun(url=url, bypass_cache=True, screenshot=True) result = await crawler.arun(url=url, bypass_cache=True, screenshot=True)
assert result.success assert result.success
assert result.screenshot is not None assert result.screenshot is not None
image_data = base64.b64decode(result.screenshot) image_data = base64.b64decode(result.screenshot)
image = Image.open(io.BytesIO(image_data)) image = Image.open(io.BytesIO(image_data))
assert image.format == "PNG" assert image.format == "PNG"
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_screenshot_comparison(): async def test_screenshot_comparison():
async with AsyncWebCrawler(verbose=True) as crawler: async with AsyncWebCrawler(verbose=True) as crawler:
url = "https://www.reddit.com" url = "https://www.reddit.com"
wait_for = "css:#SHORTCUT_FOCUSABLE_DIV" wait_for = "css:#SHORTCUT_FOCUSABLE_DIV"
# Take screenshot without wait_for # Take screenshot without wait_for
result_without_wait = await crawler.arun( result_without_wait = await crawler.arun(
url=url, url=url, bypass_cache=True, screenshot=True
bypass_cache=True,
screenshot=True
) )
# Take screenshot with wait_for # Take screenshot with wait_for
result_with_wait = await crawler.arun( result_with_wait = await crawler.arun(
url=url, url=url, bypass_cache=True, screenshot=True, wait_for=wait_for
bypass_cache=True,
screenshot=True,
wait_for=wait_for
) )
assert result_without_wait.success and result_with_wait.success assert result_without_wait.success and result_with_wait.success
assert result_without_wait.screenshot is not None assert result_without_wait.screenshot is not None
assert result_with_wait.screenshot is not None assert result_with_wait.screenshot is not None
# Compare the two screenshots # Compare the two screenshots
image_without_wait = Image.open(io.BytesIO(base64.b64decode(result_without_wait.screenshot))) image_without_wait = Image.open(
io.BytesIO(base64.b64decode(result_without_wait.screenshot))
)
image_with_wait = Image.open(io.BytesIO(base64.b64decode(result_with_wait.screenshot))) image_with_wait = Image.open(
io.BytesIO(base64.b64decode(result_with_wait.screenshot))
)
# This is a simple size comparison. In a real-world scenario, you might want to use # This is a simple size comparison. In a real-world scenario, you might want to use
# more sophisticated image comparison techniques. # more sophisticated image comparison techniques.
assert image_with_wait.size[0] >= image_without_wait.size[0] assert image_with_wait.size[0] >= image_without_wait.size[0]
assert image_with_wait.size[1] >= image_without_wait.size[1] assert image_with_wait.size[1] >= image_without_wait.size[1]
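# A more sophisticated sketch (hypothetical, using PIL's ImageChops): pixel-diff
# the two captures after normalizing mode and size.
# from PIL import ImageChops
# a = image_without_wait.convert("RGB").resize(image_with_wait.size)
# b = image_with_wait.convert("RGB")
# assert ImageChops.difference(a, b).getbbox() is not None  # some region differs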
# Entry point for debugging # Entry point for debugging
if __name__ == "__main__": if __name__ == "__main__":
pytest.main([__file__, "-v"]) pytest.main([__file__, "-v"])
View File
@@ -6,53 +6,72 @@ import base64
import os import os
from typing import Dict, Any from typing import Dict, Any
class Crawl4AiTester: class Crawl4AiTester:
def __init__(self, base_url: str = "http://localhost:11235", api_token: str = None): def __init__(self, base_url: str = "http://localhost:11235", api_token: str = None):
self.base_url = base_url self.base_url = base_url
self.api_token = api_token or os.getenv('CRAWL4AI_API_TOKEN') # Check environment variable as fallback self.api_token = api_token or os.getenv(
"CRAWL4AI_API_TOKEN"
) # Check environment variable as fallback
self.headers = {'Authorization': f'Bearer {self.api_token}'} if self.api_token else {} self.headers = (
{"Authorization": f"Bearer {self.api_token}"} if self.api_token else {}
)
def submit_and_wait(self, request_data: Dict[str, Any], timeout: int = 300) -> Dict[str, Any]: def submit_and_wait(
self, request_data: Dict[str, Any], timeout: int = 300
) -> Dict[str, Any]:
# Submit crawl job # Submit crawl job
response = requests.post(f"{self.base_url}/crawl", json=request_data, headers=self.headers) response = requests.post(
f"{self.base_url}/crawl", json=request_data, headers=self.headers
)
if response.status_code == 403: if response.status_code == 403:
raise Exception("API token is invalid or missing") raise Exception("API token is invalid or missing")
task_id = response.json()["task_id"] task_id = response.json()["task_id"]
print(f"Task ID: {task_id}") print(f"Task ID: {task_id}")
# Poll for result # Poll for result
start_time = time.time() start_time = time.time()
while True: while True:
if time.time() - start_time > timeout: if time.time() - start_time > timeout:
raise TimeoutError(f"Task {task_id} did not complete within {timeout} seconds") raise TimeoutError(
f"Task {task_id} did not complete within {timeout} seconds"
)
result = requests.get(f"{self.base_url}/task/{task_id}", headers=self.headers) result = requests.get(
f"{self.base_url}/task/{task_id}", headers=self.headers
)
status = result.json() status = result.json()
if status["status"] == "failed": if status["status"] == "failed":
print("Task failed:", status.get("error")) print("Task failed:", status.get("error"))
raise Exception(f"Task failed: {status.get('error')}") raise Exception(f"Task failed: {status.get('error')}")
if status["status"] == "completed": if status["status"] == "completed":
return status return status
time.sleep(2) time.sleep(2)
def submit_sync(self, request_data: Dict[str, Any]) -> Dict[str, Any]: def submit_sync(self, request_data: Dict[str, Any]) -> Dict[str, Any]:
response = requests.post(f"{self.base_url}/crawl_sync", json=request_data, headers=self.headers, timeout=60) response = requests.post(
f"{self.base_url}/crawl_sync",
json=request_data,
headers=self.headers,
timeout=60,
)
if response.status_code == 408: if response.status_code == 408:
raise TimeoutError("Task did not complete within server timeout") raise TimeoutError("Task did not complete within server timeout")
response.raise_for_status() response.raise_for_status()
return response.json() return response.json()
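# Minimal usage sketch (values are illustrative placeholders):
# tester = Crawl4AiTester(base_url="http://localhost:11235", api_token="test")
# done = tester.submit_and_wait({"urls": "https://example.com", "priority": 5})
# print(done["result"]["markdown"][:200])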
def test_docker_deployment(version="basic"): def test_docker_deployment(version="basic"):
tester = Crawl4AiTester( tester = Crawl4AiTester(
# base_url="http://localhost:11235" , # base_url="http://localhost:11235" ,
base_url="https://crawl4ai-sby74.ondigitalocean.app", base_url="https://crawl4ai-sby74.ondigitalocean.app",
api_token="test" api_token="test",
) )
print(f"Testing Crawl4AI Docker {version} version") print(f"Testing Crawl4AI Docker {version} version")
# Health check with timeout and retry # Health check with timeout and retry
max_retries = 5 max_retries = 5
for i in range(max_retries): for i in range(max_retries):
@@ -60,18 +79,18 @@ def test_docker_deployment(version="basic"):
health = requests.get(f"{tester.base_url}/health", timeout=10) health = requests.get(f"{tester.base_url}/health", timeout=10)
print("Health check:", health.json()) print("Health check:", health.json())
break break
except requests.exceptions.RequestException as e: except requests.exceptions.RequestException:
if i == max_retries - 1: if i == max_retries - 1:
print(f"Failed to connect after {max_retries} attempts") print(f"Failed to connect after {max_retries} attempts")
sys.exit(1) sys.exit(1)
print(f"Waiting for service to start (attempt {i+1}/{max_retries})...") print(f"Waiting for service to start (attempt {i+1}/{max_retries})...")
time.sleep(5) time.sleep(5)
# Test cases based on version # Test cases based on version
test_basic_crawl(tester) test_basic_crawl(tester)
test_basic_crawl(tester) test_basic_crawl(tester)
test_basic_crawl_sync(tester) test_basic_crawl_sync(tester)
# if version in ["full", "transformer"]: # if version in ["full", "transformer"]:
# test_cosine_extraction(tester) # test_cosine_extraction(tester)
@@ -81,35 +100,37 @@ def test_docker_deployment(version="basic"):
# test_llm_extraction(tester) # test_llm_extraction(tester)
# test_llm_with_ollama(tester) # test_llm_with_ollama(tester)
# test_screenshot(tester) # test_screenshot(tester)
def test_basic_crawl(tester: Crawl4AiTester): def test_basic_crawl(tester: Crawl4AiTester):
print("\n=== Testing Basic Crawl ===") print("\n=== Testing Basic Crawl ===")
request = { request = {
"urls": "https://www.nbcnews.com/business", "urls": "https://www.nbcnews.com/business",
"priority": 10, "priority": 10,
"session_id": "test" "session_id": "test",
} }
result = tester.submit_and_wait(request) result = tester.submit_and_wait(request)
print(f"Basic crawl result length: {len(result['result']['markdown'])}") print(f"Basic crawl result length: {len(result['result']['markdown'])}")
assert result["result"]["success"] assert result["result"]["success"]
assert len(result["result"]["markdown"]) > 0 assert len(result["result"]["markdown"]) > 0
def test_basic_crawl_sync(tester: Crawl4AiTester): def test_basic_crawl_sync(tester: Crawl4AiTester):
print("\n=== Testing Basic Crawl (Sync) ===") print("\n=== Testing Basic Crawl (Sync) ===")
request = { request = {
"urls": "https://www.nbcnews.com/business", "urls": "https://www.nbcnews.com/business",
"priority": 10, "priority": 10,
"session_id": "test" "session_id": "test",
} }
result = tester.submit_sync(request) result = tester.submit_sync(request)
print(f"Basic crawl result length: {len(result['result']['markdown'])}") print(f"Basic crawl result length: {len(result['result']['markdown'])}")
assert result['status'] == 'completed' assert result["status"] == "completed"
assert result['result']['success'] assert result["result"]["success"]
assert len(result['result']['markdown']) > 0 assert len(result["result"]["markdown"]) > 0
def test_js_execution(tester: Crawl4AiTester): def test_js_execution(tester: Crawl4AiTester):
print("\n=== Testing JS Execution ===") print("\n=== Testing JS Execution ===")
request = { request = {
@@ -119,32 +140,29 @@ def test_js_execution(tester: Crawl4AiTester):
"const loadMoreButton = Array.from(document.querySelectorAll('button')).find(button => button.textContent.includes('Load More')); loadMoreButton && loadMoreButton.click();" "const loadMoreButton = Array.from(document.querySelectorAll('button')).find(button => button.textContent.includes('Load More')); loadMoreButton && loadMoreButton.click();"
], ],
"wait_for": "article.tease-card:nth-child(10)", "wait_for": "article.tease-card:nth-child(10)",
"crawler_params": { "crawler_params": {"headless": True},
"headless": True
}
} }
result = tester.submit_and_wait(request) result = tester.submit_and_wait(request)
print(f"JS execution result length: {len(result['result']['markdown'])}") print(f"JS execution result length: {len(result['result']['markdown'])}")
assert result["result"]["success"] assert result["result"]["success"]
def test_css_selector(tester: Crawl4AiTester): def test_css_selector(tester: Crawl4AiTester):
print("\n=== Testing CSS Selector ===") print("\n=== Testing CSS Selector ===")
request = { request = {
"urls": "https://www.nbcnews.com/business", "urls": "https://www.nbcnews.com/business",
"priority": 7, "priority": 7,
"css_selector": ".wide-tease-item__description", "css_selector": ".wide-tease-item__description",
"crawler_params": { "crawler_params": {"headless": True},
"headless": True "extra": {"word_count_threshold": 10},
},
"extra": {"word_count_threshold": 10}
} }
result = tester.submit_and_wait(request) result = tester.submit_and_wait(request)
print(f"CSS selector result length: {len(result['result']['markdown'])}") print(f"CSS selector result length: {len(result['result']['markdown'])}")
assert result["result"]["success"] assert result["result"]["success"]
def test_structured_extraction(tester: Crawl4AiTester): def test_structured_extraction(tester: Crawl4AiTester):
print("\n=== Testing Structured Extraction ===") print("\n=== Testing Structured Extraction ===")
schema = { schema = {
@@ -165,21 +183,16 @@ def test_structured_extraction(tester: Crawl4AiTester):
"name": "price", "name": "price",
"selector": "td:nth-child(2)", "selector": "td:nth-child(2)",
"type": "text", "type": "text",
} },
], ],
} }
request = { request = {
"urls": "https://www.coinbase.com/explore", "urls": "https://www.coinbase.com/explore",
"priority": 9, "priority": 9,
"extraction_config": { "extraction_config": {"type": "json_css", "params": {"schema": schema}},
"type": "json_css",
"params": {
"schema": schema
}
}
} }
result = tester.submit_and_wait(request) result = tester.submit_and_wait(request)
extracted = json.loads(result["result"]["extracted_content"]) extracted = json.loads(result["result"]["extracted_content"])
print(f"Extracted {len(extracted)} items") print(f"Extracted {len(extracted)} items")
@@ -187,6 +200,7 @@ def test_structured_extraction(tester: Crawl4AiTester):
assert result["result"]["success"] assert result["result"]["success"]
assert len(extracted) > 0 assert len(extracted) > 0
def test_llm_extraction(tester: Crawl4AiTester): def test_llm_extraction(tester: Crawl4AiTester):
print("\n=== Testing LLM Extraction ===") print("\n=== Testing LLM Extraction ===")
schema = { schema = {
@@ -194,20 +208,20 @@ def test_llm_extraction(tester: Crawl4AiTester):
"properties": { "properties": {
"model_name": { "model_name": {
"type": "string", "type": "string",
"description": "Name of the OpenAI model." "description": "Name of the OpenAI model.",
}, },
"input_fee": { "input_fee": {
"type": "string", "type": "string",
"description": "Fee for input token for the OpenAI model." "description": "Fee for input token for the OpenAI model.",
}, },
"output_fee": { "output_fee": {
"type": "string", "type": "string",
"description": "Fee for output token for the OpenAI model." "description": "Fee for output token for the OpenAI model.",
} },
}, },
"required": ["model_name", "input_fee", "output_fee"] "required": ["model_name", "input_fee", "output_fee"],
} }
request = { request = {
"urls": "https://openai.com/api/pricing", "urls": "https://openai.com/api/pricing",
"priority": 8, "priority": 8,
@@ -218,12 +232,12 @@ def test_llm_extraction(tester: Crawl4AiTester):
"api_token": os.getenv("OPENAI_API_KEY"), "api_token": os.getenv("OPENAI_API_KEY"),
"schema": schema, "schema": schema,
"extraction_type": "schema", "extraction_type": "schema",
"instruction": """From the crawled content, extract all mentioned model names along with their fees for input and output tokens.""" "instruction": """From the crawled content, extract all mentioned model names along with their fees for input and output tokens.""",
} },
}, },
"crawler_params": {"word_count_threshold": 1} "crawler_params": {"word_count_threshold": 1},
} }
try: try:
result = tester.submit_and_wait(request) result = tester.submit_and_wait(request)
extracted = json.loads(result["result"]["extracted_content"]) extracted = json.loads(result["result"]["extracted_content"])
@@ -233,6 +247,7 @@ def test_llm_extraction(tester: Crawl4AiTester):
except Exception as e: except Exception as e:
print(f"LLM extraction test failed (might be due to missing API key): {str(e)}") print(f"LLM extraction test failed (might be due to missing API key): {str(e)}")
def test_llm_with_ollama(tester: Crawl4AiTester): def test_llm_with_ollama(tester: Crawl4AiTester):
print("\n=== Testing LLM with Ollama ===") print("\n=== Testing LLM with Ollama ===")
schema = { schema = {
@@ -240,20 +255,20 @@ def test_llm_with_ollama(tester: Crawl4AiTester):
"properties": { "properties": {
"article_title": { "article_title": {
"type": "string", "type": "string",
"description": "The main title of the news article" "description": "The main title of the news article",
}, },
"summary": { "summary": {
"type": "string", "type": "string",
"description": "A brief summary of the article content" "description": "A brief summary of the article content",
}, },
"main_topics": { "main_topics": {
"type": "array", "type": "array",
"items": {"type": "string"}, "items": {"type": "string"},
"description": "Main topics or themes discussed in the article" "description": "Main topics or themes discussed in the article",
} },
} },
} }
request = { request = {
"urls": "https://www.nbcnews.com/business", "urls": "https://www.nbcnews.com/business",
"priority": 8, "priority": 8,
@@ -263,13 +278,13 @@ def test_llm_with_ollama(tester: Crawl4AiTester):
"provider": "ollama/llama2", "provider": "ollama/llama2",
"schema": schema, "schema": schema,
"extraction_type": "schema", "extraction_type": "schema",
"instruction": "Extract the main article information including title, summary, and main topics." "instruction": "Extract the main article information including title, summary, and main topics.",
} },
}, },
"extra": {"word_count_threshold": 1}, "extra": {"word_count_threshold": 1},
"crawler_params": {"verbose": True} "crawler_params": {"verbose": True},
} }
try: try:
result = tester.submit_and_wait(request) result = tester.submit_and_wait(request)
extracted = json.loads(result["result"]["extracted_content"]) extracted = json.loads(result["result"]["extracted_content"])
@@ -278,6 +293,7 @@ def test_llm_with_ollama(tester: Crawl4AiTester):
except Exception as e: except Exception as e:
print(f"Ollama extraction test failed: {str(e)}") print(f"Ollama extraction test failed: {str(e)}")
def test_cosine_extraction(tester: Crawl4AiTester): def test_cosine_extraction(tester: Crawl4AiTester):
print("\n=== Testing Cosine Extraction ===") print("\n=== Testing Cosine Extraction ===")
request = { request = {
@@ -289,11 +305,11 @@ def test_cosine_extraction(tester: Crawl4AiTester):
"semantic_filter": "business finance economy", "semantic_filter": "business finance economy",
"word_count_threshold": 10, "word_count_threshold": 10,
"max_dist": 0.2, "max_dist": 0.2,
"top_k": 3 "top_k": 3,
} },
} },
} }
try: try:
result = tester.submit_and_wait(request) result = tester.submit_and_wait(request)
extracted = json.loads(result["result"]["extracted_content"]) extracted = json.loads(result["result"]["extracted_content"])
@@ -303,30 +319,30 @@ def test_cosine_extraction(tester: Crawl4AiTester):
except Exception as e: except Exception as e:
print(f"Cosine extraction test failed: {str(e)}") print(f"Cosine extraction test failed: {str(e)}")
def test_screenshot(tester: Crawl4AiTester): def test_screenshot(tester: Crawl4AiTester):
print("\n=== Testing Screenshot ===") print("\n=== Testing Screenshot ===")
request = { request = {
"urls": "https://www.nbcnews.com/business", "urls": "https://www.nbcnews.com/business",
"priority": 5, "priority": 5,
"screenshot": True, "screenshot": True,
"crawler_params": { "crawler_params": {"headless": True},
"headless": True
}
} }
result = tester.submit_and_wait(request) result = tester.submit_and_wait(request)
print("Screenshot captured:", bool(result["result"]["screenshot"])) print("Screenshot captured:", bool(result["result"]["screenshot"]))
if result["result"]["screenshot"]: if result["result"]["screenshot"]:
# Save screenshot # Save screenshot
screenshot_data = base64.b64decode(result["result"]["screenshot"]) screenshot_data = base64.b64decode(result["result"]["screenshot"])
with open("test_screenshot.jpg", "wb") as f: with open("test_screenshot.jpg", "wb") as f:
f.write(screenshot_data) f.write(screenshot_data)
print("Screenshot saved as test_screenshot.jpg") print("Screenshot saved as test_screenshot.jpg")
assert result["result"]["success"] assert result["result"]["success"]
if __name__ == "__main__": if __name__ == "__main__":
version = sys.argv[1] if len(sys.argv) > 1 else "basic" version = sys.argv[1] if len(sys.argv) > 1 else "basic"
# version = "full" # version = "full"
test_docker_deployment(version) test_docker_deployment(version)
View File
@@ -1,13 +1,13 @@
import asyncio import asyncio
from pathlib import Path
from crawl4ai.docs_manager import DocsManager from crawl4ai.docs_manager import DocsManager
from click.testing import CliRunner from click.testing import CliRunner
from crawl4ai.cli import cli from crawl4ai.cli import cli
def test_cli(): def test_cli():
"""Test all CLI commands""" """Test all CLI commands"""
runner = CliRunner() runner = CliRunner()
print("\n1. Testing docs update...") print("\n1. Testing docs update...")
# Use sync version for testing # Use sync version for testing
docs_manager = DocsManager() docs_manager = DocsManager()
@@ -27,17 +27,18 @@ def test_cli():
# print("\n3. Testing search...") # print("\n3. Testing search...")
# result = runner.invoke(cli, ['docs', 'search', 'how to use crawler', '--build-index']) # result = runner.invoke(cli, ['docs', 'search', 'how to use crawler', '--build-index'])
# print(f"Status: {'✅' if result.exit_code == 0 else '❌'}") # print(f"Status: {'✅' if result.exit_code == 0 else '❌'}")
# print(f"First 200 chars: {result.output[:200]}...") # print(f"First 200 chars: {result.output[:200]}...")
# print("\n4. Testing combine with sections...") # print("\n4. Testing combine with sections...")
# result = runner.invoke(cli, ['docs', 'combine', 'chunking_strategies', 'extraction_strategies', '--mode', 'extended']) # result = runner.invoke(cli, ['docs', 'combine', 'chunking_strategies', 'extraction_strategies', '--mode', 'extended'])
# print(f"Status: {'✅' if result.exit_code == 0 else '❌'}") # print(f"Status: {'✅' if result.exit_code == 0 else '❌'}")
# print(f"First 200 chars: {result.output[:200]}...") # print(f"First 200 chars: {result.output[:200]}...")
print("\n5. Testing combine all sections...") print("\n5. Testing combine all sections...")
result = runner.invoke(cli, ['docs', 'combine', '--mode', 'condensed']) result = runner.invoke(cli, ["docs", "combine", "--mode", "condensed"])
print(f"Status: {'' if result.exit_code == 0 else ''}") print(f"Status: {'' if result.exit_code == 0 else ''}")
print(f"First 200 chars: {result.output[:200]}...") print(f"First 200 chars: {result.output[:200]}...")
if __name__ == "__main__": if __name__ == "__main__":
test_cli() test_cli()
View File
@@ -6,38 +6,44 @@ import base64
import os import os
from typing import Dict, Any from typing import Dict, Any
class Crawl4AiTester: class Crawl4AiTester:
def __init__(self, base_url: str = "http://localhost:11235"): def __init__(self, base_url: str = "http://localhost:11235"):
self.base_url = base_url self.base_url = base_url
def submit_and_wait(self, request_data: Dict[str, Any], timeout: int = 300) -> Dict[str, Any]: def submit_and_wait(
self, request_data: Dict[str, Any], timeout: int = 300
) -> Dict[str, Any]:
# Submit crawl job # Submit crawl job
response = requests.post(f"{self.base_url}/crawl", json=request_data) response = requests.post(f"{self.base_url}/crawl", json=request_data)
task_id = response.json()["task_id"] task_id = response.json()["task_id"]
print(f"Task ID: {task_id}") print(f"Task ID: {task_id}")
# Poll for result # Poll for result
start_time = time.time() start_time = time.time()
while True: while True:
if time.time() - start_time > timeout: if time.time() - start_time > timeout:
raise TimeoutError(f"Task {task_id} did not complete within {timeout} seconds") raise TimeoutError(
f"Task {task_id} did not complete within {timeout} seconds"
)
result = requests.get(f"{self.base_url}/task/{task_id}") result = requests.get(f"{self.base_url}/task/{task_id}")
status = result.json() status = result.json()
if status["status"] == "failed": if status["status"] == "failed":
print("Task failed:", status.get("error")) print("Task failed:", status.get("error"))
raise Exception(f"Task failed: {status.get('error')}") raise Exception(f"Task failed: {status.get('error')}")
if status["status"] == "completed": if status["status"] == "completed":
return status return status
time.sleep(2) time.sleep(2)
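# Variant (a sketch, not used by this tester): poll with capped exponential
# backoff instead of a fixed 2 s sleep, e.g. start at delay = 2.0 and after
# each poll do time.sleep(delay); delay = min(delay * 2, 30.0)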
def test_docker_deployment(version="basic"): def test_docker_deployment(version="basic"):
tester = Crawl4AiTester() tester = Crawl4AiTester()
print(f"Testing Crawl4AI Docker {version} version") print(f"Testing Crawl4AI Docker {version} version")
# Health check with timeout and retry # Health check with timeout and retry
max_retries = 5 max_retries = 5
for i in range(max_retries): for i in range(max_retries):
@@ -45,16 +51,16 @@ def test_docker_deployment(version="basic"):
health = requests.get(f"{tester.base_url}/health", timeout=10) health = requests.get(f"{tester.base_url}/health", timeout=10)
print("Health check:", health.json()) print("Health check:", health.json())
break break
except requests.exceptions.RequestException as e: except requests.exceptions.RequestException:
if i == max_retries - 1: if i == max_retries - 1:
print(f"Failed to connect after {max_retries} attempts") print(f"Failed to connect after {max_retries} attempts")
sys.exit(1) sys.exit(1)
print(f"Waiting for service to start (attempt {i+1}/{max_retries})...") print(f"Waiting for service to start (attempt {i+1}/{max_retries})...")
time.sleep(5) time.sleep(5)
# Test cases based on version # Test cases based on version
test_basic_crawl(tester) test_basic_crawl(tester)
# if version in ["full", "transformer"]: # if version in ["full", "transformer"]:
# test_cosine_extraction(tester) # test_cosine_extraction(tester)
@@ -64,20 +70,18 @@ def test_docker_deployment(version="basic"):
# test_llm_extraction(tester) # test_llm_extraction(tester)
# test_llm_with_ollama(tester) # test_llm_with_ollama(tester)
# test_screenshot(tester) # test_screenshot(tester)
def test_basic_crawl(tester: Crawl4AiTester): def test_basic_crawl(tester: Crawl4AiTester):
print("\n=== Testing Basic Crawl ===") print("\n=== Testing Basic Crawl ===")
request = { request = {"urls": "https://www.nbcnews.com/business", "priority": 10}
"urls": "https://www.nbcnews.com/business",
"priority": 10
}
result = tester.submit_and_wait(request) result = tester.submit_and_wait(request)
print(f"Basic crawl result length: {len(result['result']['markdown'])}") print(f"Basic crawl result length: {len(result['result']['markdown'])}")
assert result["result"]["success"] assert result["result"]["success"]
assert len(result["result"]["markdown"]) > 0 assert len(result["result"]["markdown"]) > 0
def test_js_execution(tester: Crawl4AiTester): def test_js_execution(tester: Crawl4AiTester):
print("\n=== Testing JS Execution ===") print("\n=== Testing JS Execution ===")
request = { request = {
@@ -87,32 +91,29 @@ def test_js_execution(tester: Crawl4AiTester):
"const loadMoreButton = Array.from(document.querySelectorAll('button')).find(button => button.textContent.includes('Load More')); loadMoreButton && loadMoreButton.click();" "const loadMoreButton = Array.from(document.querySelectorAll('button')).find(button => button.textContent.includes('Load More')); loadMoreButton && loadMoreButton.click();"
], ],
"wait_for": "article.tease-card:nth-child(10)", "wait_for": "article.tease-card:nth-child(10)",
"crawler_params": { "crawler_params": {"headless": True},
"headless": True
}
} }
result = tester.submit_and_wait(request) result = tester.submit_and_wait(request)
print(f"JS execution result length: {len(result['result']['markdown'])}") print(f"JS execution result length: {len(result['result']['markdown'])}")
assert result["result"]["success"] assert result["result"]["success"]
def test_css_selector(tester: Crawl4AiTester): def test_css_selector(tester: Crawl4AiTester):
print("\n=== Testing CSS Selector ===") print("\n=== Testing CSS Selector ===")
request = { request = {
"urls": "https://www.nbcnews.com/business", "urls": "https://www.nbcnews.com/business",
"priority": 7, "priority": 7,
"css_selector": ".wide-tease-item__description", "css_selector": ".wide-tease-item__description",
"crawler_params": { "crawler_params": {"headless": True},
"headless": True "extra": {"word_count_threshold": 10},
},
"extra": {"word_count_threshold": 10}
} }
result = tester.submit_and_wait(request) result = tester.submit_and_wait(request)
print(f"CSS selector result length: {len(result['result']['markdown'])}") print(f"CSS selector result length: {len(result['result']['markdown'])}")
assert result["result"]["success"] assert result["result"]["success"]
def test_structured_extraction(tester: Crawl4AiTester): def test_structured_extraction(tester: Crawl4AiTester):
print("\n=== Testing Structured Extraction ===") print("\n=== Testing Structured Extraction ===")
schema = { schema = {
@@ -133,21 +134,16 @@ def test_structured_extraction(tester: Crawl4AiTester):
"name": "price", "name": "price",
"selector": "td:nth-child(2)", "selector": "td:nth-child(2)",
"type": "text", "type": "text",
} },
], ],
} }
request = { request = {
"urls": "https://www.coinbase.com/explore", "urls": "https://www.coinbase.com/explore",
"priority": 9, "priority": 9,
"extraction_config": { "extraction_config": {"type": "json_css", "params": {"schema": schema}},
"type": "json_css",
"params": {
"schema": schema
}
}
} }
result = tester.submit_and_wait(request) result = tester.submit_and_wait(request)
extracted = json.loads(result["result"]["extracted_content"]) extracted = json.loads(result["result"]["extracted_content"])
print(f"Extracted {len(extracted)} items") print(f"Extracted {len(extracted)} items")
@@ -155,6 +151,7 @@ def test_structured_extraction(tester: Crawl4AiTester):
assert result["result"]["success"] assert result["result"]["success"]
assert len(extracted) > 0 assert len(extracted) > 0
def test_llm_extraction(tester: Crawl4AiTester): def test_llm_extraction(tester: Crawl4AiTester):
print("\n=== Testing LLM Extraction ===") print("\n=== Testing LLM Extraction ===")
schema = { schema = {
@@ -162,20 +159,20 @@ def test_llm_extraction(tester: Crawl4AiTester):
"properties": { "properties": {
"model_name": { "model_name": {
"type": "string", "type": "string",
"description": "Name of the OpenAI model." "description": "Name of the OpenAI model.",
}, },
"input_fee": { "input_fee": {
"type": "string", "type": "string",
"description": "Fee for input token for the OpenAI model." "description": "Fee for input token for the OpenAI model.",
}, },
"output_fee": { "output_fee": {
"type": "string", "type": "string",
"description": "Fee for output token for the OpenAI model." "description": "Fee for output token for the OpenAI model.",
} },
}, },
"required": ["model_name", "input_fee", "output_fee"] "required": ["model_name", "input_fee", "output_fee"],
} }
request = { request = {
"urls": "https://openai.com/api/pricing", "urls": "https://openai.com/api/pricing",
"priority": 8, "priority": 8,
@@ -186,12 +183,12 @@ def test_llm_extraction(tester: Crawl4AiTester):
"api_token": os.getenv("OPENAI_API_KEY"), "api_token": os.getenv("OPENAI_API_KEY"),
"schema": schema, "schema": schema,
"extraction_type": "schema", "extraction_type": "schema",
"instruction": """From the crawled content, extract all mentioned model names along with their fees for input and output tokens.""" "instruction": """From the crawled content, extract all mentioned model names along with their fees for input and output tokens.""",
} },
}, },
"crawler_params": {"word_count_threshold": 1} "crawler_params": {"word_count_threshold": 1},
} }
try: try:
result = tester.submit_and_wait(request) result = tester.submit_and_wait(request)
extracted = json.loads(result["result"]["extracted_content"]) extracted = json.loads(result["result"]["extracted_content"])
@@ -201,6 +198,7 @@ def test_llm_extraction(tester: Crawl4AiTester):
    except Exception as e:
        print(f"LLM extraction test failed (might be due to missing API key): {str(e)}")


def test_llm_with_ollama(tester: Crawl4AiTester):
    print("\n=== Testing LLM with Ollama ===")
    schema = {
@@ -208,20 +206,20 @@ def test_llm_with_ollama(tester: Crawl4AiTester):
"properties": { "properties": {
"article_title": { "article_title": {
"type": "string", "type": "string",
"description": "The main title of the news article" "description": "The main title of the news article",
}, },
"summary": { "summary": {
"type": "string", "type": "string",
"description": "A brief summary of the article content" "description": "A brief summary of the article content",
}, },
"main_topics": { "main_topics": {
"type": "array", "type": "array",
"items": {"type": "string"}, "items": {"type": "string"},
"description": "Main topics or themes discussed in the article" "description": "Main topics or themes discussed in the article",
} },
} },
} }
request = { request = {
"urls": "https://www.nbcnews.com/business", "urls": "https://www.nbcnews.com/business",
"priority": 8, "priority": 8,
@@ -231,13 +229,13 @@ def test_llm_with_ollama(tester: Crawl4AiTester):
"provider": "ollama/llama2", "provider": "ollama/llama2",
"schema": schema, "schema": schema,
"extraction_type": "schema", "extraction_type": "schema",
"instruction": "Extract the main article information including title, summary, and main topics." "instruction": "Extract the main article information including title, summary, and main topics.",
} },
}, },
"extra": {"word_count_threshold": 1}, "extra": {"word_count_threshold": 1},
"crawler_params": {"verbose": True} "crawler_params": {"verbose": True},
} }
try: try:
result = tester.submit_and_wait(request) result = tester.submit_and_wait(request)
extracted = json.loads(result["result"]["extracted_content"]) extracted = json.loads(result["result"]["extracted_content"])
@@ -246,6 +244,7 @@ def test_llm_with_ollama(tester: Crawl4AiTester):
    except Exception as e:
        print(f"Ollama extraction test failed: {str(e)}")


def test_cosine_extraction(tester: Crawl4AiTester):
    print("\n=== Testing Cosine Extraction ===")
    request = {
@@ -257,11 +256,11 @@ def test_cosine_extraction(tester: Crawl4AiTester):
"semantic_filter": "business finance economy", "semantic_filter": "business finance economy",
"word_count_threshold": 10, "word_count_threshold": 10,
"max_dist": 0.2, "max_dist": 0.2,
"top_k": 3 "top_k": 3,
} },
} },
} }
try: try:
result = tester.submit_and_wait(request) result = tester.submit_and_wait(request)
extracted = json.loads(result["result"]["extracted_content"]) extracted = json.loads(result["result"]["extracted_content"])
@@ -271,30 +270,30 @@ def test_cosine_extraction(tester: Crawl4AiTester):
    except Exception as e:
        print(f"Cosine extraction test failed: {str(e)}")


def test_screenshot(tester: Crawl4AiTester):
    print("\n=== Testing Screenshot ===")
    request = {
        "urls": "https://www.nbcnews.com/business",
        "priority": 5,
        "screenshot": True,
        "crawler_params": {"headless": True},
    }
    result = tester.submit_and_wait(request)
    print("Screenshot captured:", bool(result["result"]["screenshot"]))
    if result["result"]["screenshot"]:
        # Save screenshot
        screenshot_data = base64.b64decode(result["result"]["screenshot"])
        with open("test_screenshot.jpg", "wb") as f:
            f.write(screenshot_data)
        print("Screenshot saved as test_screenshot.jpg")
    assert result["result"]["success"]


if __name__ == "__main__":
    version = sys.argv[1] if len(sys.argv) > 1 else "basic"
    # version = "full"
    test_docker_deployment(version)
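
For quick local verification, the request/poll pattern these tests rely on can be reproduced with a plain synchronous client. A minimal sketch only, assuming the same /crawl and /task/{task_id} endpoints used by the async client later in this commit; the submit_and_wait helper name here is hypothetical:

import time
import requests


def submit_and_wait(base_url: str, request_data: dict, timeout: int = 300) -> dict:
    # Submit a crawl job, then poll the task endpoint until it completes or fails.
    task_id = requests.post(f"{base_url}/crawl", json=request_data).json()["task_id"]
    start = time.time()
    while True:
        if time.time() - start > timeout:
            raise TimeoutError(f"Task {task_id} did not complete within {timeout} seconds")
        status = requests.get(f"{base_url}/task/{task_id}").json()
        if status["status"] in ["completed", "failed"]:
            return status
        time.sleep(2)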

View File

@@ -3,20 +3,21 @@ from crawl4ai.async_logger import AsyncLogger
from pathlib import Path
import asyncio


async def main():
    current_file = Path(__file__).resolve()
    # base_dir = current_file.parent.parent / "local/_docs/llm.txt/test_docs"
    base_dir = current_file.parent.parent / "local/_docs/llm.txt"
    docs_dir = base_dir

    # Create directory if it doesn't exist
    docs_dir.mkdir(parents=True, exist_ok=True)

    # Initialize logger
    logger = AsyncLogger()

    # Updated initialization with default batching params
    # manager = AsyncLLMTextManager(docs_dir, logger, max_concurrent_calls=3, batch_size=2)
    manager = AsyncLLMTextManager(docs_dir, logger, batch_size=2)

    # Let's first check what files we have
    print("\nAvailable files:")
@@ -26,8 +27,7 @@ async def main():
    # Generate index files
    print("\nGenerating index files...")
    await manager.generate_index_files(
        force_generate_facts=False, clear_bm25_cache=False
    )

    # Test some relevant queries about Crawl4AI
@@ -41,9 +41,12 @@ async def main():
        results = manager.search(query, top_k=2)
        print(f"Results length: {len(results)} characters")
        if results:
            print(
                "First 200 chars of results:", results[:200].replace("\n", " "), "..."
            )
        else:
            print("No results found")


if __name__ == "__main__":
    asyncio.run(main())

View File

@@ -3,8 +3,8 @@ import aiohttp
import json
import time
import os
from typing import Dict, Any


class NBCNewsAPITest:
    def __init__(self, base_url: str = "http://localhost:8000"):
@@ -20,7 +20,9 @@ class NBCNewsAPITest:
        await self.session.close()

    async def submit_crawl(self, request_data: Dict[str, Any]) -> str:
        async with self.session.post(
            f"{self.base_url}/crawl", json=request_data
        ) as response:
            result = await response.json()
            return result["task_id"]
@@ -28,11 +30,15 @@ class NBCNewsAPITest:
        async with self.session.get(f"{self.base_url}/task/{task_id}") as response:
            return await response.json()

    async def wait_for_task(
        self, task_id: str, timeout: int = 300, poll_interval: int = 2
    ) -> Dict[str, Any]:
        start_time = time.time()
        while True:
            if time.time() - start_time > timeout:
                raise TimeoutError(
                    f"Task {task_id} did not complete within {timeout} seconds"
                )

            status = await self.get_task_status(task_id)
            if status["status"] in ["completed", "failed"]:
@@ -44,13 +50,11 @@ class NBCNewsAPITest:
        async with self.session.get(f"{self.base_url}/health") as response:
            return await response.json()


async def test_basic_crawl():
    print("\n=== Testing Basic Crawl ===")
    async with NBCNewsAPITest() as api:
        request = {"urls": "https://www.nbcnews.com/business", "priority": 10}

        task_id = await api.submit_crawl(request)
        result = await api.wait_for_task(task_id)
        print(f"Basic crawl result length: {len(result['result']['markdown'])}")
@@ -58,6 +62,7 @@ async def test_basic_crawl():
assert "result" in result assert "result" in result
assert result["result"]["success"] assert result["result"]["success"]
async def test_js_execution(): async def test_js_execution():
print("\n=== Testing JS Execution ===") print("\n=== Testing JS Execution ===")
async with NBCNewsAPITest() as api: async with NBCNewsAPITest() as api:
@@ -68,9 +73,7 @@ async def test_js_execution():
"const loadMoreButton = Array.from(document.querySelectorAll('button')).find(button => button.textContent.includes('Load More')); loadMoreButton && loadMoreButton.click();" "const loadMoreButton = Array.from(document.querySelectorAll('button')).find(button => button.textContent.includes('Load More')); loadMoreButton && loadMoreButton.click();"
], ],
"wait_for": "article.tease-card:nth-child(10)", "wait_for": "article.tease-card:nth-child(10)",
"crawler_params": { "crawler_params": {"headless": True},
"headless": True
}
} }
task_id = await api.submit_crawl(request) task_id = await api.submit_crawl(request)
result = await api.wait_for_task(task_id) result = await api.wait_for_task(task_id)
@@ -78,13 +81,14 @@ async def test_js_execution():
assert result["status"] == "completed" assert result["status"] == "completed"
assert result["result"]["success"] assert result["result"]["success"]
async def test_css_selector(): async def test_css_selector():
print("\n=== Testing CSS Selector ===") print("\n=== Testing CSS Selector ===")
async with NBCNewsAPITest() as api: async with NBCNewsAPITest() as api:
request = { request = {
"urls": "https://www.nbcnews.com/business", "urls": "https://www.nbcnews.com/business",
"priority": 7, "priority": 7,
"css_selector": ".wide-tease-item__description" "css_selector": ".wide-tease-item__description",
} }
task_id = await api.submit_crawl(request) task_id = await api.submit_crawl(request)
result = await api.wait_for_task(task_id) result = await api.wait_for_task(task_id)
@@ -92,6 +96,7 @@ async def test_css_selector():
assert result["status"] == "completed" assert result["status"] == "completed"
assert result["result"]["success"] assert result["result"]["success"]
async def test_structured_extraction(): async def test_structured_extraction():
print("\n=== Testing Structured Extraction ===") print("\n=== Testing Structured Extraction ===")
async with NBCNewsAPITest() as api: async with NBCNewsAPITest() as api:
@@ -99,34 +104,25 @@ async def test_structured_extraction():
"name": "NBC News Articles", "name": "NBC News Articles",
"baseSelector": "article.tease-card", "baseSelector": "article.tease-card",
"fields": [ "fields": [
{ {"name": "title", "selector": "h2", "type": "text"},
"name": "title",
"selector": "h2",
"type": "text"
},
{ {
"name": "description", "name": "description",
"selector": ".tease-card__description", "selector": ".tease-card__description",
"type": "text" "type": "text",
}, },
{ {
"name": "link", "name": "link",
"selector": "a", "selector": "a",
"type": "attribute", "type": "attribute",
"attribute": "href" "attribute": "href",
} },
] ],
} }
request = { request = {
"urls": "https://www.nbcnews.com/business", "urls": "https://www.nbcnews.com/business",
"priority": 9, "priority": 9,
"extraction_config": { "extraction_config": {"type": "json_css", "params": {"schema": schema}},
"type": "json_css",
"params": {
"schema": schema
}
}
} }
task_id = await api.submit_crawl(request) task_id = await api.submit_crawl(request)
result = await api.wait_for_task(task_id) result = await api.wait_for_task(task_id)
@@ -136,6 +132,7 @@ async def test_structured_extraction():
assert result["result"]["success"] assert result["result"]["success"]
assert len(extracted) > 0 assert len(extracted) > 0
async def test_batch_crawl(): async def test_batch_crawl():
print("\n=== Testing Batch Crawl ===") print("\n=== Testing Batch Crawl ===")
async with NBCNewsAPITest() as api: async with NBCNewsAPITest() as api:
@@ -143,12 +140,10 @@ async def test_batch_crawl():
"urls": [ "urls": [
"https://www.nbcnews.com/business", "https://www.nbcnews.com/business",
"https://www.nbcnews.com/business/consumer", "https://www.nbcnews.com/business/consumer",
"https://www.nbcnews.com/business/economy" "https://www.nbcnews.com/business/economy",
], ],
"priority": 6, "priority": 6,
"crawler_params": { "crawler_params": {"headless": True},
"headless": True
}
} }
task_id = await api.submit_crawl(request) task_id = await api.submit_crawl(request)
result = await api.wait_for_task(task_id) result = await api.wait_for_task(task_id)
@@ -157,6 +152,7 @@ async def test_batch_crawl():
assert "results" in result assert "results" in result
assert len(result["results"]) == 3 assert len(result["results"]) == 3
async def test_llm_extraction(): async def test_llm_extraction():
print("\n=== Testing LLM Extraction with Ollama ===") print("\n=== Testing LLM Extraction with Ollama ===")
async with NBCNewsAPITest() as api: async with NBCNewsAPITest() as api:
@@ -165,19 +161,19 @@ async def test_llm_extraction():
"properties": { "properties": {
"article_title": { "article_title": {
"type": "string", "type": "string",
"description": "The main title of the news article" "description": "The main title of the news article",
}, },
"summary": { "summary": {
"type": "string", "type": "string",
"description": "A brief summary of the article content" "description": "A brief summary of the article content",
}, },
"main_topics": { "main_topics": {
"type": "array", "type": "array",
"items": {"type": "string"}, "items": {"type": "string"},
"description": "Main topics or themes discussed in the article" "description": "Main topics or themes discussed in the article",
} },
}, },
"required": ["article_title", "summary", "main_topics"] "required": ["article_title", "summary", "main_topics"],
} }
request = { request = {
@@ -191,26 +187,24 @@ async def test_llm_extraction():
"schema": schema, "schema": schema,
"extraction_type": "schema", "extraction_type": "schema",
"instruction": """Extract the main article information including title, a brief summary, and main topics discussed. "instruction": """Extract the main article information including title, a brief summary, and main topics discussed.
Focus on the primary business news article on the page.""" Focus on the primary business news article on the page.""",
} },
}, },
"crawler_params": { "crawler_params": {"headless": True, "word_count_threshold": 1},
"headless": True,
"word_count_threshold": 1
}
} }
task_id = await api.submit_crawl(request) task_id = await api.submit_crawl(request)
result = await api.wait_for_task(task_id) result = await api.wait_for_task(task_id)
if result["status"] == "completed": if result["status"] == "completed":
extracted = json.loads(result["result"]["extracted_content"]) extracted = json.loads(result["result"]["extracted_content"])
print(f"Extracted article analysis:") print("Extracted article analysis:")
print(json.dumps(extracted, indent=2)) print(json.dumps(extracted, indent=2))
assert result["status"] == "completed" assert result["status"] == "completed"
assert result["result"]["success"] assert result["result"]["success"]


async def test_screenshot():
    print("\n=== Testing Screenshot ===")
    async with NBCNewsAPITest() as api:
@@ -218,9 +212,7 @@ async def test_screenshot():
"urls": "https://www.nbcnews.com/business", "urls": "https://www.nbcnews.com/business",
"priority": 5, "priority": 5,
"screenshot": True, "screenshot": True,
"crawler_params": { "crawler_params": {"headless": True},
"headless": True
}
} }
task_id = await api.submit_crawl(request) task_id = await api.submit_crawl(request)
result = await api.wait_for_task(task_id) result = await api.wait_for_task(task_id)
@@ -229,6 +221,7 @@ async def test_screenshot():
assert result["result"]["success"] assert result["result"]["success"]
assert result["result"]["screenshot"] is not None assert result["result"]["screenshot"] is not None
async def test_priority_handling(): async def test_priority_handling():
print("\n=== Testing Priority Handling ===") print("\n=== Testing Priority Handling ===")
async with NBCNewsAPITest() as api: async with NBCNewsAPITest() as api:
@@ -236,7 +229,7 @@ async def test_priority_handling():
        low_priority = {
            "urls": "https://www.nbcnews.com/business",
            "priority": 1,
            "crawler_params": {"headless": True},
        }
        low_task_id = await api.submit_crawl(low_priority)
@@ -244,7 +237,7 @@ async def test_priority_handling():
        high_priority = {
            "urls": "https://www.nbcnews.com/business/consumer",
            "priority": 10,
            "crawler_params": {"headless": True},
        }
        high_task_id = await api.submit_crawl(high_priority)
@@ -256,6 +249,7 @@ async def test_priority_handling():
assert high_result["status"] == "completed" assert high_result["status"] == "completed"
assert low_result["status"] == "completed" assert low_result["status"] == "completed"
async def main(): async def main():
try: try:
# Start with health check # Start with health check
@@ -277,5 +271,6 @@ async def main():
print(f"Test failed: {str(e)}") print(f"Test failed: {str(e)}")
raise raise
if __name__ == "__main__": if __name__ == "__main__":
asyncio.run(main()) asyncio.run(main())
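
If only a subset of these checks is needed during review, the independent coroutines can also be gathered directly instead of running the full main(); a sketch under the assumption that order-independent tests are safe to run concurrently:

import asyncio


async def run_subset():
    # Order-independent tests can run concurrently; test_priority_handling
    # should stay sequential, since it compares task completion order.
    await asyncio.gather(test_basic_crawl(), test_css_selector())


asyncio.run(run_subset())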

View File

@@ -1,21 +1,26 @@
import nest_asyncio

nest_asyncio.apply()
import asyncio
from crawl4ai import (
    AsyncWebCrawler,
    CrawlerRunConfig,
    LXMLWebScrapingStrategy,
    CacheMode,
)


async def main():
    config = CrawlerRunConfig(
        cache_mode=CacheMode.BYPASS,
        scraping_strategy=LXMLWebScrapingStrategy(),  # Faster alternative to default BeautifulSoup
    )
    async with AsyncWebCrawler() as crawler:
        result = await crawler.arun(url="https://example.com", config=config)
        print(f"Success: {result.success}")
        print(f"Markdown length: {len(result.markdown_v2.raw_markdown)}")


if __name__ == "__main__":
    asyncio.run(main())
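
A rough way to see the speed difference this demo points at is to time both exported scraping strategies on the same page. A sketch only, assuming network variance is acceptable for a ballpark comparison; WebScrapingStrategy and LXMLWebScrapingStrategy are both exported by the package:

import asyncio
import time

from crawl4ai import (
    AsyncWebCrawler,
    CrawlerRunConfig,
    CacheMode,
    WebScrapingStrategy,
    LXMLWebScrapingStrategy,
)


async def time_strategy(strategy) -> float:
    # Crawl the same page with the given scraping strategy and return wall time.
    config = CrawlerRunConfig(cache_mode=CacheMode.BYPASS, scraping_strategy=strategy)
    async with AsyncWebCrawler() as crawler:
        start = time.perf_counter()
        await crawler.arun(url="https://example.com", config=config)
    return time.perf_counter() - start


async def compare():
    bs_time = await time_strategy(WebScrapingStrategy())
    lxml_time = await time_strategy(LXMLWebScrapingStrategy())
    print(f"BeautifulSoup-based: {bs_time:.2f}s, LXML-based: {lxml_time:.2f}s")


asyncio.run(compare())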

View File

@@ -1,79 +1,105 @@
import unittest, os
from crawl4ai.web_crawler import WebCrawler
from crawl4ai.chunking_strategy import (
    RegexChunking,
    FixedLengthWordChunking,
    SlidingWindowChunking,
)
from crawl4ai.extraction_strategy import (
    CosineStrategy,
    LLMExtractionStrategy,
    TopicExtractionStrategy,
    NoExtractionStrategy,
)


class TestWebCrawler(unittest.TestCase):
    def setUp(self):
        self.crawler = WebCrawler()

    def test_warmup(self):
        self.crawler.warmup()
        self.assertTrue(self.crawler.ready, "WebCrawler failed to warm up")

    def test_run_default_strategies(self):
        result = self.crawler.run(
            url="https://www.nbcnews.com/business",
            word_count_threshold=5,
            chunking_strategy=RegexChunking(),
            extraction_strategy=CosineStrategy(),
            bypass_cache=True,
        )
        self.assertTrue(
            result.success, "Failed to crawl and extract using default strategies"
        )

    def test_run_different_strategies(self):
        url = "https://www.nbcnews.com/business"

        # Test with FixedLengthWordChunking and LLMExtractionStrategy
        result = self.crawler.run(
            url=url,
            word_count_threshold=5,
            chunking_strategy=FixedLengthWordChunking(chunk_size=100),
            extraction_strategy=LLMExtractionStrategy(
                provider="openai/gpt-3.5-turbo", api_token=os.getenv("OPENAI_API_KEY")
            ),
            bypass_cache=True,
        )
        self.assertTrue(
            result.success,
            "Failed to crawl and extract with FixedLengthWordChunking and LLMExtractionStrategy",
        )

        # Test with SlidingWindowChunking and TopicExtractionStrategy
        result = self.crawler.run(
            url=url,
            word_count_threshold=5,
            chunking_strategy=SlidingWindowChunking(window_size=100, step=50),
            extraction_strategy=TopicExtractionStrategy(num_keywords=5),
            bypass_cache=True,
        )
        self.assertTrue(
            result.success,
            "Failed to crawl and extract with SlidingWindowChunking and TopicExtractionStrategy",
        )

    def test_invalid_url(self):
        with self.assertRaises(Exception) as context:
            self.crawler.run(url="invalid_url", bypass_cache=True)
        self.assertIn("Invalid URL", str(context.exception))

    def test_unsupported_extraction_strategy(self):
        with self.assertRaises(Exception) as context:
            self.crawler.run(
                url="https://www.nbcnews.com/business",
                extraction_strategy="UnsupportedStrategy",
                bypass_cache=True,
            )
        self.assertIn("Unsupported extraction strategy", str(context.exception))

    def test_invalid_css_selector(self):
        with self.assertRaises(ValueError) as context:
            self.crawler.run(
                url="https://www.nbcnews.com/business",
                css_selector="invalid_selector",
                bypass_cache=True,
            )
        self.assertIn("Invalid CSS selector", str(context.exception))

    def test_crawl_with_cache_and_bypass_cache(self):
        url = "https://www.nbcnews.com/business"
        # First crawl with cache enabled
        result = self.crawler.run(url=url, bypass_cache=False)
        self.assertTrue(result.success, "Failed to crawl and cache the result")
        # Second crawl with bypass_cache=True
        result = self.crawler.run(url=url, bypass_cache=True)
        self.assertTrue(result.success, "Failed to bypass cache and fetch fresh data")

    def test_fetch_multiple_pages(self):
        urls = ["https://www.nbcnews.com/business", "https://www.bbc.com/news"]
        results = []
        for url in urls:
            result = self.crawler.run(
@@ -81,31 +107,42 @@ class TestWebCrawler(unittest.TestCase):
                word_count_threshold=5,
                chunking_strategy=RegexChunking(),
                extraction_strategy=CosineStrategy(),
                bypass_cache=True,
            )
            results.append(result)

        self.assertEqual(len(results), 2, "Failed to crawl and extract multiple pages")
        for result in results:
            self.assertTrue(
                result.success, "Failed to crawl and extract a page in the list"
            )

    def test_run_fixed_length_word_chunking_and_no_extraction(self):
        result = self.crawler.run(
            url="https://www.nbcnews.com/business",
            word_count_threshold=5,
            chunking_strategy=FixedLengthWordChunking(chunk_size=100),
            extraction_strategy=NoExtractionStrategy(),
            bypass_cache=True,
        )
        self.assertTrue(
            result.success,
            "Failed to crawl and extract with FixedLengthWordChunking and NoExtractionStrategy",
        )

    def test_run_sliding_window_and_no_extraction(self):
        result = self.crawler.run(
            url="https://www.nbcnews.com/business",
            word_count_threshold=5,
            chunking_strategy=SlidingWindowChunking(window_size=100, step=50),
            extraction_strategy=NoExtractionStrategy(),
            bypass_cache=True,
        )
        self.assertTrue(
            result.success,
            "Failed to crawl and extract with SlidingWindowChunking and NoExtractionStrategy",
        )


if __name__ == "__main__":
    unittest.main()
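
To run one of these cases in isolation while reviewing, the standard unittest CLI is enough; the module path below is hypothetical, since the file's location isn't shown in this diff:

# Whole suite, verbose:
#   python -m unittest tests.test_web_crawler -v
# A single case:
#   python -m unittest tests.test_web_crawler.TestWebCrawler.test_warmup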