Apply Ruff Corrections
This commit is contained in:
8
.pre-commit-config.yaml
Normal file
8
.pre-commit-config.yaml
Normal file
@@ -0,0 +1,8 @@
|
||||
# .pre-commit-config.yaml
|
||||
repos:
|
||||
- repo: https://github.com/astral-sh/ruff-pre-commit
|
||||
rev: v0.1.11
|
||||
hooks:
|
||||
- id: ruff
|
||||
args: [--fix]
|
||||
- id: ruff-format
|
||||
@@ -2,14 +2,28 @@
|
||||
|
||||
from .async_webcrawler import AsyncWebCrawler, CacheMode
|
||||
from .async_configs import BrowserConfig, CrawlerRunConfig
|
||||
from .content_scraping_strategy import ContentScrapingStrategy, WebScrapingStrategy, LXMLWebScrapingStrategy
|
||||
from .extraction_strategy import ExtractionStrategy, LLMExtractionStrategy, CosineStrategy, JsonCssExtractionStrategy
|
||||
from .content_scraping_strategy import (
|
||||
ContentScrapingStrategy,
|
||||
WebScrapingStrategy,
|
||||
LXMLWebScrapingStrategy,
|
||||
)
|
||||
from .extraction_strategy import (
|
||||
ExtractionStrategy,
|
||||
LLMExtractionStrategy,
|
||||
CosineStrategy,
|
||||
JsonCssExtractionStrategy,
|
||||
)
|
||||
from .chunking_strategy import ChunkingStrategy, RegexChunking
|
||||
from .markdown_generation_strategy import DefaultMarkdownGenerator
|
||||
from .content_filter_strategy import PruningContentFilter, BM25ContentFilter
|
||||
from .models import CrawlResult, MarkdownGenerationResult
|
||||
from .async_dispatcher import MemoryAdaptiveDispatcher, SemaphoreDispatcher, RateLimiter, CrawlerMonitor, DisplayMode
|
||||
from .__version__ import __version__
|
||||
from .async_dispatcher import (
|
||||
MemoryAdaptiveDispatcher,
|
||||
SemaphoreDispatcher,
|
||||
RateLimiter,
|
||||
CrawlerMonitor,
|
||||
DisplayMode,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"AsyncWebCrawler",
|
||||
@@ -18,40 +32,45 @@ __all__ = [
|
||||
"ContentScrapingStrategy",
|
||||
"WebScrapingStrategy",
|
||||
"LXMLWebScrapingStrategy",
|
||||
'BrowserConfig',
|
||||
'CrawlerRunConfig',
|
||||
'ExtractionStrategy',
|
||||
'LLMExtractionStrategy',
|
||||
'CosineStrategy',
|
||||
'JsonCssExtractionStrategy',
|
||||
'ChunkingStrategy',
|
||||
'RegexChunking',
|
||||
'DefaultMarkdownGenerator',
|
||||
'PruningContentFilter',
|
||||
'BM25ContentFilter',
|
||||
'MemoryAdaptiveDispatcher',
|
||||
'SemaphoreDispatcher',
|
||||
'RateLimiter',
|
||||
'CrawlerMonitor',
|
||||
'DisplayMode',
|
||||
'MarkdownGenerationResult',
|
||||
"BrowserConfig",
|
||||
"CrawlerRunConfig",
|
||||
"ExtractionStrategy",
|
||||
"LLMExtractionStrategy",
|
||||
"CosineStrategy",
|
||||
"JsonCssExtractionStrategy",
|
||||
"ChunkingStrategy",
|
||||
"RegexChunking",
|
||||
"DefaultMarkdownGenerator",
|
||||
"PruningContentFilter",
|
||||
"BM25ContentFilter",
|
||||
"MemoryAdaptiveDispatcher",
|
||||
"SemaphoreDispatcher",
|
||||
"RateLimiter",
|
||||
"CrawlerMonitor",
|
||||
"DisplayMode",
|
||||
"MarkdownGenerationResult",
|
||||
]
|
||||
|
||||
|
||||
def is_sync_version_installed():
|
||||
try:
|
||||
import selenium
|
||||
|
||||
return True
|
||||
except ImportError:
|
||||
return False
|
||||
|
||||
|
||||
if is_sync_version_installed():
|
||||
try:
|
||||
from .web_crawler import WebCrawler
|
||||
|
||||
__all__.append("WebCrawler")
|
||||
except ImportError:
|
||||
import warnings
|
||||
print("Warning: Failed to import WebCrawler even though selenium is installed. This might be due to other missing dependencies.")
|
||||
print(
|
||||
"Warning: Failed to import WebCrawler even though selenium is installed. This might be due to other missing dependencies."
|
||||
)
|
||||
else:
|
||||
WebCrawler = None
|
||||
# import warnings
|
||||
# print("Warning: Synchronous WebCrawler is not available. Install crawl4ai[sync] for synchronous support. However, please note that the synchronous version will be deprecated soon.")
|
||||
# print("Warning: Synchronous WebCrawler is not available. Install crawl4ai[sync] for synchronous support. However, please note that the synchronous version will be deprecated soon.")
|
||||
|
||||
@@ -5,7 +5,6 @@ from .config import (
|
||||
PAGE_TIMEOUT,
|
||||
IMAGE_SCORE_THRESHOLD,
|
||||
SOCIAL_MEDIA_DOMAINS,
|
||||
|
||||
)
|
||||
from .user_agent_generator import UserAgentGenerator
|
||||
from .extraction_strategy import ExtractionStrategy
|
||||
@@ -14,6 +13,7 @@ from .markdown_generation_strategy import MarkdownGenerationStrategy
|
||||
from .content_scraping_strategy import ContentScrapingStrategy, WebScrapingStrategy
|
||||
from typing import Union, List
|
||||
|
||||
|
||||
class BrowserConfig:
|
||||
"""
|
||||
Configuration class for setting up a browser instance and its context in AsyncPlaywrightCrawlerStrategy.
|
||||
@@ -84,7 +84,7 @@ class BrowserConfig:
|
||||
proxy: str = None,
|
||||
proxy_config: dict = None,
|
||||
viewport_width: int = 1080,
|
||||
viewport_height: int = 600,
|
||||
viewport_height: int = 600,
|
||||
accept_downloads: bool = False,
|
||||
downloads_path: str = None,
|
||||
storage_state=None,
|
||||
@@ -103,7 +103,7 @@ class BrowserConfig:
|
||||
text_mode: bool = False,
|
||||
light_mode: bool = False,
|
||||
extra_args: list = None,
|
||||
debugging_port : int = 9222,
|
||||
debugging_port: int = 9222,
|
||||
):
|
||||
self.browser_type = browser_type
|
||||
self.headless = headless
|
||||
@@ -142,7 +142,7 @@ class BrowserConfig:
|
||||
self.user_agent = user_agenr_generator.generate()
|
||||
else:
|
||||
pass
|
||||
|
||||
|
||||
self.browser_hint = user_agenr_generator.generate_client_hints(self.user_agent)
|
||||
self.headers.setdefault("sec-ch-ua", self.browser_hint)
|
||||
|
||||
@@ -313,7 +313,7 @@ class CrawlerRunConfig:
|
||||
Default: True.
|
||||
log_console (bool): If True, log console messages from the page.
|
||||
Default: False.
|
||||
|
||||
|
||||
# Optional Parameters
|
||||
url: str = None # This is not a compulsory parameter
|
||||
"""
|
||||
@@ -335,10 +335,8 @@ class CrawlerRunConfig:
|
||||
prettiify: bool = False,
|
||||
parser_type: str = "lxml",
|
||||
scraping_strategy: ContentScrapingStrategy = None,
|
||||
|
||||
# SSL Parameters
|
||||
fetch_ssl_certificate: bool = False,
|
||||
|
||||
# Caching Parameters
|
||||
cache_mode=None,
|
||||
session_id: str = None,
|
||||
@@ -346,7 +344,6 @@ class CrawlerRunConfig:
|
||||
disable_cache: bool = False,
|
||||
no_cache_read: bool = False,
|
||||
no_cache_write: bool = False,
|
||||
|
||||
# Page Navigation and Timing Parameters
|
||||
wait_until: str = "domcontentloaded",
|
||||
page_timeout: int = PAGE_TIMEOUT,
|
||||
@@ -356,7 +353,6 @@ class CrawlerRunConfig:
|
||||
mean_delay: float = 0.1,
|
||||
max_range: float = 0.3,
|
||||
semaphore_count: int = 5,
|
||||
|
||||
# Page Interaction Parameters
|
||||
js_code: Union[str, List[str]] = None,
|
||||
js_only: bool = False,
|
||||
@@ -369,7 +365,6 @@ class CrawlerRunConfig:
|
||||
override_navigator: bool = False,
|
||||
magic: bool = False,
|
||||
adjust_viewport_to_content: bool = False,
|
||||
|
||||
# Media Handling Parameters
|
||||
screenshot: bool = False,
|
||||
screenshot_wait_for: float = None,
|
||||
@@ -378,21 +373,18 @@ class CrawlerRunConfig:
|
||||
image_description_min_word_threshold: int = IMAGE_DESCRIPTION_MIN_WORD_THRESHOLD,
|
||||
image_score_threshold: int = IMAGE_SCORE_THRESHOLD,
|
||||
exclude_external_images: bool = False,
|
||||
|
||||
# Link and Domain Handling Parameters
|
||||
exclude_social_media_domains: list = None,
|
||||
exclude_external_links: bool = False,
|
||||
exclude_social_media_links: bool = False,
|
||||
exclude_domains: list = None,
|
||||
|
||||
# Debugging and Logging Parameters
|
||||
verbose: bool = True,
|
||||
log_console: bool = False,
|
||||
|
||||
url: str = None,
|
||||
):
|
||||
self.url = url
|
||||
|
||||
|
||||
# Content Processing Parameters
|
||||
self.word_count_threshold = word_count_threshold
|
||||
self.extraction_strategy = extraction_strategy
|
||||
@@ -453,7 +445,9 @@ class CrawlerRunConfig:
|
||||
self.exclude_external_images = exclude_external_images
|
||||
|
||||
# Link and Domain Handling Parameters
|
||||
self.exclude_social_media_domains = exclude_social_media_domains or SOCIAL_MEDIA_DOMAINS
|
||||
self.exclude_social_media_domains = (
|
||||
exclude_social_media_domains or SOCIAL_MEDIA_DOMAINS
|
||||
)
|
||||
self.exclude_external_links = exclude_external_links
|
||||
self.exclude_social_media_links = exclude_social_media_links
|
||||
self.exclude_domains = exclude_domains or []
|
||||
@@ -466,11 +460,15 @@ class CrawlerRunConfig:
|
||||
if self.extraction_strategy is not None and not isinstance(
|
||||
self.extraction_strategy, ExtractionStrategy
|
||||
):
|
||||
raise ValueError("extraction_strategy must be an instance of ExtractionStrategy")
|
||||
raise ValueError(
|
||||
"extraction_strategy must be an instance of ExtractionStrategy"
|
||||
)
|
||||
if self.chunking_strategy is not None and not isinstance(
|
||||
self.chunking_strategy, ChunkingStrategy
|
||||
):
|
||||
raise ValueError("chunking_strategy must be an instance of ChunkingStrategy")
|
||||
raise ValueError(
|
||||
"chunking_strategy must be an instance of ChunkingStrategy"
|
||||
)
|
||||
|
||||
# Set default chunking strategy if None
|
||||
if self.chunking_strategy is None:
|
||||
@@ -494,10 +492,8 @@ class CrawlerRunConfig:
|
||||
prettiify=kwargs.get("prettiify", False),
|
||||
parser_type=kwargs.get("parser_type", "lxml"),
|
||||
scraping_strategy=kwargs.get("scraping_strategy"),
|
||||
|
||||
# SSL Parameters
|
||||
fetch_ssl_certificate=kwargs.get("fetch_ssl_certificate", False),
|
||||
|
||||
# Caching Parameters
|
||||
cache_mode=kwargs.get("cache_mode"),
|
||||
session_id=kwargs.get("session_id"),
|
||||
@@ -505,7 +501,6 @@ class CrawlerRunConfig:
|
||||
disable_cache=kwargs.get("disable_cache", False),
|
||||
no_cache_read=kwargs.get("no_cache_read", False),
|
||||
no_cache_write=kwargs.get("no_cache_write", False),
|
||||
|
||||
# Page Navigation and Timing Parameters
|
||||
wait_until=kwargs.get("wait_until", "domcontentloaded"),
|
||||
page_timeout=kwargs.get("page_timeout", 60000),
|
||||
@@ -515,7 +510,6 @@ class CrawlerRunConfig:
|
||||
mean_delay=kwargs.get("mean_delay", 0.1),
|
||||
max_range=kwargs.get("max_range", 0.3),
|
||||
semaphore_count=kwargs.get("semaphore_count", 5),
|
||||
|
||||
# Page Interaction Parameters
|
||||
js_code=kwargs.get("js_code"),
|
||||
js_only=kwargs.get("js_only", False),
|
||||
@@ -528,29 +522,34 @@ class CrawlerRunConfig:
|
||||
override_navigator=kwargs.get("override_navigator", False),
|
||||
magic=kwargs.get("magic", False),
|
||||
adjust_viewport_to_content=kwargs.get("adjust_viewport_to_content", False),
|
||||
|
||||
# Media Handling Parameters
|
||||
screenshot=kwargs.get("screenshot", False),
|
||||
screenshot_wait_for=kwargs.get("screenshot_wait_for"),
|
||||
screenshot_height_threshold=kwargs.get("screenshot_height_threshold", SCREENSHOT_HEIGHT_TRESHOLD),
|
||||
screenshot_height_threshold=kwargs.get(
|
||||
"screenshot_height_threshold", SCREENSHOT_HEIGHT_TRESHOLD
|
||||
),
|
||||
pdf=kwargs.get("pdf", False),
|
||||
image_description_min_word_threshold=kwargs.get("image_description_min_word_threshold", IMAGE_DESCRIPTION_MIN_WORD_THRESHOLD),
|
||||
image_score_threshold=kwargs.get("image_score_threshold", IMAGE_SCORE_THRESHOLD),
|
||||
image_description_min_word_threshold=kwargs.get(
|
||||
"image_description_min_word_threshold",
|
||||
IMAGE_DESCRIPTION_MIN_WORD_THRESHOLD,
|
||||
),
|
||||
image_score_threshold=kwargs.get(
|
||||
"image_score_threshold", IMAGE_SCORE_THRESHOLD
|
||||
),
|
||||
exclude_external_images=kwargs.get("exclude_external_images", False),
|
||||
|
||||
# Link and Domain Handling Parameters
|
||||
exclude_social_media_domains=kwargs.get("exclude_social_media_domains", SOCIAL_MEDIA_DOMAINS),
|
||||
exclude_social_media_domains=kwargs.get(
|
||||
"exclude_social_media_domains", SOCIAL_MEDIA_DOMAINS
|
||||
),
|
||||
exclude_external_links=kwargs.get("exclude_external_links", False),
|
||||
exclude_social_media_links=kwargs.get("exclude_social_media_links", False),
|
||||
exclude_domains=kwargs.get("exclude_domains", []),
|
||||
|
||||
# Debugging and Logging Parameters
|
||||
verbose=kwargs.get("verbose", True),
|
||||
log_console=kwargs.get("log_console", False),
|
||||
|
||||
url=kwargs.get("url"),
|
||||
)
|
||||
|
||||
|
||||
# Create a funciton returns dict of the object
|
||||
def to_dict(self):
|
||||
return {
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -1,27 +1,29 @@
|
||||
import os, sys
|
||||
import os
|
||||
from pathlib import Path
|
||||
import aiosqlite
|
||||
import asyncio
|
||||
from typing import Optional, Tuple, Dict
|
||||
from typing import Optional, Dict
|
||||
from contextlib import asynccontextmanager
|
||||
import logging
|
||||
import json # Added for serialization/deserialization
|
||||
from .utils import ensure_content_dirs, generate_content_hash
|
||||
from .models import CrawlResult, MarkdownGenerationResult
|
||||
import xxhash
|
||||
import aiofiles
|
||||
from .config import NEED_MIGRATION
|
||||
from .version_manager import VersionManager
|
||||
from .async_logger import AsyncLogger
|
||||
from .utils import get_error_context, create_box_message
|
||||
|
||||
# Set up logging
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
base_directory = DB_PATH = os.path.join(os.getenv("CRAWL4_AI_BASE_DIRECTORY", Path.home()), ".crawl4ai")
|
||||
base_directory = DB_PATH = os.path.join(
|
||||
os.getenv("CRAWL4_AI_BASE_DIRECTORY", Path.home()), ".crawl4ai"
|
||||
)
|
||||
os.makedirs(DB_PATH, exist_ok=True)
|
||||
DB_PATH = os.path.join(base_directory, "crawl4ai.db")
|
||||
|
||||
|
||||
class AsyncDatabaseManager:
|
||||
def __init__(self, pool_size: int = 10, max_retries: int = 3):
|
||||
self.db_path = DB_PATH
|
||||
@@ -32,28 +34,27 @@ class AsyncDatabaseManager:
|
||||
self.pool_lock = asyncio.Lock()
|
||||
self.init_lock = asyncio.Lock()
|
||||
self.connection_semaphore = asyncio.Semaphore(pool_size)
|
||||
self._initialized = False
|
||||
self._initialized = False
|
||||
self.version_manager = VersionManager()
|
||||
self.logger = AsyncLogger(
|
||||
log_file=os.path.join(base_directory, ".crawl4ai", "crawler_db.log"),
|
||||
verbose=False,
|
||||
tag_width=10
|
||||
tag_width=10,
|
||||
)
|
||||
|
||||
|
||||
|
||||
async def initialize(self):
|
||||
"""Initialize the database and connection pool"""
|
||||
try:
|
||||
self.logger.info("Initializing database", tag="INIT")
|
||||
# Ensure the database file exists
|
||||
os.makedirs(os.path.dirname(self.db_path), exist_ok=True)
|
||||
|
||||
|
||||
# Check if version update is needed
|
||||
needs_update = self.version_manager.needs_update()
|
||||
|
||||
|
||||
# Always ensure base table exists
|
||||
await self.ainit_db()
|
||||
|
||||
|
||||
# Verify the table exists
|
||||
async with aiosqlite.connect(self.db_path, timeout=30.0) as db:
|
||||
async with db.execute(
|
||||
@@ -62,33 +63,37 @@ class AsyncDatabaseManager:
|
||||
result = await cursor.fetchone()
|
||||
if not result:
|
||||
raise Exception("crawled_data table was not created")
|
||||
|
||||
|
||||
# If version changed or fresh install, run updates
|
||||
if needs_update:
|
||||
self.logger.info("New version detected, running updates", tag="INIT")
|
||||
await self.update_db_schema()
|
||||
from .migrations import run_migration # Import here to avoid circular imports
|
||||
from .migrations import (
|
||||
run_migration,
|
||||
) # Import here to avoid circular imports
|
||||
|
||||
await run_migration()
|
||||
self.version_manager.update_version() # Update stored version after successful migration
|
||||
self.logger.success("Version update completed successfully", tag="COMPLETE")
|
||||
self.logger.success(
|
||||
"Version update completed successfully", tag="COMPLETE"
|
||||
)
|
||||
else:
|
||||
self.logger.success("Database initialization completed successfully", tag="COMPLETE")
|
||||
self.logger.success(
|
||||
"Database initialization completed successfully", tag="COMPLETE"
|
||||
)
|
||||
|
||||
|
||||
except Exception as e:
|
||||
self.logger.error(
|
||||
message="Database initialization error: {error}",
|
||||
tag="ERROR",
|
||||
params={"error": str(e)}
|
||||
params={"error": str(e)},
|
||||
)
|
||||
self.logger.info(
|
||||
message="Database will be initialized on first use",
|
||||
tag="INIT"
|
||||
message="Database will be initialized on first use", tag="INIT"
|
||||
)
|
||||
|
||||
|
||||
raise
|
||||
|
||||
|
||||
async def cleanup(self):
|
||||
"""Cleanup connections when shutting down"""
|
||||
async with self.pool_lock:
|
||||
@@ -107,6 +112,7 @@ class AsyncDatabaseManager:
|
||||
self._initialized = True
|
||||
except Exception as e:
|
||||
import sys
|
||||
|
||||
error_context = get_error_context(sys.exc_info())
|
||||
self.logger.error(
|
||||
message="Database initialization failed:\n{error}\n\nContext:\n{context}\n\nTraceback:\n{traceback}",
|
||||
@@ -115,41 +121,52 @@ class AsyncDatabaseManager:
|
||||
params={
|
||||
"error": str(e),
|
||||
"context": error_context["code_context"],
|
||||
"traceback": error_context["full_traceback"]
|
||||
}
|
||||
"traceback": error_context["full_traceback"],
|
||||
},
|
||||
)
|
||||
raise
|
||||
|
||||
await self.connection_semaphore.acquire()
|
||||
task_id = id(asyncio.current_task())
|
||||
|
||||
|
||||
try:
|
||||
async with self.pool_lock:
|
||||
if task_id not in self.connection_pool:
|
||||
try:
|
||||
conn = await aiosqlite.connect(
|
||||
self.db_path,
|
||||
timeout=30.0
|
||||
)
|
||||
await conn.execute('PRAGMA journal_mode = WAL')
|
||||
await conn.execute('PRAGMA busy_timeout = 5000')
|
||||
|
||||
conn = await aiosqlite.connect(self.db_path, timeout=30.0)
|
||||
await conn.execute("PRAGMA journal_mode = WAL")
|
||||
await conn.execute("PRAGMA busy_timeout = 5000")
|
||||
|
||||
# Verify database structure
|
||||
async with conn.execute("PRAGMA table_info(crawled_data)") as cursor:
|
||||
async with conn.execute(
|
||||
"PRAGMA table_info(crawled_data)"
|
||||
) as cursor:
|
||||
columns = await cursor.fetchall()
|
||||
column_names = [col[1] for col in columns]
|
||||
expected_columns = {
|
||||
'url', 'html', 'cleaned_html', 'markdown', 'extracted_content',
|
||||
'success', 'media', 'links', 'metadata', 'screenshot',
|
||||
'response_headers', 'downloaded_files'
|
||||
"url",
|
||||
"html",
|
||||
"cleaned_html",
|
||||
"markdown",
|
||||
"extracted_content",
|
||||
"success",
|
||||
"media",
|
||||
"links",
|
||||
"metadata",
|
||||
"screenshot",
|
||||
"response_headers",
|
||||
"downloaded_files",
|
||||
}
|
||||
missing_columns = expected_columns - set(column_names)
|
||||
if missing_columns:
|
||||
raise ValueError(f"Database missing columns: {missing_columns}")
|
||||
|
||||
raise ValueError(
|
||||
f"Database missing columns: {missing_columns}"
|
||||
)
|
||||
|
||||
self.connection_pool[task_id] = conn
|
||||
except Exception as e:
|
||||
import sys
|
||||
|
||||
error_context = get_error_context(sys.exc_info())
|
||||
error_message = (
|
||||
f"Unexpected error in db get_connection at line {error_context['line_no']} "
|
||||
@@ -158,7 +175,7 @@ class AsyncDatabaseManager:
|
||||
f"Code context:\n{error_context['code_context']}"
|
||||
)
|
||||
self.logger.error(
|
||||
message=create_box_message(error_message, type= "error"),
|
||||
message=create_box_message(error_message, type="error"),
|
||||
)
|
||||
|
||||
raise
|
||||
@@ -167,6 +184,7 @@ class AsyncDatabaseManager:
|
||||
|
||||
except Exception as e:
|
||||
import sys
|
||||
|
||||
error_context = get_error_context(sys.exc_info())
|
||||
error_message = (
|
||||
f"Unexpected error in db get_connection at line {error_context['line_no']} "
|
||||
@@ -175,7 +193,7 @@ class AsyncDatabaseManager:
|
||||
f"Code context:\n{error_context['code_context']}"
|
||||
)
|
||||
self.logger.error(
|
||||
message=create_box_message(error_message, type= "error"),
|
||||
message=create_box_message(error_message, type="error"),
|
||||
)
|
||||
raise
|
||||
finally:
|
||||
@@ -185,7 +203,6 @@ class AsyncDatabaseManager:
|
||||
del self.connection_pool[task_id]
|
||||
self.connection_semaphore.release()
|
||||
|
||||
|
||||
async def execute_with_retry(self, operation, *args):
|
||||
"""Execute database operations with retry logic"""
|
||||
for attempt in range(self.max_retries):
|
||||
@@ -200,18 +217,16 @@ class AsyncDatabaseManager:
|
||||
message="Operation failed after {retries} attempts: {error}",
|
||||
tag="ERROR",
|
||||
force_verbose=True,
|
||||
params={
|
||||
"retries": self.max_retries,
|
||||
"error": str(e)
|
||||
}
|
||||
)
|
||||
params={"retries": self.max_retries, "error": str(e)},
|
||||
)
|
||||
raise
|
||||
await asyncio.sleep(1 * (attempt + 1)) # Exponential backoff
|
||||
|
||||
async def ainit_db(self):
|
||||
"""Initialize database schema"""
|
||||
async with aiosqlite.connect(self.db_path, timeout=30.0) as db:
|
||||
await db.execute('''
|
||||
await db.execute(
|
||||
"""
|
||||
CREATE TABLE IF NOT EXISTS crawled_data (
|
||||
url TEXT PRIMARY KEY,
|
||||
html TEXT,
|
||||
@@ -226,21 +241,27 @@ class AsyncDatabaseManager:
|
||||
response_headers TEXT DEFAULT "{}",
|
||||
downloaded_files TEXT DEFAULT "{}" -- New column added
|
||||
)
|
||||
''')
|
||||
"""
|
||||
)
|
||||
await db.commit()
|
||||
|
||||
|
||||
|
||||
async def update_db_schema(self):
|
||||
"""Update database schema if needed"""
|
||||
async with aiosqlite.connect(self.db_path, timeout=30.0) as db:
|
||||
cursor = await db.execute("PRAGMA table_info(crawled_data)")
|
||||
columns = await cursor.fetchall()
|
||||
column_names = [column[1] for column in columns]
|
||||
|
||||
|
||||
# List of new columns to add
|
||||
new_columns = ['media', 'links', 'metadata', 'screenshot', 'response_headers', 'downloaded_files']
|
||||
|
||||
new_columns = [
|
||||
"media",
|
||||
"links",
|
||||
"metadata",
|
||||
"screenshot",
|
||||
"response_headers",
|
||||
"downloaded_files",
|
||||
]
|
||||
|
||||
for column in new_columns:
|
||||
if column not in column_names:
|
||||
await self.aalter_db_add_column(column, db)
|
||||
@@ -248,75 +269,91 @@ class AsyncDatabaseManager:
|
||||
|
||||
async def aalter_db_add_column(self, new_column: str, db):
|
||||
"""Add new column to the database"""
|
||||
if new_column == 'response_headers':
|
||||
await db.execute(f'ALTER TABLE crawled_data ADD COLUMN {new_column} TEXT DEFAULT "{{}}"')
|
||||
if new_column == "response_headers":
|
||||
await db.execute(
|
||||
f'ALTER TABLE crawled_data ADD COLUMN {new_column} TEXT DEFAULT "{{}}"'
|
||||
)
|
||||
else:
|
||||
await db.execute(f'ALTER TABLE crawled_data ADD COLUMN {new_column} TEXT DEFAULT ""')
|
||||
await db.execute(
|
||||
f'ALTER TABLE crawled_data ADD COLUMN {new_column} TEXT DEFAULT ""'
|
||||
)
|
||||
self.logger.info(
|
||||
message="Added column '{column}' to the database",
|
||||
tag="INIT",
|
||||
params={"column": new_column}
|
||||
)
|
||||
|
||||
params={"column": new_column},
|
||||
)
|
||||
|
||||
async def aget_cached_url(self, url: str) -> Optional[CrawlResult]:
|
||||
"""Retrieve cached URL data as CrawlResult"""
|
||||
|
||||
async def _get(db):
|
||||
async with db.execute(
|
||||
'SELECT * FROM crawled_data WHERE url = ?', (url,)
|
||||
"SELECT * FROM crawled_data WHERE url = ?", (url,)
|
||||
) as cursor:
|
||||
row = await cursor.fetchone()
|
||||
if not row:
|
||||
return None
|
||||
|
||||
|
||||
# Get column names
|
||||
columns = [description[0] for description in cursor.description]
|
||||
# Create dict from row data
|
||||
row_dict = dict(zip(columns, row))
|
||||
|
||||
|
||||
# Load content from files using stored hashes
|
||||
content_fields = {
|
||||
'html': row_dict['html'],
|
||||
'cleaned_html': row_dict['cleaned_html'],
|
||||
'markdown': row_dict['markdown'],
|
||||
'extracted_content': row_dict['extracted_content'],
|
||||
'screenshot': row_dict['screenshot'],
|
||||
'screenshots': row_dict['screenshot'],
|
||||
"html": row_dict["html"],
|
||||
"cleaned_html": row_dict["cleaned_html"],
|
||||
"markdown": row_dict["markdown"],
|
||||
"extracted_content": row_dict["extracted_content"],
|
||||
"screenshot": row_dict["screenshot"],
|
||||
"screenshots": row_dict["screenshot"],
|
||||
}
|
||||
|
||||
|
||||
for field, hash_value in content_fields.items():
|
||||
if hash_value:
|
||||
content = await self._load_content(
|
||||
hash_value,
|
||||
field.split('_')[0] # Get content type from field name
|
||||
hash_value,
|
||||
field.split("_")[0], # Get content type from field name
|
||||
)
|
||||
row_dict[field] = content or ""
|
||||
else:
|
||||
row_dict[field] = ""
|
||||
|
||||
# Parse JSON fields
|
||||
json_fields = ['media', 'links', 'metadata', 'response_headers', 'markdown']
|
||||
json_fields = [
|
||||
"media",
|
||||
"links",
|
||||
"metadata",
|
||||
"response_headers",
|
||||
"markdown",
|
||||
]
|
||||
for field in json_fields:
|
||||
try:
|
||||
row_dict[field] = json.loads(row_dict[field]) if row_dict[field] else {}
|
||||
row_dict[field] = (
|
||||
json.loads(row_dict[field]) if row_dict[field] else {}
|
||||
)
|
||||
except json.JSONDecodeError:
|
||||
row_dict[field] = {}
|
||||
|
||||
if isinstance(row_dict['markdown'], Dict):
|
||||
row_dict['markdown_v2'] = row_dict['markdown']
|
||||
if row_dict['markdown'].get('raw_markdown'):
|
||||
row_dict['markdown'] = row_dict['markdown']['raw_markdown']
|
||||
|
||||
if isinstance(row_dict["markdown"], Dict):
|
||||
row_dict["markdown_v2"] = row_dict["markdown"]
|
||||
if row_dict["markdown"].get("raw_markdown"):
|
||||
row_dict["markdown"] = row_dict["markdown"]["raw_markdown"]
|
||||
|
||||
# Parse downloaded_files
|
||||
try:
|
||||
row_dict['downloaded_files'] = json.loads(row_dict['downloaded_files']) if row_dict['downloaded_files'] else []
|
||||
row_dict["downloaded_files"] = (
|
||||
json.loads(row_dict["downloaded_files"])
|
||||
if row_dict["downloaded_files"]
|
||||
else []
|
||||
)
|
||||
except json.JSONDecodeError:
|
||||
row_dict['downloaded_files'] = []
|
||||
row_dict["downloaded_files"] = []
|
||||
|
||||
# Remove any fields not in CrawlResult model
|
||||
valid_fields = CrawlResult.__annotations__.keys()
|
||||
filtered_dict = {k: v for k, v in row_dict.items() if k in valid_fields}
|
||||
|
||||
|
||||
return CrawlResult(**filtered_dict)
|
||||
|
||||
try:
|
||||
@@ -326,7 +363,7 @@ class AsyncDatabaseManager:
|
||||
message="Error retrieving cached URL: {error}",
|
||||
tag="ERROR",
|
||||
force_verbose=True,
|
||||
params={"error": str(e)}
|
||||
params={"error": str(e)},
|
||||
)
|
||||
return None
|
||||
|
||||
@@ -334,37 +371,52 @@ class AsyncDatabaseManager:
|
||||
"""Cache CrawlResult data"""
|
||||
# Store content files and get hashes
|
||||
content_map = {
|
||||
'html': (result.html, 'html'),
|
||||
'cleaned_html': (result.cleaned_html or "", 'cleaned'),
|
||||
'markdown': None,
|
||||
'extracted_content': (result.extracted_content or "", 'extracted'),
|
||||
'screenshot': (result.screenshot or "", 'screenshots')
|
||||
"html": (result.html, "html"),
|
||||
"cleaned_html": (result.cleaned_html or "", "cleaned"),
|
||||
"markdown": None,
|
||||
"extracted_content": (result.extracted_content or "", "extracted"),
|
||||
"screenshot": (result.screenshot or "", "screenshots"),
|
||||
}
|
||||
|
||||
try:
|
||||
if isinstance(result.markdown, MarkdownGenerationResult):
|
||||
content_map['markdown'] = (result.markdown.model_dump_json(), 'markdown')
|
||||
elif hasattr(result, 'markdown_v2'):
|
||||
content_map['markdown'] = (result.markdown_v2.model_dump_json(), 'markdown')
|
||||
content_map["markdown"] = (
|
||||
result.markdown.model_dump_json(),
|
||||
"markdown",
|
||||
)
|
||||
elif hasattr(result, "markdown_v2"):
|
||||
content_map["markdown"] = (
|
||||
result.markdown_v2.model_dump_json(),
|
||||
"markdown",
|
||||
)
|
||||
elif isinstance(result.markdown, str):
|
||||
markdown_result = MarkdownGenerationResult(raw_markdown=result.markdown)
|
||||
content_map['markdown'] = (markdown_result.model_dump_json(), 'markdown')
|
||||
content_map["markdown"] = (
|
||||
markdown_result.model_dump_json(),
|
||||
"markdown",
|
||||
)
|
||||
else:
|
||||
content_map['markdown'] = (MarkdownGenerationResult().model_dump_json(), 'markdown')
|
||||
content_map["markdown"] = (
|
||||
MarkdownGenerationResult().model_dump_json(),
|
||||
"markdown",
|
||||
)
|
||||
except Exception as e:
|
||||
self.logger.warning(
|
||||
message=f"Error processing markdown content: {str(e)}",
|
||||
tag="WARNING"
|
||||
message=f"Error processing markdown content: {str(e)}", tag="WARNING"
|
||||
)
|
||||
# Fallback to empty markdown result
|
||||
content_map['markdown'] = (MarkdownGenerationResult().model_dump_json(), 'markdown')
|
||||
|
||||
content_map["markdown"] = (
|
||||
MarkdownGenerationResult().model_dump_json(),
|
||||
"markdown",
|
||||
)
|
||||
|
||||
content_hashes = {}
|
||||
for field, (content, content_type) in content_map.items():
|
||||
content_hashes[field] = await self._store_content(content, content_type)
|
||||
|
||||
async def _cache(db):
|
||||
await db.execute('''
|
||||
await db.execute(
|
||||
"""
|
||||
INSERT INTO crawled_data (
|
||||
url, html, cleaned_html, markdown,
|
||||
extracted_content, success, media, links, metadata,
|
||||
@@ -383,20 +435,22 @@ class AsyncDatabaseManager:
|
||||
screenshot = excluded.screenshot,
|
||||
response_headers = excluded.response_headers,
|
||||
downloaded_files = excluded.downloaded_files
|
||||
''', (
|
||||
result.url,
|
||||
content_hashes['html'],
|
||||
content_hashes['cleaned_html'],
|
||||
content_hashes['markdown'],
|
||||
content_hashes['extracted_content'],
|
||||
result.success,
|
||||
json.dumps(result.media),
|
||||
json.dumps(result.links),
|
||||
json.dumps(result.metadata or {}),
|
||||
content_hashes['screenshot'],
|
||||
json.dumps(result.response_headers or {}),
|
||||
json.dumps(result.downloaded_files or [])
|
||||
))
|
||||
""",
|
||||
(
|
||||
result.url,
|
||||
content_hashes["html"],
|
||||
content_hashes["cleaned_html"],
|
||||
content_hashes["markdown"],
|
||||
content_hashes["extracted_content"],
|
||||
result.success,
|
||||
json.dumps(result.media),
|
||||
json.dumps(result.links),
|
||||
json.dumps(result.metadata or {}),
|
||||
content_hashes["screenshot"],
|
||||
json.dumps(result.response_headers or {}),
|
||||
json.dumps(result.downloaded_files or []),
|
||||
),
|
||||
)
|
||||
|
||||
try:
|
||||
await self.execute_with_retry(_cache)
|
||||
@@ -405,14 +459,14 @@ class AsyncDatabaseManager:
|
||||
message="Error caching URL: {error}",
|
||||
tag="ERROR",
|
||||
force_verbose=True,
|
||||
params={"error": str(e)}
|
||||
params={"error": str(e)},
|
||||
)
|
||||
|
||||
|
||||
async def aget_total_count(self) -> int:
|
||||
"""Get total number of cached URLs"""
|
||||
|
||||
async def _count(db):
|
||||
async with db.execute('SELECT COUNT(*) FROM crawled_data') as cursor:
|
||||
async with db.execute("SELECT COUNT(*) FROM crawled_data") as cursor:
|
||||
result = await cursor.fetchone()
|
||||
return result[0] if result else 0
|
||||
|
||||
@@ -423,14 +477,15 @@ class AsyncDatabaseManager:
|
||||
message="Error getting total count: {error}",
|
||||
tag="ERROR",
|
||||
force_verbose=True,
|
||||
params={"error": str(e)}
|
||||
params={"error": str(e)},
|
||||
)
|
||||
return 0
|
||||
|
||||
async def aclear_db(self):
|
||||
"""Clear all data from the database"""
|
||||
|
||||
async def _clear(db):
|
||||
await db.execute('DELETE FROM crawled_data')
|
||||
await db.execute("DELETE FROM crawled_data")
|
||||
|
||||
try:
|
||||
await self.execute_with_retry(_clear)
|
||||
@@ -439,13 +494,14 @@ class AsyncDatabaseManager:
|
||||
message="Error clearing database: {error}",
|
||||
tag="ERROR",
|
||||
force_verbose=True,
|
||||
params={"error": str(e)}
|
||||
params={"error": str(e)},
|
||||
)
|
||||
|
||||
async def aflush_db(self):
|
||||
"""Drop the entire table"""
|
||||
|
||||
async def _flush(db):
|
||||
await db.execute('DROP TABLE IF EXISTS crawled_data')
|
||||
await db.execute("DROP TABLE IF EXISTS crawled_data")
|
||||
|
||||
try:
|
||||
await self.execute_with_retry(_flush)
|
||||
@@ -454,42 +510,44 @@ class AsyncDatabaseManager:
|
||||
message="Error flushing database: {error}",
|
||||
tag="ERROR",
|
||||
force_verbose=True,
|
||||
params={"error": str(e)}
|
||||
params={"error": str(e)},
|
||||
)
|
||||
|
||||
|
||||
|
||||
async def _store_content(self, content: str, content_type: str) -> str:
|
||||
"""Store content in filesystem and return hash"""
|
||||
if not content:
|
||||
return ""
|
||||
|
||||
|
||||
content_hash = generate_content_hash(content)
|
||||
file_path = os.path.join(self.content_paths[content_type], content_hash)
|
||||
|
||||
|
||||
# Only write if file doesn't exist
|
||||
if not os.path.exists(file_path):
|
||||
async with aiofiles.open(file_path, 'w', encoding='utf-8') as f:
|
||||
async with aiofiles.open(file_path, "w", encoding="utf-8") as f:
|
||||
await f.write(content)
|
||||
|
||||
|
||||
return content_hash
|
||||
|
||||
async def _load_content(self, content_hash: str, content_type: str) -> Optional[str]:
|
||||
async def _load_content(
|
||||
self, content_hash: str, content_type: str
|
||||
) -> Optional[str]:
|
||||
"""Load content from filesystem by hash"""
|
||||
if not content_hash:
|
||||
return None
|
||||
|
||||
|
||||
file_path = os.path.join(self.content_paths[content_type], content_hash)
|
||||
try:
|
||||
async with aiofiles.open(file_path, 'r', encoding='utf-8') as f:
|
||||
async with aiofiles.open(file_path, "r", encoding="utf-8") as f:
|
||||
return await f.read()
|
||||
except:
|
||||
self.logger.error(
|
||||
message="Failed to load content: {file_path}",
|
||||
tag="ERROR",
|
||||
force_verbose=True,
|
||||
params={"file_path": file_path}
|
||||
params={"file_path": file_path},
|
||||
)
|
||||
return None
|
||||
|
||||
|
||||
# Create a singleton instance
|
||||
async_db_manager = AsyncDatabaseManager()
|
||||
|
||||
@@ -1,14 +1,19 @@
|
||||
from typing import Dict, Optional, List
|
||||
from .async_configs import *
|
||||
from .models import *
|
||||
from typing import Dict, Optional, List, Tuple
|
||||
from .async_configs import CrawlerRunConfig
|
||||
from .models import (
|
||||
CrawlResult,
|
||||
CrawlerTaskResult,
|
||||
CrawlStatus,
|
||||
DisplayMode,
|
||||
CrawlStats,
|
||||
DomainState,
|
||||
)
|
||||
|
||||
from rich.live import Live
|
||||
from rich.table import Table
|
||||
from rich.console import Console
|
||||
from rich.style import Style
|
||||
from rich import box
|
||||
from datetime import datetime, timedelta
|
||||
from dataclasses import dataclass
|
||||
|
||||
import time
|
||||
import psutil
|
||||
@@ -26,63 +31,66 @@ class RateLimiter:
|
||||
base_delay: Tuple[float, float] = (1.0, 3.0),
|
||||
max_delay: float = 60.0,
|
||||
max_retries: int = 3,
|
||||
rate_limit_codes: List[int] = None
|
||||
rate_limit_codes: List[int] = None,
|
||||
):
|
||||
self.base_delay = base_delay
|
||||
self.max_delay = max_delay
|
||||
self.max_retries = max_retries
|
||||
self.rate_limit_codes = rate_limit_codes or [429, 503]
|
||||
self.domains: Dict[str, DomainState] = {}
|
||||
|
||||
|
||||
def get_domain(self, url: str) -> str:
|
||||
return urlparse(url).netloc
|
||||
|
||||
|
||||
async def wait_if_needed(self, url: str) -> None:
|
||||
domain = self.get_domain(url)
|
||||
state = self.domains.get(domain)
|
||||
|
||||
|
||||
if not state:
|
||||
self.domains[domain] = DomainState()
|
||||
state = self.domains[domain]
|
||||
|
||||
|
||||
now = time.time()
|
||||
if state.last_request_time:
|
||||
wait_time = max(0, state.current_delay - (now - state.last_request_time))
|
||||
if wait_time > 0:
|
||||
await asyncio.sleep(wait_time)
|
||||
|
||||
|
||||
# Random delay within base range if no current delay
|
||||
if state.current_delay == 0:
|
||||
state.current_delay = random.uniform(*self.base_delay)
|
||||
|
||||
|
||||
state.last_request_time = time.time()
|
||||
|
||||
|
||||
def update_delay(self, url: str, status_code: int) -> bool:
|
||||
domain = self.get_domain(url)
|
||||
state = self.domains[domain]
|
||||
|
||||
|
||||
if status_code in self.rate_limit_codes:
|
||||
state.fail_count += 1
|
||||
if state.fail_count > self.max_retries:
|
||||
return False
|
||||
|
||||
|
||||
# Exponential backoff with random jitter
|
||||
state.current_delay = min(
|
||||
state.current_delay * 2 * random.uniform(0.75, 1.25),
|
||||
self.max_delay
|
||||
state.current_delay * 2 * random.uniform(0.75, 1.25), self.max_delay
|
||||
)
|
||||
else:
|
||||
# Gradually reduce delay on success
|
||||
state.current_delay = max(
|
||||
random.uniform(*self.base_delay),
|
||||
state.current_delay * 0.75
|
||||
random.uniform(*self.base_delay), state.current_delay * 0.75
|
||||
)
|
||||
state.fail_count = 0
|
||||
|
||||
|
||||
return True
|
||||
|
||||
|
||||
class CrawlerMonitor:
|
||||
def __init__(self, max_visible_rows: int = 15, display_mode: DisplayMode = DisplayMode.DETAILED):
|
||||
def __init__(
|
||||
self,
|
||||
max_visible_rows: int = 15,
|
||||
display_mode: DisplayMode = DisplayMode.DETAILED,
|
||||
):
|
||||
self.console = Console()
|
||||
self.max_visible_rows = max_visible_rows
|
||||
self.display_mode = display_mode
|
||||
@@ -90,23 +98,25 @@ class CrawlerMonitor:
|
||||
self.process = psutil.Process()
|
||||
self.start_time = datetime.now()
|
||||
self.live = Live(self._create_table(), refresh_per_second=2)
|
||||
|
||||
|
||||
def start(self):
|
||||
self.live.start()
|
||||
|
||||
|
||||
def stop(self):
|
||||
self.live.stop()
|
||||
|
||||
|
||||
def add_task(self, task_id: str, url: str):
|
||||
self.stats[task_id] = CrawlStats(task_id=task_id, url=url, status=CrawlStatus.QUEUED)
|
||||
self.stats[task_id] = CrawlStats(
|
||||
task_id=task_id, url=url, status=CrawlStatus.QUEUED
|
||||
)
|
||||
self.live.update(self._create_table())
|
||||
|
||||
|
||||
def update_task(self, task_id: str, **kwargs):
|
||||
if task_id in self.stats:
|
||||
for key, value in kwargs.items():
|
||||
setattr(self.stats[task_id], key, value)
|
||||
self.live.update(self._create_table())
|
||||
|
||||
|
||||
def _create_aggregated_table(self) -> Table:
|
||||
"""Creates a compact table showing only aggregated statistics"""
|
||||
table = Table(
|
||||
@@ -114,78 +124,78 @@ class CrawlerMonitor:
|
||||
title="Crawler Status Overview",
|
||||
title_style="bold magenta",
|
||||
header_style="bold blue",
|
||||
show_lines=True
|
||||
show_lines=True,
|
||||
)
|
||||
|
||||
|
||||
# Calculate statistics
|
||||
total_tasks = len(self.stats)
|
||||
queued = sum(1 for stat in self.stats.values() if stat.status == CrawlStatus.QUEUED)
|
||||
in_progress = sum(1 for stat in self.stats.values() if stat.status == CrawlStatus.IN_PROGRESS)
|
||||
completed = sum(1 for stat in self.stats.values() if stat.status == CrawlStatus.COMPLETED)
|
||||
failed = sum(1 for stat in self.stats.values() if stat.status == CrawlStatus.FAILED)
|
||||
|
||||
queued = sum(
|
||||
1 for stat in self.stats.values() if stat.status == CrawlStatus.QUEUED
|
||||
)
|
||||
in_progress = sum(
|
||||
1 for stat in self.stats.values() if stat.status == CrawlStatus.IN_PROGRESS
|
||||
)
|
||||
completed = sum(
|
||||
1 for stat in self.stats.values() if stat.status == CrawlStatus.COMPLETED
|
||||
)
|
||||
failed = sum(
|
||||
1 for stat in self.stats.values() if stat.status == CrawlStatus.FAILED
|
||||
)
|
||||
|
||||
# Memory statistics
|
||||
current_memory = self.process.memory_info().rss / (1024 * 1024)
|
||||
total_task_memory = sum(stat.memory_usage for stat in self.stats.values())
|
||||
peak_memory = max((stat.peak_memory for stat in self.stats.values()), default=0.0)
|
||||
|
||||
peak_memory = max(
|
||||
(stat.peak_memory for stat in self.stats.values()), default=0.0
|
||||
)
|
||||
|
||||
# Duration
|
||||
duration = datetime.now() - self.start_time
|
||||
|
||||
|
||||
# Create status row
|
||||
table.add_column("Status", style="bold cyan")
|
||||
table.add_column("Count", justify="right")
|
||||
table.add_column("Percentage", justify="right")
|
||||
|
||||
table.add_row(
|
||||
"Total Tasks",
|
||||
str(total_tasks),
|
||||
"100%"
|
||||
)
|
||||
|
||||
table.add_row("Total Tasks", str(total_tasks), "100%")
|
||||
table.add_row(
|
||||
"[yellow]In Queue[/yellow]",
|
||||
str(queued),
|
||||
f"{(queued/total_tasks*100):.1f}%" if total_tasks > 0 else "0%"
|
||||
f"{(queued/total_tasks*100):.1f}%" if total_tasks > 0 else "0%",
|
||||
)
|
||||
table.add_row(
|
||||
"[blue]In Progress[/blue]",
|
||||
str(in_progress),
|
||||
f"{(in_progress/total_tasks*100):.1f}%" if total_tasks > 0 else "0%"
|
||||
f"{(in_progress/total_tasks*100):.1f}%" if total_tasks > 0 else "0%",
|
||||
)
|
||||
table.add_row(
|
||||
"[green]Completed[/green]",
|
||||
str(completed),
|
||||
f"{(completed/total_tasks*100):.1f}%" if total_tasks > 0 else "0%"
|
||||
f"{(completed/total_tasks*100):.1f}%" if total_tasks > 0 else "0%",
|
||||
)
|
||||
table.add_row(
|
||||
"[red]Failed[/red]",
|
||||
str(failed),
|
||||
f"{(failed/total_tasks*100):.1f}%" if total_tasks > 0 else "0%"
|
||||
f"{(failed/total_tasks*100):.1f}%" if total_tasks > 0 else "0%",
|
||||
)
|
||||
|
||||
|
||||
# Add memory information
|
||||
table.add_section()
|
||||
table.add_row(
|
||||
"[magenta]Current Memory[/magenta]",
|
||||
f"{current_memory:.1f} MB",
|
||||
""
|
||||
"[magenta]Current Memory[/magenta]", f"{current_memory:.1f} MB", ""
|
||||
)
|
||||
table.add_row(
|
||||
"[magenta]Total Task Memory[/magenta]",
|
||||
f"{total_task_memory:.1f} MB",
|
||||
""
|
||||
"[magenta]Total Task Memory[/magenta]", f"{total_task_memory:.1f} MB", ""
|
||||
)
|
||||
table.add_row(
|
||||
"[magenta]Peak Task Memory[/magenta]",
|
||||
f"{peak_memory:.1f} MB",
|
||||
""
|
||||
"[magenta]Peak Task Memory[/magenta]", f"{peak_memory:.1f} MB", ""
|
||||
)
|
||||
table.add_row(
|
||||
"[yellow]Runtime[/yellow]",
|
||||
str(timedelta(seconds=int(duration.total_seconds()))),
|
||||
""
|
||||
"",
|
||||
)
|
||||
|
||||
|
||||
return table
|
||||
|
||||
def _create_detailed_table(self) -> Table:
|
||||
@@ -193,9 +203,9 @@ class CrawlerMonitor:
|
||||
box=box.ROUNDED,
|
||||
title="Crawler Performance Monitor",
|
||||
title_style="bold magenta",
|
||||
header_style="bold blue"
|
||||
header_style="bold blue",
|
||||
)
|
||||
|
||||
|
||||
# Add columns
|
||||
table.add_column("Task ID", style="cyan", no_wrap=True)
|
||||
table.add_column("URL", style="cyan", no_wrap=True)
|
||||
@@ -204,47 +214,54 @@ class CrawlerMonitor:
|
||||
table.add_column("Peak (MB)", justify="right")
|
||||
table.add_column("Duration", justify="right")
|
||||
table.add_column("Info", style="italic")
|
||||
|
||||
|
||||
# Add summary row
|
||||
total_memory = sum(stat.memory_usage for stat in self.stats.values())
|
||||
active_count = sum(1 for stat in self.stats.values()
|
||||
if stat.status == CrawlStatus.IN_PROGRESS)
|
||||
completed_count = sum(1 for stat in self.stats.values()
|
||||
if stat.status == CrawlStatus.COMPLETED)
|
||||
failed_count = sum(1 for stat in self.stats.values()
|
||||
if stat.status == CrawlStatus.FAILED)
|
||||
|
||||
active_count = sum(
|
||||
1 for stat in self.stats.values() if stat.status == CrawlStatus.IN_PROGRESS
|
||||
)
|
||||
completed_count = sum(
|
||||
1 for stat in self.stats.values() if stat.status == CrawlStatus.COMPLETED
|
||||
)
|
||||
failed_count = sum(
|
||||
1 for stat in self.stats.values() if stat.status == CrawlStatus.FAILED
|
||||
)
|
||||
|
||||
table.add_row(
|
||||
"[bold yellow]SUMMARY",
|
||||
f"Total: {len(self.stats)}",
|
||||
f"Active: {active_count}",
|
||||
f"{total_memory:.1f}",
|
||||
f"{self.process.memory_info().rss / (1024 * 1024):.1f}",
|
||||
str(timedelta(seconds=int((datetime.now() - self.start_time).total_seconds()))),
|
||||
str(
|
||||
timedelta(
|
||||
seconds=int((datetime.now() - self.start_time).total_seconds())
|
||||
)
|
||||
),
|
||||
f"✓{completed_count} ✗{failed_count}",
|
||||
style="bold"
|
||||
style="bold",
|
||||
)
|
||||
|
||||
|
||||
table.add_section()
|
||||
|
||||
|
||||
# Add rows for each task
|
||||
visible_stats = sorted(
|
||||
self.stats.values(),
|
||||
key=lambda x: (
|
||||
x.status != CrawlStatus.IN_PROGRESS,
|
||||
x.status != CrawlStatus.QUEUED,
|
||||
x.end_time or datetime.max
|
||||
)
|
||||
)[:self.max_visible_rows]
|
||||
|
||||
x.end_time or datetime.max,
|
||||
),
|
||||
)[: self.max_visible_rows]
|
||||
|
||||
for stat in visible_stats:
|
||||
status_style = {
|
||||
CrawlStatus.QUEUED: "white",
|
||||
CrawlStatus.IN_PROGRESS: "yellow",
|
||||
CrawlStatus.COMPLETED: "green",
|
||||
CrawlStatus.FAILED: "red"
|
||||
CrawlStatus.FAILED: "red",
|
||||
}[stat.status]
|
||||
|
||||
|
||||
table.add_row(
|
||||
stat.task_id[:8], # Show first 8 chars of task ID
|
||||
stat.url[:40] + "..." if len(stat.url) > 40 else stat.url,
|
||||
@@ -252,9 +269,9 @@ class CrawlerMonitor:
|
||||
f"{stat.memory_usage:.1f}",
|
||||
f"{stat.peak_memory:.1f}",
|
||||
stat.duration,
|
||||
stat.error_message[:40] if stat.error_message else ""
|
||||
stat.error_message[:40] if stat.error_message else "",
|
||||
)
|
||||
|
||||
|
||||
return table
|
||||
|
||||
def _create_table(self) -> Table:
|
||||
@@ -268,7 +285,7 @@ class BaseDispatcher(ABC):
|
||||
def __init__(
|
||||
self,
|
||||
rate_limiter: Optional[RateLimiter] = None,
|
||||
monitor: Optional[CrawlerMonitor] = None
|
||||
monitor: Optional[CrawlerMonitor] = None,
|
||||
):
|
||||
self.crawler = None
|
||||
self._domain_last_hit: Dict[str, float] = {}
|
||||
@@ -278,24 +295,25 @@ class BaseDispatcher(ABC):
|
||||
|
||||
@abstractmethod
|
||||
async def crawl_url(
|
||||
self,
|
||||
url: str,
|
||||
config: CrawlerRunConfig,
|
||||
self,
|
||||
url: str,
|
||||
config: CrawlerRunConfig,
|
||||
task_id: str,
|
||||
monitor: Optional[CrawlerMonitor] = None
|
||||
monitor: Optional[CrawlerMonitor] = None,
|
||||
) -> CrawlerTaskResult:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def run_urls(
|
||||
self,
|
||||
urls: List[str],
|
||||
crawler: "AsyncWebCrawler",
|
||||
self,
|
||||
urls: List[str],
|
||||
crawler: "AsyncWebCrawler", # noqa: F821
|
||||
config: CrawlerRunConfig,
|
||||
monitor: Optional[CrawlerMonitor] = None
|
||||
monitor: Optional[CrawlerMonitor] = None,
|
||||
) -> List[CrawlerTaskResult]:
|
||||
pass
|
||||
|
||||
|
||||
class MemoryAdaptiveDispatcher(BaseDispatcher):
|
||||
def __init__(
|
||||
self,
|
||||
@@ -304,39 +322,41 @@ class MemoryAdaptiveDispatcher(BaseDispatcher):
|
||||
max_session_permit: int = 20,
|
||||
memory_wait_timeout: float = 300.0, # 5 minutes default timeout
|
||||
rate_limiter: Optional[RateLimiter] = None,
|
||||
monitor: Optional[CrawlerMonitor] = None
|
||||
monitor: Optional[CrawlerMonitor] = None,
|
||||
):
|
||||
super().__init__(rate_limiter, monitor)
|
||||
self.memory_threshold_percent = memory_threshold_percent
|
||||
self.check_interval = check_interval
|
||||
self.max_session_permit = max_session_permit
|
||||
self.memory_wait_timeout = memory_wait_timeout
|
||||
|
||||
|
||||
async def crawl_url(
|
||||
self,
|
||||
url: str,
|
||||
config: CrawlerRunConfig,
|
||||
self,
|
||||
url: str,
|
||||
config: CrawlerRunConfig,
|
||||
task_id: str,
|
||||
) -> CrawlerTaskResult:
|
||||
start_time = datetime.now()
|
||||
error_message = ""
|
||||
memory_usage = peak_memory = 0.0
|
||||
|
||||
|
||||
try:
|
||||
if self.monitor:
|
||||
self.monitor.update_task(task_id, status=CrawlStatus.IN_PROGRESS, start_time=start_time)
|
||||
self.monitor.update_task(
|
||||
task_id, status=CrawlStatus.IN_PROGRESS, start_time=start_time
|
||||
)
|
||||
self.concurrent_sessions += 1
|
||||
|
||||
|
||||
if self.rate_limiter:
|
||||
await self.rate_limiter.wait_if_needed(url)
|
||||
|
||||
|
||||
process = psutil.Process()
|
||||
start_memory = process.memory_info().rss / (1024 * 1024)
|
||||
result = await self.crawler.arun(url, config=config, session_id=task_id)
|
||||
end_memory = process.memory_info().rss / (1024 * 1024)
|
||||
|
||||
|
||||
memory_usage = peak_memory = end_memory - start_memory
|
||||
|
||||
|
||||
if self.rate_limiter and result.status_code:
|
||||
if not self.rate_limiter.update_delay(url, result.status_code):
|
||||
error_message = f"Rate limit retry count exceeded for domain {urlparse(url).netloc}"
|
||||
@@ -350,22 +370,24 @@ class MemoryAdaptiveDispatcher(BaseDispatcher):
|
||||
peak_memory=peak_memory,
|
||||
start_time=start_time,
|
||||
end_time=datetime.now(),
|
||||
error_message=error_message
|
||||
error_message=error_message,
|
||||
)
|
||||
|
||||
|
||||
if not result.success:
|
||||
error_message = result.error_message
|
||||
if self.monitor:
|
||||
self.monitor.update_task(task_id, status=CrawlStatus.FAILED)
|
||||
elif self.monitor:
|
||||
self.monitor.update_task(task_id, status=CrawlStatus.COMPLETED)
|
||||
|
||||
|
||||
except Exception as e:
|
||||
error_message = str(e)
|
||||
if self.monitor:
|
||||
self.monitor.update_task(task_id, status=CrawlStatus.FAILED)
|
||||
result = CrawlResult(url=url, html="", metadata={}, success=False, error_message=str(e))
|
||||
|
||||
result = CrawlResult(
|
||||
url=url, html="", metadata={}, success=False, error_message=str(e)
|
||||
)
|
||||
|
||||
finally:
|
||||
end_time = datetime.now()
|
||||
if self.monitor:
|
||||
@@ -374,10 +396,10 @@ class MemoryAdaptiveDispatcher(BaseDispatcher):
|
||||
end_time=end_time,
|
||||
memory_usage=memory_usage,
|
||||
peak_memory=peak_memory,
|
||||
error_message=error_message
|
||||
error_message=error_message,
|
||||
)
|
||||
self.concurrent_sessions -= 1
|
||||
|
||||
|
||||
return CrawlerTaskResult(
|
||||
task_id=task_id,
|
||||
url=url,
|
||||
@@ -386,20 +408,20 @@ class MemoryAdaptiveDispatcher(BaseDispatcher):
|
||||
peak_memory=peak_memory,
|
||||
start_time=start_time,
|
||||
end_time=end_time,
|
||||
error_message=error_message
|
||||
error_message=error_message,
|
||||
)
|
||||
|
||||
async def run_urls(
|
||||
self,
|
||||
urls: List[str],
|
||||
crawler: "AsyncWebCrawler",
|
||||
self,
|
||||
urls: List[str],
|
||||
crawler: "AsyncWebCrawler", # noqa: F821
|
||||
config: CrawlerRunConfig,
|
||||
) -> List[CrawlerTaskResult]:
|
||||
self.crawler = crawler
|
||||
|
||||
|
||||
if self.monitor:
|
||||
self.monitor.start()
|
||||
|
||||
|
||||
try:
|
||||
pending_tasks = []
|
||||
active_tasks = []
|
||||
@@ -417,23 +439,24 @@ class MemoryAdaptiveDispatcher(BaseDispatcher):
|
||||
if psutil.virtual_memory().percent >= self.memory_threshold_percent:
|
||||
# Check if we've exceeded the timeout
|
||||
if time.time() - wait_start_time > self.memory_wait_timeout:
|
||||
raise MemoryError(f"Memory usage above threshold ({self.memory_threshold_percent}%) for more than {self.memory_wait_timeout} seconds")
|
||||
raise MemoryError(
|
||||
f"Memory usage above threshold ({self.memory_threshold_percent}%) for more than {self.memory_wait_timeout} seconds"
|
||||
)
|
||||
await asyncio.sleep(self.check_interval)
|
||||
continue
|
||||
|
||||
|
||||
url, task_id = task_queue.pop(0)
|
||||
task = asyncio.create_task(self.crawl_url(url, config, task_id))
|
||||
active_tasks.append(task)
|
||||
|
||||
|
||||
if not active_tasks:
|
||||
await asyncio.sleep(self.check_interval)
|
||||
continue
|
||||
|
||||
|
||||
done, pending = await asyncio.wait(
|
||||
active_tasks,
|
||||
return_when=asyncio.FIRST_COMPLETED
|
||||
active_tasks, return_when=asyncio.FIRST_COMPLETED
|
||||
)
|
||||
|
||||
|
||||
pending_tasks.extend(done)
|
||||
active_tasks = list(pending)
|
||||
|
||||
@@ -442,24 +465,25 @@ class MemoryAdaptiveDispatcher(BaseDispatcher):
|
||||
if self.monitor:
|
||||
self.monitor.stop()
|
||||
|
||||
|
||||
class SemaphoreDispatcher(BaseDispatcher):
|
||||
def __init__(
|
||||
self,
|
||||
semaphore_count: int = 5,
|
||||
max_session_permit: int = 20,
|
||||
rate_limiter: Optional[RateLimiter] = None,
|
||||
monitor: Optional[CrawlerMonitor] = None
|
||||
monitor: Optional[CrawlerMonitor] = None,
|
||||
):
|
||||
super().__init__(rate_limiter, monitor)
|
||||
self.semaphore_count = semaphore_count
|
||||
self.max_session_permit = max_session_permit
|
||||
|
||||
self.max_session_permit = max_session_permit
|
||||
|
||||
async def crawl_url(
|
||||
self,
|
||||
url: str,
|
||||
config: CrawlerRunConfig,
|
||||
self,
|
||||
url: str,
|
||||
config: CrawlerRunConfig,
|
||||
task_id: str,
|
||||
semaphore: asyncio.Semaphore = None
|
||||
semaphore: asyncio.Semaphore = None,
|
||||
) -> CrawlerTaskResult:
|
||||
start_time = datetime.now()
|
||||
error_message = ""
|
||||
@@ -467,7 +491,9 @@ class SemaphoreDispatcher(BaseDispatcher):
|
||||
|
||||
try:
|
||||
if self.monitor:
|
||||
self.monitor.update_task(task_id, status=CrawlStatus.IN_PROGRESS, start_time=start_time)
|
||||
self.monitor.update_task(
|
||||
task_id, status=CrawlStatus.IN_PROGRESS, start_time=start_time
|
||||
)
|
||||
|
||||
if self.rate_limiter:
|
||||
await self.rate_limiter.wait_if_needed(url)
|
||||
@@ -477,7 +503,7 @@ class SemaphoreDispatcher(BaseDispatcher):
|
||||
start_memory = process.memory_info().rss / (1024 * 1024)
|
||||
result = await self.crawler.arun(url, config=config, session_id=task_id)
|
||||
end_memory = process.memory_info().rss / (1024 * 1024)
|
||||
|
||||
|
||||
memory_usage = peak_memory = end_memory - start_memory
|
||||
|
||||
if self.rate_limiter and result.status_code:
|
||||
@@ -493,7 +519,7 @@ class SemaphoreDispatcher(BaseDispatcher):
|
||||
peak_memory=peak_memory,
|
||||
start_time=start_time,
|
||||
end_time=datetime.now(),
|
||||
error_message=error_message
|
||||
error_message=error_message,
|
||||
)
|
||||
|
||||
if not result.success:
|
||||
@@ -507,7 +533,9 @@ class SemaphoreDispatcher(BaseDispatcher):
|
||||
error_message = str(e)
|
||||
if self.monitor:
|
||||
self.monitor.update_task(task_id, status=CrawlStatus.FAILED)
|
||||
result = CrawlResult(url=url, html="", metadata={}, success=False, error_message=str(e))
|
||||
result = CrawlResult(
|
||||
url=url, html="", metadata={}, success=False, error_message=str(e)
|
||||
)
|
||||
|
||||
finally:
|
||||
end_time = datetime.now()
|
||||
@@ -517,7 +545,7 @@ class SemaphoreDispatcher(BaseDispatcher):
|
||||
end_time=end_time,
|
||||
memory_usage=memory_usage,
|
||||
peak_memory=peak_memory,
|
||||
error_message=error_message
|
||||
error_message=error_message,
|
||||
)
|
||||
|
||||
return CrawlerTaskResult(
|
||||
@@ -528,13 +556,13 @@ class SemaphoreDispatcher(BaseDispatcher):
|
||||
peak_memory=peak_memory,
|
||||
start_time=start_time,
|
||||
end_time=end_time,
|
||||
error_message=error_message
|
||||
error_message=error_message,
|
||||
)
|
||||
|
||||
async def run_urls(
|
||||
self,
|
||||
crawler: "AsyncWebCrawler",
|
||||
urls: List[str],
|
||||
self,
|
||||
crawler: "AsyncWebCrawler", # noqa: F821
|
||||
urls: List[str],
|
||||
config: CrawlerRunConfig,
|
||||
) -> List[CrawlerTaskResult]:
|
||||
self.crawler = crawler
|
||||
@@ -557,4 +585,4 @@ class SemaphoreDispatcher(BaseDispatcher):
|
||||
return await asyncio.gather(*tasks, return_exceptions=True)
|
||||
finally:
|
||||
if self.monitor:
|
||||
self.monitor.stop()
|
||||
self.monitor.stop()
|
||||
|
||||
@@ -1,10 +1,10 @@
|
||||
from enum import Enum
|
||||
from typing import Optional, Dict, Any, Union
|
||||
from colorama import Fore, Back, Style, init
|
||||
import time
|
||||
from typing import Optional, Dict, Any
|
||||
from colorama import Fore, Style, init
|
||||
import os
|
||||
from datetime import datetime
|
||||
|
||||
|
||||
class LogLevel(Enum):
|
||||
DEBUG = 1
|
||||
INFO = 2
|
||||
@@ -12,23 +12,24 @@ class LogLevel(Enum):
|
||||
WARNING = 4
|
||||
ERROR = 5
|
||||
|
||||
|
||||
class AsyncLogger:
|
||||
"""
|
||||
Asynchronous logger with support for colored console output and file logging.
|
||||
Supports templated messages with colored components.
|
||||
"""
|
||||
|
||||
|
||||
DEFAULT_ICONS = {
|
||||
'INIT': '→',
|
||||
'READY': '✓',
|
||||
'FETCH': '↓',
|
||||
'SCRAPE': '◆',
|
||||
'EXTRACT': '■',
|
||||
'COMPLETE': '●',
|
||||
'ERROR': '×',
|
||||
'DEBUG': '⋯',
|
||||
'INFO': 'ℹ',
|
||||
'WARNING': '⚠',
|
||||
"INIT": "→",
|
||||
"READY": "✓",
|
||||
"FETCH": "↓",
|
||||
"SCRAPE": "◆",
|
||||
"EXTRACT": "■",
|
||||
"COMPLETE": "●",
|
||||
"ERROR": "×",
|
||||
"DEBUG": "⋯",
|
||||
"INFO": "ℹ",
|
||||
"WARNING": "⚠",
|
||||
}
|
||||
|
||||
DEFAULT_COLORS = {
|
||||
@@ -46,11 +47,11 @@ class AsyncLogger:
|
||||
tag_width: int = 10,
|
||||
icons: Optional[Dict[str, str]] = None,
|
||||
colors: Optional[Dict[LogLevel, str]] = None,
|
||||
verbose: bool = True
|
||||
verbose: bool = True,
|
||||
):
|
||||
"""
|
||||
Initialize the logger.
|
||||
|
||||
|
||||
Args:
|
||||
log_file: Optional file path for logging
|
||||
log_level: Minimum log level to display
|
||||
@@ -66,7 +67,7 @@ class AsyncLogger:
|
||||
self.icons = icons or self.DEFAULT_ICONS
|
||||
self.colors = colors or self.DEFAULT_COLORS
|
||||
self.verbose = verbose
|
||||
|
||||
|
||||
# Create log file directory if needed
|
||||
if log_file:
|
||||
os.makedirs(os.path.dirname(os.path.abspath(log_file)), exist_ok=True)
|
||||
@@ -77,18 +78,20 @@ class AsyncLogger:
|
||||
|
||||
def _get_icon(self, tag: str) -> str:
|
||||
"""Get the icon for a tag, defaulting to info icon if not found."""
|
||||
return self.icons.get(tag, self.icons['INFO'])
|
||||
return self.icons.get(tag, self.icons["INFO"])
|
||||
|
||||
def _write_to_file(self, message: str):
|
||||
"""Write a message to the log file if configured."""
|
||||
if self.log_file:
|
||||
timestamp = datetime.now().strftime('%Y-%m-%d %H:%M:%S.%f')[:-3]
|
||||
with open(self.log_file, 'a', encoding='utf-8') as f:
|
||||
timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S.%f")[:-3]
|
||||
with open(self.log_file, "a", encoding="utf-8") as f:
|
||||
# Strip ANSI color codes for file output
|
||||
clean_message = message.replace(Fore.RESET, '').replace(Style.RESET_ALL, '')
|
||||
clean_message = message.replace(Fore.RESET, "").replace(
|
||||
Style.RESET_ALL, ""
|
||||
)
|
||||
for color in vars(Fore).values():
|
||||
if isinstance(color, str):
|
||||
clean_message = clean_message.replace(color, '')
|
||||
clean_message = clean_message.replace(color, "")
|
||||
f.write(f"[{timestamp}] {clean_message}\n")
|
||||
|
||||
def _log(
|
||||
@@ -99,11 +102,11 @@ class AsyncLogger:
|
||||
params: Optional[Dict[str, Any]] = None,
|
||||
colors: Optional[Dict[str, str]] = None,
|
||||
base_color: Optional[str] = None,
|
||||
**kwargs
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
Core logging method that handles message formatting and output.
|
||||
|
||||
|
||||
Args:
|
||||
level: Log level for this message
|
||||
message: Message template string
|
||||
@@ -120,7 +123,7 @@ class AsyncLogger:
|
||||
try:
|
||||
# First format the message with raw parameters
|
||||
formatted_message = message.format(**params)
|
||||
|
||||
|
||||
# Then apply colors if specified
|
||||
if colors:
|
||||
for key, color in colors.items():
|
||||
@@ -128,12 +131,13 @@ class AsyncLogger:
|
||||
if key in params:
|
||||
value_str = str(params[key])
|
||||
formatted_message = formatted_message.replace(
|
||||
value_str,
|
||||
f"{color}{value_str}{Style.RESET_ALL}"
|
||||
value_str, f"{color}{value_str}{Style.RESET_ALL}"
|
||||
)
|
||||
|
||||
|
||||
except KeyError as e:
|
||||
formatted_message = f"LOGGING ERROR: Missing parameter {e} in message template"
|
||||
formatted_message = (
|
||||
f"LOGGING ERROR: Missing parameter {e} in message template"
|
||||
)
|
||||
level = LogLevel.ERROR
|
||||
else:
|
||||
formatted_message = message
|
||||
@@ -175,11 +179,11 @@ class AsyncLogger:
|
||||
success: bool,
|
||||
timing: float,
|
||||
tag: str = "FETCH",
|
||||
url_length: int = 50
|
||||
url_length: int = 50,
|
||||
):
|
||||
"""
|
||||
Convenience method for logging URL fetch status.
|
||||
|
||||
|
||||
Args:
|
||||
url: The URL being processed
|
||||
success: Whether the operation was successful
|
||||
@@ -195,24 +199,20 @@ class AsyncLogger:
|
||||
"url": url,
|
||||
"url_length": url_length,
|
||||
"status": success,
|
||||
"timing": timing
|
||||
"timing": timing,
|
||||
},
|
||||
colors={
|
||||
"status": Fore.GREEN if success else Fore.RED,
|
||||
"timing": Fore.YELLOW
|
||||
}
|
||||
"timing": Fore.YELLOW,
|
||||
},
|
||||
)
|
||||
|
||||
def error_status(
|
||||
self,
|
||||
url: str,
|
||||
error: str,
|
||||
tag: str = "ERROR",
|
||||
url_length: int = 50
|
||||
self, url: str, error: str, tag: str = "ERROR", url_length: int = 50
|
||||
):
|
||||
"""
|
||||
Convenience method for logging error status.
|
||||
|
||||
|
||||
Args:
|
||||
url: The URL being processed
|
||||
error: Error message
|
||||
@@ -223,9 +223,5 @@ class AsyncLogger:
|
||||
level=LogLevel.ERROR,
|
||||
message="{url:.{url_length}}... | Error: {error}",
|
||||
tag=tag,
|
||||
params={
|
||||
"url": url,
|
||||
"url_length": url_length,
|
||||
"error": error
|
||||
}
|
||||
)
|
||||
params={"url": url, "url_length": url_length, "error": error},
|
||||
)
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -4,7 +4,7 @@ from enum import Enum
|
||||
class CacheMode(Enum):
|
||||
"""
|
||||
Defines the caching behavior for web crawling operations.
|
||||
|
||||
|
||||
Modes:
|
||||
- ENABLED: Normal caching behavior (read and write)
|
||||
- DISABLED: No caching at all
|
||||
@@ -12,6 +12,7 @@ class CacheMode(Enum):
|
||||
- WRITE_ONLY: Only write to cache, don't read
|
||||
- BYPASS: Bypass cache for this operation
|
||||
"""
|
||||
|
||||
ENABLED = "enabled"
|
||||
DISABLED = "disabled"
|
||||
READ_ONLY = "read_only"
|
||||
@@ -22,10 +23,10 @@ class CacheMode(Enum):
|
||||
class CacheContext:
|
||||
"""
|
||||
Encapsulates cache-related decisions and URL handling.
|
||||
|
||||
|
||||
This class centralizes all cache-related logic and URL type checking,
|
||||
making the caching behavior more predictable and maintainable.
|
||||
|
||||
|
||||
Attributes:
|
||||
url (str): The URL being processed.
|
||||
cache_mode (CacheMode): The cache mode for the current operation.
|
||||
@@ -36,10 +37,11 @@ class CacheContext:
|
||||
is_raw_html (bool): True if the URL is raw HTML, False otherwise.
|
||||
_url_display (str): The display name for the URL (web, local file, or raw HTML).
|
||||
"""
|
||||
|
||||
def __init__(self, url: str, cache_mode: CacheMode, always_bypass: bool = False):
|
||||
"""
|
||||
Initializes the CacheContext with the provided URL and cache mode.
|
||||
|
||||
|
||||
Args:
|
||||
url (str): The URL being processed.
|
||||
cache_mode (CacheMode): The cache mode for the current operation.
|
||||
@@ -48,42 +50,42 @@ class CacheContext:
|
||||
self.url = url
|
||||
self.cache_mode = cache_mode
|
||||
self.always_bypass = always_bypass
|
||||
self.is_cacheable = url.startswith(('http://', 'https://', 'file://'))
|
||||
self.is_web_url = url.startswith(('http://', 'https://'))
|
||||
self.is_cacheable = url.startswith(("http://", "https://", "file://"))
|
||||
self.is_web_url = url.startswith(("http://", "https://"))
|
||||
self.is_local_file = url.startswith("file://")
|
||||
self.is_raw_html = url.startswith("raw:")
|
||||
self._url_display = url if not self.is_raw_html else "Raw HTML"
|
||||
|
||||
|
||||
def should_read(self) -> bool:
|
||||
"""
|
||||
Determines if cache should be read based on context.
|
||||
|
||||
|
||||
How it works:
|
||||
1. If always_bypass is True or is_cacheable is False, return False.
|
||||
2. If cache_mode is ENABLED or READ_ONLY, return True.
|
||||
|
||||
|
||||
Returns:
|
||||
bool: True if cache should be read, False otherwise.
|
||||
"""
|
||||
if self.always_bypass or not self.is_cacheable:
|
||||
return False
|
||||
return self.cache_mode in [CacheMode.ENABLED, CacheMode.READ_ONLY]
|
||||
|
||||
|
||||
def should_write(self) -> bool:
|
||||
"""
|
||||
Determines if cache should be written based on context.
|
||||
|
||||
|
||||
How it works:
|
||||
1. If always_bypass is True or is_cacheable is False, return False.
|
||||
2. If cache_mode is ENABLED or WRITE_ONLY, return True.
|
||||
|
||||
|
||||
Returns:
|
||||
bool: True if cache should be written, False otherwise.
|
||||
"""
|
||||
if self.always_bypass or not self.is_cacheable:
|
||||
return False
|
||||
return self.cache_mode in [CacheMode.ENABLED, CacheMode.WRITE_ONLY]
|
||||
|
||||
|
||||
@property
|
||||
def display_url(self) -> str:
|
||||
"""Returns the URL in display format."""
|
||||
@@ -94,11 +96,11 @@ def _legacy_to_cache_mode(
|
||||
disable_cache: bool = False,
|
||||
bypass_cache: bool = False,
|
||||
no_cache_read: bool = False,
|
||||
no_cache_write: bool = False
|
||||
no_cache_write: bool = False,
|
||||
) -> CacheMode:
|
||||
"""
|
||||
Converts legacy cache parameters to the new CacheMode enum.
|
||||
|
||||
|
||||
This is an internal function to help transition from the old boolean flags
|
||||
to the new CacheMode system.
|
||||
"""
|
||||
|
||||
@@ -3,49 +3,53 @@ import re
|
||||
from collections import Counter
|
||||
import string
|
||||
from .model_loader import load_nltk_punkt
|
||||
from .utils import *
|
||||
|
||||
|
||||
# Define the abstract base class for chunking strategies
|
||||
class ChunkingStrategy(ABC):
|
||||
"""
|
||||
Abstract base class for chunking strategies.
|
||||
"""
|
||||
|
||||
|
||||
@abstractmethod
|
||||
def chunk(self, text: str) -> list:
|
||||
"""
|
||||
Abstract method to chunk the given text.
|
||||
|
||||
|
||||
Args:
|
||||
text (str): The text to chunk.
|
||||
|
||||
|
||||
Returns:
|
||||
list: A list of chunks.
|
||||
"""
|
||||
pass
|
||||
|
||||
|
||||
# Create an identity chunking strategy f(x) = [x]
|
||||
class IdentityChunking(ChunkingStrategy):
|
||||
"""
|
||||
Chunking strategy that returns the input text as a single chunk.
|
||||
"""
|
||||
|
||||
def chunk(self, text: str) -> list:
|
||||
return [text]
|
||||
|
||||
|
||||
# Regex-based chunking
|
||||
class RegexChunking(ChunkingStrategy):
|
||||
"""
|
||||
Chunking strategy that splits text based on regular expression patterns.
|
||||
"""
|
||||
|
||||
def __init__(self, patterns=None, **kwargs):
|
||||
"""
|
||||
Initialize the RegexChunking object.
|
||||
|
||||
|
||||
Args:
|
||||
patterns (list): A list of regular expression patterns to split text.
|
||||
"""
|
||||
if patterns is None:
|
||||
patterns = [r'\n\n'] # Default split pattern
|
||||
patterns = [r"\n\n"] # Default split pattern
|
||||
self.patterns = patterns
|
||||
|
||||
def chunk(self, text: str) -> list:
|
||||
@@ -56,18 +60,19 @@ class RegexChunking(ChunkingStrategy):
|
||||
new_paragraphs.extend(re.split(pattern, paragraph))
|
||||
paragraphs = new_paragraphs
|
||||
return paragraphs
|
||||
|
||||
# NLP-based sentence chunking
|
||||
|
||||
|
||||
# NLP-based sentence chunking
|
||||
class NlpSentenceChunking(ChunkingStrategy):
|
||||
"""
|
||||
Chunking strategy that splits text into sentences using NLTK's sentence tokenizer.
|
||||
"""
|
||||
"""
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
"""
|
||||
Initialize the NlpSentenceChunking object.
|
||||
"""
|
||||
load_nltk_punkt()
|
||||
|
||||
|
||||
def chunk(self, text: str) -> list:
|
||||
# Improved regex for sentence splitting
|
||||
@@ -75,31 +80,34 @@ class NlpSentenceChunking(ChunkingStrategy):
|
||||
# r'(?<!\w\.\w.)(?<![A-Z][a-z]\.)(?<![A-Z][A-Z]\.)(?<![A-Za-z]\.)(?<=\.|\?|\!|\n)\s'
|
||||
# )
|
||||
# sentences = sentence_endings.split(text)
|
||||
# sens = [sent.strip() for sent in sentences if sent]
|
||||
# sens = [sent.strip() for sent in sentences if sent]
|
||||
from nltk.tokenize import sent_tokenize
|
||||
|
||||
sentences = sent_tokenize(text)
|
||||
sens = [sent.strip() for sent in sentences]
|
||||
|
||||
sens = [sent.strip() for sent in sentences]
|
||||
|
||||
return list(set(sens))
|
||||
|
||||
|
||||
|
||||
# Topic-based segmentation using TextTiling
|
||||
class TopicSegmentationChunking(ChunkingStrategy):
|
||||
"""
|
||||
Chunking strategy that segments text into topics using NLTK's TextTilingTokenizer.
|
||||
|
||||
|
||||
How it works:
|
||||
1. Segment the text into topics using TextTilingTokenizer
|
||||
2. Extract keywords for each topic segment
|
||||
"""
|
||||
|
||||
|
||||
def __init__(self, num_keywords=3, **kwargs):
|
||||
"""
|
||||
Initialize the TopicSegmentationChunking object.
|
||||
|
||||
|
||||
Args:
|
||||
num_keywords (int): The number of keywords to extract for each topic segment.
|
||||
"""
|
||||
import nltk as nl
|
||||
|
||||
self.tokenizer = nl.tokenize.TextTilingTokenizer()
|
||||
self.num_keywords = num_keywords
|
||||
|
||||
@@ -111,8 +119,14 @@ class TopicSegmentationChunking(ChunkingStrategy):
|
||||
def extract_keywords(self, text: str) -> list:
|
||||
# Tokenize and remove stopwords and punctuation
|
||||
import nltk as nl
|
||||
|
||||
tokens = nl.toknize.word_tokenize(text)
|
||||
tokens = [token.lower() for token in tokens if token not in nl.corpus.stopwords.words('english') and token not in string.punctuation]
|
||||
tokens = [
|
||||
token.lower()
|
||||
for token in tokens
|
||||
if token not in nl.corpus.stopwords.words("english")
|
||||
and token not in string.punctuation
|
||||
]
|
||||
|
||||
# Calculate frequency distribution
|
||||
freq_dist = Counter(tokens)
|
||||
@@ -123,23 +137,27 @@ class TopicSegmentationChunking(ChunkingStrategy):
|
||||
# Segment the text into topics
|
||||
segments = self.chunk(text)
|
||||
# Extract keywords for each topic segment
|
||||
segments_with_topics = [(segment, self.extract_keywords(segment)) for segment in segments]
|
||||
segments_with_topics = [
|
||||
(segment, self.extract_keywords(segment)) for segment in segments
|
||||
]
|
||||
return segments_with_topics
|
||||
|
||||
|
||||
|
||||
# Fixed-length word chunks
|
||||
class FixedLengthWordChunking(ChunkingStrategy):
|
||||
"""
|
||||
Chunking strategy that splits text into fixed-length word chunks.
|
||||
|
||||
|
||||
How it works:
|
||||
1. Split the text into words
|
||||
2. Create chunks of fixed length
|
||||
3. Return the list of chunks
|
||||
"""
|
||||
|
||||
def __init__(self, chunk_size=100, **kwargs):
|
||||
"""
|
||||
Initialize the fixed-length word chunking strategy with the given chunk size.
|
||||
|
||||
|
||||
Args:
|
||||
chunk_size (int): The size of each chunk in words.
|
||||
"""
|
||||
@@ -147,23 +165,28 @@ class FixedLengthWordChunking(ChunkingStrategy):
|
||||
|
||||
def chunk(self, text: str) -> list:
|
||||
words = text.split()
|
||||
return [' '.join(words[i:i + self.chunk_size]) for i in range(0, len(words), self.chunk_size)]
|
||||
|
||||
return [
|
||||
" ".join(words[i : i + self.chunk_size])
|
||||
for i in range(0, len(words), self.chunk_size)
|
||||
]
|
||||
|
||||
|
||||
# Sliding window chunking
|
||||
class SlidingWindowChunking(ChunkingStrategy):
|
||||
"""
|
||||
Chunking strategy that splits text into overlapping word chunks.
|
||||
|
||||
|
||||
How it works:
|
||||
1. Split the text into words
|
||||
2. Create chunks of fixed length
|
||||
3. Return the list of chunks
|
||||
"""
|
||||
|
||||
def __init__(self, window_size=100, step=50, **kwargs):
|
||||
"""
|
||||
Initialize the sliding window chunking strategy with the given window size and
|
||||
step size.
|
||||
|
||||
|
||||
Args:
|
||||
window_size (int): The size of the sliding window in words.
|
||||
step (int): The step size for sliding the window in words.
|
||||
@@ -174,35 +197,37 @@ class SlidingWindowChunking(ChunkingStrategy):
|
||||
def chunk(self, text: str) -> list:
|
||||
words = text.split()
|
||||
chunks = []
|
||||
|
||||
|
||||
if len(words) <= self.window_size:
|
||||
return [text]
|
||||
|
||||
|
||||
for i in range(0, len(words) - self.window_size + 1, self.step):
|
||||
chunk = ' '.join(words[i:i + self.window_size])
|
||||
chunk = " ".join(words[i : i + self.window_size])
|
||||
chunks.append(chunk)
|
||||
|
||||
|
||||
# Handle the last chunk if it doesn't align perfectly
|
||||
if i + self.window_size < len(words):
|
||||
chunks.append(' '.join(words[-self.window_size:]))
|
||||
|
||||
chunks.append(" ".join(words[-self.window_size :]))
|
||||
|
||||
return chunks
|
||||
|
||||
|
||||
|
||||
class OverlappingWindowChunking(ChunkingStrategy):
|
||||
"""
|
||||
Chunking strategy that splits text into overlapping word chunks.
|
||||
|
||||
|
||||
How it works:
|
||||
1. Split the text into words using whitespace
|
||||
2. Create chunks of fixed length equal to the window size
|
||||
3. Slide the window by the overlap size
|
||||
4. Return the list of chunks
|
||||
"""
|
||||
|
||||
def __init__(self, window_size=1000, overlap=100, **kwargs):
|
||||
"""
|
||||
Initialize the overlapping window chunking strategy with the given window size and
|
||||
overlap size.
|
||||
|
||||
|
||||
Args:
|
||||
window_size (int): The size of the window in words.
|
||||
overlap (int): The size of the overlap between consecutive chunks in words.
|
||||
@@ -213,19 +238,19 @@ class OverlappingWindowChunking(ChunkingStrategy):
|
||||
def chunk(self, text: str) -> list:
|
||||
words = text.split()
|
||||
chunks = []
|
||||
|
||||
|
||||
if len(words) <= self.window_size:
|
||||
return [text]
|
||||
|
||||
|
||||
start = 0
|
||||
while start < len(words):
|
||||
end = start + self.window_size
|
||||
chunk = ' '.join(words[start:end])
|
||||
chunk = " ".join(words[start:end])
|
||||
chunks.append(chunk)
|
||||
|
||||
|
||||
if end >= len(words):
|
||||
break
|
||||
|
||||
|
||||
start = end - self.overlap
|
||||
|
||||
return chunks
|
||||
|
||||
return chunks
|
||||
|
||||
@@ -8,15 +8,22 @@ from .async_logger import AsyncLogger
|
||||
logger = AsyncLogger(verbose=True)
|
||||
docs_manager = DocsManager(logger)
|
||||
|
||||
|
||||
def print_table(headers: List[str], rows: List[List[str]], padding: int = 2):
|
||||
"""Print formatted table with headers and rows"""
|
||||
widths = [max(len(str(cell)) for cell in col) for col in zip(headers, *rows)]
|
||||
border = '+' + '+'.join('-' * (w + 2 * padding) for w in widths) + '+'
|
||||
|
||||
border = "+" + "+".join("-" * (w + 2 * padding) for w in widths) + "+"
|
||||
|
||||
def format_row(row):
|
||||
return '|' + '|'.join(f"{' ' * padding}{str(cell):<{w}}{' ' * padding}"
|
||||
for cell, w in zip(row, widths)) + '|'
|
||||
|
||||
return (
|
||||
"|"
|
||||
+ "|".join(
|
||||
f"{' ' * padding}{str(cell):<{w}}{' ' * padding}"
|
||||
for cell, w in zip(row, widths)
|
||||
)
|
||||
+ "|"
|
||||
)
|
||||
|
||||
click.echo(border)
|
||||
click.echo(format_row(headers))
|
||||
click.echo(border)
|
||||
@@ -24,19 +31,24 @@ def print_table(headers: List[str], rows: List[List[str]], padding: int = 2):
|
||||
click.echo(format_row(row))
|
||||
click.echo(border)
|
||||
|
||||
|
||||
@click.group()
|
||||
def cli():
|
||||
"""Crawl4AI Command Line Interface"""
|
||||
pass
|
||||
|
||||
|
||||
@cli.group()
|
||||
def docs():
|
||||
"""Documentation operations"""
|
||||
pass
|
||||
|
||||
|
||||
@docs.command()
|
||||
@click.argument('sections', nargs=-1)
|
||||
@click.option('--mode', type=click.Choice(['extended', 'condensed']), default='extended')
|
||||
@click.argument("sections", nargs=-1)
|
||||
@click.option(
|
||||
"--mode", type=click.Choice(["extended", "condensed"]), default="extended"
|
||||
)
|
||||
def combine(sections: tuple, mode: str):
|
||||
"""Combine documentation sections"""
|
||||
try:
|
||||
@@ -46,16 +58,17 @@ def combine(sections: tuple, mode: str):
|
||||
logger.error(str(e), tag="ERROR")
|
||||
sys.exit(1)
|
||||
|
||||
|
||||
@docs.command()
|
||||
@click.argument('query')
|
||||
@click.option('--top-k', '-k', default=5)
|
||||
@click.option('--build-index', is_flag=True, help='Build index if missing')
|
||||
@click.argument("query")
|
||||
@click.option("--top-k", "-k", default=5)
|
||||
@click.option("--build-index", is_flag=True, help="Build index if missing")
|
||||
def search(query: str, top_k: int, build_index: bool):
|
||||
"""Search documentation"""
|
||||
try:
|
||||
result = docs_manager.search(query, top_k)
|
||||
if result == "No search index available. Call build_search_index() first.":
|
||||
if build_index or click.confirm('No search index found. Build it now?'):
|
||||
if build_index or click.confirm("No search index found. Build it now?"):
|
||||
asyncio.run(docs_manager.llm_text.generate_index_files())
|
||||
result = docs_manager.search(query, top_k)
|
||||
click.echo(result)
|
||||
@@ -63,6 +76,7 @@ def search(query: str, top_k: int, build_index: bool):
|
||||
click.echo(f"Error: {str(e)}", err=True)
|
||||
sys.exit(1)
|
||||
|
||||
|
||||
@docs.command()
|
||||
def update():
|
||||
"""Update docs from GitHub"""
|
||||
@@ -73,22 +87,25 @@ def update():
|
||||
click.echo(f"Error: {str(e)}", err=True)
|
||||
sys.exit(1)
|
||||
|
||||
|
||||
@docs.command()
|
||||
@click.option('--force-facts', is_flag=True, help='Force regenerate fact files')
|
||||
@click.option('--clear-cache', is_flag=True, help='Clear BM25 cache')
|
||||
@click.option("--force-facts", is_flag=True, help="Force regenerate fact files")
|
||||
@click.option("--clear-cache", is_flag=True, help="Clear BM25 cache")
|
||||
def index(force_facts: bool, clear_cache: bool):
|
||||
"""Build or rebuild search indexes"""
|
||||
try:
|
||||
asyncio.run(docs_manager.ensure_docs_exist())
|
||||
asyncio.run(docs_manager.llm_text.generate_index_files(
|
||||
force_generate_facts=force_facts,
|
||||
clear_bm25_cache=clear_cache
|
||||
))
|
||||
asyncio.run(
|
||||
docs_manager.llm_text.generate_index_files(
|
||||
force_generate_facts=force_facts, clear_bm25_cache=clear_cache
|
||||
)
|
||||
)
|
||||
click.echo("Search indexes built successfully")
|
||||
except Exception as e:
|
||||
click.echo(f"Error: {str(e)}", err=True)
|
||||
sys.exit(1)
|
||||
|
||||
|
||||
# Add docs list command
|
||||
@docs.command()
|
||||
def list():
|
||||
@@ -96,10 +113,11 @@ def list():
|
||||
try:
|
||||
sections = docs_manager.list()
|
||||
print_table(["Sections"], [[section] for section in sections])
|
||||
|
||||
|
||||
except Exception as e:
|
||||
click.echo(f"Error: {str(e)}", err=True)
|
||||
sys.exit(1)
|
||||
|
||||
if __name__ == '__main__':
|
||||
cli()
|
||||
|
||||
if __name__ == "__main__":
|
||||
cli()
|
||||
|
||||
@@ -8,7 +8,7 @@ DEFAULT_PROVIDER = "openai/gpt-4o-mini"
|
||||
MODEL_REPO_BRANCH = "new-release-0.0.2"
|
||||
# Provider-model dictionary, ONLY used when the extraction strategy is LLMExtractionStrategy
|
||||
PROVIDER_MODELS = {
|
||||
"ollama/llama3": "no-token-needed", # Any model from Ollama no need for API token
|
||||
"ollama/llama3": "no-token-needed", # Any model from Ollama no need for API token
|
||||
"groq/llama3-70b-8192": os.getenv("GROQ_API_KEY"),
|
||||
"groq/llama3-8b-8192": os.getenv("GROQ_API_KEY"),
|
||||
"openai/gpt-4o-mini": os.getenv("OPENAI_API_KEY"),
|
||||
@@ -22,27 +22,49 @@ PROVIDER_MODELS = {
|
||||
}
|
||||
|
||||
# Chunk token threshold
|
||||
CHUNK_TOKEN_THRESHOLD = 2 ** 11 # 2048 tokens
|
||||
CHUNK_TOKEN_THRESHOLD = 2**11 # 2048 tokens
|
||||
OVERLAP_RATE = 0.1
|
||||
WORD_TOKEN_RATE = 1.3
|
||||
|
||||
# Threshold for the minimum number of word in a HTML tag to be considered
|
||||
# Threshold for the minimum number of word in a HTML tag to be considered
|
||||
MIN_WORD_THRESHOLD = 1
|
||||
IMAGE_DESCRIPTION_MIN_WORD_THRESHOLD = 1
|
||||
|
||||
IMPORTANT_ATTRS = ['src', 'href', 'alt', 'title', 'width', 'height']
|
||||
ONLY_TEXT_ELIGIBLE_TAGS = ['b', 'i', 'u', 'span', 'del', 'ins', 'sub', 'sup', 'strong', 'em', 'code', 'kbd', 'var', 's', 'q', 'abbr', 'cite', 'dfn', 'time', 'small', 'mark']
|
||||
IMPORTANT_ATTRS = ["src", "href", "alt", "title", "width", "height"]
|
||||
ONLY_TEXT_ELIGIBLE_TAGS = [
|
||||
"b",
|
||||
"i",
|
||||
"u",
|
||||
"span",
|
||||
"del",
|
||||
"ins",
|
||||
"sub",
|
||||
"sup",
|
||||
"strong",
|
||||
"em",
|
||||
"code",
|
||||
"kbd",
|
||||
"var",
|
||||
"s",
|
||||
"q",
|
||||
"abbr",
|
||||
"cite",
|
||||
"dfn",
|
||||
"time",
|
||||
"small",
|
||||
"mark",
|
||||
]
|
||||
SOCIAL_MEDIA_DOMAINS = [
|
||||
'facebook.com',
|
||||
'twitter.com',
|
||||
'x.com',
|
||||
'linkedin.com',
|
||||
'instagram.com',
|
||||
'pinterest.com',
|
||||
'tiktok.com',
|
||||
'snapchat.com',
|
||||
'reddit.com',
|
||||
]
|
||||
"facebook.com",
|
||||
"twitter.com",
|
||||
"x.com",
|
||||
"linkedin.com",
|
||||
"instagram.com",
|
||||
"pinterest.com",
|
||||
"tiktok.com",
|
||||
"snapchat.com",
|
||||
"reddit.com",
|
||||
]
|
||||
|
||||
# Threshold for the Image extraction - Range is 1 to 6
|
||||
# Images are scored based on point based system, to filter based on usefulness. Points are assigned
|
||||
@@ -60,5 +82,5 @@ NEED_MIGRATION = True
|
||||
URL_LOG_SHORTEN_LENGTH = 30
|
||||
SHOW_DEPRECATION_WARNINGS = True
|
||||
SCREENSHOT_HEIGHT_TRESHOLD = 10000
|
||||
PAGE_TIMEOUT=60000
|
||||
DOWNLOAD_PAGE_TIMEOUT=60000
|
||||
PAGE_TIMEOUT = 60000
|
||||
DOWNLOAD_PAGE_TIMEOUT = 60000
|
||||
|
||||
@@ -1,59 +1,100 @@
|
||||
import re
|
||||
from bs4 import BeautifulSoup, Tag
|
||||
from typing import List, Tuple, Dict
|
||||
from typing import List, Tuple
|
||||
from rank_bm25 import BM25Okapi
|
||||
from time import perf_counter
|
||||
from collections import deque
|
||||
from bs4 import BeautifulSoup, NavigableString, Tag, Comment
|
||||
from bs4 import NavigableString, Comment
|
||||
from .utils import clean_tokens
|
||||
from abc import ABC, abstractmethod
|
||||
import math
|
||||
from snowballstemmer import stemmer
|
||||
|
||||
|
||||
class RelevantContentFilter(ABC):
|
||||
"""Abstract base class for content filtering strategies"""
|
||||
|
||||
def __init__(self, user_query: str = None):
|
||||
self.user_query = user_query
|
||||
self.included_tags = {
|
||||
# Primary structure
|
||||
'article', 'main', 'section', 'div',
|
||||
"article",
|
||||
"main",
|
||||
"section",
|
||||
"div",
|
||||
# List structures
|
||||
'ul', 'ol', 'li', 'dl', 'dt', 'dd',
|
||||
"ul",
|
||||
"ol",
|
||||
"li",
|
||||
"dl",
|
||||
"dt",
|
||||
"dd",
|
||||
# Text content
|
||||
'p', 'span', 'blockquote', 'pre', 'code',
|
||||
"p",
|
||||
"span",
|
||||
"blockquote",
|
||||
"pre",
|
||||
"code",
|
||||
# Headers
|
||||
'h1', 'h2', 'h3', 'h4', 'h5', 'h6',
|
||||
"h1",
|
||||
"h2",
|
||||
"h3",
|
||||
"h4",
|
||||
"h5",
|
||||
"h6",
|
||||
# Tables
|
||||
'table', 'thead', 'tbody', 'tr', 'td', 'th',
|
||||
"table",
|
||||
"thead",
|
||||
"tbody",
|
||||
"tr",
|
||||
"td",
|
||||
"th",
|
||||
# Other semantic elements
|
||||
'figure', 'figcaption', 'details', 'summary',
|
||||
"figure",
|
||||
"figcaption",
|
||||
"details",
|
||||
"summary",
|
||||
# Text formatting
|
||||
'em', 'strong', 'b', 'i', 'mark', 'small',
|
||||
"em",
|
||||
"strong",
|
||||
"b",
|
||||
"i",
|
||||
"mark",
|
||||
"small",
|
||||
# Rich content
|
||||
'time', 'address', 'cite', 'q'
|
||||
"time",
|
||||
"address",
|
||||
"cite",
|
||||
"q",
|
||||
}
|
||||
self.excluded_tags = {
|
||||
'nav', 'footer', 'header', 'aside', 'script',
|
||||
'style', 'form', 'iframe', 'noscript'
|
||||
"nav",
|
||||
"footer",
|
||||
"header",
|
||||
"aside",
|
||||
"script",
|
||||
"style",
|
||||
"form",
|
||||
"iframe",
|
||||
"noscript",
|
||||
}
|
||||
self.header_tags = {'h1', 'h2', 'h3', 'h4', 'h5', 'h6'}
|
||||
self.header_tags = {"h1", "h2", "h3", "h4", "h5", "h6"}
|
||||
self.negative_patterns = re.compile(
|
||||
r'nav|footer|header|sidebar|ads|comment|promo|advert|social|share',
|
||||
re.I
|
||||
r"nav|footer|header|sidebar|ads|comment|promo|advert|social|share", re.I
|
||||
)
|
||||
self.min_word_count = 2
|
||||
|
||||
|
||||
@abstractmethod
|
||||
def filter_content(self, html: str) -> List[str]:
|
||||
"""Abstract method to be implemented by specific filtering strategies"""
|
||||
pass
|
||||
|
||||
|
||||
def extract_page_query(self, soup: BeautifulSoup, body: Tag) -> str:
|
||||
"""Common method to extract page metadata with fallbacks"""
|
||||
if self.user_query:
|
||||
return self.user_query
|
||||
|
||||
query_parts = []
|
||||
|
||||
|
||||
# Title
|
||||
try:
|
||||
title = soup.title.string
|
||||
@@ -62,109 +103,145 @@ class RelevantContentFilter(ABC):
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
if soup.find('h1'):
|
||||
query_parts.append(soup.find('h1').get_text())
|
||||
|
||||
if soup.find("h1"):
|
||||
query_parts.append(soup.find("h1").get_text())
|
||||
|
||||
# Meta tags
|
||||
temp = ""
|
||||
for meta_name in ['keywords', 'description']:
|
||||
meta = soup.find('meta', attrs={'name': meta_name})
|
||||
if meta and meta.get('content'):
|
||||
query_parts.append(meta['content'])
|
||||
temp += meta['content']
|
||||
|
||||
for meta_name in ["keywords", "description"]:
|
||||
meta = soup.find("meta", attrs={"name": meta_name})
|
||||
if meta and meta.get("content"):
|
||||
query_parts.append(meta["content"])
|
||||
temp += meta["content"]
|
||||
|
||||
# If still empty, grab first significant paragraph
|
||||
if not temp:
|
||||
# Find the first tag P thatits text contains more than 50 characters
|
||||
for p in body.find_all('p'):
|
||||
for p in body.find_all("p"):
|
||||
if len(p.get_text()) > 150:
|
||||
query_parts.append(p.get_text()[:150])
|
||||
break
|
||||
|
||||
return ' '.join(filter(None, query_parts))
|
||||
break
|
||||
|
||||
def extract_text_chunks(self, body: Tag, min_word_threshold: int = None) -> List[Tuple[str, str]]:
|
||||
return " ".join(filter(None, query_parts))
|
||||
|
||||
def extract_text_chunks(
|
||||
self, body: Tag, min_word_threshold: int = None
|
||||
) -> List[Tuple[str, str]]:
|
||||
"""
|
||||
Extracts text chunks from a BeautifulSoup body element while preserving order.
|
||||
Returns list of tuples (text, tag_name) for classification.
|
||||
|
||||
|
||||
Args:
|
||||
body: BeautifulSoup Tag object representing the body element
|
||||
|
||||
|
||||
Returns:
|
||||
List of (text, tag_name) tuples
|
||||
"""
|
||||
# Tags to ignore - inline elements that shouldn't break text flow
|
||||
INLINE_TAGS = {
|
||||
'a', 'abbr', 'acronym', 'b', 'bdo', 'big', 'br', 'button', 'cite', 'code',
|
||||
'dfn', 'em', 'i', 'img', 'input', 'kbd', 'label', 'map', 'object', 'q',
|
||||
'samp', 'script', 'select', 'small', 'span', 'strong', 'sub', 'sup',
|
||||
'textarea', 'time', 'tt', 'var'
|
||||
"a",
|
||||
"abbr",
|
||||
"acronym",
|
||||
"b",
|
||||
"bdo",
|
||||
"big",
|
||||
"br",
|
||||
"button",
|
||||
"cite",
|
||||
"code",
|
||||
"dfn",
|
||||
"em",
|
||||
"i",
|
||||
"img",
|
||||
"input",
|
||||
"kbd",
|
||||
"label",
|
||||
"map",
|
||||
"object",
|
||||
"q",
|
||||
"samp",
|
||||
"script",
|
||||
"select",
|
||||
"small",
|
||||
"span",
|
||||
"strong",
|
||||
"sub",
|
||||
"sup",
|
||||
"textarea",
|
||||
"time",
|
||||
"tt",
|
||||
"var",
|
||||
}
|
||||
|
||||
|
||||
# Tags that typically contain meaningful headers
|
||||
HEADER_TAGS = {'h1', 'h2', 'h3', 'h4', 'h5', 'h6', 'header'}
|
||||
|
||||
HEADER_TAGS = {"h1", "h2", "h3", "h4", "h5", "h6", "header"}
|
||||
|
||||
chunks = []
|
||||
current_text = []
|
||||
chunk_index = 0
|
||||
|
||||
|
||||
def should_break_chunk(tag: Tag) -> bool:
|
||||
"""Determine if a tag should cause a break in the current text chunk"""
|
||||
return (
|
||||
tag.name not in INLINE_TAGS
|
||||
and not (tag.name == 'p' and len(current_text) == 0)
|
||||
return tag.name not in INLINE_TAGS and not (
|
||||
tag.name == "p" and len(current_text) == 0
|
||||
)
|
||||
|
||||
|
||||
# Use deque for efficient push/pop operations
|
||||
stack = deque([(body, False)])
|
||||
|
||||
|
||||
while stack:
|
||||
element, visited = stack.pop()
|
||||
|
||||
|
||||
if visited:
|
||||
# End of block element - flush accumulated text
|
||||
if current_text and should_break_chunk(element):
|
||||
text = ' '.join(''.join(current_text).split())
|
||||
text = " ".join("".join(current_text).split())
|
||||
if text:
|
||||
tag_type = 'header' if element.name in HEADER_TAGS else 'content'
|
||||
tag_type = (
|
||||
"header" if element.name in HEADER_TAGS else "content"
|
||||
)
|
||||
chunks.append((chunk_index, text, tag_type, element))
|
||||
chunk_index += 1
|
||||
current_text = []
|
||||
continue
|
||||
|
||||
|
||||
if isinstance(element, NavigableString):
|
||||
if str(element).strip():
|
||||
current_text.append(str(element).strip())
|
||||
continue
|
||||
|
||||
|
||||
# Pre-allocate children to avoid multiple list operations
|
||||
children = list(element.children)
|
||||
if not children:
|
||||
continue
|
||||
|
||||
|
||||
# Mark block for revisit after processing children
|
||||
stack.append((element, True))
|
||||
|
||||
|
||||
# Add children in reverse order for correct processing
|
||||
for child in reversed(children):
|
||||
if isinstance(child, (Tag, NavigableString)):
|
||||
stack.append((child, False))
|
||||
|
||||
|
||||
# Handle any remaining text
|
||||
if current_text:
|
||||
text = ' '.join(''.join(current_text).split())
|
||||
text = " ".join("".join(current_text).split())
|
||||
if text:
|
||||
chunks.append((chunk_index, text, 'content', body))
|
||||
|
||||
if min_word_threshold:
|
||||
chunks = [chunk for chunk in chunks if len(chunk[1].split()) >= min_word_threshold]
|
||||
|
||||
return chunks
|
||||
chunks.append((chunk_index, text, "content", body))
|
||||
|
||||
def _deprecated_extract_text_chunks(self, soup: BeautifulSoup) -> List[Tuple[int, str, Tag]]:
|
||||
if min_word_threshold:
|
||||
chunks = [
|
||||
chunk for chunk in chunks if len(chunk[1].split()) >= min_word_threshold
|
||||
]
|
||||
|
||||
return chunks
|
||||
|
||||
def _deprecated_extract_text_chunks(
|
||||
self, soup: BeautifulSoup
|
||||
) -> List[Tuple[int, str, Tag]]:
|
||||
"""Common method for extracting text chunks"""
|
||||
_text_cache = {}
|
||||
|
||||
def fast_text(element: Tag) -> str:
|
||||
elem_id = id(element)
|
||||
if elem_id in _text_cache:
|
||||
@@ -175,13 +252,13 @@ class RelevantContentFilter(ABC):
|
||||
text = content.strip()
|
||||
if text:
|
||||
texts.append(text)
|
||||
result = ' '.join(texts)
|
||||
result = " ".join(texts)
|
||||
_text_cache[elem_id] = result
|
||||
return result
|
||||
|
||||
|
||||
candidates = []
|
||||
index = 0
|
||||
|
||||
|
||||
def dfs(element):
|
||||
nonlocal index
|
||||
if isinstance(element, Tag):
|
||||
@@ -189,7 +266,7 @@ class RelevantContentFilter(ABC):
|
||||
if not self.is_excluded(element):
|
||||
text = fast_text(element)
|
||||
word_count = len(text.split())
|
||||
|
||||
|
||||
# Headers pass through with adjusted minimum
|
||||
if element.name in self.header_tags:
|
||||
if word_count >= 3: # Minimal sanity check for headers
|
||||
@@ -199,7 +276,7 @@ class RelevantContentFilter(ABC):
|
||||
elif word_count >= self.min_word_count:
|
||||
candidates.append((index, text, element))
|
||||
index += 1
|
||||
|
||||
|
||||
for child in element.children:
|
||||
dfs(child)
|
||||
|
||||
@@ -210,59 +287,67 @@ class RelevantContentFilter(ABC):
|
||||
"""Common method for exclusion logic"""
|
||||
if tag.name in self.excluded_tags:
|
||||
return True
|
||||
class_id = ' '.join(filter(None, [
|
||||
' '.join(tag.get('class', [])),
|
||||
tag.get('id', '')
|
||||
]))
|
||||
class_id = " ".join(
|
||||
filter(None, [" ".join(tag.get("class", [])), tag.get("id", "")])
|
||||
)
|
||||
return bool(self.negative_patterns.search(class_id))
|
||||
|
||||
def clean_element(self, tag: Tag) -> str:
|
||||
"""Common method for cleaning HTML elements with minimal overhead"""
|
||||
if not tag or not isinstance(tag, Tag):
|
||||
return ""
|
||||
|
||||
unwanted_tags = {'script', 'style', 'aside', 'form', 'iframe', 'noscript'}
|
||||
unwanted_attrs = {'style', 'onclick', 'onmouseover', 'align', 'bgcolor', 'class', 'id'}
|
||||
|
||||
|
||||
unwanted_tags = {"script", "style", "aside", "form", "iframe", "noscript"}
|
||||
unwanted_attrs = {
|
||||
"style",
|
||||
"onclick",
|
||||
"onmouseover",
|
||||
"align",
|
||||
"bgcolor",
|
||||
"class",
|
||||
"id",
|
||||
}
|
||||
|
||||
# Use string builder pattern for better performance
|
||||
builder = []
|
||||
|
||||
|
||||
def render_tag(elem):
|
||||
if not isinstance(elem, Tag):
|
||||
if isinstance(elem, str):
|
||||
builder.append(elem.strip())
|
||||
return
|
||||
|
||||
|
||||
if elem.name in unwanted_tags:
|
||||
return
|
||||
|
||||
|
||||
# Start tag
|
||||
builder.append(f'<{elem.name}')
|
||||
|
||||
builder.append(f"<{elem.name}")
|
||||
|
||||
# Add cleaned attributes
|
||||
attrs = {k: v for k, v in elem.attrs.items() if k not in unwanted_attrs}
|
||||
for key, value in attrs.items():
|
||||
builder.append(f' {key}="{value}"')
|
||||
|
||||
builder.append('>')
|
||||
|
||||
|
||||
builder.append(">")
|
||||
|
||||
# Process children
|
||||
for child in elem.children:
|
||||
render_tag(child)
|
||||
|
||||
|
||||
# Close tag
|
||||
builder.append(f'</{elem.name}>')
|
||||
|
||||
builder.append(f"</{elem.name}>")
|
||||
|
||||
try:
|
||||
render_tag(tag)
|
||||
return ''.join(builder)
|
||||
return "".join(builder)
|
||||
except Exception:
|
||||
return str(tag) # Fallback to original if anything fails
|
||||
|
||||
|
||||
class BM25ContentFilter(RelevantContentFilter):
|
||||
"""
|
||||
Content filtering using BM25 algorithm with priority tag handling.
|
||||
|
||||
|
||||
How it works:
|
||||
1. Extracts page metadata with fallbacks.
|
||||
2. Extracts text chunks from the body element.
|
||||
@@ -271,22 +356,28 @@ class BM25ContentFilter(RelevantContentFilter):
|
||||
5. Filters out chunks below the threshold.
|
||||
6. Sorts chunks by score in descending order.
|
||||
7. Returns the top N chunks.
|
||||
|
||||
|
||||
Attributes:
|
||||
user_query (str): User query for filtering (optional).
|
||||
bm25_threshold (float): BM25 threshold for filtering (default: 1.0).
|
||||
language (str): Language for stemming (default: 'english').
|
||||
|
||||
|
||||
Methods:
|
||||
filter_content(self, html: str, min_word_threshold: int = None)
|
||||
"""
|
||||
def __init__(self, user_query: str = None, bm25_threshold: float = 1.0, language: str = 'english'):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
user_query: str = None,
|
||||
bm25_threshold: float = 1.0,
|
||||
language: str = "english",
|
||||
):
|
||||
"""
|
||||
Initializes the BM25ContentFilter class, if not provided, falls back to page metadata.
|
||||
|
||||
|
||||
Note:
|
||||
If no query is given and no page metadata is available, then it tries to pick up the first significant paragraph.
|
||||
|
||||
|
||||
Args:
|
||||
user_query (str): User query for filtering (optional).
|
||||
bm25_threshold (float): BM25 threshold for filtering (default: 1.0).
|
||||
@@ -295,52 +386,52 @@ class BM25ContentFilter(RelevantContentFilter):
|
||||
super().__init__(user_query=user_query)
|
||||
self.bm25_threshold = bm25_threshold
|
||||
self.priority_tags = {
|
||||
'h1': 5.0,
|
||||
'h2': 4.0,
|
||||
'h3': 3.0,
|
||||
'title': 4.0,
|
||||
'strong': 2.0,
|
||||
'b': 1.5,
|
||||
'em': 1.5,
|
||||
'blockquote': 2.0,
|
||||
'code': 2.0,
|
||||
'pre': 1.5,
|
||||
'th': 1.5, # Table headers
|
||||
"h1": 5.0,
|
||||
"h2": 4.0,
|
||||
"h3": 3.0,
|
||||
"title": 4.0,
|
||||
"strong": 2.0,
|
||||
"b": 1.5,
|
||||
"em": 1.5,
|
||||
"blockquote": 2.0,
|
||||
"code": 2.0,
|
||||
"pre": 1.5,
|
||||
"th": 1.5, # Table headers
|
||||
}
|
||||
self.stemmer = stemmer(language)
|
||||
|
||||
def filter_content(self, html: str, min_word_threshold: int = None) -> List[str]:
|
||||
"""
|
||||
Implements content filtering using BM25 algorithm with priority tag handling.
|
||||
|
||||
|
||||
Note:
|
||||
This method implements the filtering logic for the BM25ContentFilter class.
|
||||
It takes HTML content as input and returns a list of filtered text chunks.
|
||||
|
||||
|
||||
Args:
|
||||
html (str): HTML content to be filtered.
|
||||
min_word_threshold (int): Minimum word threshold for filtering (optional).
|
||||
|
||||
|
||||
Returns:
|
||||
List[str]: List of filtered text chunks.
|
||||
"""
|
||||
if not html or not isinstance(html, str):
|
||||
return []
|
||||
|
||||
soup = BeautifulSoup(html, 'lxml')
|
||||
|
||||
soup = BeautifulSoup(html, "lxml")
|
||||
|
||||
# Check if body is present
|
||||
if not soup.body:
|
||||
# Wrap in body tag if missing
|
||||
soup = BeautifulSoup(f'<body>{html}</body>', 'lxml')
|
||||
body = soup.find('body')
|
||||
|
||||
soup = BeautifulSoup(f"<body>{html}</body>", "lxml")
|
||||
body = soup.find("body")
|
||||
|
||||
query = self.extract_page_query(soup, body)
|
||||
|
||||
|
||||
if not query:
|
||||
return []
|
||||
# return [self.clean_element(soup)]
|
||||
|
||||
|
||||
candidates = self.extract_text_chunks(body, min_word_threshold)
|
||||
|
||||
if not candidates:
|
||||
@@ -349,16 +440,20 @@ class BM25ContentFilter(RelevantContentFilter):
|
||||
# Tokenize corpus
|
||||
# tokenized_corpus = [chunk.lower().split() for _, chunk, _, _ in candidates]
|
||||
# tokenized_query = query.lower().split()
|
||||
|
||||
# tokenized_corpus = [[ps.stem(word) for word in chunk.lower().split()]
|
||||
# for _, chunk, _, _ in candidates]
|
||||
# tokenized_query = [ps.stem(word) for word in query.lower().split()]
|
||||
|
||||
tokenized_corpus = [[self.stemmer.stemWord(word) for word in chunk.lower().split()]
|
||||
for _, chunk, _, _ in candidates]
|
||||
tokenized_query = [self.stemmer.stemWord(word) for word in query.lower().split()]
|
||||
|
||||
# tokenized_corpus = [[self.stemmer.stemWord(word) for word in tokenize_text(chunk.lower())]
|
||||
# tokenized_corpus = [[ps.stem(word) for word in chunk.lower().split()]
|
||||
# for _, chunk, _, _ in candidates]
|
||||
# tokenized_query = [ps.stem(word) for word in query.lower().split()]
|
||||
|
||||
tokenized_corpus = [
|
||||
[self.stemmer.stemWord(word) for word in chunk.lower().split()]
|
||||
for _, chunk, _, _ in candidates
|
||||
]
|
||||
tokenized_query = [
|
||||
self.stemmer.stemWord(word) for word in query.lower().split()
|
||||
]
|
||||
|
||||
# tokenized_corpus = [[self.stemmer.stemWord(word) for word in tokenize_text(chunk.lower())]
|
||||
# for _, chunk, _, _ in candidates]
|
||||
# tokenized_query = [self.stemmer.stemWord(word) for word in tokenize_text(query.lower())]
|
||||
|
||||
@@ -378,7 +473,8 @@ class BM25ContentFilter(RelevantContentFilter):
|
||||
|
||||
# Filter candidates by threshold
|
||||
selected_candidates = [
|
||||
(index, chunk, tag) for adjusted_score, index, chunk, tag in adjusted_candidates
|
||||
(index, chunk, tag)
|
||||
for adjusted_score, index, chunk, tag in adjusted_candidates
|
||||
if adjusted_score >= self.bm25_threshold
|
||||
]
|
||||
|
||||
@@ -390,10 +486,11 @@ class BM25ContentFilter(RelevantContentFilter):
|
||||
|
||||
return [self.clean_element(tag) for _, _, tag in selected_candidates]
|
||||
|
||||
|
||||
class PruningContentFilter(RelevantContentFilter):
|
||||
"""
|
||||
Content filtering using pruning algorithm with dynamic threshold.
|
||||
|
||||
|
||||
How it works:
|
||||
1. Extracts page metadata with fallbacks.
|
||||
2. Extracts text chunks from the body element.
|
||||
@@ -407,18 +504,24 @@ class PruningContentFilter(RelevantContentFilter):
|
||||
min_word_threshold (int): Minimum word threshold for filtering (optional).
|
||||
threshold_type (str): Threshold type for dynamic threshold (default: 'fixed').
|
||||
threshold (float): Fixed threshold value (default: 0.48).
|
||||
|
||||
|
||||
Methods:
|
||||
filter_content(self, html: str, min_word_threshold: int = None):
|
||||
"""
|
||||
def __init__(self, user_query: str = None, min_word_threshold: int = None,
|
||||
threshold_type: str = 'fixed', threshold: float = 0.48):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
user_query: str = None,
|
||||
min_word_threshold: int = None,
|
||||
threshold_type: str = "fixed",
|
||||
threshold: float = 0.48,
|
||||
):
|
||||
"""
|
||||
Initializes the PruningContentFilter class, if not provided, falls back to page metadata.
|
||||
|
||||
|
||||
Note:
|
||||
If no query is given and no page metadata is available, then it tries to pick up the first significant paragraph.
|
||||
|
||||
|
||||
Args:
|
||||
user_query (str): User query for filtering (optional).
|
||||
min_word_threshold (int): Minimum word threshold for filtering (optional).
|
||||
@@ -429,92 +532,92 @@ class PruningContentFilter(RelevantContentFilter):
|
||||
self.min_word_threshold = min_word_threshold
|
||||
self.threshold_type = threshold_type
|
||||
self.threshold = threshold
|
||||
|
||||
|
||||
# Add tag importance for dynamic threshold
|
||||
self.tag_importance = {
|
||||
'article': 1.5,
|
||||
'main': 1.4,
|
||||
'section': 1.3,
|
||||
'p': 1.2,
|
||||
'h1': 1.4,
|
||||
'h2': 1.3,
|
||||
'h3': 1.2,
|
||||
'div': 0.7,
|
||||
'span': 0.6
|
||||
"article": 1.5,
|
||||
"main": 1.4,
|
||||
"section": 1.3,
|
||||
"p": 1.2,
|
||||
"h1": 1.4,
|
||||
"h2": 1.3,
|
||||
"h3": 1.2,
|
||||
"div": 0.7,
|
||||
"span": 0.6,
|
||||
}
|
||||
|
||||
|
||||
# Metric configuration
|
||||
self.metric_config = {
|
||||
'text_density': True,
|
||||
'link_density': True,
|
||||
'tag_weight': True,
|
||||
'class_id_weight': True,
|
||||
'text_length': True,
|
||||
"text_density": True,
|
||||
"link_density": True,
|
||||
"tag_weight": True,
|
||||
"class_id_weight": True,
|
||||
"text_length": True,
|
||||
}
|
||||
|
||||
|
||||
self.metric_weights = {
|
||||
'text_density': 0.4,
|
||||
'link_density': 0.2,
|
||||
'tag_weight': 0.2,
|
||||
'class_id_weight': 0.1,
|
||||
'text_length': 0.1,
|
||||
"text_density": 0.4,
|
||||
"link_density": 0.2,
|
||||
"tag_weight": 0.2,
|
||||
"class_id_weight": 0.1,
|
||||
"text_length": 0.1,
|
||||
}
|
||||
|
||||
|
||||
self.tag_weights = {
|
||||
'div': 0.5,
|
||||
'p': 1.0,
|
||||
'article': 1.5,
|
||||
'section': 1.0,
|
||||
'span': 0.3,
|
||||
'li': 0.5,
|
||||
'ul': 0.5,
|
||||
'ol': 0.5,
|
||||
'h1': 1.2,
|
||||
'h2': 1.1,
|
||||
'h3': 1.0,
|
||||
'h4': 0.9,
|
||||
'h5': 0.8,
|
||||
'h6': 0.7,
|
||||
"div": 0.5,
|
||||
"p": 1.0,
|
||||
"article": 1.5,
|
||||
"section": 1.0,
|
||||
"span": 0.3,
|
||||
"li": 0.5,
|
||||
"ul": 0.5,
|
||||
"ol": 0.5,
|
||||
"h1": 1.2,
|
||||
"h2": 1.1,
|
||||
"h3": 1.0,
|
||||
"h4": 0.9,
|
||||
"h5": 0.8,
|
||||
"h6": 0.7,
|
||||
}
|
||||
|
||||
def filter_content(self, html: str, min_word_threshold: int = None) -> List[str]:
|
||||
"""
|
||||
Implements content filtering using pruning algorithm with dynamic threshold.
|
||||
|
||||
|
||||
Note:
|
||||
This method implements the filtering logic for the PruningContentFilter class.
|
||||
It takes HTML content as input and returns a list of filtered text chunks.
|
||||
|
||||
|
||||
Args:
|
||||
html (str): HTML content to be filtered.
|
||||
min_word_threshold (int): Minimum word threshold for filtering (optional).
|
||||
|
||||
|
||||
Returns:
|
||||
List[str]: List of filtered text chunks.
|
||||
"""
|
||||
if not html or not isinstance(html, str):
|
||||
return []
|
||||
|
||||
soup = BeautifulSoup(html, 'lxml')
|
||||
|
||||
soup = BeautifulSoup(html, "lxml")
|
||||
if not soup.body:
|
||||
soup = BeautifulSoup(f'<body>{html}</body>', 'lxml')
|
||||
|
||||
soup = BeautifulSoup(f"<body>{html}</body>", "lxml")
|
||||
|
||||
# Remove comments and unwanted tags
|
||||
self._remove_comments(soup)
|
||||
self._remove_unwanted_tags(soup)
|
||||
|
||||
|
||||
# Prune tree starting from body
|
||||
body = soup.find('body')
|
||||
body = soup.find("body")
|
||||
self._prune_tree(body)
|
||||
|
||||
|
||||
# Extract remaining content as list of HTML strings
|
||||
content_blocks = []
|
||||
for element in body.children:
|
||||
if isinstance(element, str) or not hasattr(element, 'name'):
|
||||
if isinstance(element, str) or not hasattr(element, "name"):
|
||||
continue
|
||||
if len(element.get_text(strip=True)) > 0:
|
||||
content_blocks.append(str(element))
|
||||
|
||||
|
||||
return content_blocks
|
||||
|
||||
def _remove_comments(self, soup):
|
||||
@@ -531,34 +634,38 @@ class PruningContentFilter(RelevantContentFilter):
|
||||
def _prune_tree(self, node):
|
||||
"""
|
||||
Prunes the tree starting from the given node.
|
||||
|
||||
|
||||
Args:
|
||||
node (Tag): The node from which the pruning starts.
|
||||
"""
|
||||
if not node or not hasattr(node, 'name') or node.name is None:
|
||||
if not node or not hasattr(node, "name") or node.name is None:
|
||||
return
|
||||
|
||||
text_len = len(node.get_text(strip=True))
|
||||
tag_len = len(node.encode_contents().decode('utf-8'))
|
||||
link_text_len = sum(len(s.strip()) for s in (a.string for a in node.find_all('a', recursive=False)) if s)
|
||||
tag_len = len(node.encode_contents().decode("utf-8"))
|
||||
link_text_len = sum(
|
||||
len(s.strip())
|
||||
for s in (a.string for a in node.find_all("a", recursive=False))
|
||||
if s
|
||||
)
|
||||
|
||||
metrics = {
|
||||
'node': node,
|
||||
'tag_name': node.name,
|
||||
'text_len': text_len,
|
||||
'tag_len': tag_len,
|
||||
'link_text_len': link_text_len
|
||||
"node": node,
|
||||
"tag_name": node.name,
|
||||
"text_len": text_len,
|
||||
"tag_len": tag_len,
|
||||
"link_text_len": link_text_len,
|
||||
}
|
||||
|
||||
score = self._compute_composite_score(metrics, text_len, tag_len, link_text_len)
|
||||
|
||||
if self.threshold_type == 'fixed':
|
||||
if self.threshold_type == "fixed":
|
||||
should_remove = score < self.threshold
|
||||
else: # dynamic
|
||||
tag_importance = self.tag_importance.get(node.name, 0.7)
|
||||
text_ratio = text_len / tag_len if tag_len > 0 else 0
|
||||
link_ratio = link_text_len / text_len if text_len > 0 else 1
|
||||
|
||||
|
||||
threshold = self.threshold # base threshold
|
||||
if tag_importance > 1:
|
||||
threshold *= 0.8
|
||||
@@ -566,13 +673,13 @@ class PruningContentFilter(RelevantContentFilter):
|
||||
threshold *= 0.9
|
||||
if link_ratio > 0.6:
|
||||
threshold *= 1.2
|
||||
|
||||
|
||||
should_remove = score < threshold
|
||||
|
||||
if should_remove:
|
||||
node.decompose()
|
||||
else:
|
||||
children = [child for child in node.children if hasattr(child, 'name')]
|
||||
children = [child for child in node.children if hasattr(child, "name")]
|
||||
for child in children:
|
||||
self._prune_tree(child)
|
||||
|
||||
@@ -580,48 +687,48 @@ class PruningContentFilter(RelevantContentFilter):
|
||||
"""Computes the composite score"""
|
||||
if self.min_word_threshold:
|
||||
# Get raw text from metrics node - avoid extra processing
|
||||
text = metrics['node'].get_text(strip=True)
|
||||
word_count = text.count(' ') + 1
|
||||
text = metrics["node"].get_text(strip=True)
|
||||
word_count = text.count(" ") + 1
|
||||
if word_count < self.min_word_threshold:
|
||||
return -1.0 # Guaranteed removal
|
||||
score = 0.0
|
||||
total_weight = 0.0
|
||||
|
||||
if self.metric_config['text_density']:
|
||||
if self.metric_config["text_density"]:
|
||||
density = text_len / tag_len if tag_len > 0 else 0
|
||||
score += self.metric_weights['text_density'] * density
|
||||
total_weight += self.metric_weights['text_density']
|
||||
score += self.metric_weights["text_density"] * density
|
||||
total_weight += self.metric_weights["text_density"]
|
||||
|
||||
if self.metric_config['link_density']:
|
||||
if self.metric_config["link_density"]:
|
||||
density = 1 - (link_text_len / text_len if text_len > 0 else 0)
|
||||
score += self.metric_weights['link_density'] * density
|
||||
total_weight += self.metric_weights['link_density']
|
||||
score += self.metric_weights["link_density"] * density
|
||||
total_weight += self.metric_weights["link_density"]
|
||||
|
||||
if self.metric_config['tag_weight']:
|
||||
tag_score = self.tag_weights.get(metrics['tag_name'], 0.5)
|
||||
score += self.metric_weights['tag_weight'] * tag_score
|
||||
total_weight += self.metric_weights['tag_weight']
|
||||
if self.metric_config["tag_weight"]:
|
||||
tag_score = self.tag_weights.get(metrics["tag_name"], 0.5)
|
||||
score += self.metric_weights["tag_weight"] * tag_score
|
||||
total_weight += self.metric_weights["tag_weight"]
|
||||
|
||||
if self.metric_config['class_id_weight']:
|
||||
class_score = self._compute_class_id_weight(metrics['node'])
|
||||
score += self.metric_weights['class_id_weight'] * max(0, class_score)
|
||||
total_weight += self.metric_weights['class_id_weight']
|
||||
if self.metric_config["class_id_weight"]:
|
||||
class_score = self._compute_class_id_weight(metrics["node"])
|
||||
score += self.metric_weights["class_id_weight"] * max(0, class_score)
|
||||
total_weight += self.metric_weights["class_id_weight"]
|
||||
|
||||
if self.metric_config['text_length']:
|
||||
score += self.metric_weights['text_length'] * math.log(text_len + 1)
|
||||
total_weight += self.metric_weights['text_length']
|
||||
if self.metric_config["text_length"]:
|
||||
score += self.metric_weights["text_length"] * math.log(text_len + 1)
|
||||
total_weight += self.metric_weights["text_length"]
|
||||
|
||||
return score / total_weight if total_weight > 0 else 0
|
||||
|
||||
def _compute_class_id_weight(self, node):
|
||||
"""Computes the class ID weight"""
|
||||
class_id_score = 0
|
||||
if 'class' in node.attrs:
|
||||
classes = ' '.join(node['class'])
|
||||
if "class" in node.attrs:
|
||||
classes = " ".join(node["class"])
|
||||
if self.negative_patterns.match(classes):
|
||||
class_id_score -= 0.5
|
||||
if 'id' in node.attrs:
|
||||
element_id = node['id']
|
||||
if "id" in node.attrs:
|
||||
element_id = node["id"]
|
||||
if self.negative_patterns.match(element_id):
|
||||
class_id_score -= 0.5
|
||||
return class_id_score
|
||||
return class_id_score
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -15,54 +15,53 @@ import logging, time
|
||||
import base64
|
||||
from PIL import Image, ImageDraw, ImageFont
|
||||
from io import BytesIO
|
||||
from typing import List, Callable
|
||||
from typing import Callable
|
||||
import requests
|
||||
import os
|
||||
from pathlib import Path
|
||||
from .utils import *
|
||||
|
||||
logger = logging.getLogger('selenium.webdriver.remote.remote_connection')
|
||||
logger = logging.getLogger("selenium.webdriver.remote.remote_connection")
|
||||
logger.setLevel(logging.WARNING)
|
||||
|
||||
logger_driver = logging.getLogger('selenium.webdriver.common.service')
|
||||
logger_driver = logging.getLogger("selenium.webdriver.common.service")
|
||||
logger_driver.setLevel(logging.WARNING)
|
||||
|
||||
urllib3_logger = logging.getLogger('urllib3.connectionpool')
|
||||
urllib3_logger = logging.getLogger("urllib3.connectionpool")
|
||||
urllib3_logger.setLevel(logging.WARNING)
|
||||
|
||||
# Disable http.client logging
|
||||
http_client_logger = logging.getLogger('http.client')
|
||||
http_client_logger = logging.getLogger("http.client")
|
||||
http_client_logger.setLevel(logging.WARNING)
|
||||
|
||||
# Disable driver_finder and service logging
|
||||
driver_finder_logger = logging.getLogger('selenium.webdriver.common.driver_finder')
|
||||
driver_finder_logger = logging.getLogger("selenium.webdriver.common.driver_finder")
|
||||
driver_finder_logger.setLevel(logging.WARNING)
|
||||
|
||||
|
||||
|
||||
|
||||
class CrawlerStrategy(ABC):
|
||||
@abstractmethod
|
||||
def crawl(self, url: str, **kwargs) -> str:
|
||||
pass
|
||||
|
||||
|
||||
@abstractmethod
|
||||
def take_screenshot(self, save_path: str):
|
||||
pass
|
||||
|
||||
|
||||
@abstractmethod
|
||||
def update_user_agent(self, user_agent: str):
|
||||
pass
|
||||
|
||||
|
||||
@abstractmethod
|
||||
def set_hook(self, hook_type: str, hook: Callable):
|
||||
pass
|
||||
|
||||
|
||||
class CloudCrawlerStrategy(CrawlerStrategy):
|
||||
def __init__(self, use_cached_html = False):
|
||||
def __init__(self, use_cached_html=False):
|
||||
super().__init__()
|
||||
self.use_cached_html = use_cached_html
|
||||
|
||||
|
||||
def crawl(self, url: str) -> str:
|
||||
data = {
|
||||
"urls": [url],
|
||||
@@ -76,6 +75,7 @@ class CloudCrawlerStrategy(CrawlerStrategy):
|
||||
html = response["results"][0]["html"]
|
||||
return sanitize_input_encode(html)
|
||||
|
||||
|
||||
class LocalSeleniumCrawlerStrategy(CrawlerStrategy):
|
||||
def __init__(self, use_cached_html=False, js_code=None, **kwargs):
|
||||
super().__init__()
|
||||
@@ -87,20 +87,25 @@ class LocalSeleniumCrawlerStrategy(CrawlerStrategy):
|
||||
if kwargs.get("user_agent"):
|
||||
self.options.add_argument("--user-agent=" + kwargs.get("user_agent"))
|
||||
else:
|
||||
user_agent = kwargs.get("user_agent", "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.124 Safari/537.36")
|
||||
user_agent = kwargs.get(
|
||||
"user_agent",
|
||||
"Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.124 Safari/537.36",
|
||||
)
|
||||
self.options.add_argument(f"--user-agent={user_agent}")
|
||||
self.options.add_argument("user-agent=Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.124 Safari/537.36")
|
||||
|
||||
self.options.add_argument(
|
||||
"user-agent=Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.124 Safari/537.36"
|
||||
)
|
||||
|
||||
self.options.headless = kwargs.get("headless", True)
|
||||
if self.options.headless:
|
||||
self.options.add_argument("--headless")
|
||||
|
||||
self.options.add_argument("--disable-gpu")
|
||||
|
||||
self.options.add_argument("--disable-gpu")
|
||||
self.options.add_argument("--window-size=1920,1080")
|
||||
self.options.add_argument("--no-sandbox")
|
||||
self.options.add_argument("--disable-dev-shm-usage")
|
||||
self.options.add_argument("--disable-blink-features=AutomationControlled")
|
||||
|
||||
self.options.add_argument("--disable-blink-features=AutomationControlled")
|
||||
|
||||
# self.options.add_argument("--disable-dev-shm-usage")
|
||||
self.options.add_argument("--disable-gpu")
|
||||
# self.options.add_argument("--disable-extensions")
|
||||
@@ -120,14 +125,14 @@ class LocalSeleniumCrawlerStrategy(CrawlerStrategy):
|
||||
self.use_cached_html = use_cached_html
|
||||
self.js_code = js_code
|
||||
self.verbose = kwargs.get("verbose", False)
|
||||
|
||||
|
||||
# Hooks
|
||||
self.hooks = {
|
||||
'on_driver_created': None,
|
||||
'on_user_agent_updated': None,
|
||||
'before_get_url': None,
|
||||
'after_get_url': None,
|
||||
'before_return_html': None
|
||||
"on_driver_created": None,
|
||||
"on_user_agent_updated": None,
|
||||
"before_get_url": None,
|
||||
"after_get_url": None,
|
||||
"before_return_html": None,
|
||||
}
|
||||
|
||||
# chromedriver_autoinstaller.install()
|
||||
@@ -137,31 +142,28 @@ class LocalSeleniumCrawlerStrategy(CrawlerStrategy):
|
||||
# chromedriver_path = chromedriver_autoinstaller.install()
|
||||
# chromedriver_path = chromedriver_autoinstaller.utils.download_chromedriver()
|
||||
# self.service = Service(chromedriver_autoinstaller.install())
|
||||
|
||||
|
||||
|
||||
# chromedriver_path = ChromeDriverManager().install()
|
||||
# self.service = Service(chromedriver_path)
|
||||
# self.service.log_path = "NUL"
|
||||
# self.driver = webdriver.Chrome(service=self.service, options=self.options)
|
||||
|
||||
|
||||
# Use selenium-manager (built into Selenium 4.10.0+)
|
||||
self.service = Service()
|
||||
self.driver = webdriver.Chrome(options=self.options)
|
||||
|
||||
self.driver = self.execute_hook('on_driver_created', self.driver)
|
||||
|
||||
|
||||
self.driver = self.execute_hook("on_driver_created", self.driver)
|
||||
|
||||
if kwargs.get("cookies"):
|
||||
for cookie in kwargs.get("cookies"):
|
||||
self.driver.add_cookie(cookie)
|
||||
|
||||
|
||||
|
||||
def set_hook(self, hook_type: str, hook: Callable):
|
||||
if hook_type in self.hooks:
|
||||
self.hooks[hook_type] = hook
|
||||
else:
|
||||
raise ValueError(f"Invalid hook type: {hook_type}")
|
||||
|
||||
|
||||
def execute_hook(self, hook_type: str, *args):
|
||||
hook = self.hooks.get(hook_type)
|
||||
if hook:
|
||||
@@ -170,7 +172,9 @@ class LocalSeleniumCrawlerStrategy(CrawlerStrategy):
|
||||
if isinstance(result, webdriver.Chrome):
|
||||
return result
|
||||
else:
|
||||
raise TypeError(f"Hook {hook_type} must return an instance of webdriver.Chrome or None.")
|
||||
raise TypeError(
|
||||
f"Hook {hook_type} must return an instance of webdriver.Chrome or None."
|
||||
)
|
||||
# If the hook returns None or there is no hook, return self.driver
|
||||
return self.driver
|
||||
|
||||
@@ -178,60 +182,77 @@ class LocalSeleniumCrawlerStrategy(CrawlerStrategy):
|
||||
self.options.add_argument(f"user-agent={user_agent}")
|
||||
self.driver.quit()
|
||||
self.driver = webdriver.Chrome(service=self.service, options=self.options)
|
||||
self.driver = self.execute_hook('on_user_agent_updated', self.driver)
|
||||
self.driver = self.execute_hook("on_user_agent_updated", self.driver)
|
||||
|
||||
def set_custom_headers(self, headers: dict):
|
||||
# Enable Network domain for sending headers
|
||||
self.driver.execute_cdp_cmd('Network.enable', {})
|
||||
self.driver.execute_cdp_cmd("Network.enable", {})
|
||||
# Set extra HTTP headers
|
||||
self.driver.execute_cdp_cmd('Network.setExtraHTTPHeaders', {'headers': headers})
|
||||
self.driver.execute_cdp_cmd("Network.setExtraHTTPHeaders", {"headers": headers})
|
||||
|
||||
def _ensure_page_load(self, max_checks=6, check_interval=0.01):
|
||||
def _ensure_page_load(self, max_checks=6, check_interval=0.01):
|
||||
initial_length = len(self.driver.page_source)
|
||||
|
||||
|
||||
for ix in range(max_checks):
|
||||
# print(f"Checking page load: {ix}")
|
||||
time.sleep(check_interval)
|
||||
current_length = len(self.driver.page_source)
|
||||
|
||||
|
||||
if current_length != initial_length:
|
||||
break
|
||||
|
||||
return self.driver.page_source
|
||||
|
||||
|
||||
def crawl(self, url: str, **kwargs) -> str:
|
||||
# Create md5 hash of the URL
|
||||
import hashlib
|
||||
|
||||
url_hash = hashlib.md5(url.encode()).hexdigest()
|
||||
|
||||
|
||||
if self.use_cached_html:
|
||||
cache_file_path = os.path.join(os.getenv("CRAWL4_AI_BASE_DIRECTORY", Path.home()), ".crawl4ai", "cache", url_hash)
|
||||
cache_file_path = os.path.join(
|
||||
os.getenv("CRAWL4_AI_BASE_DIRECTORY", Path.home()),
|
||||
".crawl4ai",
|
||||
"cache",
|
||||
url_hash,
|
||||
)
|
||||
if os.path.exists(cache_file_path):
|
||||
with open(cache_file_path, "r") as f:
|
||||
return sanitize_input_encode(f.read())
|
||||
|
||||
try:
|
||||
self.driver = self.execute_hook('before_get_url', self.driver)
|
||||
self.driver = self.execute_hook("before_get_url", self.driver)
|
||||
if self.verbose:
|
||||
print(f"[LOG] 🕸️ Crawling {url} using LocalSeleniumCrawlerStrategy...")
|
||||
self.driver.get(url) #<html><head></head><body></body></html>
|
||||
|
||||
self.driver.get(url) # <html><head></head><body></body></html>
|
||||
|
||||
WebDriverWait(self.driver, 20).until(
|
||||
lambda d: d.execute_script('return document.readyState') == 'complete'
|
||||
lambda d: d.execute_script("return document.readyState") == "complete"
|
||||
)
|
||||
WebDriverWait(self.driver, 10).until(
|
||||
EC.presence_of_all_elements_located((By.TAG_NAME, "body"))
|
||||
)
|
||||
|
||||
self.driver.execute_script("window.scrollTo(0, document.body.scrollHeight);")
|
||||
|
||||
self.driver = self.execute_hook('after_get_url', self.driver)
|
||||
html = sanitize_input_encode(self._ensure_page_load()) # self.driver.page_source
|
||||
can_not_be_done_headless = False # Look at my creativity for naming variables
|
||||
|
||||
|
||||
self.driver.execute_script(
|
||||
"window.scrollTo(0, document.body.scrollHeight);"
|
||||
)
|
||||
|
||||
self.driver = self.execute_hook("after_get_url", self.driver)
|
||||
html = sanitize_input_encode(
|
||||
self._ensure_page_load()
|
||||
) # self.driver.page_source
|
||||
can_not_be_done_headless = (
|
||||
False # Look at my creativity for naming variables
|
||||
)
|
||||
|
||||
# TODO: Very ugly approach, but promise to change it!
|
||||
if kwargs.get('bypass_headless', False) or html == "<html><head></head><body></body></html>":
|
||||
print("[LOG] 🙌 Page could not be loaded in headless mode. Trying non-headless mode...")
|
||||
if (
|
||||
kwargs.get("bypass_headless", False)
|
||||
or html == "<html><head></head><body></body></html>"
|
||||
):
|
||||
print(
|
||||
"[LOG] 🙌 Page could not be loaded in headless mode. Trying non-headless mode..."
|
||||
)
|
||||
can_not_be_done_headless = True
|
||||
options = Options()
|
||||
options.headless = False
|
||||
@@ -239,27 +260,31 @@ class LocalSeleniumCrawlerStrategy(CrawlerStrategy):
|
||||
options.add_argument("--window-size=5,5")
|
||||
driver = webdriver.Chrome(service=self.service, options=options)
|
||||
driver.get(url)
|
||||
self.driver = self.execute_hook('after_get_url', driver)
|
||||
self.driver = self.execute_hook("after_get_url", driver)
|
||||
html = sanitize_input_encode(driver.page_source)
|
||||
driver.quit()
|
||||
|
||||
|
||||
# Execute JS code if provided
|
||||
self.js_code = kwargs.get("js_code", self.js_code)
|
||||
if self.js_code and type(self.js_code) == str:
|
||||
self.driver.execute_script(self.js_code)
|
||||
# Optionally, wait for some condition after executing the JS code
|
||||
WebDriverWait(self.driver, 10).until(
|
||||
lambda driver: driver.execute_script("return document.readyState") == "complete"
|
||||
lambda driver: driver.execute_script("return document.readyState")
|
||||
== "complete"
|
||||
)
|
||||
elif self.js_code and type(self.js_code) == list:
|
||||
for js in self.js_code:
|
||||
self.driver.execute_script(js)
|
||||
WebDriverWait(self.driver, 10).until(
|
||||
lambda driver: driver.execute_script("return document.readyState") == "complete"
|
||||
lambda driver: driver.execute_script(
|
||||
"return document.readyState"
|
||||
)
|
||||
== "complete"
|
||||
)
|
||||
|
||||
|
||||
# Optionally, wait for some condition after executing the JS code : Contributed by (https://github.com/jonymusky)
|
||||
wait_for = kwargs.get('wait_for', False)
|
||||
wait_for = kwargs.get("wait_for", False)
|
||||
if wait_for:
|
||||
if callable(wait_for):
|
||||
print("[LOG] 🔄 Waiting for condition...")
|
||||
@@ -268,32 +293,37 @@ class LocalSeleniumCrawlerStrategy(CrawlerStrategy):
|
||||
print("[LOG] 🔄 Waiting for condition...")
|
||||
WebDriverWait(self.driver, 20).until(
|
||||
EC.presence_of_element_located((By.CSS_SELECTOR, wait_for))
|
||||
)
|
||||
|
||||
)
|
||||
|
||||
if not can_not_be_done_headless:
|
||||
html = sanitize_input_encode(self.driver.page_source)
|
||||
self.driver = self.execute_hook('before_return_html', self.driver, html)
|
||||
|
||||
self.driver = self.execute_hook("before_return_html", self.driver, html)
|
||||
|
||||
# Store in cache
|
||||
cache_file_path = os.path.join(os.getenv("CRAWL4_AI_BASE_DIRECTORY", Path.home()), ".crawl4ai", "cache", url_hash)
|
||||
cache_file_path = os.path.join(
|
||||
os.getenv("CRAWL4_AI_BASE_DIRECTORY", Path.home()),
|
||||
".crawl4ai",
|
||||
"cache",
|
||||
url_hash,
|
||||
)
|
||||
with open(cache_file_path, "w", encoding="utf-8") as f:
|
||||
f.write(html)
|
||||
|
||||
|
||||
if self.verbose:
|
||||
print(f"[LOG] ✅ Crawled {url} successfully!")
|
||||
|
||||
|
||||
return html
|
||||
except InvalidArgumentException as e:
|
||||
if not hasattr(e, 'msg'):
|
||||
if not hasattr(e, "msg"):
|
||||
e.msg = sanitize_input_encode(str(e))
|
||||
raise InvalidArgumentException(f"Failed to crawl {url}: {e.msg}")
|
||||
except WebDriverException as e:
|
||||
# If e does nlt have msg attribute create it and set it to str(e)
|
||||
if not hasattr(e, 'msg'):
|
||||
if not hasattr(e, "msg"):
|
||||
e.msg = sanitize_input_encode(str(e))
|
||||
raise WebDriverException(f"Failed to crawl {url}: {e.msg}")
|
||||
raise WebDriverException(f"Failed to crawl {url}: {e.msg}")
|
||||
except Exception as e:
|
||||
if not hasattr(e, 'msg'):
|
||||
if not hasattr(e, "msg"):
|
||||
e.msg = sanitize_input_encode(str(e))
|
||||
raise Exception(f"Failed to crawl {url}: {e.msg}")
|
||||
|
||||
@@ -301,7 +331,9 @@ class LocalSeleniumCrawlerStrategy(CrawlerStrategy):
|
||||
try:
|
||||
# Get the dimensions of the page
|
||||
total_width = self.driver.execute_script("return document.body.scrollWidth")
|
||||
total_height = self.driver.execute_script("return document.body.scrollHeight")
|
||||
total_height = self.driver.execute_script(
|
||||
"return document.body.scrollHeight"
|
||||
)
|
||||
|
||||
# Set the window size to the dimensions of the page
|
||||
self.driver.set_window_size(total_width, total_height)
|
||||
@@ -313,25 +345,27 @@ class LocalSeleniumCrawlerStrategy(CrawlerStrategy):
|
||||
image = Image.open(BytesIO(screenshot))
|
||||
|
||||
# Convert image to RGB mode (this will handle both RGB and RGBA images)
|
||||
rgb_image = image.convert('RGB')
|
||||
rgb_image = image.convert("RGB")
|
||||
|
||||
# Convert to JPEG and compress
|
||||
buffered = BytesIO()
|
||||
rgb_image.save(buffered, format="JPEG", quality=85)
|
||||
img_base64 = base64.b64encode(buffered.getvalue()).decode('utf-8')
|
||||
img_base64 = base64.b64encode(buffered.getvalue()).decode("utf-8")
|
||||
|
||||
if self.verbose:
|
||||
print(f"[LOG] 📸 Screenshot taken and converted to base64")
|
||||
print("[LOG] 📸 Screenshot taken and converted to base64")
|
||||
|
||||
return img_base64
|
||||
except Exception as e:
|
||||
error_message = sanitize_input_encode(f"Failed to take screenshot: {str(e)}")
|
||||
error_message = sanitize_input_encode(
|
||||
f"Failed to take screenshot: {str(e)}"
|
||||
)
|
||||
print(error_message)
|
||||
|
||||
# Generate an image with black background
|
||||
img = Image.new('RGB', (800, 600), color='black')
|
||||
img = Image.new("RGB", (800, 600), color="black")
|
||||
draw = ImageDraw.Draw(img)
|
||||
|
||||
|
||||
# Load a font
|
||||
try:
|
||||
font = ImageFont.truetype("arial.ttf", 40)
|
||||
@@ -345,16 +379,16 @@ class LocalSeleniumCrawlerStrategy(CrawlerStrategy):
|
||||
|
||||
# Calculate text position
|
||||
text_position = (10, 10)
|
||||
|
||||
|
||||
# Draw the text on the image
|
||||
draw.text(text_position, wrapped_text, fill=text_color, font=font)
|
||||
|
||||
|
||||
# Convert to base64
|
||||
buffered = BytesIO()
|
||||
img.save(buffered, format="JPEG")
|
||||
img_base64 = base64.b64encode(buffered.getvalue()).decode('utf-8')
|
||||
img_base64 = base64.b64encode(buffered.getvalue()).decode("utf-8")
|
||||
|
||||
return img_base64
|
||||
|
||||
|
||||
def quit(self):
|
||||
self.driver.quit()
|
||||
|
||||
@@ -7,11 +7,13 @@ DB_PATH = os.path.join(os.getenv("CRAWL4_AI_BASE_DIRECTORY", Path.home()), ".cra
|
||||
os.makedirs(DB_PATH, exist_ok=True)
|
||||
DB_PATH = os.path.join(DB_PATH, "crawl4ai.db")
|
||||
|
||||
|
||||
def init_db():
|
||||
global DB_PATH
|
||||
conn = sqlite3.connect(DB_PATH)
|
||||
cursor = conn.cursor()
|
||||
cursor.execute('''
|
||||
cursor.execute(
|
||||
"""
|
||||
CREATE TABLE IF NOT EXISTS crawled_data (
|
||||
url TEXT PRIMARY KEY,
|
||||
html TEXT,
|
||||
@@ -24,31 +26,42 @@ def init_db():
|
||||
metadata TEXT DEFAULT "{}",
|
||||
screenshot TEXT DEFAULT ""
|
||||
)
|
||||
''')
|
||||
"""
|
||||
)
|
||||
conn.commit()
|
||||
conn.close()
|
||||
|
||||
|
||||
def alter_db_add_screenshot(new_column: str = "media"):
|
||||
check_db_path()
|
||||
try:
|
||||
conn = sqlite3.connect(DB_PATH)
|
||||
cursor = conn.cursor()
|
||||
cursor.execute(f'ALTER TABLE crawled_data ADD COLUMN {new_column} TEXT DEFAULT ""')
|
||||
cursor.execute(
|
||||
f'ALTER TABLE crawled_data ADD COLUMN {new_column} TEXT DEFAULT ""'
|
||||
)
|
||||
conn.commit()
|
||||
conn.close()
|
||||
except Exception as e:
|
||||
print(f"Error altering database to add screenshot column: {e}")
|
||||
|
||||
|
||||
def check_db_path():
|
||||
if not DB_PATH:
|
||||
raise ValueError("Database path is not set or is empty.")
|
||||
|
||||
def get_cached_url(url: str) -> Optional[Tuple[str, str, str, str, str, str, str, bool, str]]:
|
||||
|
||||
def get_cached_url(
|
||||
url: str,
|
||||
) -> Optional[Tuple[str, str, str, str, str, str, str, bool, str]]:
|
||||
check_db_path()
|
||||
try:
|
||||
conn = sqlite3.connect(DB_PATH)
|
||||
cursor = conn.cursor()
|
||||
cursor.execute('SELECT url, html, cleaned_html, markdown, extracted_content, success, media, links, metadata, screenshot FROM crawled_data WHERE url = ?', (url,))
|
||||
cursor.execute(
|
||||
"SELECT url, html, cleaned_html, markdown, extracted_content, success, media, links, metadata, screenshot FROM crawled_data WHERE url = ?",
|
||||
(url,),
|
||||
)
|
||||
result = cursor.fetchone()
|
||||
conn.close()
|
||||
return result
|
||||
@@ -56,12 +69,25 @@ def get_cached_url(url: str) -> Optional[Tuple[str, str, str, str, str, str, str
|
||||
print(f"Error retrieving cached URL: {e}")
|
||||
return None
|
||||
|
||||
def cache_url(url: str, html: str, cleaned_html: str, markdown: str, extracted_content: str, success: bool, media : str = "{}", links : str = "{}", metadata : str = "{}", screenshot: str = ""):
|
||||
|
||||
def cache_url(
|
||||
url: str,
|
||||
html: str,
|
||||
cleaned_html: str,
|
||||
markdown: str,
|
||||
extracted_content: str,
|
||||
success: bool,
|
||||
media: str = "{}",
|
||||
links: str = "{}",
|
||||
metadata: str = "{}",
|
||||
screenshot: str = "",
|
||||
):
|
||||
check_db_path()
|
||||
try:
|
||||
conn = sqlite3.connect(DB_PATH)
|
||||
cursor = conn.cursor()
|
||||
cursor.execute('''
|
||||
cursor.execute(
|
||||
"""
|
||||
INSERT INTO crawled_data (url, html, cleaned_html, markdown, extracted_content, success, media, links, metadata, screenshot)
|
||||
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
|
||||
ON CONFLICT(url) DO UPDATE SET
|
||||
@@ -74,18 +100,32 @@ def cache_url(url: str, html: str, cleaned_html: str, markdown: str, extracted_c
|
||||
links = excluded.links,
|
||||
metadata = excluded.metadata,
|
||||
screenshot = excluded.screenshot
|
||||
''', (url, html, cleaned_html, markdown, extracted_content, success, media, links, metadata, screenshot))
|
||||
""",
|
||||
(
|
||||
url,
|
||||
html,
|
||||
cleaned_html,
|
||||
markdown,
|
||||
extracted_content,
|
||||
success,
|
||||
media,
|
||||
links,
|
||||
metadata,
|
||||
screenshot,
|
||||
),
|
||||
)
|
||||
conn.commit()
|
||||
conn.close()
|
||||
except Exception as e:
|
||||
print(f"Error caching URL: {e}")
|
||||
|
||||
|
||||
def get_total_count() -> int:
|
||||
check_db_path()
|
||||
try:
|
||||
conn = sqlite3.connect(DB_PATH)
|
||||
cursor = conn.cursor()
|
||||
cursor.execute('SELECT COUNT(*) FROM crawled_data')
|
||||
cursor.execute("SELECT COUNT(*) FROM crawled_data")
|
||||
result = cursor.fetchone()
|
||||
conn.close()
|
||||
return result[0]
|
||||
@@ -93,43 +133,48 @@ def get_total_count() -> int:
|
||||
print(f"Error getting total count: {e}")
|
||||
return 0
|
||||
|
||||
|
||||
def clear_db():
|
||||
check_db_path()
|
||||
try:
|
||||
conn = sqlite3.connect(DB_PATH)
|
||||
cursor = conn.cursor()
|
||||
cursor.execute('DELETE FROM crawled_data')
|
||||
cursor.execute("DELETE FROM crawled_data")
|
||||
conn.commit()
|
||||
conn.close()
|
||||
except Exception as e:
|
||||
print(f"Error clearing database: {e}")
|
||||
|
||||
|
||||
|
||||
def flush_db():
|
||||
check_db_path()
|
||||
try:
|
||||
conn = sqlite3.connect(DB_PATH)
|
||||
cursor = conn.cursor()
|
||||
cursor.execute('DROP TABLE crawled_data')
|
||||
cursor.execute("DROP TABLE crawled_data")
|
||||
conn.commit()
|
||||
conn.close()
|
||||
except Exception as e:
|
||||
print(f"Error flushing database: {e}")
|
||||
|
||||
|
||||
def update_existing_records(new_column: str = "media", default_value: str = "{}"):
|
||||
check_db_path()
|
||||
try:
|
||||
conn = sqlite3.connect(DB_PATH)
|
||||
cursor = conn.cursor()
|
||||
cursor.execute(f'UPDATE crawled_data SET {new_column} = "{default_value}" WHERE screenshot IS NULL')
|
||||
cursor.execute(
|
||||
f'UPDATE crawled_data SET {new_column} = "{default_value}" WHERE screenshot IS NULL'
|
||||
)
|
||||
conn.commit()
|
||||
conn.close()
|
||||
except Exception as e:
|
||||
print(f"Error updating existing records: {e}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# Delete the existing database file
|
||||
if os.path.exists(DB_PATH):
|
||||
os.remove(DB_PATH)
|
||||
init_db()
|
||||
init_db()
|
||||
# alter_db_add_screenshot("COL_NAME")
|
||||
|
||||
|
||||
@@ -4,6 +4,7 @@ from pathlib import Path
|
||||
from crawl4ai.async_logger import AsyncLogger
|
||||
from crawl4ai.llmtxt import AsyncLLMTextManager
|
||||
|
||||
|
||||
class DocsManager:
|
||||
def __init__(self, logger=None):
|
||||
self.docs_dir = Path.home() / ".crawl4ai" / "docs"
|
||||
@@ -21,11 +22,14 @@ class DocsManager:
|
||||
"""Copy from local docs or download from GitHub"""
|
||||
try:
|
||||
# Try local first
|
||||
if self.local_docs.exists() and (any(self.local_docs.glob("*.md")) or any(self.local_docs.glob("*.tokens"))):
|
||||
if self.local_docs.exists() and (
|
||||
any(self.local_docs.glob("*.md"))
|
||||
or any(self.local_docs.glob("*.tokens"))
|
||||
):
|
||||
# Empty the local docs directory
|
||||
for file_path in self.docs_dir.glob("*.md"):
|
||||
file_path.unlink()
|
||||
# for file_path in self.docs_dir.glob("*.tokens"):
|
||||
# for file_path in self.docs_dir.glob("*.tokens"):
|
||||
# file_path.unlink()
|
||||
for file_path in self.local_docs.glob("*.md"):
|
||||
shutil.copy2(file_path, self.docs_dir / file_path.name)
|
||||
@@ -36,14 +40,14 @@ class DocsManager:
|
||||
# Fallback to GitHub
|
||||
response = requests.get(
|
||||
"https://api.github.com/repos/unclecode/crawl4ai/contents/docs/llm.txt",
|
||||
headers={'Accept': 'application/vnd.github.v3+json'}
|
||||
headers={"Accept": "application/vnd.github.v3+json"},
|
||||
)
|
||||
response.raise_for_status()
|
||||
|
||||
|
||||
for item in response.json():
|
||||
if item['type'] == 'file' and item['name'].endswith('.md'):
|
||||
content = requests.get(item['download_url']).text
|
||||
with open(self.docs_dir / item['name'], 'w', encoding='utf-8') as f:
|
||||
if item["type"] == "file" and item["name"].endswith(".md"):
|
||||
content = requests.get(item["download_url"]).text
|
||||
with open(self.docs_dir / item["name"], "w", encoding="utf-8") as f:
|
||||
f.write(content)
|
||||
return True
|
||||
|
||||
@@ -57,11 +61,15 @@ class DocsManager:
|
||||
# Remove [0-9]+_ prefix
|
||||
names = [name.split("_", 1)[1] if name[0].isdigit() else name for name in names]
|
||||
# Exclude those end with .xs.md and .q.md
|
||||
names = [name for name in names if not name.endswith(".xs") and not name.endswith(".q")]
|
||||
names = [
|
||||
name
|
||||
for name in names
|
||||
if not name.endswith(".xs") and not name.endswith(".q")
|
||||
]
|
||||
return names
|
||||
|
||||
|
||||
def generate(self, sections, mode="extended"):
|
||||
return self.llm_text.generate(sections, mode)
|
||||
|
||||
|
||||
def search(self, query: str, top_k: int = 5):
|
||||
return self.llm_text.search(query, top_k)
|
||||
return self.llm_text.search(query, top_k)
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -54,13 +54,13 @@ class HTML2Text(html.parser.HTMLParser):
|
||||
self.td_count = 0
|
||||
self.table_start = False
|
||||
self.unicode_snob = config.UNICODE_SNOB # covered in cli
|
||||
|
||||
|
||||
self.escape_snob = config.ESCAPE_SNOB # covered in cli
|
||||
self.escape_backslash = config.ESCAPE_BACKSLASH # covered in cli
|
||||
self.escape_dot = config.ESCAPE_DOT # covered in cli
|
||||
self.escape_plus = config.ESCAPE_PLUS # covered in cli
|
||||
self.escape_dash = config.ESCAPE_DASH # covered in cli
|
||||
|
||||
|
||||
self.links_each_paragraph = config.LINKS_EACH_PARAGRAPH
|
||||
self.body_width = bodywidth # covered in cli
|
||||
self.skip_internal_links = config.SKIP_INTERNAL_LINKS # covered in cli
|
||||
@@ -144,8 +144,8 @@ class HTML2Text(html.parser.HTMLParser):
|
||||
|
||||
def update_params(self, **kwargs):
|
||||
for key, value in kwargs.items():
|
||||
setattr(self, key, value)
|
||||
|
||||
setattr(self, key, value)
|
||||
|
||||
def feed(self, data: str) -> None:
|
||||
data = data.replace("</' + 'script>", "</ignore>")
|
||||
super().feed(data)
|
||||
@@ -903,7 +903,13 @@ class HTML2Text(html.parser.HTMLParser):
|
||||
self.empty_link = False
|
||||
|
||||
if not self.code and not self.pre and not entity_char:
|
||||
data = escape_md_section(data, snob=self.escape_snob, escape_dot=self.escape_dot, escape_plus=self.escape_plus, escape_dash=self.escape_dash)
|
||||
data = escape_md_section(
|
||||
data,
|
||||
snob=self.escape_snob,
|
||||
escape_dot=self.escape_dot,
|
||||
escape_plus=self.escape_plus,
|
||||
escape_dash=self.escape_dash,
|
||||
)
|
||||
self.preceding_data = data
|
||||
self.o(data, puredata=True)
|
||||
|
||||
@@ -1006,6 +1012,7 @@ class HTML2Text(html.parser.HTMLParser):
|
||||
newlines += 1
|
||||
return result
|
||||
|
||||
|
||||
def html2text(html: str, baseurl: str = "", bodywidth: Optional[int] = None) -> str:
|
||||
if bodywidth is None:
|
||||
bodywidth = config.BODY_WIDTH
|
||||
@@ -1013,6 +1020,7 @@ def html2text(html: str, baseurl: str = "", bodywidth: Optional[int] = None) ->
|
||||
|
||||
return h.handle(html)
|
||||
|
||||
|
||||
class CustomHTML2Text(HTML2Text):
|
||||
def __init__(self, *args, handle_code_in_pre=False, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
@@ -1022,8 +1030,8 @@ class CustomHTML2Text(HTML2Text):
|
||||
self.current_preserved_tag = None
|
||||
self.preserved_content = []
|
||||
self.preserve_depth = 0
|
||||
self.handle_code_in_pre = handle_code_in_pre
|
||||
|
||||
self.handle_code_in_pre = handle_code_in_pre
|
||||
|
||||
# Configuration options
|
||||
self.skip_internal_links = False
|
||||
self.single_line_break = False
|
||||
@@ -1041,9 +1049,9 @@ class CustomHTML2Text(HTML2Text):
|
||||
def update_params(self, **kwargs):
|
||||
"""Update parameters and set preserved tags."""
|
||||
for key, value in kwargs.items():
|
||||
if key == 'preserve_tags':
|
||||
if key == "preserve_tags":
|
||||
self.preserve_tags = set(value)
|
||||
elif key == 'handle_code_in_pre':
|
||||
elif key == "handle_code_in_pre":
|
||||
self.handle_code_in_pre = value
|
||||
else:
|
||||
setattr(self, key, value)
|
||||
@@ -1056,17 +1064,19 @@ class CustomHTML2Text(HTML2Text):
|
||||
self.current_preserved_tag = tag
|
||||
self.preserved_content = []
|
||||
# Format opening tag with attributes
|
||||
attr_str = ''.join(f' {k}="{v}"' for k, v in attrs.items() if v is not None)
|
||||
self.preserved_content.append(f'<{tag}{attr_str}>')
|
||||
attr_str = "".join(
|
||||
f' {k}="{v}"' for k, v in attrs.items() if v is not None
|
||||
)
|
||||
self.preserved_content.append(f"<{tag}{attr_str}>")
|
||||
self.preserve_depth += 1
|
||||
return
|
||||
else:
|
||||
self.preserve_depth -= 1
|
||||
if self.preserve_depth == 0:
|
||||
self.preserved_content.append(f'</{tag}>')
|
||||
self.preserved_content.append(f"</{tag}>")
|
||||
# Output the preserved HTML block with proper spacing
|
||||
preserved_html = ''.join(self.preserved_content)
|
||||
self.o('\n' + preserved_html + '\n')
|
||||
preserved_html = "".join(self.preserved_content)
|
||||
self.o("\n" + preserved_html + "\n")
|
||||
self.current_preserved_tag = None
|
||||
return
|
||||
|
||||
@@ -1074,29 +1084,31 @@ class CustomHTML2Text(HTML2Text):
|
||||
if self.preserve_depth > 0:
|
||||
if start:
|
||||
# Format nested tags with attributes
|
||||
attr_str = ''.join(f' {k}="{v}"' for k, v in attrs.items() if v is not None)
|
||||
self.preserved_content.append(f'<{tag}{attr_str}>')
|
||||
attr_str = "".join(
|
||||
f' {k}="{v}"' for k, v in attrs.items() if v is not None
|
||||
)
|
||||
self.preserved_content.append(f"<{tag}{attr_str}>")
|
||||
else:
|
||||
self.preserved_content.append(f'</{tag}>')
|
||||
self.preserved_content.append(f"</{tag}>")
|
||||
return
|
||||
|
||||
# Handle pre tags
|
||||
if tag == 'pre':
|
||||
if tag == "pre":
|
||||
if start:
|
||||
self.o('```\n') # Markdown code block start
|
||||
self.o("```\n") # Markdown code block start
|
||||
self.inside_pre = True
|
||||
else:
|
||||
self.o('\n```\n') # Markdown code block end
|
||||
self.o("\n```\n") # Markdown code block end
|
||||
self.inside_pre = False
|
||||
elif tag == 'code':
|
||||
elif tag == "code":
|
||||
if self.inside_pre and not self.handle_code_in_pre:
|
||||
# Ignore code tags inside pre blocks if handle_code_in_pre is False
|
||||
return
|
||||
if start:
|
||||
self.o('`') # Markdown inline code start
|
||||
self.o("`") # Markdown inline code start
|
||||
self.inside_code = True
|
||||
else:
|
||||
self.o('`') # Markdown inline code end
|
||||
self.o("`") # Markdown inline code end
|
||||
self.inside_code = False
|
||||
else:
|
||||
super().handle_tag(tag, attrs, start)
|
||||
@@ -1113,13 +1125,12 @@ class CustomHTML2Text(HTML2Text):
|
||||
return
|
||||
if self.inside_code:
|
||||
# Inline code: no newlines allowed
|
||||
self.o(data.replace('\n', ' '))
|
||||
self.o(data.replace("\n", " "))
|
||||
return
|
||||
|
||||
# Default behavior for other tags
|
||||
super().handle_data(data, entity_char)
|
||||
|
||||
|
||||
# # Handle pre tags
|
||||
# if tag == 'pre':
|
||||
# if start:
|
||||
|
||||
@@ -1,2 +1,3 @@
|
||||
class OutCallback:
|
||||
def __call__(self, s: str) -> None: ...
|
||||
def __call__(self, s: str) -> None:
|
||||
...
|
||||
|
||||
@@ -210,7 +210,7 @@ def escape_md_section(
|
||||
snob: bool = False,
|
||||
escape_dot: bool = True,
|
||||
escape_plus: bool = True,
|
||||
escape_dash: bool = True
|
||||
escape_dash: bool = True,
|
||||
) -> str:
|
||||
"""
|
||||
Escapes markdown-sensitive characters across whole document sections.
|
||||
@@ -233,6 +233,7 @@ def escape_md_section(
|
||||
|
||||
return text
|
||||
|
||||
|
||||
def reformat_table(lines: List[str], right_margin: int) -> List[str]:
|
||||
"""
|
||||
Given the lines of a table
|
||||
|
||||
@@ -6,25 +6,44 @@ from .async_logger import AsyncLogger, LogLevel
|
||||
# Initialize logger
|
||||
logger = AsyncLogger(log_level=LogLevel.DEBUG, verbose=True)
|
||||
|
||||
|
||||
def post_install():
|
||||
"""Run all post-installation tasks"""
|
||||
logger.info("Running post-installation setup...", tag="INIT")
|
||||
install_playwright()
|
||||
run_migration()
|
||||
logger.success("Post-installation setup completed!", tag="COMPLETE")
|
||||
|
||||
|
||||
|
||||
def install_playwright():
|
||||
logger.info("Installing Playwright browsers...", tag="INIT")
|
||||
try:
|
||||
# subprocess.check_call([sys.executable, "-m", "playwright", "install", "--with-deps", "--force", "chrome"])
|
||||
subprocess.check_call([sys.executable, "-m", "playwright", "install", "--with-deps", "--force", "chromium"])
|
||||
logger.success("Playwright installation completed successfully.", tag="COMPLETE")
|
||||
except subprocess.CalledProcessError as e:
|
||||
subprocess.check_call(
|
||||
[
|
||||
sys.executable,
|
||||
"-m",
|
||||
"playwright",
|
||||
"install",
|
||||
"--with-deps",
|
||||
"--force",
|
||||
"chromium",
|
||||
]
|
||||
)
|
||||
logger.success(
|
||||
"Playwright installation completed successfully.", tag="COMPLETE"
|
||||
)
|
||||
except subprocess.CalledProcessError:
|
||||
# logger.error(f"Error during Playwright installation: {e}", tag="ERROR")
|
||||
logger.warning(f"Please run '{sys.executable} -m playwright install --with-deps' manually after the installation.")
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
f"Please run '{sys.executable} -m playwright install --with-deps' manually after the installation."
|
||||
)
|
||||
except Exception:
|
||||
# logger.error(f"Unexpected error during Playwright installation: {e}", tag="ERROR")
|
||||
logger.warning(f"Please run '{sys.executable} -m playwright install --with-deps' manually after the installation.")
|
||||
logger.warning(
|
||||
f"Please run '{sys.executable} -m playwright install --with-deps' manually after the installation."
|
||||
)
|
||||
|
||||
|
||||
def run_migration():
|
||||
"""Initialize database during installation"""
|
||||
@@ -33,18 +52,26 @@ def run_migration():
|
||||
from crawl4ai.async_database import async_db_manager
|
||||
|
||||
asyncio.run(async_db_manager.initialize())
|
||||
logger.success("Database initialization completed successfully.", tag="COMPLETE")
|
||||
logger.success(
|
||||
"Database initialization completed successfully.", tag="COMPLETE"
|
||||
)
|
||||
except ImportError:
|
||||
logger.warning("Database module not found. Will initialize on first use.")
|
||||
except Exception as e:
|
||||
logger.warning(f"Database initialization failed: {e}")
|
||||
logger.warning("Database will be initialized on first use")
|
||||
|
||||
|
||||
async def run_doctor():
|
||||
"""Test if Crawl4AI is working properly"""
|
||||
logger.info("Running Crawl4AI health check...", tag="INIT")
|
||||
try:
|
||||
from .async_webcrawler import AsyncWebCrawler, BrowserConfig, CrawlerRunConfig, CacheMode
|
||||
from .async_webcrawler import (
|
||||
AsyncWebCrawler,
|
||||
BrowserConfig,
|
||||
CrawlerRunConfig,
|
||||
CacheMode,
|
||||
)
|
||||
|
||||
browser_config = BrowserConfig(
|
||||
headless=True,
|
||||
@@ -52,7 +79,7 @@ async def run_doctor():
|
||||
ignore_https_errors=True,
|
||||
light_mode=True,
|
||||
viewport_width=1280,
|
||||
viewport_height=720
|
||||
viewport_height=720,
|
||||
)
|
||||
|
||||
run_config = CrawlerRunConfig(
|
||||
@@ -62,10 +89,7 @@ async def run_doctor():
|
||||
|
||||
async with AsyncWebCrawler(config=browser_config) as crawler:
|
||||
logger.info("Testing crawling capabilities...", tag="TEST")
|
||||
result = await crawler.arun(
|
||||
url="https://crawl4ai.com",
|
||||
config=run_config
|
||||
)
|
||||
result = await crawler.arun(url="https://crawl4ai.com", config=run_config)
|
||||
|
||||
if result and result.markdown:
|
||||
logger.success("✅ Crawling test passed!", tag="COMPLETE")
|
||||
@@ -77,7 +101,9 @@ async def run_doctor():
|
||||
logger.error(f"❌ Test failed: {e}", tag="ERROR")
|
||||
return False
|
||||
|
||||
|
||||
def doctor():
|
||||
"""Entry point for the doctor command"""
|
||||
import asyncio
|
||||
|
||||
return asyncio.run(run_doctor())
|
||||
|
||||
@@ -1,15 +1,18 @@
|
||||
import os, sys
|
||||
import os
|
||||
|
||||
|
||||
# Create a function get name of a js script, then load from the CURRENT folder of this script and return its content as string, make sure its error free
|
||||
def load_js_script(script_name):
|
||||
# Get the path of the current script
|
||||
current_script_path = os.path.dirname(os.path.realpath(__file__))
|
||||
# Get the path of the script to load
|
||||
script_path = os.path.join(current_script_path, script_name + '.js')
|
||||
script_path = os.path.join(current_script_path, script_name + ".js")
|
||||
# Check if the script exists
|
||||
if not os.path.exists(script_path):
|
||||
raise ValueError(f"Script {script_name} not found in the folder {current_script_path}")
|
||||
raise ValueError(
|
||||
f"Script {script_name} not found in the folder {current_script_path}"
|
||||
)
|
||||
# Load the content of the script
|
||||
with open(script_path, 'r') as f:
|
||||
with open(script_path, "r") as f:
|
||||
script_content = f.read()
|
||||
return script_content
|
||||
|
||||
@@ -11,16 +11,16 @@ from rank_bm25 import BM25Okapi
|
||||
from nltk.tokenize import word_tokenize
|
||||
from nltk.corpus import stopwords
|
||||
from nltk.stem import WordNetLemmatizer
|
||||
from litellm import completion, batch_completion
|
||||
from litellm import batch_completion
|
||||
from .async_logger import AsyncLogger
|
||||
import litellm
|
||||
import pickle
|
||||
import hashlib # <--- ADDED for file-hash
|
||||
from fnmatch import fnmatch
|
||||
import glob
|
||||
|
||||
litellm.set_verbose = False
|
||||
|
||||
|
||||
def _compute_file_hash(file_path: Path) -> str:
|
||||
"""Compute MD5 hash for the file's entire content."""
|
||||
hash_md5 = hashlib.md5()
|
||||
@@ -29,13 +29,14 @@ def _compute_file_hash(file_path: Path) -> str:
|
||||
hash_md5.update(chunk)
|
||||
return hash_md5.hexdigest()
|
||||
|
||||
|
||||
class AsyncLLMTextManager:
|
||||
def __init__(
|
||||
self,
|
||||
docs_dir: Path,
|
||||
logger: Optional[AsyncLogger] = None,
|
||||
max_concurrent_calls: int = 5,
|
||||
batch_size: int = 3
|
||||
batch_size: int = 3,
|
||||
) -> None:
|
||||
self.docs_dir = docs_dir
|
||||
self.logger = logger
|
||||
@@ -51,7 +52,7 @@ class AsyncLLMTextManager:
|
||||
contents = []
|
||||
for file_path in doc_batch:
|
||||
try:
|
||||
with open(file_path, 'r', encoding='utf-8') as f:
|
||||
with open(file_path, "r", encoding="utf-8") as f:
|
||||
contents.append(f.read())
|
||||
except Exception as e:
|
||||
self.logger.error(f"Error reading {file_path}: {str(e)}")
|
||||
@@ -77,43 +78,53 @@ Wrap your response in <index>...</index> tags.
|
||||
# Prepare messages for batch processing
|
||||
messages_list = [
|
||||
[
|
||||
{"role": "user", "content": f"{prompt}\n\nGenerate index for this documentation:\n\n{content}"}
|
||||
{
|
||||
"role": "user",
|
||||
"content": f"{prompt}\n\nGenerate index for this documentation:\n\n{content}",
|
||||
}
|
||||
]
|
||||
for content in contents if content
|
||||
for content in contents
|
||||
if content
|
||||
]
|
||||
|
||||
try:
|
||||
responses = batch_completion(
|
||||
model="anthropic/claude-3-5-sonnet-latest",
|
||||
messages=messages_list,
|
||||
logger_fn=None
|
||||
logger_fn=None,
|
||||
)
|
||||
|
||||
# Process responses and save index files
|
||||
for response, file_path in zip(responses, doc_batch):
|
||||
try:
|
||||
index_content_match = re.search(
|
||||
r'<index>(.*?)</index>',
|
||||
r"<index>(.*?)</index>",
|
||||
response.choices[0].message.content,
|
||||
re.DOTALL
|
||||
re.DOTALL,
|
||||
)
|
||||
if not index_content_match:
|
||||
self.logger.warning(f"No <index>...</index> content found for {file_path}")
|
||||
self.logger.warning(
|
||||
f"No <index>...</index> content found for {file_path}"
|
||||
)
|
||||
continue
|
||||
|
||||
index_content = re.sub(
|
||||
r"\n\s*\n", "\n", index_content_match.group(1)
|
||||
).strip()
|
||||
if index_content:
|
||||
index_file = file_path.with_suffix('.q.md')
|
||||
with open(index_file, 'w', encoding='utf-8') as f:
|
||||
index_file = file_path.with_suffix(".q.md")
|
||||
with open(index_file, "w", encoding="utf-8") as f:
|
||||
f.write(index_content)
|
||||
self.logger.info(f"Created index file: {index_file}")
|
||||
else:
|
||||
self.logger.warning(f"No index content found in response for {file_path}")
|
||||
self.logger.warning(
|
||||
f"No index content found in response for {file_path}"
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
self.logger.error(f"Error processing response for {file_path}: {str(e)}")
|
||||
self.logger.error(
|
||||
f"Error processing response for {file_path}: {str(e)}"
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
self.logger.error(f"Error in batch completion: {str(e)}")
|
||||
@@ -171,7 +182,12 @@ Wrap your response in <index>...</index> tags.
|
||||
|
||||
lemmatizer = WordNetLemmatizer()
|
||||
stop_words = set(stopwords.words("english")) - {
|
||||
"how", "what", "when", "where", "why", "which",
|
||||
"how",
|
||||
"what",
|
||||
"when",
|
||||
"where",
|
||||
"why",
|
||||
"which",
|
||||
}
|
||||
|
||||
tokens = []
|
||||
@@ -222,7 +238,9 @@ Wrap your response in <index>...</index> tags.
|
||||
self.logger.info("Checking which .q.md files need (re)indexing...")
|
||||
|
||||
# Gather all .q.md files
|
||||
q_files = [self.docs_dir / f for f in os.listdir(self.docs_dir) if f.endswith(".q.md")]
|
||||
q_files = [
|
||||
self.docs_dir / f for f in os.listdir(self.docs_dir) if f.endswith(".q.md")
|
||||
]
|
||||
|
||||
# We'll store known (unchanged) facts in these lists
|
||||
existing_facts: List[str] = []
|
||||
@@ -243,7 +261,9 @@ Wrap your response in <index>...</index> tags.
|
||||
# Otherwise, load the existing cache and compare hash
|
||||
cache = self._load_or_create_token_cache(qf)
|
||||
# If the .q.tokens was out of date (i.e. changed hash), we reindex
|
||||
if len(cache["facts"]) == 0 or cache.get("content_hash") != _compute_file_hash(qf):
|
||||
if len(cache["facts"]) == 0 or cache.get(
|
||||
"content_hash"
|
||||
) != _compute_file_hash(qf):
|
||||
needSet.append(qf)
|
||||
else:
|
||||
# File is unchanged → retrieve cached token data
|
||||
@@ -255,20 +275,29 @@ Wrap your response in <index>...</index> tags.
|
||||
if not needSet and not clear_cache:
|
||||
# If no file needs reindexing, try loading existing index
|
||||
if self.maybe_load_bm25_index(clear_cache=False):
|
||||
self.logger.info("No new/changed .q.md files found. Using existing BM25 index.")
|
||||
self.logger.info(
|
||||
"No new/changed .q.md files found. Using existing BM25 index."
|
||||
)
|
||||
return
|
||||
else:
|
||||
# If there's no existing index, we must build a fresh index from the old caches
|
||||
self.logger.info("No existing BM25 index found. Building from cached facts.")
|
||||
self.logger.info(
|
||||
"No existing BM25 index found. Building from cached facts."
|
||||
)
|
||||
if existing_facts:
|
||||
self.logger.info(f"Building BM25 index with {len(existing_facts)} cached facts.")
|
||||
self.logger.info(
|
||||
f"Building BM25 index with {len(existing_facts)} cached facts."
|
||||
)
|
||||
self.bm25_index = BM25Okapi(existing_tokens)
|
||||
self.tokenized_facts = existing_facts
|
||||
with open(self.bm25_index_file, "wb") as f:
|
||||
pickle.dump({
|
||||
"bm25_index": self.bm25_index,
|
||||
"tokenized_facts": self.tokenized_facts
|
||||
}, f)
|
||||
pickle.dump(
|
||||
{
|
||||
"bm25_index": self.bm25_index,
|
||||
"tokenized_facts": self.tokenized_facts,
|
||||
},
|
||||
f,
|
||||
)
|
||||
else:
|
||||
self.logger.warning("No facts found at all. Index remains empty.")
|
||||
return
|
||||
@@ -311,7 +340,9 @@ Wrap your response in <index>...</index> tags.
|
||||
self._save_token_cache(file, fresh_cache)
|
||||
|
||||
mem_usage = process.memory_info().rss / 1024 / 1024
|
||||
self.logger.debug(f"Memory usage after {file.name}: {mem_usage:.2f}MB")
|
||||
self.logger.debug(
|
||||
f"Memory usage after {file.name}: {mem_usage:.2f}MB"
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
self.logger.error(f"Error processing {file}: {str(e)}")
|
||||
@@ -328,40 +359,49 @@ Wrap your response in <index>...</index> tags.
|
||||
all_tokens = existing_tokens + new_tokens
|
||||
|
||||
# 3) Build BM25 index from combined facts
|
||||
self.logger.info(f"Building BM25 index with {len(all_facts)} total facts (old + new).")
|
||||
self.logger.info(
|
||||
f"Building BM25 index with {len(all_facts)} total facts (old + new)."
|
||||
)
|
||||
self.bm25_index = BM25Okapi(all_tokens)
|
||||
self.tokenized_facts = all_facts
|
||||
|
||||
# 4) Save the updated BM25 index to disk
|
||||
with open(self.bm25_index_file, "wb") as f:
|
||||
pickle.dump({
|
||||
"bm25_index": self.bm25_index,
|
||||
"tokenized_facts": self.tokenized_facts
|
||||
}, f)
|
||||
pickle.dump(
|
||||
{
|
||||
"bm25_index": self.bm25_index,
|
||||
"tokenized_facts": self.tokenized_facts,
|
||||
},
|
||||
f,
|
||||
)
|
||||
|
||||
final_mem = process.memory_info().rss / 1024 / 1024
|
||||
self.logger.info(f"Search index updated. Final memory usage: {final_mem:.2f}MB")
|
||||
|
||||
async def generate_index_files(self, force_generate_facts: bool = False, clear_bm25_cache: bool = False) -> None:
|
||||
async def generate_index_files(
|
||||
self, force_generate_facts: bool = False, clear_bm25_cache: bool = False
|
||||
) -> None:
|
||||
"""
|
||||
Generate index files for all documents in parallel batches
|
||||
|
||||
|
||||
Args:
|
||||
force_generate_facts (bool): If True, regenerate indexes even if they exist
|
||||
clear_bm25_cache (bool): If True, clear existing BM25 index cache
|
||||
"""
|
||||
self.logger.info("Starting index generation for documentation files.")
|
||||
|
||||
|
||||
md_files = [
|
||||
self.docs_dir / f for f in os.listdir(self.docs_dir)
|
||||
if f.endswith('.md') and not any(f.endswith(x) for x in ['.q.md', '.xs.md'])
|
||||
self.docs_dir / f
|
||||
for f in os.listdir(self.docs_dir)
|
||||
if f.endswith(".md") and not any(f.endswith(x) for x in [".q.md", ".xs.md"])
|
||||
]
|
||||
|
||||
# Filter out files that already have .q files unless force=True
|
||||
if not force_generate_facts:
|
||||
md_files = [
|
||||
f for f in md_files
|
||||
if not (self.docs_dir / f.name.replace('.md', '.q.md')).exists()
|
||||
f
|
||||
for f in md_files
|
||||
if not (self.docs_dir / f.name.replace(".md", ".q.md")).exists()
|
||||
]
|
||||
|
||||
if not md_files:
|
||||
@@ -369,8 +409,10 @@ Wrap your response in <index>...</index> tags.
|
||||
else:
|
||||
# Process documents in batches
|
||||
for i in range(0, len(md_files), self.batch_size):
|
||||
batch = md_files[i:i + self.batch_size]
|
||||
self.logger.info(f"Processing batch {i//self.batch_size + 1}/{(len(md_files)//self.batch_size) + 1}")
|
||||
batch = md_files[i : i + self.batch_size]
|
||||
self.logger.info(
|
||||
f"Processing batch {i//self.batch_size + 1}/{(len(md_files)//self.batch_size) + 1}"
|
||||
)
|
||||
await self._process_document_batch(batch)
|
||||
|
||||
self.logger.info("Index generation complete, building/updating search index.")
|
||||
@@ -378,21 +420,31 @@ Wrap your response in <index>...</index> tags.
|
||||
|
||||
def generate(self, sections: List[str], mode: str = "extended") -> str:
|
||||
# Get all markdown files
|
||||
all_files = glob.glob(str(self.docs_dir / "[0-9]*.md")) + \
|
||||
glob.glob(str(self.docs_dir / "[0-9]*.xs.md"))
|
||||
|
||||
all_files = glob.glob(str(self.docs_dir / "[0-9]*.md")) + glob.glob(
|
||||
str(self.docs_dir / "[0-9]*.xs.md")
|
||||
)
|
||||
|
||||
# Extract base names without extensions
|
||||
base_docs = {Path(f).name.split('.')[0] for f in all_files
|
||||
if not Path(f).name.endswith('.q.md')}
|
||||
|
||||
base_docs = {
|
||||
Path(f).name.split(".")[0]
|
||||
for f in all_files
|
||||
if not Path(f).name.endswith(".q.md")
|
||||
}
|
||||
|
||||
# Filter by sections if provided
|
||||
if sections:
|
||||
base_docs = {doc for doc in base_docs
|
||||
if any(section.lower() in doc.lower() for section in sections)}
|
||||
|
||||
base_docs = {
|
||||
doc
|
||||
for doc in base_docs
|
||||
if any(section.lower() in doc.lower() for section in sections)
|
||||
}
|
||||
|
||||
# Get file paths based on mode
|
||||
files = []
|
||||
for doc in sorted(base_docs, key=lambda x: int(x.split('_')[0]) if x.split('_')[0].isdigit() else 999999):
|
||||
for doc in sorted(
|
||||
base_docs,
|
||||
key=lambda x: int(x.split("_")[0]) if x.split("_")[0].isdigit() else 999999,
|
||||
):
|
||||
if mode == "condensed":
|
||||
xs_file = self.docs_dir / f"{doc}.xs.md"
|
||||
regular_file = self.docs_dir / f"{doc}.md"
|
||||
@@ -404,7 +456,7 @@ Wrap your response in <index>...</index> tags.
|
||||
content = []
|
||||
for file in files:
|
||||
try:
|
||||
with open(file, 'r', encoding='utf-8') as f:
|
||||
with open(file, "r", encoding="utf-8") as f:
|
||||
fname = Path(file).name
|
||||
content.append(f"{'#'*20}\n# {fname}\n{'#'*20}\n\n{f.read()}")
|
||||
except Exception as e:
|
||||
@@ -443,15 +495,9 @@ Wrap your response in <index>...</index> tags.
|
||||
for file, _ in ranked_files:
|
||||
main_doc = str(file).replace(".q.md", ".md")
|
||||
if os.path.exists(self.docs_dir / main_doc):
|
||||
with open(self.docs_dir / main_doc, "r", encoding='utf-8') as f:
|
||||
with open(self.docs_dir / main_doc, "r", encoding="utf-8") as f:
|
||||
only_file_name = main_doc.split("/")[-1]
|
||||
content = [
|
||||
"#" * 20,
|
||||
f"# {only_file_name}",
|
||||
"#" * 20,
|
||||
"",
|
||||
f.read()
|
||||
]
|
||||
content = ["#" * 20, f"# {only_file_name}", "#" * 20, "", f.read()]
|
||||
results.append("\n".join(content))
|
||||
|
||||
return "\n\n---\n\n".join(results)
|
||||
@@ -482,7 +528,9 @@ Wrap your response in <index>...</index> tags.
|
||||
if len(components) == 3:
|
||||
code_ref = components[2].strip()
|
||||
code_tokens = self.preprocess_text(code_ref)
|
||||
code_match_score = len(set(query_tokens) & set(code_tokens)) / len(query_tokens)
|
||||
code_match_score = len(set(query_tokens) & set(code_tokens)) / len(
|
||||
query_tokens
|
||||
)
|
||||
|
||||
file_data[file_path]["total_score"] += score
|
||||
file_data[file_path]["match_count"] += 1
|
||||
|
||||
@@ -2,77 +2,94 @@ from abc import ABC, abstractmethod
|
||||
from typing import Optional, Dict, Any, Tuple
|
||||
from .models import MarkdownGenerationResult
|
||||
from .html2text import CustomHTML2Text
|
||||
from .content_filter_strategy import RelevantContentFilter, BM25ContentFilter
|
||||
from .content_filter_strategy import RelevantContentFilter
|
||||
import re
|
||||
from urllib.parse import urljoin
|
||||
|
||||
# Pre-compile the regex pattern
|
||||
LINK_PATTERN = re.compile(r'!?\[([^\]]+)\]\(([^)]+?)(?:\s+"([^"]*)")?\)')
|
||||
|
||||
|
||||
def fast_urljoin(base: str, url: str) -> str:
|
||||
"""Fast URL joining for common cases."""
|
||||
if url.startswith(('http://', 'https://', 'mailto:', '//')):
|
||||
if url.startswith(("http://", "https://", "mailto:", "//")):
|
||||
return url
|
||||
if url.startswith('/'):
|
||||
if url.startswith("/"):
|
||||
# Handle absolute paths
|
||||
if base.endswith('/'):
|
||||
if base.endswith("/"):
|
||||
return base[:-1] + url
|
||||
return base + url
|
||||
return urljoin(base, url)
|
||||
|
||||
|
||||
class MarkdownGenerationStrategy(ABC):
|
||||
"""Abstract base class for markdown generation strategies."""
|
||||
def __init__(self, content_filter: Optional[RelevantContentFilter] = None, options: Optional[Dict[str, Any]] = None):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
content_filter: Optional[RelevantContentFilter] = None,
|
||||
options: Optional[Dict[str, Any]] = None,
|
||||
):
|
||||
self.content_filter = content_filter
|
||||
self.options = options or {}
|
||||
|
||||
|
||||
@abstractmethod
|
||||
def generate_markdown(self,
|
||||
cleaned_html: str,
|
||||
base_url: str = "",
|
||||
html2text_options: Optional[Dict[str, Any]] = None,
|
||||
content_filter: Optional[RelevantContentFilter] = None,
|
||||
citations: bool = True,
|
||||
**kwargs) -> MarkdownGenerationResult:
|
||||
def generate_markdown(
|
||||
self,
|
||||
cleaned_html: str,
|
||||
base_url: str = "",
|
||||
html2text_options: Optional[Dict[str, Any]] = None,
|
||||
content_filter: Optional[RelevantContentFilter] = None,
|
||||
citations: bool = True,
|
||||
**kwargs,
|
||||
) -> MarkdownGenerationResult:
|
||||
"""Generate markdown from cleaned HTML."""
|
||||
pass
|
||||
|
||||
|
||||
class DefaultMarkdownGenerator(MarkdownGenerationStrategy):
|
||||
"""
|
||||
Default implementation of markdown generation strategy.
|
||||
|
||||
|
||||
How it works:
|
||||
1. Generate raw markdown from cleaned HTML.
|
||||
2. Convert links to citations.
|
||||
3. Generate fit markdown if content filter is provided.
|
||||
4. Return MarkdownGenerationResult.
|
||||
|
||||
|
||||
Args:
|
||||
content_filter (Optional[RelevantContentFilter]): Content filter for generating fit markdown.
|
||||
options (Optional[Dict[str, Any]]): Additional options for markdown generation. Defaults to None.
|
||||
|
||||
|
||||
Returns:
|
||||
MarkdownGenerationResult: Result containing raw markdown, fit markdown, fit HTML, and references markdown.
|
||||
"""
|
||||
def __init__(self, content_filter: Optional[RelevantContentFilter] = None, options: Optional[Dict[str, Any]] = None):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
content_filter: Optional[RelevantContentFilter] = None,
|
||||
options: Optional[Dict[str, Any]] = None,
|
||||
):
|
||||
super().__init__(content_filter, options)
|
||||
|
||||
def convert_links_to_citations(self, markdown: str, base_url: str = "") -> Tuple[str, str]:
|
||||
|
||||
def convert_links_to_citations(
|
||||
self, markdown: str, base_url: str = ""
|
||||
) -> Tuple[str, str]:
|
||||
"""
|
||||
Convert links in markdown to citations.
|
||||
|
||||
|
||||
How it works:
|
||||
1. Find all links in the markdown.
|
||||
2. Convert links to citations.
|
||||
3. Return converted markdown and references markdown.
|
||||
|
||||
|
||||
Note:
|
||||
This function uses a regex pattern to find links in markdown.
|
||||
|
||||
|
||||
Args:
|
||||
markdown (str): Markdown text.
|
||||
base_url (str): Base URL for URL joins.
|
||||
|
||||
|
||||
Returns:
|
||||
Tuple[str, str]: Converted markdown and references markdown.
|
||||
"""
|
||||
@@ -81,57 +98,65 @@ class DefaultMarkdownGenerator(MarkdownGenerationStrategy):
|
||||
parts = []
|
||||
last_end = 0
|
||||
counter = 1
|
||||
|
||||
|
||||
for match in LINK_PATTERN.finditer(markdown):
|
||||
parts.append(markdown[last_end:match.start()])
|
||||
parts.append(markdown[last_end : match.start()])
|
||||
text, url, title = match.groups()
|
||||
|
||||
|
||||
# Use cached URL if available, otherwise compute and cache
|
||||
if base_url and not url.startswith(('http://', 'https://', 'mailto:')):
|
||||
if base_url and not url.startswith(("http://", "https://", "mailto:")):
|
||||
if url not in url_cache:
|
||||
url_cache[url] = fast_urljoin(base_url, url)
|
||||
url = url_cache[url]
|
||||
|
||||
|
||||
if url not in link_map:
|
||||
desc = []
|
||||
if title: desc.append(title)
|
||||
if text and text != title: desc.append(text)
|
||||
if title:
|
||||
desc.append(title)
|
||||
if text and text != title:
|
||||
desc.append(text)
|
||||
link_map[url] = (counter, ": " + " - ".join(desc) if desc else "")
|
||||
counter += 1
|
||||
|
||||
|
||||
num = link_map[url][0]
|
||||
parts.append(f"{text}⟨{num}⟩" if not match.group(0).startswith('!') else f"![{text}⟨{num}⟩]")
|
||||
parts.append(
|
||||
f"{text}⟨{num}⟩"
|
||||
if not match.group(0).startswith("!")
|
||||
else f"![{text}⟨{num}⟩]"
|
||||
)
|
||||
last_end = match.end()
|
||||
|
||||
|
||||
parts.append(markdown[last_end:])
|
||||
converted_text = ''.join(parts)
|
||||
|
||||
converted_text = "".join(parts)
|
||||
|
||||
# Pre-build reference strings
|
||||
references = ["\n\n## References\n\n"]
|
||||
references.extend(
|
||||
f"⟨{num}⟩ {url}{desc}\n"
|
||||
f"⟨{num}⟩ {url}{desc}\n"
|
||||
for url, (num, desc) in sorted(link_map.items(), key=lambda x: x[1][0])
|
||||
)
|
||||
|
||||
return converted_text, ''.join(references)
|
||||
|
||||
def generate_markdown(self,
|
||||
cleaned_html: str,
|
||||
base_url: str = "",
|
||||
html2text_options: Optional[Dict[str, Any]] = None,
|
||||
options: Optional[Dict[str, Any]] = None,
|
||||
content_filter: Optional[RelevantContentFilter] = None,
|
||||
citations: bool = True,
|
||||
**kwargs) -> MarkdownGenerationResult:
|
||||
return converted_text, "".join(references)
|
||||
|
||||
def generate_markdown(
|
||||
self,
|
||||
cleaned_html: str,
|
||||
base_url: str = "",
|
||||
html2text_options: Optional[Dict[str, Any]] = None,
|
||||
options: Optional[Dict[str, Any]] = None,
|
||||
content_filter: Optional[RelevantContentFilter] = None,
|
||||
citations: bool = True,
|
||||
**kwargs,
|
||||
) -> MarkdownGenerationResult:
|
||||
"""
|
||||
Generate markdown with citations from cleaned HTML.
|
||||
|
||||
|
||||
How it works:
|
||||
1. Generate raw markdown from cleaned HTML.
|
||||
2. Convert links to citations.
|
||||
3. Generate fit markdown if content filter is provided.
|
||||
4. Return MarkdownGenerationResult.
|
||||
|
||||
|
||||
Args:
|
||||
cleaned_html (str): Cleaned HTML content.
|
||||
base_url (str): Base URL for URL joins.
|
||||
@@ -139,7 +164,7 @@ class DefaultMarkdownGenerator(MarkdownGenerationStrategy):
|
||||
options (Optional[Dict[str, Any]]): Additional options for markdown generation.
|
||||
content_filter (Optional[RelevantContentFilter]): Content filter for generating fit markdown.
|
||||
citations (bool): Whether to generate citations.
|
||||
|
||||
|
||||
Returns:
|
||||
MarkdownGenerationResult: Result containing raw markdown, fit markdown, fit HTML, and references markdown.
|
||||
"""
|
||||
@@ -147,16 +172,16 @@ class DefaultMarkdownGenerator(MarkdownGenerationStrategy):
|
||||
# Initialize HTML2Text with default options for better conversion
|
||||
h = CustomHTML2Text(baseurl=base_url)
|
||||
default_options = {
|
||||
'body_width': 0, # Disable text wrapping
|
||||
'ignore_emphasis': False,
|
||||
'ignore_links': False,
|
||||
'ignore_images': False,
|
||||
'protect_links': True,
|
||||
'single_line_break': True,
|
||||
'mark_code': True,
|
||||
'escape_snob': False
|
||||
"body_width": 0, # Disable text wrapping
|
||||
"ignore_emphasis": False,
|
||||
"ignore_links": False,
|
||||
"ignore_images": False,
|
||||
"protect_links": True,
|
||||
"single_line_break": True,
|
||||
"mark_code": True,
|
||||
"escape_snob": False,
|
||||
}
|
||||
|
||||
|
||||
# Update with custom options if provided
|
||||
if html2text_options:
|
||||
default_options.update(html2text_options)
|
||||
@@ -164,7 +189,7 @@ class DefaultMarkdownGenerator(MarkdownGenerationStrategy):
|
||||
default_options.update(options)
|
||||
elif self.options:
|
||||
default_options.update(self.options)
|
||||
|
||||
|
||||
h.update_params(**default_options)
|
||||
|
||||
# Ensure we have valid input
|
||||
@@ -178,17 +203,18 @@ class DefaultMarkdownGenerator(MarkdownGenerationStrategy):
|
||||
raw_markdown = h.handle(cleaned_html)
|
||||
except Exception as e:
|
||||
raw_markdown = f"Error converting HTML to markdown: {str(e)}"
|
||||
|
||||
raw_markdown = raw_markdown.replace(' ```', '```')
|
||||
|
||||
raw_markdown = raw_markdown.replace(" ```", "```")
|
||||
|
||||
# Convert links to citations
|
||||
markdown_with_citations: str = raw_markdown
|
||||
references_markdown: str = ""
|
||||
if citations:
|
||||
try:
|
||||
markdown_with_citations, references_markdown = self.convert_links_to_citations(
|
||||
raw_markdown, base_url
|
||||
)
|
||||
(
|
||||
markdown_with_citations,
|
||||
references_markdown,
|
||||
) = self.convert_links_to_citations(raw_markdown, base_url)
|
||||
except Exception as e:
|
||||
markdown_with_citations = raw_markdown
|
||||
references_markdown = f"Error generating citations: {str(e)}"
|
||||
@@ -200,7 +226,9 @@ class DefaultMarkdownGenerator(MarkdownGenerationStrategy):
|
||||
try:
|
||||
content_filter = content_filter or self.content_filter
|
||||
filtered_html = content_filter.filter_content(cleaned_html)
|
||||
filtered_html = '\n'.join('<div>{}</div>'.format(s) for s in filtered_html)
|
||||
filtered_html = "\n".join(
|
||||
"<div>{}</div>".format(s) for s in filtered_html
|
||||
)
|
||||
fit_markdown = h.handle(filtered_html)
|
||||
except Exception as e:
|
||||
fit_markdown = f"Error generating fit markdown: {str(e)}"
|
||||
|
||||
@@ -1,13 +1,11 @@
|
||||
import os
|
||||
import asyncio
|
||||
import logging
|
||||
from pathlib import Path
|
||||
import aiosqlite
|
||||
from typing import Optional
|
||||
import xxhash
|
||||
import aiofiles
|
||||
import shutil
|
||||
import time
|
||||
from datetime import datetime
|
||||
from .async_logger import AsyncLogger, LogLevel
|
||||
|
||||
@@ -17,18 +15,19 @@ logger = AsyncLogger(log_level=LogLevel.DEBUG, verbose=True)
|
||||
# logging.basicConfig(level=logging.INFO)
|
||||
# logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class DatabaseMigration:
|
||||
def __init__(self, db_path: str):
|
||||
self.db_path = db_path
|
||||
self.content_paths = self._ensure_content_dirs(os.path.dirname(db_path))
|
||||
|
||||
|
||||
def _ensure_content_dirs(self, base_path: str) -> dict:
|
||||
dirs = {
|
||||
'html': 'html_content',
|
||||
'cleaned': 'cleaned_html',
|
||||
'markdown': 'markdown_content',
|
||||
'extracted': 'extracted_content',
|
||||
'screenshots': 'screenshots'
|
||||
"html": "html_content",
|
||||
"cleaned": "cleaned_html",
|
||||
"markdown": "markdown_content",
|
||||
"extracted": "extracted_content",
|
||||
"screenshots": "screenshots",
|
||||
}
|
||||
content_paths = {}
|
||||
for key, dirname in dirs.items():
|
||||
@@ -47,43 +46,55 @@ class DatabaseMigration:
|
||||
async def _store_content(self, content: str, content_type: str) -> str:
|
||||
if not content:
|
||||
return ""
|
||||
|
||||
|
||||
content_hash = self._generate_content_hash(content)
|
||||
file_path = os.path.join(self.content_paths[content_type], content_hash)
|
||||
|
||||
|
||||
if not os.path.exists(file_path):
|
||||
async with aiofiles.open(file_path, 'w', encoding='utf-8') as f:
|
||||
async with aiofiles.open(file_path, "w", encoding="utf-8") as f:
|
||||
await f.write(content)
|
||||
|
||||
|
||||
return content_hash
|
||||
|
||||
async def migrate_database(self):
|
||||
"""Migrate existing database to file-based storage"""
|
||||
# logger.info("Starting database migration...")
|
||||
logger.info("Starting database migration...", tag="INIT")
|
||||
|
||||
|
||||
try:
|
||||
async with aiosqlite.connect(self.db_path) as db:
|
||||
# Get all rows
|
||||
async with db.execute(
|
||||
'''SELECT url, html, cleaned_html, markdown,
|
||||
extracted_content, screenshot FROM crawled_data'''
|
||||
"""SELECT url, html, cleaned_html, markdown,
|
||||
extracted_content, screenshot FROM crawled_data"""
|
||||
) as cursor:
|
||||
rows = await cursor.fetchall()
|
||||
|
||||
migrated_count = 0
|
||||
for row in rows:
|
||||
url, html, cleaned_html, markdown, extracted_content, screenshot = row
|
||||
|
||||
(
|
||||
url,
|
||||
html,
|
||||
cleaned_html,
|
||||
markdown,
|
||||
extracted_content,
|
||||
screenshot,
|
||||
) = row
|
||||
|
||||
# Store content in files and get hashes
|
||||
html_hash = await self._store_content(html, 'html')
|
||||
cleaned_hash = await self._store_content(cleaned_html, 'cleaned')
|
||||
markdown_hash = await self._store_content(markdown, 'markdown')
|
||||
extracted_hash = await self._store_content(extracted_content, 'extracted')
|
||||
screenshot_hash = await self._store_content(screenshot, 'screenshots')
|
||||
html_hash = await self._store_content(html, "html")
|
||||
cleaned_hash = await self._store_content(cleaned_html, "cleaned")
|
||||
markdown_hash = await self._store_content(markdown, "markdown")
|
||||
extracted_hash = await self._store_content(
|
||||
extracted_content, "extracted"
|
||||
)
|
||||
screenshot_hash = await self._store_content(
|
||||
screenshot, "screenshots"
|
||||
)
|
||||
|
||||
# Update database with hashes
|
||||
await db.execute('''
|
||||
await db.execute(
|
||||
"""
|
||||
UPDATE crawled_data
|
||||
SET html = ?,
|
||||
cleaned_html = ?,
|
||||
@@ -91,40 +102,51 @@ class DatabaseMigration:
|
||||
extracted_content = ?,
|
||||
screenshot = ?
|
||||
WHERE url = ?
|
||||
''', (html_hash, cleaned_hash, markdown_hash,
|
||||
extracted_hash, screenshot_hash, url))
|
||||
|
||||
""",
|
||||
(
|
||||
html_hash,
|
||||
cleaned_hash,
|
||||
markdown_hash,
|
||||
extracted_hash,
|
||||
screenshot_hash,
|
||||
url,
|
||||
),
|
||||
)
|
||||
|
||||
migrated_count += 1
|
||||
if migrated_count % 100 == 0:
|
||||
logger.info(f"Migrated {migrated_count} records...", tag="INIT")
|
||||
|
||||
|
||||
await db.commit()
|
||||
logger.success(f"Migration completed. {migrated_count} records processed.", tag="COMPLETE")
|
||||
logger.success(
|
||||
f"Migration completed. {migrated_count} records processed.",
|
||||
tag="COMPLETE",
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
# logger.error(f"Migration failed: {e}")
|
||||
logger.error(
|
||||
message="Migration failed: {error}",
|
||||
tag="ERROR",
|
||||
params={"error": str(e)}
|
||||
params={"error": str(e)},
|
||||
)
|
||||
raise e
|
||||
|
||||
|
||||
async def backup_database(db_path: str) -> str:
|
||||
"""Create backup of existing database"""
|
||||
if not os.path.exists(db_path):
|
||||
logger.info("No existing database found. Skipping backup.", tag="INIT")
|
||||
return None
|
||||
|
||||
|
||||
# Create backup with timestamp
|
||||
timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
|
||||
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
|
||||
backup_path = f"{db_path}.backup_{timestamp}"
|
||||
|
||||
|
||||
try:
|
||||
# Wait for any potential write operations to finish
|
||||
await asyncio.sleep(1)
|
||||
|
||||
|
||||
# Create backup
|
||||
shutil.copy2(db_path, backup_path)
|
||||
logger.info(f"Database backup created at: {backup_path}", tag="COMPLETE")
|
||||
@@ -132,37 +154,41 @@ async def backup_database(db_path: str) -> str:
|
||||
except Exception as e:
|
||||
# logger.error(f"Backup failed: {e}")
|
||||
logger.error(
|
||||
message="Migration failed: {error}",
|
||||
tag="ERROR",
|
||||
params={"error": str(e)}
|
||||
)
|
||||
message="Migration failed: {error}", tag="ERROR", params={"error": str(e)}
|
||||
)
|
||||
raise e
|
||||
|
||||
|
||||
|
||||
async def run_migration(db_path: Optional[str] = None):
|
||||
"""Run database migration"""
|
||||
if db_path is None:
|
||||
db_path = os.path.join(Path.home(), ".crawl4ai", "crawl4ai.db")
|
||||
|
||||
|
||||
if not os.path.exists(db_path):
|
||||
logger.info("No existing database found. Skipping migration.", tag="INIT")
|
||||
return
|
||||
|
||||
|
||||
# Create backup first
|
||||
backup_path = await backup_database(db_path)
|
||||
if not backup_path:
|
||||
return
|
||||
|
||||
|
||||
migration = DatabaseMigration(db_path)
|
||||
await migration.migrate_database()
|
||||
|
||||
|
||||
|
||||
def main():
|
||||
"""CLI entry point for migration"""
|
||||
import argparse
|
||||
parser = argparse.ArgumentParser(description='Migrate Crawl4AI database to file-based storage')
|
||||
parser.add_argument('--db-path', help='Custom database path')
|
||||
|
||||
parser = argparse.ArgumentParser(
|
||||
description="Migrate Crawl4AI database to file-based storage"
|
||||
)
|
||||
parser.add_argument("--db-path", help="Custom database path")
|
||||
args = parser.parse_args()
|
||||
|
||||
|
||||
asyncio.run(run_migration(args.db_path))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
main()
|
||||
|
||||
@@ -2,109 +2,125 @@ from functools import lru_cache
|
||||
from pathlib import Path
|
||||
import subprocess, os
|
||||
import shutil
|
||||
import tarfile
|
||||
from .model_loader import *
|
||||
import argparse
|
||||
import urllib.request
|
||||
from crawl4ai.config import MODEL_REPO_BRANCH
|
||||
|
||||
__location__ = os.path.realpath(os.path.join(os.getcwd(), os.path.dirname(__file__)))
|
||||
|
||||
|
||||
@lru_cache()
|
||||
def get_available_memory(device):
|
||||
import torch
|
||||
if device.type == 'cuda':
|
||||
|
||||
if device.type == "cuda":
|
||||
return torch.cuda.get_device_properties(device).total_memory
|
||||
elif device.type == 'mps':
|
||||
return 48 * 1024 ** 3 # Assuming 8GB for MPS, as a conservative estimate
|
||||
elif device.type == "mps":
|
||||
return 48 * 1024**3 # Assuming 8GB for MPS, as a conservative estimate
|
||||
else:
|
||||
return 0
|
||||
|
||||
|
||||
@lru_cache()
|
||||
def calculate_batch_size(device):
|
||||
available_memory = get_available_memory(device)
|
||||
|
||||
if device.type == 'cpu':
|
||||
|
||||
if device.type == "cpu":
|
||||
return 16
|
||||
elif device.type in ['cuda', 'mps']:
|
||||
elif device.type in ["cuda", "mps"]:
|
||||
# Adjust these thresholds based on your model size and available memory
|
||||
if available_memory >= 31 * 1024 ** 3: # > 32GB
|
||||
if available_memory >= 31 * 1024**3: # > 32GB
|
||||
return 256
|
||||
elif available_memory >= 15 * 1024 ** 3: # > 16GB to 32GB
|
||||
elif available_memory >= 15 * 1024**3: # > 16GB to 32GB
|
||||
return 128
|
||||
elif available_memory >= 8 * 1024 ** 3: # 8GB to 16GB
|
||||
elif available_memory >= 8 * 1024**3: # 8GB to 16GB
|
||||
return 64
|
||||
else:
|
||||
return 32
|
||||
else:
|
||||
return 16 # Default batch size
|
||||
|
||||
return 16 # Default batch size
|
||||
|
||||
|
||||
@lru_cache()
|
||||
def get_device():
|
||||
import torch
|
||||
|
||||
if torch.cuda.is_available():
|
||||
device = torch.device('cuda')
|
||||
device = torch.device("cuda")
|
||||
elif torch.backends.mps.is_available():
|
||||
device = torch.device('mps')
|
||||
device = torch.device("mps")
|
||||
else:
|
||||
device = torch.device('cpu')
|
||||
return device
|
||||
|
||||
device = torch.device("cpu")
|
||||
return device
|
||||
|
||||
|
||||
def set_model_device(model):
|
||||
device = get_device()
|
||||
model.to(device)
|
||||
model.to(device)
|
||||
return model, device
|
||||
|
||||
|
||||
@lru_cache()
|
||||
def get_home_folder():
|
||||
home_folder = os.path.join(os.getenv("CRAWL4_AI_BASE_DIRECTORY", Path.home()), ".crawl4ai")
|
||||
home_folder = os.path.join(
|
||||
os.getenv("CRAWL4_AI_BASE_DIRECTORY", Path.home()), ".crawl4ai"
|
||||
)
|
||||
os.makedirs(home_folder, exist_ok=True)
|
||||
os.makedirs(f"{home_folder}/cache", exist_ok=True)
|
||||
os.makedirs(f"{home_folder}/models", exist_ok=True)
|
||||
return home_folder
|
||||
return home_folder
|
||||
|
||||
|
||||
@lru_cache()
|
||||
def load_bert_base_uncased():
|
||||
from transformers import BertTokenizer, BertModel, AutoTokenizer, AutoModel
|
||||
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased', resume_download=None)
|
||||
model = BertModel.from_pretrained('bert-base-uncased', resume_download=None)
|
||||
from transformers import BertTokenizer, BertModel
|
||||
|
||||
tokenizer = BertTokenizer.from_pretrained("bert-base-uncased", resume_download=None)
|
||||
model = BertModel.from_pretrained("bert-base-uncased", resume_download=None)
|
||||
model.eval()
|
||||
model, device = set_model_device(model)
|
||||
return tokenizer, model
|
||||
|
||||
|
||||
@lru_cache()
|
||||
def load_HF_embedding_model(model_name="BAAI/bge-small-en-v1.5") -> tuple:
|
||||
"""Load the Hugging Face model for embedding.
|
||||
|
||||
|
||||
Args:
|
||||
model_name (str, optional): The model name to load. Defaults to "BAAI/bge-small-en-v1.5".
|
||||
|
||||
|
||||
Returns:
|
||||
tuple: The tokenizer and model.
|
||||
"""
|
||||
from transformers import BertTokenizer, BertModel, AutoTokenizer, AutoModel
|
||||
from transformers import AutoTokenizer, AutoModel
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained(model_name, resume_download=None)
|
||||
model = AutoModel.from_pretrained(model_name, resume_download=None)
|
||||
model.eval()
|
||||
model, device = set_model_device(model)
|
||||
return tokenizer, model
|
||||
|
||||
|
||||
@lru_cache()
|
||||
def load_text_classifier():
|
||||
from transformers import AutoTokenizer, AutoModelForSequenceClassification
|
||||
from transformers import pipeline
|
||||
import torch
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained("dstefa/roberta-base_topic_classification_nyt_news")
|
||||
model = AutoModelForSequenceClassification.from_pretrained("dstefa/roberta-base_topic_classification_nyt_news")
|
||||
tokenizer = AutoTokenizer.from_pretrained(
|
||||
"dstefa/roberta-base_topic_classification_nyt_news"
|
||||
)
|
||||
model = AutoModelForSequenceClassification.from_pretrained(
|
||||
"dstefa/roberta-base_topic_classification_nyt_news"
|
||||
)
|
||||
model.eval()
|
||||
model, device = set_model_device(model)
|
||||
pipe = pipeline("text-classification", model=model, tokenizer=tokenizer)
|
||||
return pipe
|
||||
|
||||
|
||||
@lru_cache()
|
||||
def load_text_multilabel_classifier():
|
||||
from transformers import AutoModelForSequenceClassification, AutoTokenizer
|
||||
import numpy as np
|
||||
from scipy.special import expit
|
||||
import torch
|
||||
|
||||
@@ -116,18 +132,27 @@ def load_text_multilabel_classifier():
|
||||
# else:
|
||||
# device = torch.device("cpu")
|
||||
# # return load_spacy_model(), torch.device("cpu")
|
||||
|
||||
|
||||
MODEL = "cardiffnlp/tweet-topic-21-multi"
|
||||
tokenizer = AutoTokenizer.from_pretrained(MODEL, resume_download=None)
|
||||
model = AutoModelForSequenceClassification.from_pretrained(MODEL, resume_download=None)
|
||||
model = AutoModelForSequenceClassification.from_pretrained(
|
||||
MODEL, resume_download=None
|
||||
)
|
||||
model.eval()
|
||||
model, device = set_model_device(model)
|
||||
class_mapping = model.config.id2label
|
||||
|
||||
def _classifier(texts, threshold=0.5, max_length=64):
|
||||
tokens = tokenizer(texts, return_tensors='pt', padding=True, truncation=True, max_length=max_length)
|
||||
tokens = {key: val.to(device) for key, val in tokens.items()} # Move tokens to the selected device
|
||||
tokens = tokenizer(
|
||||
texts,
|
||||
return_tensors="pt",
|
||||
padding=True,
|
||||
truncation=True,
|
||||
max_length=max_length,
|
||||
)
|
||||
tokens = {
|
||||
key: val.to(device) for key, val in tokens.items()
|
||||
} # Move tokens to the selected device
|
||||
|
||||
with torch.no_grad():
|
||||
output = model(**tokens)
|
||||
@@ -138,35 +163,41 @@ def load_text_multilabel_classifier():
|
||||
|
||||
batch_labels = []
|
||||
for prediction in predictions:
|
||||
labels = [class_mapping[i] for i, value in enumerate(prediction) if value == 1]
|
||||
labels = [
|
||||
class_mapping[i] for i, value in enumerate(prediction) if value == 1
|
||||
]
|
||||
batch_labels.append(labels)
|
||||
|
||||
return batch_labels
|
||||
|
||||
return _classifier, device
|
||||
|
||||
|
||||
@lru_cache()
|
||||
def load_nltk_punkt():
|
||||
import nltk
|
||||
|
||||
try:
|
||||
nltk.data.find('tokenizers/punkt')
|
||||
nltk.data.find("tokenizers/punkt")
|
||||
except LookupError:
|
||||
nltk.download('punkt')
|
||||
return nltk.data.find('tokenizers/punkt')
|
||||
nltk.download("punkt")
|
||||
return nltk.data.find("tokenizers/punkt")
|
||||
|
||||
|
||||
@lru_cache()
|
||||
def load_spacy_model():
|
||||
import spacy
|
||||
|
||||
name = "models/reuters"
|
||||
home_folder = get_home_folder()
|
||||
model_folder = Path(home_folder) / name
|
||||
|
||||
|
||||
# Check if the model directory already exists
|
||||
if not (model_folder.exists() and any(model_folder.iterdir())):
|
||||
repo_url = "https://github.com/unclecode/crawl4ai.git"
|
||||
branch = MODEL_REPO_BRANCH
|
||||
branch = MODEL_REPO_BRANCH
|
||||
repo_folder = Path(home_folder) / "crawl4ai"
|
||||
|
||||
|
||||
print("[LOG] ⏬ Downloading Spacy model for the first time...")
|
||||
|
||||
# Remove existing repo folder if it exists
|
||||
@@ -176,7 +207,9 @@ def load_spacy_model():
|
||||
if model_folder.exists():
|
||||
shutil.rmtree(model_folder)
|
||||
except PermissionError:
|
||||
print("[WARNING] Unable to remove existing folders. Please manually delete the following folders and try again:")
|
||||
print(
|
||||
"[WARNING] Unable to remove existing folders. Please manually delete the following folders and try again:"
|
||||
)
|
||||
print(f"- {repo_folder}")
|
||||
print(f"- {model_folder}")
|
||||
return None
|
||||
@@ -187,7 +220,7 @@ def load_spacy_model():
|
||||
["git", "clone", "-b", branch, repo_url, str(repo_folder)],
|
||||
stdout=subprocess.DEVNULL,
|
||||
stderr=subprocess.DEVNULL,
|
||||
check=True
|
||||
check=True,
|
||||
)
|
||||
|
||||
# Create the models directory if it doesn't exist
|
||||
@@ -215,6 +248,7 @@ def load_spacy_model():
|
||||
print(f"Error loading spacy model: {e}")
|
||||
return None
|
||||
|
||||
|
||||
def download_all_models(remove_existing=False):
|
||||
"""Download all models required for Crawl4AI."""
|
||||
if remove_existing:
|
||||
@@ -243,14 +277,20 @@ def download_all_models(remove_existing=False):
|
||||
load_nltk_punkt()
|
||||
print("[LOG] ✅ All models downloaded successfully.")
|
||||
|
||||
|
||||
def main():
|
||||
print("[LOG] Welcome to the Crawl4AI Model Downloader!")
|
||||
print("[LOG] This script will download all the models required for Crawl4AI.")
|
||||
parser = argparse.ArgumentParser(description="Crawl4AI Model Downloader")
|
||||
parser.add_argument('--remove-existing', action='store_true', help="Remove existing models before downloading")
|
||||
parser.add_argument(
|
||||
"--remove-existing",
|
||||
action="store_true",
|
||||
help="Remove existing models before downloading",
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
|
||||
download_all_models(remove_existing=args.remove_existing)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
||||
@@ -1,18 +1,12 @@
|
||||
from pydantic import BaseModel, HttpUrl
|
||||
from typing import List, Dict, Optional, Callable, Awaitable, Union, Tuple, Any
|
||||
from typing import List, Dict, Optional, Callable, Awaitable, Union, Any
|
||||
from enum import Enum
|
||||
from dataclasses import dataclass, field
|
||||
from .ssl_certificate import SSLCertificate
|
||||
|
||||
from dataclasses import dataclass
|
||||
from .ssl_certificate import SSLCertificate
|
||||
from datetime import datetime
|
||||
from enum import Enum
|
||||
from typing import Optional
|
||||
|
||||
from datetime import timedelta
|
||||
|
||||
|
||||
|
||||
###############################
|
||||
# Dispatcher Models
|
||||
###############################
|
||||
@@ -22,6 +16,7 @@ class DomainState:
|
||||
current_delay: float = 0
|
||||
fail_count: int = 0
|
||||
|
||||
|
||||
@dataclass
|
||||
class CrawlerTaskResult:
|
||||
task_id: str
|
||||
@@ -33,12 +28,14 @@ class CrawlerTaskResult:
|
||||
end_time: datetime
|
||||
error_message: str = ""
|
||||
|
||||
|
||||
class CrawlStatus(Enum):
|
||||
QUEUED = "QUEUED"
|
||||
IN_PROGRESS = "IN_PROGRESS"
|
||||
COMPLETED = "COMPLETED"
|
||||
FAILED = "FAILED"
|
||||
|
||||
|
||||
@dataclass
|
||||
class CrawlStats:
|
||||
task_id: str
|
||||
@@ -49,7 +46,7 @@ class CrawlStats:
|
||||
memory_usage: float = 0.0
|
||||
peak_memory: float = 0.0
|
||||
error_message: str = ""
|
||||
|
||||
|
||||
@property
|
||||
def duration(self) -> str:
|
||||
if not self.start_time:
|
||||
@@ -58,26 +55,29 @@ class CrawlStats:
|
||||
duration = end - self.start_time
|
||||
return str(timedelta(seconds=int(duration.total_seconds())))
|
||||
|
||||
|
||||
class DisplayMode(Enum):
|
||||
DETAILED = "DETAILED"
|
||||
AGGREGATED = "AGGREGATED"
|
||||
|
||||
|
||||
###############################
|
||||
# Crawler Models
|
||||
###############################
|
||||
@dataclass
|
||||
class TokenUsage:
|
||||
completion_tokens: int = 0
|
||||
prompt_tokens: int = 0
|
||||
prompt_tokens: int = 0
|
||||
total_tokens: int = 0
|
||||
completion_tokens_details: Optional[dict] = None
|
||||
prompt_tokens_details: Optional[dict] = None
|
||||
|
||||
|
||||
|
||||
class UrlModel(BaseModel):
|
||||
url: HttpUrl
|
||||
forced: bool = False
|
||||
|
||||
|
||||
class MarkdownGenerationResult(BaseModel):
|
||||
raw_markdown: str
|
||||
markdown_with_citations: str
|
||||
@@ -85,6 +85,7 @@ class MarkdownGenerationResult(BaseModel):
|
||||
fit_markdown: Optional[str] = None
|
||||
fit_html: Optional[str] = None
|
||||
|
||||
|
||||
class DispatchResult(BaseModel):
|
||||
task_id: str
|
||||
memory_usage: float
|
||||
@@ -92,6 +93,8 @@ class DispatchResult(BaseModel):
|
||||
start_time: datetime
|
||||
end_time: datetime
|
||||
error_message: str = ""
|
||||
|
||||
|
||||
class CrawlResult(BaseModel):
|
||||
url: str
|
||||
html: str
|
||||
@@ -101,7 +104,7 @@ class CrawlResult(BaseModel):
|
||||
links: Dict[str, List[Dict]] = {}
|
||||
downloaded_files: Optional[List[str]] = None
|
||||
screenshot: Optional[str] = None
|
||||
pdf : Optional[bytes] = None
|
||||
pdf: Optional[bytes] = None
|
||||
markdown: Optional[Union[str, MarkdownGenerationResult]] = None
|
||||
markdown_v2: Optional[MarkdownGenerationResult] = None
|
||||
fit_markdown: Optional[str] = None
|
||||
@@ -114,9 +117,11 @@ class CrawlResult(BaseModel):
|
||||
status_code: Optional[int] = None
|
||||
ssl_certificate: Optional[SSLCertificate] = None
|
||||
dispatch_result: Optional[DispatchResult] = None
|
||||
|
||||
class Config:
|
||||
arbitrary_types_allowed = True
|
||||
|
||||
|
||||
class AsyncCrawlResponse(BaseModel):
|
||||
html: str
|
||||
response_headers: Dict[str, str]
|
||||
@@ -130,6 +135,7 @@ class AsyncCrawlResponse(BaseModel):
|
||||
class Config:
|
||||
arbitrary_types_allowed = True
|
||||
|
||||
|
||||
###############################
|
||||
# Scraping Models
|
||||
###############################
|
||||
@@ -143,21 +149,29 @@ class MediaItem(BaseModel):
|
||||
format: Optional[str] = None
|
||||
width: Optional[int] = None
|
||||
|
||||
|
||||
class Link(BaseModel):
|
||||
href: str
|
||||
text: str
|
||||
title: Optional[str] = None
|
||||
base_domain: str
|
||||
|
||||
|
||||
class Media(BaseModel):
|
||||
images: List[MediaItem] = []
|
||||
videos: List[MediaItem] = [] # Using MediaItem model for now, can be extended with Video model if needed
|
||||
audios: List[MediaItem] = [] # Using MediaItem model for now, can be extended with Audio model if needed
|
||||
videos: List[
|
||||
MediaItem
|
||||
] = [] # Using MediaItem model for now, can be extended with Video model if needed
|
||||
audios: List[
|
||||
MediaItem
|
||||
] = [] # Using MediaItem model for now, can be extended with Audio model if needed
|
||||
|
||||
|
||||
class Links(BaseModel):
|
||||
internal: List[Link] = []
|
||||
external: List[Link] = []
|
||||
|
||||
|
||||
class ScrapingResult(BaseModel):
|
||||
cleaned_html: str
|
||||
success: bool
|
||||
|
||||
@@ -13,10 +13,10 @@ from pathlib import Path
|
||||
class SSLCertificate:
|
||||
"""
|
||||
A class representing an SSL certificate with methods to export in various formats.
|
||||
|
||||
|
||||
Attributes:
|
||||
cert_info (Dict[str, Any]): The certificate information.
|
||||
|
||||
|
||||
Methods:
|
||||
from_url(url: str, timeout: int = 10) -> Optional['SSLCertificate']: Create SSLCertificate instance from a URL.
|
||||
from_file(file_path: str) -> Optional['SSLCertificate']: Create SSLCertificate instance from a file.
|
||||
@@ -26,32 +26,35 @@ class SSLCertificate:
|
||||
export_as_json() -> Dict[str, Any]: Export the certificate as JSON format.
|
||||
export_as_text() -> str: Export the certificate as text format.
|
||||
"""
|
||||
|
||||
def __init__(self, cert_info: Dict[str, Any]):
|
||||
self._cert_info = self._decode_cert_data(cert_info)
|
||||
|
||||
@staticmethod
|
||||
def from_url(url: str, timeout: int = 10) -> Optional['SSLCertificate']:
|
||||
def from_url(url: str, timeout: int = 10) -> Optional["SSLCertificate"]:
|
||||
"""
|
||||
Create SSLCertificate instance from a URL.
|
||||
|
||||
|
||||
Args:
|
||||
url (str): URL of the website.
|
||||
timeout (int): Timeout for the connection (default: 10).
|
||||
|
||||
|
||||
Returns:
|
||||
Optional[SSLCertificate]: SSLCertificate instance if successful, None otherwise.
|
||||
"""
|
||||
try:
|
||||
hostname = urlparse(url).netloc
|
||||
if ':' in hostname:
|
||||
hostname = hostname.split(':')[0]
|
||||
|
||||
if ":" in hostname:
|
||||
hostname = hostname.split(":")[0]
|
||||
|
||||
context = ssl.create_default_context()
|
||||
with socket.create_connection((hostname, 443), timeout=timeout) as sock:
|
||||
with context.wrap_socket(sock, server_hostname=hostname) as ssock:
|
||||
cert_binary = ssock.getpeercert(binary_form=True)
|
||||
x509 = OpenSSL.crypto.load_certificate(OpenSSL.crypto.FILETYPE_ASN1, cert_binary)
|
||||
|
||||
x509 = OpenSSL.crypto.load_certificate(
|
||||
OpenSSL.crypto.FILETYPE_ASN1, cert_binary
|
||||
)
|
||||
|
||||
cert_info = {
|
||||
"subject": dict(x509.get_subject().get_components()),
|
||||
"issuer": dict(x509.get_issuer().get_components()),
|
||||
@@ -61,32 +64,33 @@ class SSLCertificate:
|
||||
"not_after": x509.get_notAfter(),
|
||||
"fingerprint": x509.digest("sha256").hex(),
|
||||
"signature_algorithm": x509.get_signature_algorithm(),
|
||||
"raw_cert": base64.b64encode(cert_binary)
|
||||
"raw_cert": base64.b64encode(cert_binary),
|
||||
}
|
||||
|
||||
|
||||
# Add extensions
|
||||
extensions = []
|
||||
for i in range(x509.get_extension_count()):
|
||||
ext = x509.get_extension(i)
|
||||
extensions.append({
|
||||
"name": ext.get_short_name(),
|
||||
"value": str(ext)
|
||||
})
|
||||
extensions.append(
|
||||
{"name": ext.get_short_name(), "value": str(ext)}
|
||||
)
|
||||
cert_info["extensions"] = extensions
|
||||
|
||||
|
||||
return SSLCertificate(cert_info)
|
||||
|
||||
except Exception as e:
|
||||
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
def _decode_cert_data(data: Any) -> Any:
|
||||
"""Helper method to decode bytes in certificate data."""
|
||||
if isinstance(data, bytes):
|
||||
return data.decode('utf-8')
|
||||
return data.decode("utf-8")
|
||||
elif isinstance(data, dict):
|
||||
return {
|
||||
(k.decode('utf-8') if isinstance(k, bytes) else k): SSLCertificate._decode_cert_data(v)
|
||||
(
|
||||
k.decode("utf-8") if isinstance(k, bytes) else k
|
||||
): SSLCertificate._decode_cert_data(v)
|
||||
for k, v in data.items()
|
||||
}
|
||||
elif isinstance(data, list):
|
||||
@@ -96,58 +100,57 @@ class SSLCertificate:
|
||||
def to_json(self, filepath: Optional[str] = None) -> Optional[str]:
|
||||
"""
|
||||
Export certificate as JSON.
|
||||
|
||||
|
||||
Args:
|
||||
filepath (Optional[str]): Path to save the JSON file (default: None).
|
||||
|
||||
|
||||
Returns:
|
||||
Optional[str]: JSON string if successful, None otherwise.
|
||||
"""
|
||||
json_str = json.dumps(self._cert_info, indent=2, ensure_ascii=False)
|
||||
if filepath:
|
||||
Path(filepath).write_text(json_str, encoding='utf-8')
|
||||
Path(filepath).write_text(json_str, encoding="utf-8")
|
||||
return None
|
||||
return json_str
|
||||
|
||||
def to_pem(self, filepath: Optional[str] = None) -> Optional[str]:
|
||||
"""
|
||||
Export certificate as PEM.
|
||||
|
||||
|
||||
Args:
|
||||
filepath (Optional[str]): Path to save the PEM file (default: None).
|
||||
|
||||
|
||||
Returns:
|
||||
Optional[str]: PEM string if successful, None otherwise.
|
||||
"""
|
||||
try:
|
||||
x509 = OpenSSL.crypto.load_certificate(
|
||||
OpenSSL.crypto.FILETYPE_ASN1,
|
||||
base64.b64decode(self._cert_info['raw_cert'])
|
||||
OpenSSL.crypto.FILETYPE_ASN1,
|
||||
base64.b64decode(self._cert_info["raw_cert"]),
|
||||
)
|
||||
pem_data = OpenSSL.crypto.dump_certificate(
|
||||
OpenSSL.crypto.FILETYPE_PEM,
|
||||
x509
|
||||
).decode('utf-8')
|
||||
|
||||
OpenSSL.crypto.FILETYPE_PEM, x509
|
||||
).decode("utf-8")
|
||||
|
||||
if filepath:
|
||||
Path(filepath).write_text(pem_data, encoding='utf-8')
|
||||
Path(filepath).write_text(pem_data, encoding="utf-8")
|
||||
return None
|
||||
return pem_data
|
||||
except Exception as e:
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
def to_der(self, filepath: Optional[str] = None) -> Optional[bytes]:
|
||||
"""
|
||||
Export certificate as DER.
|
||||
|
||||
|
||||
Args:
|
||||
filepath (Optional[str]): Path to save the DER file (default: None).
|
||||
|
||||
|
||||
Returns:
|
||||
Optional[bytes]: DER bytes if successful, None otherwise.
|
||||
"""
|
||||
try:
|
||||
der_data = base64.b64decode(self._cert_info['raw_cert'])
|
||||
der_data = base64.b64decode(self._cert_info["raw_cert"])
|
||||
if filepath:
|
||||
Path(filepath).write_bytes(der_data)
|
||||
return None
|
||||
@@ -158,24 +161,24 @@ class SSLCertificate:
|
||||
@property
|
||||
def issuer(self) -> Dict[str, str]:
|
||||
"""Get certificate issuer information."""
|
||||
return self._cert_info.get('issuer', {})
|
||||
return self._cert_info.get("issuer", {})
|
||||
|
||||
@property
|
||||
def subject(self) -> Dict[str, str]:
|
||||
"""Get certificate subject information."""
|
||||
return self._cert_info.get('subject', {})
|
||||
return self._cert_info.get("subject", {})
|
||||
|
||||
@property
|
||||
def valid_from(self) -> str:
|
||||
"""Get certificate validity start date."""
|
||||
return self._cert_info.get('not_before', '')
|
||||
return self._cert_info.get("not_before", "")
|
||||
|
||||
@property
|
||||
def valid_until(self) -> str:
|
||||
"""Get certificate validity end date."""
|
||||
return self._cert_info.get('not_after', '')
|
||||
return self._cert_info.get("not_after", "")
|
||||
|
||||
@property
|
||||
def fingerprint(self) -> str:
|
||||
"""Get certificate fingerprint."""
|
||||
return self._cert_info.get('fingerprint', '')
|
||||
return self._cert_info.get("fingerprint", "")
|
||||
|
||||
@@ -6,7 +6,7 @@ import re
|
||||
class UserAgentGenerator:
|
||||
"""
|
||||
Generate random user agents with specified constraints.
|
||||
|
||||
|
||||
Attributes:
|
||||
desktop_platforms (dict): A dictionary of possible desktop platforms and their corresponding user agent strings.
|
||||
mobile_platforms (dict): A dictionary of possible mobile platforms and their corresponding user agent strings.
|
||||
@@ -18,7 +18,7 @@ class UserAgentGenerator:
|
||||
safari_versions (list): A list of possible Safari browser versions.
|
||||
ios_versions (list): A list of possible iOS browser versions.
|
||||
android_versions (list): A list of possible Android browser versions.
|
||||
|
||||
|
||||
Methods:
|
||||
generate_user_agent(
|
||||
platform: Literal["desktop", "mobile"] = "desktop",
|
||||
@@ -30,8 +30,9 @@ class UserAgentGenerator:
|
||||
safari_version: Optional[str] = None,
|
||||
ios_version: Optional[str] = None,
|
||||
android_version: Optional[str] = None
|
||||
): Generates a random user agent string based on the specified parameters.
|
||||
): Generates a random user agent string based on the specified parameters.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
# Previous platform definitions remain the same...
|
||||
self.desktop_platforms = {
|
||||
@@ -47,7 +48,7 @@ class UserAgentGenerator:
|
||||
"generic": "(X11; Linux x86_64)",
|
||||
"ubuntu": "(X11; Ubuntu; Linux x86_64)",
|
||||
"chrome_os": "(X11; CrOS x86_64 14541.0.0)",
|
||||
}
|
||||
},
|
||||
}
|
||||
|
||||
self.mobile_platforms = {
|
||||
@@ -60,26 +61,14 @@ class UserAgentGenerator:
|
||||
"ios": {
|
||||
"iphone": "(iPhone; CPU iPhone OS 16_5 like Mac OS X)",
|
||||
"ipad": "(iPad; CPU OS 16_5 like Mac OS X)",
|
||||
}
|
||||
},
|
||||
}
|
||||
|
||||
# Browser Combinations
|
||||
self.browser_combinations = {
|
||||
1: [
|
||||
["chrome"],
|
||||
["firefox"],
|
||||
["safari"],
|
||||
["edge"]
|
||||
],
|
||||
2: [
|
||||
["gecko", "firefox"],
|
||||
["chrome", "safari"],
|
||||
["webkit", "safari"]
|
||||
],
|
||||
3: [
|
||||
["chrome", "safari", "edge"],
|
||||
["webkit", "chrome", "safari"]
|
||||
]
|
||||
1: [["chrome"], ["firefox"], ["safari"], ["edge"]],
|
||||
2: [["gecko", "firefox"], ["chrome", "safari"], ["webkit", "safari"]],
|
||||
3: [["chrome", "safari", "edge"], ["webkit", "chrome", "safari"]],
|
||||
}
|
||||
|
||||
# Rendering Engines with versions
|
||||
@@ -90,7 +79,7 @@ class UserAgentGenerator:
|
||||
"Gecko/20100101",
|
||||
"Gecko/20100101", # Firefox usually uses this constant version
|
||||
"Gecko/2010010",
|
||||
]
|
||||
],
|
||||
}
|
||||
|
||||
# Browser Versions
|
||||
@@ -135,25 +124,25 @@ class UserAgentGenerator:
|
||||
def get_browser_stack(self, num_browsers: int = 1) -> List[str]:
|
||||
"""
|
||||
Get a valid combination of browser versions.
|
||||
|
||||
|
||||
How it works:
|
||||
1. Check if the number of browsers is supported.
|
||||
2. Randomly choose a combination of browsers.
|
||||
3. Iterate through the combination and add browser versions.
|
||||
4. Return the browser stack.
|
||||
|
||||
|
||||
Args:
|
||||
num_browsers: Number of browser specifications (1-3)
|
||||
|
||||
|
||||
Returns:
|
||||
List[str]: A list of browser versions.
|
||||
"""
|
||||
if num_browsers not in self.browser_combinations:
|
||||
raise ValueError(f"Unsupported number of browsers: {num_browsers}")
|
||||
|
||||
|
||||
combination = random.choice(self.browser_combinations[num_browsers])
|
||||
browser_stack = []
|
||||
|
||||
|
||||
for browser in combination:
|
||||
if browser == "chrome":
|
||||
browser_stack.append(random.choice(self.chrome_versions))
|
||||
@@ -167,18 +156,20 @@ class UserAgentGenerator:
|
||||
browser_stack.append(random.choice(self.rendering_engines["gecko"]))
|
||||
elif browser == "webkit":
|
||||
browser_stack.append(self.rendering_engines["chrome_webkit"])
|
||||
|
||||
|
||||
return browser_stack
|
||||
|
||||
def generate(self,
|
||||
device_type: Optional[Literal['desktop', 'mobile']] = None,
|
||||
os_type: Optional[str] = None,
|
||||
device_brand: Optional[str] = None,
|
||||
browser_type: Optional[Literal['chrome', 'edge', 'safari', 'firefox']] = None,
|
||||
num_browsers: int = 3) -> str:
|
||||
def generate(
|
||||
self,
|
||||
device_type: Optional[Literal["desktop", "mobile"]] = None,
|
||||
os_type: Optional[str] = None,
|
||||
device_brand: Optional[str] = None,
|
||||
browser_type: Optional[Literal["chrome", "edge", "safari", "firefox"]] = None,
|
||||
num_browsers: int = 3,
|
||||
) -> str:
|
||||
"""
|
||||
Generate a random user agent with specified constraints.
|
||||
|
||||
|
||||
Args:
|
||||
device_type: 'desktop' or 'mobile'
|
||||
os_type: 'windows', 'macos', 'linux', 'android', 'ios'
|
||||
@@ -188,23 +179,23 @@ class UserAgentGenerator:
|
||||
"""
|
||||
# Get platform string
|
||||
platform = self.get_random_platform(device_type, os_type, device_brand)
|
||||
|
||||
|
||||
# Start with Mozilla
|
||||
components = ["Mozilla/5.0", platform]
|
||||
|
||||
|
||||
# Add browser stack
|
||||
browser_stack = self.get_browser_stack(num_browsers)
|
||||
|
||||
|
||||
# Add appropriate legacy token based on browser stack
|
||||
if "Firefox" in str(browser_stack):
|
||||
components.append(random.choice(self.rendering_engines["gecko"]))
|
||||
elif "Chrome" in str(browser_stack) or "Safari" in str(browser_stack):
|
||||
components.append(self.rendering_engines["chrome_webkit"])
|
||||
components.append("(KHTML, like Gecko)")
|
||||
|
||||
|
||||
# Add browser versions
|
||||
components.extend(browser_stack)
|
||||
|
||||
|
||||
return " ".join(components)
|
||||
|
||||
def generate_with_client_hints(self, **kwargs) -> Tuple[str, str]:
|
||||
@@ -215,16 +206,20 @@ class UserAgentGenerator:
|
||||
|
||||
def get_random_platform(self, device_type, os_type, device_brand):
|
||||
"""Helper method to get random platform based on constraints"""
|
||||
platforms = self.desktop_platforms if device_type == 'desktop' else \
|
||||
self.mobile_platforms if device_type == 'mobile' else \
|
||||
{**self.desktop_platforms, **self.mobile_platforms}
|
||||
|
||||
platforms = (
|
||||
self.desktop_platforms
|
||||
if device_type == "desktop"
|
||||
else self.mobile_platforms
|
||||
if device_type == "mobile"
|
||||
else {**self.desktop_platforms, **self.mobile_platforms}
|
||||
)
|
||||
|
||||
if os_type:
|
||||
for platform_group in [self.desktop_platforms, self.mobile_platforms]:
|
||||
if os_type in platform_group:
|
||||
platforms = {os_type: platform_group[os_type]}
|
||||
break
|
||||
|
||||
|
||||
os_key = random.choice(list(platforms.keys()))
|
||||
if device_brand and device_brand in platforms[os_key]:
|
||||
return platforms[os_key][device_brand]
|
||||
@@ -233,73 +228,72 @@ class UserAgentGenerator:
|
||||
def parse_user_agent(self, user_agent: str) -> Dict[str, str]:
|
||||
"""Parse a user agent string to extract browser and version information"""
|
||||
browsers = {
|
||||
'chrome': r'Chrome/(\d+)',
|
||||
'edge': r'Edg/(\d+)',
|
||||
'safari': r'Version/(\d+)',
|
||||
'firefox': r'Firefox/(\d+)'
|
||||
"chrome": r"Chrome/(\d+)",
|
||||
"edge": r"Edg/(\d+)",
|
||||
"safari": r"Version/(\d+)",
|
||||
"firefox": r"Firefox/(\d+)",
|
||||
}
|
||||
|
||||
|
||||
result = {}
|
||||
for browser, pattern in browsers.items():
|
||||
match = re.search(pattern, user_agent)
|
||||
if match:
|
||||
result[browser] = match.group(1)
|
||||
|
||||
|
||||
return result
|
||||
|
||||
def generate_client_hints(self, user_agent: str) -> str:
|
||||
"""Generate Sec-CH-UA header value based on user agent string"""
|
||||
browsers = self.parse_user_agent(user_agent)
|
||||
|
||||
|
||||
# Client hints components
|
||||
hints = []
|
||||
|
||||
|
||||
# Handle different browser combinations
|
||||
if 'chrome' in browsers:
|
||||
if "chrome" in browsers:
|
||||
hints.append(f'"Chromium";v="{browsers["chrome"]}"')
|
||||
hints.append('"Not_A Brand";v="8"')
|
||||
|
||||
if 'edge' in browsers:
|
||||
|
||||
if "edge" in browsers:
|
||||
hints.append(f'"Microsoft Edge";v="{browsers["edge"]}"')
|
||||
else:
|
||||
hints.append(f'"Google Chrome";v="{browsers["chrome"]}"')
|
||||
|
||||
elif 'firefox' in browsers:
|
||||
|
||||
elif "firefox" in browsers:
|
||||
# Firefox doesn't typically send Sec-CH-UA
|
||||
return '""'
|
||||
|
||||
elif 'safari' in browsers:
|
||||
|
||||
elif "safari" in browsers:
|
||||
# Safari's format for client hints
|
||||
hints.append(f'"Safari";v="{browsers["safari"]}"')
|
||||
hints.append('"Not_A Brand";v="8"')
|
||||
|
||||
return ', '.join(hints)
|
||||
|
||||
return ", ".join(hints)
|
||||
|
||||
|
||||
# Example usage:
|
||||
if __name__ == "__main__":
|
||||
generator = UserAgentGenerator()
|
||||
print(generator.generate())
|
||||
|
||||
|
||||
print("\nSingle browser (Chrome):")
|
||||
print(generator.generate(num_browsers=1, browser_type='chrome'))
|
||||
|
||||
print(generator.generate(num_browsers=1, browser_type="chrome"))
|
||||
|
||||
print("\nTwo browsers (Gecko/Firefox):")
|
||||
print(generator.generate(num_browsers=2))
|
||||
|
||||
|
||||
print("\nThree browsers (Chrome/Safari/Edge):")
|
||||
print(generator.generate(num_browsers=3))
|
||||
|
||||
|
||||
print("\nFirefox on Linux:")
|
||||
print(generator.generate(
|
||||
device_type='desktop',
|
||||
os_type='linux',
|
||||
browser_type='firefox',
|
||||
num_browsers=2
|
||||
))
|
||||
|
||||
print(
|
||||
generator.generate(
|
||||
device_type="desktop",
|
||||
os_type="linux",
|
||||
browser_type="firefox",
|
||||
num_browsers=2,
|
||||
)
|
||||
)
|
||||
|
||||
print("\nChrome/Safari/Edge on Windows:")
|
||||
print(generator.generate(
|
||||
device_type='desktop',
|
||||
os_type='windows',
|
||||
num_browsers=3
|
||||
))
|
||||
print(generator.generate(device_type="desktop", os_type="windows", num_browsers=3))
|
||||
|
||||
1360
crawl4ai/utils.py
1360
crawl4ai/utils.py
File diff suppressed because it is too large
Load Diff
@@ -1,14 +1,14 @@
|
||||
# version_manager.py
|
||||
import os
|
||||
from pathlib import Path
|
||||
from packaging import version
|
||||
from . import __version__
|
||||
|
||||
|
||||
class VersionManager:
|
||||
def __init__(self):
|
||||
self.home_dir = Path.home() / ".crawl4ai"
|
||||
self.version_file = self.home_dir / "version.txt"
|
||||
|
||||
|
||||
def get_installed_version(self):
|
||||
"""Get the version recorded in home directory"""
|
||||
if not self.version_file.exists():
|
||||
@@ -17,14 +17,13 @@ class VersionManager:
|
||||
return version.parse(self.version_file.read_text().strip())
|
||||
except:
|
||||
return None
|
||||
|
||||
|
||||
def update_version(self):
|
||||
"""Update the version file to current library version"""
|
||||
self.version_file.write_text(__version__.__version__)
|
||||
|
||||
|
||||
def needs_update(self):
|
||||
"""Check if database needs update based on version"""
|
||||
installed = self.get_installed_version()
|
||||
current = version.parse(__version__.__version__)
|
||||
return installed is None or installed < current
|
||||
|
||||
|
||||
@@ -1,9 +1,10 @@
|
||||
import os, time
|
||||
|
||||
os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
||||
from pathlib import Path
|
||||
|
||||
from .models import UrlModel, CrawlResult
|
||||
from .database import init_db, get_cached_url, cache_url, DB_PATH, flush_db
|
||||
from .database import init_db, get_cached_url, cache_url
|
||||
from .utils import *
|
||||
from .chunking_strategy import *
|
||||
from .extraction_strategy import *
|
||||
@@ -14,31 +15,44 @@ from .content_scraping_strategy import WebScrapingStrategy
|
||||
from .config import *
|
||||
import warnings
|
||||
import json
|
||||
warnings.filterwarnings("ignore", message='Field "model_name" has conflict with protected namespace "model_".')
|
||||
|
||||
warnings.filterwarnings(
|
||||
"ignore",
|
||||
message='Field "model_name" has conflict with protected namespace "model_".',
|
||||
)
|
||||
|
||||
|
||||
class WebCrawler:
|
||||
def __init__(self, crawler_strategy: CrawlerStrategy = None, always_by_pass_cache: bool = False, verbose: bool = False):
|
||||
self.crawler_strategy = crawler_strategy or LocalSeleniumCrawlerStrategy(verbose=verbose)
|
||||
def __init__(
|
||||
self,
|
||||
crawler_strategy: CrawlerStrategy = None,
|
||||
always_by_pass_cache: bool = False,
|
||||
verbose: bool = False,
|
||||
):
|
||||
self.crawler_strategy = crawler_strategy or LocalSeleniumCrawlerStrategy(
|
||||
verbose=verbose
|
||||
)
|
||||
self.always_by_pass_cache = always_by_pass_cache
|
||||
self.crawl4ai_folder = os.path.join(os.getenv("CRAWL4_AI_BASE_DIRECTORY", Path.home()), ".crawl4ai")
|
||||
self.crawl4ai_folder = os.path.join(
|
||||
os.getenv("CRAWL4_AI_BASE_DIRECTORY", Path.home()), ".crawl4ai"
|
||||
)
|
||||
os.makedirs(self.crawl4ai_folder, exist_ok=True)
|
||||
os.makedirs(f"{self.crawl4ai_folder}/cache", exist_ok=True)
|
||||
init_db()
|
||||
self.ready = False
|
||||
|
||||
|
||||
def warmup(self):
|
||||
print("[LOG] 🌤️ Warming up the WebCrawler")
|
||||
self.run(
|
||||
url='https://google.com/',
|
||||
url="https://google.com/",
|
||||
word_count_threshold=5,
|
||||
extraction_strategy=NoExtractionStrategy(),
|
||||
bypass_cache=False,
|
||||
verbose=False
|
||||
verbose=False,
|
||||
)
|
||||
self.ready = True
|
||||
print("[LOG] 🌞 WebCrawler is ready to crawl")
|
||||
|
||||
|
||||
def fetch_page(
|
||||
self,
|
||||
url_model: UrlModel,
|
||||
@@ -80,6 +94,7 @@ class WebCrawler:
|
||||
**kwargs,
|
||||
) -> List[CrawlResult]:
|
||||
extraction_strategy = extraction_strategy or NoExtractionStrategy()
|
||||
|
||||
def fetch_page_wrapper(url_model, *args, **kwargs):
|
||||
return self.fetch_page(url_model, *args, **kwargs)
|
||||
|
||||
@@ -104,150 +119,176 @@ class WebCrawler:
|
||||
return results
|
||||
|
||||
def run(
|
||||
self,
|
||||
url: str,
|
||||
word_count_threshold=MIN_WORD_THRESHOLD,
|
||||
extraction_strategy: ExtractionStrategy = None,
|
||||
chunking_strategy: ChunkingStrategy = RegexChunking(),
|
||||
bypass_cache: bool = False,
|
||||
css_selector: str = None,
|
||||
screenshot: bool = False,
|
||||
user_agent: str = None,
|
||||
verbose=True,
|
||||
**kwargs,
|
||||
) -> CrawlResult:
|
||||
try:
|
||||
extraction_strategy = extraction_strategy or NoExtractionStrategy()
|
||||
extraction_strategy.verbose = verbose
|
||||
if not isinstance(extraction_strategy, ExtractionStrategy):
|
||||
raise ValueError("Unsupported extraction strategy")
|
||||
if not isinstance(chunking_strategy, ChunkingStrategy):
|
||||
raise ValueError("Unsupported chunking strategy")
|
||||
|
||||
word_count_threshold = max(word_count_threshold, MIN_WORD_THRESHOLD)
|
||||
self,
|
||||
url: str,
|
||||
word_count_threshold=MIN_WORD_THRESHOLD,
|
||||
extraction_strategy: ExtractionStrategy = None,
|
||||
chunking_strategy: ChunkingStrategy = RegexChunking(),
|
||||
bypass_cache: bool = False,
|
||||
css_selector: str = None,
|
||||
screenshot: bool = False,
|
||||
user_agent: str = None,
|
||||
verbose=True,
|
||||
**kwargs,
|
||||
) -> CrawlResult:
|
||||
try:
|
||||
extraction_strategy = extraction_strategy or NoExtractionStrategy()
|
||||
extraction_strategy.verbose = verbose
|
||||
if not isinstance(extraction_strategy, ExtractionStrategy):
|
||||
raise ValueError("Unsupported extraction strategy")
|
||||
if not isinstance(chunking_strategy, ChunkingStrategy):
|
||||
raise ValueError("Unsupported chunking strategy")
|
||||
|
||||
cached = None
|
||||
screenshot_data = None
|
||||
extracted_content = None
|
||||
if not bypass_cache and not self.always_by_pass_cache:
|
||||
cached = get_cached_url(url)
|
||||
|
||||
if kwargs.get("warmup", True) and not self.ready:
|
||||
return None
|
||||
|
||||
if cached:
|
||||
html = sanitize_input_encode(cached[1])
|
||||
extracted_content = sanitize_input_encode(cached[4])
|
||||
if screenshot:
|
||||
screenshot_data = cached[9]
|
||||
if not screenshot_data:
|
||||
cached = None
|
||||
|
||||
if not cached or not html:
|
||||
if user_agent:
|
||||
self.crawler_strategy.update_user_agent(user_agent)
|
||||
t1 = time.time()
|
||||
html = sanitize_input_encode(self.crawler_strategy.crawl(url, **kwargs))
|
||||
t2 = time.time()
|
||||
if verbose:
|
||||
print(f"[LOG] 🚀 Crawling done for {url}, success: {bool(html)}, time taken: {t2 - t1:.2f} seconds")
|
||||
if screenshot:
|
||||
screenshot_data = self.crawler_strategy.take_screenshot()
|
||||
word_count_threshold = max(word_count_threshold, MIN_WORD_THRESHOLD)
|
||||
|
||||
|
||||
crawl_result = self.process_html(url, html, extracted_content, word_count_threshold, extraction_strategy, chunking_strategy, css_selector, screenshot_data, verbose, bool(cached), **kwargs)
|
||||
crawl_result.success = bool(html)
|
||||
return crawl_result
|
||||
except Exception as e:
|
||||
if not hasattr(e, "msg"):
|
||||
e.msg = str(e)
|
||||
print(f"[ERROR] 🚫 Failed to crawl {url}, error: {e.msg}")
|
||||
return CrawlResult(url=url, html="", success=False, error_message=e.msg)
|
||||
cached = None
|
||||
screenshot_data = None
|
||||
extracted_content = None
|
||||
if not bypass_cache and not self.always_by_pass_cache:
|
||||
cached = get_cached_url(url)
|
||||
|
||||
if kwargs.get("warmup", True) and not self.ready:
|
||||
return None
|
||||
|
||||
if cached:
|
||||
html = sanitize_input_encode(cached[1])
|
||||
extracted_content = sanitize_input_encode(cached[4])
|
||||
if screenshot:
|
||||
screenshot_data = cached[9]
|
||||
if not screenshot_data:
|
||||
cached = None
|
||||
|
||||
if not cached or not html:
|
||||
if user_agent:
|
||||
self.crawler_strategy.update_user_agent(user_agent)
|
||||
t1 = time.time()
|
||||
html = sanitize_input_encode(self.crawler_strategy.crawl(url, **kwargs))
|
||||
t2 = time.time()
|
||||
if verbose:
|
||||
print(
|
||||
f"[LOG] 🚀 Crawling done for {url}, success: {bool(html)}, time taken: {t2 - t1:.2f} seconds"
|
||||
)
|
||||
if screenshot:
|
||||
screenshot_data = self.crawler_strategy.take_screenshot()
|
||||
|
||||
crawl_result = self.process_html(
|
||||
url,
|
||||
html,
|
||||
extracted_content,
|
||||
word_count_threshold,
|
||||
extraction_strategy,
|
||||
chunking_strategy,
|
||||
css_selector,
|
||||
screenshot_data,
|
||||
verbose,
|
||||
bool(cached),
|
||||
**kwargs,
|
||||
)
|
||||
crawl_result.success = bool(html)
|
||||
return crawl_result
|
||||
except Exception as e:
|
||||
if not hasattr(e, "msg"):
|
||||
e.msg = str(e)
|
||||
print(f"[ERROR] 🚫 Failed to crawl {url}, error: {e.msg}")
|
||||
return CrawlResult(url=url, html="", success=False, error_message=e.msg)
|
||||
|
||||
def process_html(
|
||||
self,
|
||||
url: str,
|
||||
html: str,
|
||||
extracted_content: str,
|
||||
word_count_threshold: int,
|
||||
extraction_strategy: ExtractionStrategy,
|
||||
chunking_strategy: ChunkingStrategy,
|
||||
css_selector: str,
|
||||
screenshot: bool,
|
||||
verbose: bool,
|
||||
is_cached: bool,
|
||||
**kwargs,
|
||||
) -> CrawlResult:
|
||||
t = time.time()
|
||||
# Extract content from HTML
|
||||
try:
|
||||
t1 = time.time()
|
||||
scrapping_strategy = WebScrapingStrategy()
|
||||
extra_params = {k: v for k, v in kwargs.items() if k not in ["only_text", "image_description_min_word_threshold"]}
|
||||
result = scrapping_strategy.scrap(
|
||||
url,
|
||||
html,
|
||||
word_count_threshold=word_count_threshold,
|
||||
css_selector=css_selector,
|
||||
only_text=kwargs.get("only_text", False),
|
||||
image_description_min_word_threshold=kwargs.get(
|
||||
"image_description_min_word_threshold", IMAGE_DESCRIPTION_MIN_WORD_THRESHOLD
|
||||
),
|
||||
**extra_params,
|
||||
self,
|
||||
url: str,
|
||||
html: str,
|
||||
extracted_content: str,
|
||||
word_count_threshold: int,
|
||||
extraction_strategy: ExtractionStrategy,
|
||||
chunking_strategy: ChunkingStrategy,
|
||||
css_selector: str,
|
||||
screenshot: bool,
|
||||
verbose: bool,
|
||||
is_cached: bool,
|
||||
**kwargs,
|
||||
) -> CrawlResult:
|
||||
t = time.time()
|
||||
# Extract content from HTML
|
||||
try:
|
||||
t1 = time.time()
|
||||
scrapping_strategy = WebScrapingStrategy()
|
||||
extra_params = {
|
||||
k: v
|
||||
for k, v in kwargs.items()
|
||||
if k not in ["only_text", "image_description_min_word_threshold"]
|
||||
}
|
||||
result = scrapping_strategy.scrap(
|
||||
url,
|
||||
html,
|
||||
word_count_threshold=word_count_threshold,
|
||||
css_selector=css_selector,
|
||||
only_text=kwargs.get("only_text", False),
|
||||
image_description_min_word_threshold=kwargs.get(
|
||||
"image_description_min_word_threshold",
|
||||
IMAGE_DESCRIPTION_MIN_WORD_THRESHOLD,
|
||||
),
|
||||
**extra_params,
|
||||
)
|
||||
|
||||
# result = get_content_of_website_optimized(url, html, word_count_threshold, css_selector=css_selector, only_text=kwargs.get("only_text", False))
|
||||
if verbose:
|
||||
print(
|
||||
f"[LOG] 🚀 Content extracted for {url}, success: True, time taken: {time.time() - t1:.2f} seconds"
|
||||
)
|
||||
|
||||
# result = get_content_of_website_optimized(url, html, word_count_threshold, css_selector=css_selector, only_text=kwargs.get("only_text", False))
|
||||
if verbose:
|
||||
print(f"[LOG] 🚀 Content extracted for {url}, success: True, time taken: {time.time() - t1:.2f} seconds")
|
||||
|
||||
if result is None:
|
||||
raise ValueError(f"Failed to extract content from the website: {url}")
|
||||
except InvalidCSSSelectorError as e:
|
||||
raise ValueError(str(e))
|
||||
|
||||
cleaned_html = sanitize_input_encode(result.get("cleaned_html", ""))
|
||||
markdown = sanitize_input_encode(result.get("markdown", ""))
|
||||
media = result.get("media", [])
|
||||
links = result.get("links", [])
|
||||
metadata = result.get("metadata", {})
|
||||
|
||||
if extracted_content is None:
|
||||
if verbose:
|
||||
print(f"[LOG] 🔥 Extracting semantic blocks for {url}, Strategy: {extraction_strategy.name}")
|
||||
|
||||
sections = chunking_strategy.chunk(markdown)
|
||||
extracted_content = extraction_strategy.run(url, sections)
|
||||
extracted_content = json.dumps(extracted_content, indent=4, default=str, ensure_ascii=False)
|
||||
if result is None:
|
||||
raise ValueError(f"Failed to extract content from the website: {url}")
|
||||
except InvalidCSSSelectorError as e:
|
||||
raise ValueError(str(e))
|
||||
|
||||
if verbose:
|
||||
print(f"[LOG] 🚀 Extraction done for {url}, time taken: {time.time() - t:.2f} seconds.")
|
||||
|
||||
screenshot = None if not screenshot else screenshot
|
||||
|
||||
if not is_cached:
|
||||
cache_url(
|
||||
url,
|
||||
html,
|
||||
cleaned_html,
|
||||
markdown,
|
||||
extracted_content,
|
||||
True,
|
||||
json.dumps(media),
|
||||
json.dumps(links),
|
||||
json.dumps(metadata),
|
||||
screenshot=screenshot,
|
||||
)
|
||||
|
||||
return CrawlResult(
|
||||
url=url,
|
||||
html=html,
|
||||
cleaned_html=format_html(cleaned_html),
|
||||
markdown=markdown,
|
||||
media=media,
|
||||
links=links,
|
||||
metadata=metadata,
|
||||
cleaned_html = sanitize_input_encode(result.get("cleaned_html", ""))
|
||||
markdown = sanitize_input_encode(result.get("markdown", ""))
|
||||
media = result.get("media", [])
|
||||
links = result.get("links", [])
|
||||
metadata = result.get("metadata", {})
|
||||
|
||||
if extracted_content is None:
|
||||
if verbose:
|
||||
print(
|
||||
f"[LOG] 🔥 Extracting semantic blocks for {url}, Strategy: {extraction_strategy.name}"
|
||||
)
|
||||
|
||||
sections = chunking_strategy.chunk(markdown)
|
||||
extracted_content = extraction_strategy.run(url, sections)
|
||||
extracted_content = json.dumps(
|
||||
extracted_content, indent=4, default=str, ensure_ascii=False
|
||||
)
|
||||
|
||||
if verbose:
|
||||
print(
|
||||
f"[LOG] 🚀 Extraction done for {url}, time taken: {time.time() - t:.2f} seconds."
|
||||
)
|
||||
|
||||
screenshot = None if not screenshot else screenshot
|
||||
|
||||
if not is_cached:
|
||||
cache_url(
|
||||
url,
|
||||
html,
|
||||
cleaned_html,
|
||||
markdown,
|
||||
extracted_content,
|
||||
True,
|
||||
json.dumps(media),
|
||||
json.dumps(links),
|
||||
json.dumps(metadata),
|
||||
screenshot=screenshot,
|
||||
extracted_content=extracted_content,
|
||||
success=True,
|
||||
error_message="",
|
||||
)
|
||||
)
|
||||
|
||||
return CrawlResult(
|
||||
url=url,
|
||||
html=html,
|
||||
cleaned_html=format_html(cleaned_html),
|
||||
markdown=markdown,
|
||||
media=media,
|
||||
links=links,
|
||||
metadata=metadata,
|
||||
screenshot=screenshot,
|
||||
extracted_content=extracted_content,
|
||||
success=True,
|
||||
error_message="",
|
||||
)
|
||||
|
||||
@@ -9,13 +9,11 @@ from crawl4ai.extraction_strategy import JsonCssExtractionStrategy
|
||||
from crawl4ai.async_configs import BrowserConfig, CrawlerRunConfig
|
||||
import json
|
||||
|
||||
|
||||
async def extract_amazon_products():
|
||||
# Initialize browser config
|
||||
browser_config = BrowserConfig(
|
||||
browser_type="chromium",
|
||||
headless=True
|
||||
)
|
||||
|
||||
browser_config = BrowserConfig(browser_type="chromium", headless=True)
|
||||
|
||||
# Initialize crawler config with JSON CSS extraction strategy
|
||||
crawler_config = CrawlerRunConfig(
|
||||
extraction_strategy=JsonCssExtractionStrategy(
|
||||
@@ -27,74 +25,70 @@ async def extract_amazon_products():
|
||||
"name": "asin",
|
||||
"selector": "",
|
||||
"type": "attribute",
|
||||
"attribute": "data-asin"
|
||||
},
|
||||
{
|
||||
"name": "title",
|
||||
"selector": "h2 a span",
|
||||
"type": "text"
|
||||
"attribute": "data-asin",
|
||||
},
|
||||
{"name": "title", "selector": "h2 a span", "type": "text"},
|
||||
{
|
||||
"name": "url",
|
||||
"selector": "h2 a",
|
||||
"type": "attribute",
|
||||
"attribute": "href"
|
||||
"attribute": "href",
|
||||
},
|
||||
{
|
||||
"name": "image",
|
||||
"selector": ".s-image",
|
||||
"type": "attribute",
|
||||
"attribute": "src"
|
||||
"attribute": "src",
|
||||
},
|
||||
{
|
||||
"name": "rating",
|
||||
"selector": ".a-icon-star-small .a-icon-alt",
|
||||
"type": "text"
|
||||
"type": "text",
|
||||
},
|
||||
{
|
||||
"name": "reviews_count",
|
||||
"selector": "[data-csa-c-func-deps='aui-da-a-popover'] ~ span span",
|
||||
"type": "text"
|
||||
"type": "text",
|
||||
},
|
||||
{
|
||||
"name": "price",
|
||||
"selector": ".a-price .a-offscreen",
|
||||
"type": "text"
|
||||
"type": "text",
|
||||
},
|
||||
{
|
||||
"name": "original_price",
|
||||
"selector": ".a-price.a-text-price .a-offscreen",
|
||||
"type": "text"
|
||||
"type": "text",
|
||||
},
|
||||
{
|
||||
"name": "sponsored",
|
||||
"selector": ".puis-sponsored-label-text",
|
||||
"type": "exists"
|
||||
"type": "exists",
|
||||
},
|
||||
{
|
||||
"name": "delivery_info",
|
||||
"selector": "[data-cy='delivery-recipe'] .a-color-base",
|
||||
"type": "text",
|
||||
"multiple": True
|
||||
}
|
||||
]
|
||||
"multiple": True,
|
||||
},
|
||||
],
|
||||
}
|
||||
)
|
||||
)
|
||||
|
||||
# Example search URL (you should replace with your actual Amazon URL)
|
||||
url = "https://www.amazon.com/s?k=Samsung+Galaxy+Tab"
|
||||
|
||||
|
||||
# Use context manager for proper resource handling
|
||||
async with AsyncWebCrawler(config=browser_config) as crawler:
|
||||
# Extract the data
|
||||
result = await crawler.arun(url=url, config=crawler_config)
|
||||
|
||||
|
||||
# Process and print the results
|
||||
if result and result.extracted_content:
|
||||
# Parse the JSON string into a list of products
|
||||
products = json.loads(result.extracted_content)
|
||||
|
||||
|
||||
# Process each product in the list
|
||||
for product in products:
|
||||
print("\nProduct Details:")
|
||||
@@ -105,10 +99,12 @@ async def extract_amazon_products():
|
||||
print(f"Rating: {product.get('rating')}")
|
||||
print(f"Reviews: {product.get('reviews_count')}")
|
||||
print(f"Sponsored: {'Yes' if product.get('sponsored') else 'No'}")
|
||||
if product.get('delivery_info'):
|
||||
if product.get("delivery_info"):
|
||||
print(f"Delivery: {' '.join(product['delivery_info'])}")
|
||||
print("-" * 80)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import asyncio
|
||||
|
||||
asyncio.run(extract_amazon_products())
|
||||
|
||||
@@ -10,17 +10,17 @@ from crawl4ai.async_configs import BrowserConfig, CrawlerRunConfig
|
||||
import json
|
||||
from playwright.async_api import Page, BrowserContext
|
||||
|
||||
|
||||
async def extract_amazon_products():
|
||||
# Initialize browser config
|
||||
browser_config = BrowserConfig(
|
||||
# browser_type="chromium",
|
||||
headless=True
|
||||
)
|
||||
|
||||
|
||||
# Initialize crawler config with JSON CSS extraction strategy nav-search-submit-button
|
||||
crawler_config = CrawlerRunConfig(
|
||||
cache_mode=CacheMode.BYPASS,
|
||||
|
||||
extraction_strategy=JsonCssExtractionStrategy(
|
||||
schema={
|
||||
"name": "Amazon Product Search Results",
|
||||
@@ -30,102 +30,105 @@ async def extract_amazon_products():
|
||||
"name": "asin",
|
||||
"selector": "",
|
||||
"type": "attribute",
|
||||
"attribute": "data-asin"
|
||||
},
|
||||
{
|
||||
"name": "title",
|
||||
"selector": "h2 a span",
|
||||
"type": "text"
|
||||
"attribute": "data-asin",
|
||||
},
|
||||
{"name": "title", "selector": "h2 a span", "type": "text"},
|
||||
{
|
||||
"name": "url",
|
||||
"selector": "h2 a",
|
||||
"type": "attribute",
|
||||
"attribute": "href"
|
||||
"attribute": "href",
|
||||
},
|
||||
{
|
||||
"name": "image",
|
||||
"selector": ".s-image",
|
||||
"type": "attribute",
|
||||
"attribute": "src"
|
||||
"attribute": "src",
|
||||
},
|
||||
{
|
||||
"name": "rating",
|
||||
"selector": ".a-icon-star-small .a-icon-alt",
|
||||
"type": "text"
|
||||
"type": "text",
|
||||
},
|
||||
{
|
||||
"name": "reviews_count",
|
||||
"selector": "[data-csa-c-func-deps='aui-da-a-popover'] ~ span span",
|
||||
"type": "text"
|
||||
"type": "text",
|
||||
},
|
||||
{
|
||||
"name": "price",
|
||||
"selector": ".a-price .a-offscreen",
|
||||
"type": "text"
|
||||
"type": "text",
|
||||
},
|
||||
{
|
||||
"name": "original_price",
|
||||
"selector": ".a-price.a-text-price .a-offscreen",
|
||||
"type": "text"
|
||||
"type": "text",
|
||||
},
|
||||
{
|
||||
"name": "sponsored",
|
||||
"selector": ".puis-sponsored-label-text",
|
||||
"type": "exists"
|
||||
"type": "exists",
|
||||
},
|
||||
{
|
||||
"name": "delivery_info",
|
||||
"selector": "[data-cy='delivery-recipe'] .a-color-base",
|
||||
"type": "text",
|
||||
"multiple": True
|
||||
}
|
||||
]
|
||||
"multiple": True,
|
||||
},
|
||||
],
|
||||
}
|
||||
)
|
||||
),
|
||||
)
|
||||
|
||||
url = "https://www.amazon.com/"
|
||||
|
||||
async def after_goto(page: Page, context: BrowserContext, url: str, response: dict, **kwargs):
|
||||
|
||||
async def after_goto(
|
||||
page: Page, context: BrowserContext, url: str, response: dict, **kwargs
|
||||
):
|
||||
"""Hook called after navigating to each URL"""
|
||||
print(f"[HOOK] after_goto - Successfully loaded: {url}")
|
||||
|
||||
|
||||
try:
|
||||
# Wait for search box to be available
|
||||
search_box = await page.wait_for_selector('#twotabsearchtextbox', timeout=1000)
|
||||
|
||||
search_box = await page.wait_for_selector(
|
||||
"#twotabsearchtextbox", timeout=1000
|
||||
)
|
||||
|
||||
# Type the search query
|
||||
await search_box.fill('Samsung Galaxy Tab')
|
||||
|
||||
await search_box.fill("Samsung Galaxy Tab")
|
||||
|
||||
# Get the search button and prepare for navigation
|
||||
search_button = await page.wait_for_selector('#nav-search-submit-button', timeout=1000)
|
||||
|
||||
search_button = await page.wait_for_selector(
|
||||
"#nav-search-submit-button", timeout=1000
|
||||
)
|
||||
|
||||
# Click with navigation waiting
|
||||
await search_button.click()
|
||||
|
||||
|
||||
# Wait for search results to load
|
||||
await page.wait_for_selector('[data-component-type="s-search-result"]', timeout=10000)
|
||||
await page.wait_for_selector(
|
||||
'[data-component-type="s-search-result"]', timeout=10000
|
||||
)
|
||||
print("[HOOK] Search completed and results loaded!")
|
||||
|
||||
|
||||
except Exception as e:
|
||||
print(f"[HOOK] Error during search operation: {str(e)}")
|
||||
|
||||
return page
|
||||
|
||||
|
||||
return page
|
||||
|
||||
# Use context manager for proper resource handling
|
||||
async with AsyncWebCrawler(config=browser_config) as crawler:
|
||||
|
||||
crawler.crawler_strategy.set_hook("after_goto", after_goto)
|
||||
|
||||
|
||||
# Extract the data
|
||||
result = await crawler.arun(url=url, config=crawler_config)
|
||||
|
||||
|
||||
# Process and print the results
|
||||
if result and result.extracted_content:
|
||||
# Parse the JSON string into a list of products
|
||||
products = json.loads(result.extracted_content)
|
||||
|
||||
|
||||
# Process each product in the list
|
||||
for product in products:
|
||||
print("\nProduct Details:")
|
||||
@@ -136,10 +139,12 @@ async def extract_amazon_products():
|
||||
print(f"Rating: {product.get('rating')}")
|
||||
print(f"Reviews: {product.get('reviews_count')}")
|
||||
print(f"Sponsored: {'Yes' if product.get('sponsored') else 'No'}")
|
||||
if product.get('delivery_info'):
|
||||
if product.get("delivery_info"):
|
||||
print(f"Delivery: {' '.join(product['delivery_info'])}")
|
||||
print("-" * 80)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import asyncio
|
||||
|
||||
asyncio.run(extract_amazon_products())
|
||||
|
||||
@@ -8,7 +8,7 @@ from crawl4ai import AsyncWebCrawler, CacheMode
|
||||
from crawl4ai.extraction_strategy import JsonCssExtractionStrategy
|
||||
from crawl4ai.async_configs import BrowserConfig, CrawlerRunConfig
|
||||
import json
|
||||
from playwright.async_api import Page, BrowserContext
|
||||
|
||||
|
||||
async def extract_amazon_products():
|
||||
# Initialize browser config
|
||||
@@ -16,7 +16,7 @@ async def extract_amazon_products():
|
||||
# browser_type="chromium",
|
||||
headless=True
|
||||
)
|
||||
|
||||
|
||||
js_code_to_search = """
|
||||
const task = async () => {
|
||||
document.querySelector('#twotabsearchtextbox').value = 'Samsung Galaxy Tab';
|
||||
@@ -30,7 +30,7 @@ async def extract_amazon_products():
|
||||
"""
|
||||
crawler_config = CrawlerRunConfig(
|
||||
cache_mode=CacheMode.BYPASS,
|
||||
js_code = js_code_to_search,
|
||||
js_code=js_code_to_search,
|
||||
wait_for='css:[data-component-type="s-search-result"]',
|
||||
extraction_strategy=JsonCssExtractionStrategy(
|
||||
schema={
|
||||
@@ -41,75 +41,70 @@ async def extract_amazon_products():
|
||||
"name": "asin",
|
||||
"selector": "",
|
||||
"type": "attribute",
|
||||
"attribute": "data-asin"
|
||||
},
|
||||
{
|
||||
"name": "title",
|
||||
"selector": "h2 a span",
|
||||
"type": "text"
|
||||
"attribute": "data-asin",
|
||||
},
|
||||
{"name": "title", "selector": "h2 a span", "type": "text"},
|
||||
{
|
||||
"name": "url",
|
||||
"selector": "h2 a",
|
||||
"type": "attribute",
|
||||
"attribute": "href"
|
||||
"attribute": "href",
|
||||
},
|
||||
{
|
||||
"name": "image",
|
||||
"selector": ".s-image",
|
||||
"type": "attribute",
|
||||
"attribute": "src"
|
||||
"attribute": "src",
|
||||
},
|
||||
{
|
||||
"name": "rating",
|
||||
"selector": ".a-icon-star-small .a-icon-alt",
|
||||
"type": "text"
|
||||
"type": "text",
|
||||
},
|
||||
{
|
||||
"name": "reviews_count",
|
||||
"selector": "[data-csa-c-func-deps='aui-da-a-popover'] ~ span span",
|
||||
"type": "text"
|
||||
"type": "text",
|
||||
},
|
||||
{
|
||||
"name": "price",
|
||||
"selector": ".a-price .a-offscreen",
|
||||
"type": "text"
|
||||
"type": "text",
|
||||
},
|
||||
{
|
||||
"name": "original_price",
|
||||
"selector": ".a-price.a-text-price .a-offscreen",
|
||||
"type": "text"
|
||||
"type": "text",
|
||||
},
|
||||
{
|
||||
"name": "sponsored",
|
||||
"selector": ".puis-sponsored-label-text",
|
||||
"type": "exists"
|
||||
"type": "exists",
|
||||
},
|
||||
{
|
||||
"name": "delivery_info",
|
||||
"selector": "[data-cy='delivery-recipe'] .a-color-base",
|
||||
"type": "text",
|
||||
"multiple": True
|
||||
}
|
||||
]
|
||||
"multiple": True,
|
||||
},
|
||||
],
|
||||
}
|
||||
)
|
||||
),
|
||||
)
|
||||
|
||||
# Example search URL (you should replace with your actual Amazon URL)
|
||||
url = "https://www.amazon.com/"
|
||||
|
||||
|
||||
|
||||
# Use context manager for proper resource handling
|
||||
async with AsyncWebCrawler(config=browser_config) as crawler:
|
||||
# Extract the data
|
||||
result = await crawler.arun(url=url, config=crawler_config)
|
||||
|
||||
|
||||
# Process and print the results
|
||||
if result and result.extracted_content:
|
||||
# Parse the JSON string into a list of products
|
||||
products = json.loads(result.extracted_content)
|
||||
|
||||
|
||||
# Process each product in the list
|
||||
for product in products:
|
||||
print("\nProduct Details:")
|
||||
@@ -120,10 +115,12 @@ async def extract_amazon_products():
|
||||
print(f"Rating: {product.get('rating')}")
|
||||
print(f"Reviews: {product.get('reviews_count')}")
|
||||
print(f"Sponsored: {'Yes' if product.get('sponsored') else 'No'}")
|
||||
if product.get('delivery_info'):
|
||||
if product.get("delivery_info"):
|
||||
print(f"Delivery: {' '.join(product['delivery_info'])}")
|
||||
print("-" * 80)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import asyncio
|
||||
|
||||
asyncio.run(extract_amazon_products())
|
||||
|
||||
@@ -1,12 +1,16 @@
|
||||
# File: async_webcrawler_multiple_urls_example.py
|
||||
import os, sys
|
||||
|
||||
# append 2 parent directories to sys.path to import crawl4ai
|
||||
parent_dir = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||
parent_dir = os.path.dirname(
|
||||
os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
|
||||
)
|
||||
sys.path.append(parent_dir)
|
||||
|
||||
import asyncio
|
||||
from crawl4ai import AsyncWebCrawler
|
||||
|
||||
|
||||
async def main():
|
||||
# Initialize the AsyncWebCrawler
|
||||
async with AsyncWebCrawler(verbose=True) as crawler:
|
||||
@@ -16,7 +20,7 @@ async def main():
|
||||
"https://python.org",
|
||||
"https://github.com",
|
||||
"https://stackoverflow.com",
|
||||
"https://news.ycombinator.com"
|
||||
"https://news.ycombinator.com",
|
||||
]
|
||||
|
||||
# Set up crawling parameters
|
||||
@@ -27,7 +31,7 @@ async def main():
|
||||
urls=urls,
|
||||
word_count_threshold=word_count_threshold,
|
||||
bypass_cache=True,
|
||||
verbose=True
|
||||
verbose=True,
|
||||
)
|
||||
|
||||
# Process the results
|
||||
@@ -36,7 +40,9 @@ async def main():
|
||||
print(f"Successfully crawled: {result.url}")
|
||||
print(f"Title: {result.metadata.get('title', 'N/A')}")
|
||||
print(f"Word count: {len(result.markdown.split())}")
|
||||
print(f"Number of links: {len(result.links.get('internal', [])) + len(result.links.get('external', []))}")
|
||||
print(
|
||||
f"Number of links: {len(result.links.get('internal', [])) + len(result.links.get('external', []))}"
|
||||
)
|
||||
print(f"Number of images: {len(result.media.get('images', []))}")
|
||||
print("---")
|
||||
else:
|
||||
@@ -44,5 +50,6 @@ async def main():
|
||||
print(f"Error: {result.error_message}")
|
||||
print("---")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
asyncio.run(main())
|
||||
asyncio.run(main())
|
||||
|
||||
@@ -6,10 +6,8 @@ This example demonstrates optimal browser usage patterns in Crawl4AI:
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import os
|
||||
from typing import List
|
||||
from crawl4ai import AsyncWebCrawler, BrowserConfig, CrawlerRunConfig
|
||||
from crawl4ai.content_filter_strategy import PruningContentFilter
|
||||
from crawl4ai.markdown_generation_strategy import DefaultMarkdownGenerator
|
||||
|
||||
|
||||
|
||||
@@ -1,31 +1,32 @@
|
||||
import os, time
|
||||
|
||||
# append the path to the root of the project
|
||||
import sys
|
||||
import asyncio
|
||||
sys.path.append(os.path.join(os.path.dirname(__file__), '..', '..'))
|
||||
|
||||
sys.path.append(os.path.join(os.path.dirname(__file__), "..", ".."))
|
||||
from firecrawl import FirecrawlApp
|
||||
from crawl4ai import AsyncWebCrawler
|
||||
__data__ = os.path.join(os.path.dirname(__file__), '..', '..') + '/.data'
|
||||
|
||||
__data__ = os.path.join(os.path.dirname(__file__), "..", "..") + "/.data"
|
||||
|
||||
|
||||
async def compare():
|
||||
app = FirecrawlApp(api_key=os.environ['FIRECRAWL_API_KEY'])
|
||||
app = FirecrawlApp(api_key=os.environ["FIRECRAWL_API_KEY"])
|
||||
|
||||
# Tet Firecrawl with a simple crawl
|
||||
start = time.time()
|
||||
scrape_status = app.scrape_url(
|
||||
'https://www.nbcnews.com/business',
|
||||
params={'formats': ['markdown', 'html']}
|
||||
"https://www.nbcnews.com/business", params={"formats": ["markdown", "html"]}
|
||||
)
|
||||
end = time.time()
|
||||
print(f"Time taken: {end - start} seconds")
|
||||
print(len(scrape_status['markdown']))
|
||||
print(len(scrape_status["markdown"]))
|
||||
# save the markdown content with provider name
|
||||
with open(f"{__data__}/firecrawl_simple.md", "w") as f:
|
||||
f.write(scrape_status['markdown'])
|
||||
f.write(scrape_status["markdown"])
|
||||
# Count how many "cldnry.s-nbcnews.com" are in the markdown
|
||||
print(scrape_status['markdown'].count("cldnry.s-nbcnews.com"))
|
||||
|
||||
|
||||
print(scrape_status["markdown"].count("cldnry.s-nbcnews.com"))
|
||||
|
||||
async with AsyncWebCrawler() as crawler:
|
||||
start = time.time()
|
||||
@@ -33,13 +34,13 @@ async def compare():
|
||||
url="https://www.nbcnews.com/business",
|
||||
# js_code=["const loadMoreButton = Array.from(document.querySelectorAll('button')).find(button => button.textContent.includes('Load More')); loadMoreButton && loadMoreButton.click();"],
|
||||
word_count_threshold=0,
|
||||
bypass_cache=True,
|
||||
verbose=False
|
||||
bypass_cache=True,
|
||||
verbose=False,
|
||||
)
|
||||
end = time.time()
|
||||
print(f"Time taken: {end - start} seconds")
|
||||
print(len(result.markdown))
|
||||
# save the markdown content with provider name
|
||||
# save the markdown content with provider name
|
||||
with open(f"{__data__}/crawl4ai_simple.md", "w") as f:
|
||||
f.write(result.markdown)
|
||||
# count how many "cldnry.s-nbcnews.com" are in the markdown
|
||||
@@ -48,10 +49,12 @@ async def compare():
|
||||
start = time.time()
|
||||
result = await crawler.arun(
|
||||
url="https://www.nbcnews.com/business",
|
||||
js_code=["const loadMoreButton = Array.from(document.querySelectorAll('button')).find(button => button.textContent.includes('Load More')); loadMoreButton && loadMoreButton.click();"],
|
||||
js_code=[
|
||||
"const loadMoreButton = Array.from(document.querySelectorAll('button')).find(button => button.textContent.includes('Load More')); loadMoreButton && loadMoreButton.click();"
|
||||
],
|
||||
word_count_threshold=0,
|
||||
bypass_cache=True,
|
||||
verbose=False
|
||||
bypass_cache=True,
|
||||
verbose=False,
|
||||
)
|
||||
end = time.time()
|
||||
print(f"Time taken: {end - start} seconds")
|
||||
@@ -61,7 +64,7 @@ async def compare():
|
||||
f.write(result.markdown)
|
||||
# count how many "cldnry.s-nbcnews.com" are in the markdown
|
||||
print(result.markdown.count("cldnry.s-nbcnews.com"))
|
||||
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
asyncio.run(compare())
|
||||
|
||||
@@ -3,11 +3,18 @@ import time
|
||||
from rich import print
|
||||
from rich.table import Table
|
||||
from crawl4ai import (
|
||||
AsyncWebCrawler, BrowserConfig, CrawlerRunConfig,
|
||||
MemoryAdaptiveDispatcher, SemaphoreDispatcher,
|
||||
RateLimiter, CrawlerMonitor, DisplayMode, CacheMode
|
||||
AsyncWebCrawler,
|
||||
BrowserConfig,
|
||||
CrawlerRunConfig,
|
||||
MemoryAdaptiveDispatcher,
|
||||
SemaphoreDispatcher,
|
||||
RateLimiter,
|
||||
CrawlerMonitor,
|
||||
DisplayMode,
|
||||
CacheMode,
|
||||
)
|
||||
|
||||
|
||||
async def memory_adaptive(urls, browser_config, run_config):
|
||||
"""Memory adaptive crawler with monitoring"""
|
||||
start = time.perf_counter()
|
||||
@@ -16,14 +23,16 @@ async def memory_adaptive(urls, browser_config, run_config):
|
||||
memory_threshold_percent=70.0,
|
||||
max_session_permit=10,
|
||||
monitor=CrawlerMonitor(
|
||||
max_visible_rows=15,
|
||||
display_mode=DisplayMode.DETAILED
|
||||
)
|
||||
max_visible_rows=15, display_mode=DisplayMode.DETAILED
|
||||
),
|
||||
)
|
||||
results = await crawler.arun_many(
|
||||
urls, config=run_config, dispatcher=dispatcher
|
||||
)
|
||||
results = await crawler.arun_many(urls, config=run_config, dispatcher=dispatcher)
|
||||
duration = time.perf_counter() - start
|
||||
return len(results), duration
|
||||
|
||||
|
||||
async def memory_adaptive_with_rate_limit(urls, browser_config, run_config):
|
||||
"""Memory adaptive crawler with rate limiting"""
|
||||
start = time.perf_counter()
|
||||
@@ -32,19 +41,19 @@ async def memory_adaptive_with_rate_limit(urls, browser_config, run_config):
|
||||
memory_threshold_percent=70.0,
|
||||
max_session_permit=10,
|
||||
rate_limiter=RateLimiter(
|
||||
base_delay=(1.0, 2.0),
|
||||
max_delay=30.0,
|
||||
max_retries=2
|
||||
base_delay=(1.0, 2.0), max_delay=30.0, max_retries=2
|
||||
),
|
||||
monitor=CrawlerMonitor(
|
||||
max_visible_rows=15,
|
||||
display_mode=DisplayMode.DETAILED
|
||||
)
|
||||
max_visible_rows=15, display_mode=DisplayMode.DETAILED
|
||||
),
|
||||
)
|
||||
results = await crawler.arun_many(
|
||||
urls, config=run_config, dispatcher=dispatcher
|
||||
)
|
||||
results = await crawler.arun_many(urls, config=run_config, dispatcher=dispatcher)
|
||||
duration = time.perf_counter() - start
|
||||
return len(results), duration
|
||||
|
||||
|
||||
async def semaphore(urls, browser_config, run_config):
|
||||
"""Basic semaphore crawler"""
|
||||
start = time.perf_counter()
|
||||
@@ -52,14 +61,16 @@ async def semaphore(urls, browser_config, run_config):
|
||||
dispatcher = SemaphoreDispatcher(
|
||||
semaphore_count=5,
|
||||
monitor=CrawlerMonitor(
|
||||
max_visible_rows=15,
|
||||
display_mode=DisplayMode.DETAILED
|
||||
)
|
||||
max_visible_rows=15, display_mode=DisplayMode.DETAILED
|
||||
),
|
||||
)
|
||||
results = await crawler.arun_many(
|
||||
urls, config=run_config, dispatcher=dispatcher
|
||||
)
|
||||
results = await crawler.arun_many(urls, config=run_config, dispatcher=dispatcher)
|
||||
duration = time.perf_counter() - start
|
||||
return len(results), duration
|
||||
|
||||
|
||||
async def semaphore_with_rate_limit(urls, browser_config, run_config):
|
||||
"""Semaphore crawler with rate limiting"""
|
||||
start = time.perf_counter()
|
||||
@@ -67,19 +78,19 @@ async def semaphore_with_rate_limit(urls, browser_config, run_config):
|
||||
dispatcher = SemaphoreDispatcher(
|
||||
semaphore_count=5,
|
||||
rate_limiter=RateLimiter(
|
||||
base_delay=(1.0, 2.0),
|
||||
max_delay=30.0,
|
||||
max_retries=2
|
||||
base_delay=(1.0, 2.0), max_delay=30.0, max_retries=2
|
||||
),
|
||||
monitor=CrawlerMonitor(
|
||||
max_visible_rows=15,
|
||||
display_mode=DisplayMode.DETAILED
|
||||
)
|
||||
max_visible_rows=15, display_mode=DisplayMode.DETAILED
|
||||
),
|
||||
)
|
||||
results = await crawler.arun_many(
|
||||
urls, config=run_config, dispatcher=dispatcher
|
||||
)
|
||||
results = await crawler.arun_many(urls, config=run_config, dispatcher=dispatcher)
|
||||
duration = time.perf_counter() - start
|
||||
return len(results), duration
|
||||
|
||||
|
||||
def create_performance_table(results):
|
||||
"""Creates a rich table showing performance results"""
|
||||
table = Table(title="Crawler Strategy Performance Comparison")
|
||||
@@ -89,18 +100,16 @@ def create_performance_table(results):
|
||||
table.add_column("URLs/second", justify="right", style="magenta")
|
||||
|
||||
sorted_results = sorted(results.items(), key=lambda x: x[1][1])
|
||||
|
||||
|
||||
for strategy, (urls_crawled, duration) in sorted_results:
|
||||
urls_per_second = urls_crawled / duration
|
||||
table.add_row(
|
||||
strategy,
|
||||
str(urls_crawled),
|
||||
f"{duration:.2f}",
|
||||
f"{urls_per_second:.2f}"
|
||||
strategy, str(urls_crawled), f"{duration:.2f}", f"{urls_per_second:.2f}"
|
||||
)
|
||||
|
||||
|
||||
return table
|
||||
|
||||
|
||||
async def main():
|
||||
urls = [f"https://example.com/page{i}" for i in range(1, 20)]
|
||||
browser_config = BrowserConfig(headless=True, verbose=False)
|
||||
@@ -108,14 +117,19 @@ async def main():
|
||||
|
||||
results = {
|
||||
"Memory Adaptive": await memory_adaptive(urls, browser_config, run_config),
|
||||
"Memory Adaptive + Rate Limit": await memory_adaptive_with_rate_limit(urls, browser_config, run_config),
|
||||
"Memory Adaptive + Rate Limit": await memory_adaptive_with_rate_limit(
|
||||
urls, browser_config, run_config
|
||||
),
|
||||
"Semaphore": await semaphore(urls, browser_config, run_config),
|
||||
"Semaphore + Rate Limit": await semaphore_with_rate_limit(urls, browser_config, run_config),
|
||||
"Semaphore + Rate Limit": await semaphore_with_rate_limit(
|
||||
urls, browser_config, run_config
|
||||
),
|
||||
}
|
||||
|
||||
table = create_performance_table(results)
|
||||
print("\nPerformance Summary:")
|
||||
print(table)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
asyncio.run(main())
|
||||
asyncio.run(main())
|
||||
|
||||
@@ -6,63 +6,80 @@ import base64
|
||||
import os
|
||||
from typing import Dict, Any
|
||||
|
||||
|
||||
class Crawl4AiTester:
|
||||
def __init__(self, base_url: str = "http://localhost:11235", api_token: str = None):
|
||||
self.base_url = base_url
|
||||
self.api_token = api_token or os.getenv('CRAWL4AI_API_TOKEN') or "test_api_code" # Check environment variable as fallback
|
||||
self.headers = {'Authorization': f'Bearer {self.api_token}'} if self.api_token else {}
|
||||
|
||||
def submit_and_wait(self, request_data: Dict[str, Any], timeout: int = 300) -> Dict[str, Any]:
|
||||
self.api_token = (
|
||||
api_token or os.getenv("CRAWL4AI_API_TOKEN") or "test_api_code"
|
||||
) # Check environment variable as fallback
|
||||
self.headers = (
|
||||
{"Authorization": f"Bearer {self.api_token}"} if self.api_token else {}
|
||||
)
|
||||
|
||||
def submit_and_wait(
|
||||
self, request_data: Dict[str, Any], timeout: int = 300
|
||||
) -> Dict[str, Any]:
|
||||
# Submit crawl job
|
||||
response = requests.post(f"{self.base_url}/crawl", json=request_data, headers=self.headers)
|
||||
response = requests.post(
|
||||
f"{self.base_url}/crawl", json=request_data, headers=self.headers
|
||||
)
|
||||
if response.status_code == 403:
|
||||
raise Exception("API token is invalid or missing")
|
||||
task_id = response.json()["task_id"]
|
||||
print(f"Task ID: {task_id}")
|
||||
|
||||
|
||||
# Poll for result
|
||||
start_time = time.time()
|
||||
while True:
|
||||
if time.time() - start_time > timeout:
|
||||
raise TimeoutError(f"Task {task_id} did not complete within {timeout} seconds")
|
||||
|
||||
result = requests.get(f"{self.base_url}/task/{task_id}", headers=self.headers)
|
||||
raise TimeoutError(
|
||||
f"Task {task_id} did not complete within {timeout} seconds"
|
||||
)
|
||||
|
||||
result = requests.get(
|
||||
f"{self.base_url}/task/{task_id}", headers=self.headers
|
||||
)
|
||||
status = result.json()
|
||||
|
||||
|
||||
if status["status"] == "failed":
|
||||
print("Task failed:", status.get("error"))
|
||||
raise Exception(f"Task failed: {status.get('error')}")
|
||||
|
||||
|
||||
if status["status"] == "completed":
|
||||
return status
|
||||
|
||||
|
||||
time.sleep(2)
|
||||
|
||||
|
||||
def submit_sync(self, request_data: Dict[str, Any]) -> Dict[str, Any]:
|
||||
response = requests.post(f"{self.base_url}/crawl_sync", json=request_data, headers=self.headers, timeout=60)
|
||||
response = requests.post(
|
||||
f"{self.base_url}/crawl_sync",
|
||||
json=request_data,
|
||||
headers=self.headers,
|
||||
timeout=60,
|
||||
)
|
||||
if response.status_code == 408:
|
||||
raise TimeoutError("Task did not complete within server timeout")
|
||||
response.raise_for_status()
|
||||
return response.json()
|
||||
|
||||
|
||||
def crawl_direct(self, request_data: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""Directly crawl without using task queue"""
|
||||
response = requests.post(
|
||||
f"{self.base_url}/crawl_direct",
|
||||
json=request_data,
|
||||
headers=self.headers
|
||||
f"{self.base_url}/crawl_direct", json=request_data, headers=self.headers
|
||||
)
|
||||
response.raise_for_status()
|
||||
return response.json()
|
||||
|
||||
|
||||
def test_docker_deployment(version="basic"):
|
||||
tester = Crawl4AiTester(
|
||||
base_url="http://localhost:11235" ,
|
||||
base_url="http://localhost:11235",
|
||||
# base_url="https://api.crawl4ai.com" # just for example
|
||||
# api_token="test" # just for example
|
||||
)
|
||||
print(f"Testing Crawl4AI Docker {version} version")
|
||||
|
||||
|
||||
# Health check with timeout and retry
|
||||
max_retries = 5
|
||||
for i in range(max_retries):
|
||||
@@ -70,19 +87,19 @@ def test_docker_deployment(version="basic"):
|
||||
health = requests.get(f"{tester.base_url}/health", timeout=10)
|
||||
print("Health check:", health.json())
|
||||
break
|
||||
except requests.exceptions.RequestException as e:
|
||||
except requests.exceptions.RequestException:
|
||||
if i == max_retries - 1:
|
||||
print(f"Failed to connect after {max_retries} attempts")
|
||||
sys.exit(1)
|
||||
print(f"Waiting for service to start (attempt {i+1}/{max_retries})...")
|
||||
time.sleep(5)
|
||||
|
||||
|
||||
# Test cases based on version
|
||||
test_basic_crawl_direct(tester)
|
||||
test_basic_crawl(tester)
|
||||
test_basic_crawl(tester)
|
||||
test_basic_crawl_sync(tester)
|
||||
|
||||
|
||||
if version in ["full", "transformer"]:
|
||||
test_cosine_extraction(tester)
|
||||
|
||||
@@ -92,49 +109,52 @@ def test_docker_deployment(version="basic"):
|
||||
test_llm_extraction(tester)
|
||||
test_llm_with_ollama(tester)
|
||||
test_screenshot(tester)
|
||||
|
||||
|
||||
|
||||
def test_basic_crawl(tester: Crawl4AiTester):
|
||||
print("\n=== Testing Basic Crawl ===")
|
||||
request = {
|
||||
"urls": "https://www.nbcnews.com/business",
|
||||
"priority": 10,
|
||||
"session_id": "test"
|
||||
"priority": 10,
|
||||
"session_id": "test",
|
||||
}
|
||||
|
||||
|
||||
result = tester.submit_and_wait(request)
|
||||
print(f"Basic crawl result length: {len(result['result']['markdown'])}")
|
||||
assert result["result"]["success"]
|
||||
assert len(result["result"]["markdown"]) > 0
|
||||
|
||||
|
||||
def test_basic_crawl_sync(tester: Crawl4AiTester):
|
||||
print("\n=== Testing Basic Crawl (Sync) ===")
|
||||
request = {
|
||||
"urls": "https://www.nbcnews.com/business",
|
||||
"priority": 10,
|
||||
"session_id": "test"
|
||||
"session_id": "test",
|
||||
}
|
||||
|
||||
|
||||
result = tester.submit_sync(request)
|
||||
print(f"Basic crawl result length: {len(result['result']['markdown'])}")
|
||||
assert result['status'] == 'completed'
|
||||
assert result['result']['success']
|
||||
assert len(result['result']['markdown']) > 0
|
||||
|
||||
assert result["status"] == "completed"
|
||||
assert result["result"]["success"]
|
||||
assert len(result["result"]["markdown"]) > 0
|
||||
|
||||
|
||||
def test_basic_crawl_direct(tester: Crawl4AiTester):
|
||||
print("\n=== Testing Basic Crawl (Direct) ===")
|
||||
request = {
|
||||
"urls": "https://www.nbcnews.com/business",
|
||||
"priority": 10,
|
||||
# "session_id": "test"
|
||||
"cache_mode": "bypass" # or "enabled", "disabled", "read_only", "write_only"
|
||||
"cache_mode": "bypass", # or "enabled", "disabled", "read_only", "write_only"
|
||||
}
|
||||
|
||||
|
||||
result = tester.crawl_direct(request)
|
||||
print(f"Basic crawl result length: {len(result['result']['markdown'])}")
|
||||
assert result['result']['success']
|
||||
assert len(result['result']['markdown']) > 0
|
||||
|
||||
assert result["result"]["success"]
|
||||
assert len(result["result"]["markdown"]) > 0
|
||||
|
||||
|
||||
def test_js_execution(tester: Crawl4AiTester):
|
||||
print("\n=== Testing JS Execution ===")
|
||||
request = {
|
||||
@@ -144,32 +164,29 @@ def test_js_execution(tester: Crawl4AiTester):
|
||||
"const loadMoreButton = Array.from(document.querySelectorAll('button')).find(button => button.textContent.includes('Load More')); loadMoreButton && loadMoreButton.click();"
|
||||
],
|
||||
"wait_for": "article.tease-card:nth-child(10)",
|
||||
"crawler_params": {
|
||||
"headless": True
|
||||
}
|
||||
"crawler_params": {"headless": True},
|
||||
}
|
||||
|
||||
|
||||
result = tester.submit_and_wait(request)
|
||||
print(f"JS execution result length: {len(result['result']['markdown'])}")
|
||||
assert result["result"]["success"]
|
||||
|
||||
|
||||
def test_css_selector(tester: Crawl4AiTester):
|
||||
print("\n=== Testing CSS Selector ===")
|
||||
request = {
|
||||
"urls": "https://www.nbcnews.com/business",
|
||||
"priority": 7,
|
||||
"css_selector": ".wide-tease-item__description",
|
||||
"crawler_params": {
|
||||
"headless": True
|
||||
},
|
||||
"extra": {"word_count_threshold": 10}
|
||||
|
||||
"crawler_params": {"headless": True},
|
||||
"extra": {"word_count_threshold": 10},
|
||||
}
|
||||
|
||||
|
||||
result = tester.submit_and_wait(request)
|
||||
print(f"CSS selector result length: {len(result['result']['markdown'])}")
|
||||
assert result["result"]["success"]
|
||||
|
||||
|
||||
def test_structured_extraction(tester: Crawl4AiTester):
|
||||
print("\n=== Testing Structured Extraction ===")
|
||||
schema = {
|
||||
@@ -190,21 +207,16 @@ def test_structured_extraction(tester: Crawl4AiTester):
|
||||
"name": "price",
|
||||
"selector": "td:nth-child(2)",
|
||||
"type": "text",
|
||||
}
|
||||
},
|
||||
],
|
||||
}
|
||||
|
||||
|
||||
request = {
|
||||
"urls": "https://www.coinbase.com/explore",
|
||||
"priority": 9,
|
||||
"extraction_config": {
|
||||
"type": "json_css",
|
||||
"params": {
|
||||
"schema": schema
|
||||
}
|
||||
}
|
||||
"extraction_config": {"type": "json_css", "params": {"schema": schema}},
|
||||
}
|
||||
|
||||
|
||||
result = tester.submit_and_wait(request)
|
||||
extracted = json.loads(result["result"]["extracted_content"])
|
||||
print(f"Extracted {len(extracted)} items")
|
||||
@@ -212,6 +224,7 @@ def test_structured_extraction(tester: Crawl4AiTester):
|
||||
assert result["result"]["success"]
|
||||
assert len(extracted) > 0
|
||||
|
||||
|
||||
def test_llm_extraction(tester: Crawl4AiTester):
|
||||
print("\n=== Testing LLM Extraction ===")
|
||||
schema = {
|
||||
@@ -219,20 +232,20 @@ def test_llm_extraction(tester: Crawl4AiTester):
|
||||
"properties": {
|
||||
"model_name": {
|
||||
"type": "string",
|
||||
"description": "Name of the OpenAI model."
|
||||
"description": "Name of the OpenAI model.",
|
||||
},
|
||||
"input_fee": {
|
||||
"type": "string",
|
||||
"description": "Fee for input token for the OpenAI model."
|
||||
"description": "Fee for input token for the OpenAI model.",
|
||||
},
|
||||
"output_fee": {
|
||||
"type": "string",
|
||||
"description": "Fee for output token for the OpenAI model."
|
||||
}
|
||||
"description": "Fee for output token for the OpenAI model.",
|
||||
},
|
||||
},
|
||||
"required": ["model_name", "input_fee", "output_fee"]
|
||||
"required": ["model_name", "input_fee", "output_fee"],
|
||||
}
|
||||
|
||||
|
||||
request = {
|
||||
"urls": "https://openai.com/api/pricing",
|
||||
"priority": 8,
|
||||
@@ -243,12 +256,12 @@ def test_llm_extraction(tester: Crawl4AiTester):
|
||||
"api_token": os.getenv("OPENAI_API_KEY"),
|
||||
"schema": schema,
|
||||
"extraction_type": "schema",
|
||||
"instruction": """From the crawled content, extract all mentioned model names along with their fees for input and output tokens."""
|
||||
}
|
||||
"instruction": """From the crawled content, extract all mentioned model names along with their fees for input and output tokens.""",
|
||||
},
|
||||
},
|
||||
"crawler_params": {"word_count_threshold": 1}
|
||||
"crawler_params": {"word_count_threshold": 1},
|
||||
}
|
||||
|
||||
|
||||
try:
|
||||
result = tester.submit_and_wait(request)
|
||||
extracted = json.loads(result["result"]["extracted_content"])
|
||||
@@ -258,6 +271,7 @@ def test_llm_extraction(tester: Crawl4AiTester):
|
||||
except Exception as e:
|
||||
print(f"LLM extraction test failed (might be due to missing API key): {str(e)}")
|
||||
|
||||
|
||||
def test_llm_with_ollama(tester: Crawl4AiTester):
|
||||
print("\n=== Testing LLM with Ollama ===")
|
||||
schema = {
|
||||
@@ -265,20 +279,20 @@ def test_llm_with_ollama(tester: Crawl4AiTester):
|
||||
"properties": {
|
||||
"article_title": {
|
||||
"type": "string",
|
||||
"description": "The main title of the news article"
|
||||
"description": "The main title of the news article",
|
||||
},
|
||||
"summary": {
|
||||
"type": "string",
|
||||
"description": "A brief summary of the article content"
|
||||
"description": "A brief summary of the article content",
|
||||
},
|
||||
"main_topics": {
|
||||
"type": "array",
|
||||
"items": {"type": "string"},
|
||||
"description": "Main topics or themes discussed in the article"
|
||||
}
|
||||
}
|
||||
"description": "Main topics or themes discussed in the article",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
request = {
|
||||
"urls": "https://www.nbcnews.com/business",
|
||||
"priority": 8,
|
||||
@@ -288,13 +302,13 @@ def test_llm_with_ollama(tester: Crawl4AiTester):
|
||||
"provider": "ollama/llama2",
|
||||
"schema": schema,
|
||||
"extraction_type": "schema",
|
||||
"instruction": "Extract the main article information including title, summary, and main topics."
|
||||
}
|
||||
"instruction": "Extract the main article information including title, summary, and main topics.",
|
||||
},
|
||||
},
|
||||
"extra": {"word_count_threshold": 1},
|
||||
"crawler_params": {"verbose": True}
|
||||
"crawler_params": {"verbose": True},
|
||||
}
|
||||
|
||||
|
||||
try:
|
||||
result = tester.submit_and_wait(request)
|
||||
extracted = json.loads(result["result"]["extracted_content"])
|
||||
@@ -303,6 +317,7 @@ def test_llm_with_ollama(tester: Crawl4AiTester):
|
||||
except Exception as e:
|
||||
print(f"Ollama extraction test failed: {str(e)}")
|
||||
|
||||
|
||||
def test_cosine_extraction(tester: Crawl4AiTester):
|
||||
print("\n=== Testing Cosine Extraction ===")
|
||||
request = {
|
||||
@@ -314,11 +329,11 @@ def test_cosine_extraction(tester: Crawl4AiTester):
|
||||
"semantic_filter": "business finance economy",
|
||||
"word_count_threshold": 10,
|
||||
"max_dist": 0.2,
|
||||
"top_k": 3
|
||||
}
|
||||
}
|
||||
"top_k": 3,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
try:
|
||||
result = tester.submit_and_wait(request)
|
||||
extracted = json.loads(result["result"]["extracted_content"])
|
||||
@@ -328,30 +343,30 @@ def test_cosine_extraction(tester: Crawl4AiTester):
|
||||
except Exception as e:
|
||||
print(f"Cosine extraction test failed: {str(e)}")
|
||||
|
||||
|
||||
def test_screenshot(tester: Crawl4AiTester):
|
||||
print("\n=== Testing Screenshot ===")
|
||||
request = {
|
||||
"urls": "https://www.nbcnews.com/business",
|
||||
"priority": 5,
|
||||
"screenshot": True,
|
||||
"crawler_params": {
|
||||
"headless": True
|
||||
}
|
||||
"crawler_params": {"headless": True},
|
||||
}
|
||||
|
||||
|
||||
result = tester.submit_and_wait(request)
|
||||
print("Screenshot captured:", bool(result["result"]["screenshot"]))
|
||||
|
||||
|
||||
if result["result"]["screenshot"]:
|
||||
# Save screenshot
|
||||
screenshot_data = base64.b64decode(result["result"]["screenshot"])
|
||||
with open("test_screenshot.jpg", "wb") as f:
|
||||
f.write(screenshot_data)
|
||||
print("Screenshot saved as test_screenshot.jpg")
|
||||
|
||||
|
||||
assert result["result"]["success"]
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
version = sys.argv[1] if len(sys.argv) > 1 else "basic"
|
||||
# version = "full"
|
||||
test_docker_deployment(version)
|
||||
test_docker_deployment(version)
|
||||
|
||||
@@ -9,18 +9,17 @@ This example shows how to:
|
||||
|
||||
import asyncio
|
||||
import os
|
||||
from typing import Dict, Any
|
||||
|
||||
from crawl4ai import AsyncWebCrawler, BrowserConfig, CrawlerRunConfig, CacheMode
|
||||
from crawl4ai.extraction_strategy import (
|
||||
LLMExtractionStrategy,
|
||||
JsonCssExtractionStrategy,
|
||||
JsonXPathExtractionStrategy
|
||||
JsonXPathExtractionStrategy,
|
||||
)
|
||||
from crawl4ai.chunking_strategy import RegexChunking, IdentityChunking
|
||||
from crawl4ai.content_filter_strategy import PruningContentFilter
|
||||
from crawl4ai.markdown_generation_strategy import DefaultMarkdownGenerator
|
||||
|
||||
|
||||
async def run_extraction(crawler: AsyncWebCrawler, url: str, strategy, name: str):
|
||||
"""Helper function to run extraction with proper configuration"""
|
||||
try:
|
||||
@@ -30,78 +29,90 @@ async def run_extraction(crawler: AsyncWebCrawler, url: str, strategy, name: str
|
||||
extraction_strategy=strategy,
|
||||
markdown_generator=DefaultMarkdownGenerator(
|
||||
content_filter=PruningContentFilter() # For fit_markdown support
|
||||
)
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
# Run the crawler
|
||||
result = await crawler.arun(url=url, config=config)
|
||||
|
||||
|
||||
if result.success:
|
||||
print(f"\n=== {name} Results ===")
|
||||
print(f"Extracted Content: {result.extracted_content}")
|
||||
print(f"Raw Markdown Length: {len(result.markdown_v2.raw_markdown)}")
|
||||
print(f"Citations Markdown Length: {len(result.markdown_v2.markdown_with_citations)}")
|
||||
print(
|
||||
f"Citations Markdown Length: {len(result.markdown_v2.markdown_with_citations)}"
|
||||
)
|
||||
else:
|
||||
print(f"Error in {name}: Crawl failed")
|
||||
|
||||
|
||||
except Exception as e:
|
||||
print(f"Error in {name}: {str(e)}")
|
||||
|
||||
|
||||
async def main():
|
||||
# Example URL (replace with actual URL)
|
||||
url = "https://example.com/product-page"
|
||||
|
||||
|
||||
# Configure browser settings
|
||||
browser_config = BrowserConfig(
|
||||
headless=True,
|
||||
verbose=True
|
||||
)
|
||||
|
||||
browser_config = BrowserConfig(headless=True, verbose=True)
|
||||
|
||||
# Initialize extraction strategies
|
||||
|
||||
|
||||
# 1. LLM Extraction with different input formats
|
||||
markdown_strategy = LLMExtractionStrategy(
|
||||
provider="openai/gpt-4o-mini",
|
||||
api_token=os.getenv("OPENAI_API_KEY"),
|
||||
instruction="Extract product information including name, price, and description"
|
||||
instruction="Extract product information including name, price, and description",
|
||||
)
|
||||
|
||||
|
||||
html_strategy = LLMExtractionStrategy(
|
||||
input_format="html",
|
||||
provider="openai/gpt-4o-mini",
|
||||
api_token=os.getenv("OPENAI_API_KEY"),
|
||||
instruction="Extract product information from HTML including structured data"
|
||||
instruction="Extract product information from HTML including structured data",
|
||||
)
|
||||
|
||||
|
||||
fit_markdown_strategy = LLMExtractionStrategy(
|
||||
input_format="fit_markdown",
|
||||
provider="openai/gpt-4o-mini",
|
||||
api_token=os.getenv("OPENAI_API_KEY"),
|
||||
instruction="Extract product information from cleaned markdown"
|
||||
instruction="Extract product information from cleaned markdown",
|
||||
)
|
||||
|
||||
|
||||
# 2. JSON CSS Extraction (automatically uses HTML input)
|
||||
css_schema = {
|
||||
"baseSelector": ".product",
|
||||
"fields": [
|
||||
{"name": "title", "selector": "h1.product-title", "type": "text"},
|
||||
{"name": "price", "selector": ".price", "type": "text"},
|
||||
{"name": "description", "selector": ".description", "type": "text"}
|
||||
]
|
||||
{"name": "description", "selector": ".description", "type": "text"},
|
||||
],
|
||||
}
|
||||
css_strategy = JsonCssExtractionStrategy(schema=css_schema)
|
||||
|
||||
|
||||
# 3. JSON XPath Extraction (automatically uses HTML input)
|
||||
xpath_schema = {
|
||||
"baseSelector": "//div[@class='product']",
|
||||
"fields": [
|
||||
{"name": "title", "selector": ".//h1[@class='product-title']/text()", "type": "text"},
|
||||
{"name": "price", "selector": ".//span[@class='price']/text()", "type": "text"},
|
||||
{"name": "description", "selector": ".//div[@class='description']/text()", "type": "text"}
|
||||
]
|
||||
{
|
||||
"name": "title",
|
||||
"selector": ".//h1[@class='product-title']/text()",
|
||||
"type": "text",
|
||||
},
|
||||
{
|
||||
"name": "price",
|
||||
"selector": ".//span[@class='price']/text()",
|
||||
"type": "text",
|
||||
},
|
||||
{
|
||||
"name": "description",
|
||||
"selector": ".//div[@class='description']/text()",
|
||||
"type": "text",
|
||||
},
|
||||
],
|
||||
}
|
||||
xpath_strategy = JsonXPathExtractionStrategy(schema=xpath_schema)
|
||||
|
||||
|
||||
# Use context manager for proper resource handling
|
||||
async with AsyncWebCrawler(config=browser_config) as crawler:
|
||||
# Run all strategies
|
||||
@@ -111,5 +122,6 @@ async def main():
|
||||
await run_extraction(crawler, url, css_strategy, "CSS Extraction")
|
||||
await run_extraction(crawler, url, xpath_strategy, "XPath Extraction")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
asyncio.run(main())
|
||||
|
||||
@@ -1,20 +1,23 @@
|
||||
import asyncio
|
||||
from crawl4ai import *
|
||||
|
||||
|
||||
async def main():
|
||||
browser_config = BrowserConfig(headless=True, verbose=True)
|
||||
async with AsyncWebCrawler(config=browser_config) as crawler:
|
||||
crawler_config = CrawlerRunConfig(
|
||||
cache_mode=CacheMode.BYPASS,
|
||||
markdown_generator=DefaultMarkdownGenerator(
|
||||
content_filter=PruningContentFilter(threshold=0.48, threshold_type="fixed", min_word_threshold=0)
|
||||
)
|
||||
content_filter=PruningContentFilter(
|
||||
threshold=0.48, threshold_type="fixed", min_word_threshold=0
|
||||
)
|
||||
),
|
||||
)
|
||||
result = await crawler.arun(
|
||||
url="https://www.helloworld.org",
|
||||
config=crawler_config
|
||||
url="https://www.helloworld.org", config=crawler_config
|
||||
)
|
||||
print(result.markdown_v2.raw_markdown[:500])
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
asyncio.run(main())
|
||||
asyncio.run(main())
|
||||
|
||||
@@ -1,19 +1,18 @@
|
||||
from crawl4ai import AsyncWebCrawler, BrowserConfig, CrawlerRunConfig, CacheMode
|
||||
from playwright.async_api import Page, BrowserContext
|
||||
|
||||
|
||||
async def main():
|
||||
print("🔗 Hooks Example: Demonstrating different hook use cases")
|
||||
|
||||
# Configure browser settings
|
||||
browser_config = BrowserConfig(
|
||||
headless=True
|
||||
)
|
||||
|
||||
browser_config = BrowserConfig(headless=True)
|
||||
|
||||
# Configure crawler settings
|
||||
crawler_run_config = CrawlerRunConfig(
|
||||
js_code="window.scrollTo(0, document.body.scrollHeight);",
|
||||
wait_for="body",
|
||||
cache_mode=CacheMode.BYPASS
|
||||
cache_mode=CacheMode.BYPASS,
|
||||
)
|
||||
|
||||
# Create crawler instance
|
||||
@@ -30,16 +29,22 @@ async def main():
|
||||
"""Hook called after a new page and context are created"""
|
||||
print("[HOOK] on_page_context_created - New page created!")
|
||||
# Example: Set default viewport size
|
||||
await context.add_cookies([{
|
||||
'name': 'session_id',
|
||||
'value': 'example_session',
|
||||
'domain': '.example.com',
|
||||
'path': '/'
|
||||
}])
|
||||
await context.add_cookies(
|
||||
[
|
||||
{
|
||||
"name": "session_id",
|
||||
"value": "example_session",
|
||||
"domain": ".example.com",
|
||||
"path": "/",
|
||||
}
|
||||
]
|
||||
)
|
||||
await page.set_viewport_size({"width": 1080, "height": 800})
|
||||
return page
|
||||
|
||||
async def on_user_agent_updated(page: Page, context: BrowserContext, user_agent: str, **kwargs):
|
||||
async def on_user_agent_updated(
|
||||
page: Page, context: BrowserContext, user_agent: str, **kwargs
|
||||
):
|
||||
"""Hook called when the user agent is updated"""
|
||||
print(f"[HOOK] on_user_agent_updated - New user agent: {user_agent}")
|
||||
return page
|
||||
@@ -53,17 +58,17 @@ async def main():
|
||||
"""Hook called before navigating to each URL"""
|
||||
print(f"[HOOK] before_goto - About to visit: {url}")
|
||||
# Example: Add custom headers for the request
|
||||
await page.set_extra_http_headers({
|
||||
"Custom-Header": "my-value"
|
||||
})
|
||||
await page.set_extra_http_headers({"Custom-Header": "my-value"})
|
||||
return page
|
||||
|
||||
async def after_goto(page: Page, context: BrowserContext, url: str, response: dict, **kwargs):
|
||||
async def after_goto(
|
||||
page: Page, context: BrowserContext, url: str, response: dict, **kwargs
|
||||
):
|
||||
"""Hook called after navigating to each URL"""
|
||||
print(f"[HOOK] after_goto - Successfully loaded: {url}")
|
||||
# Example: Wait for a specific element to be loaded
|
||||
try:
|
||||
await page.wait_for_selector('.content', timeout=1000)
|
||||
await page.wait_for_selector(".content", timeout=1000)
|
||||
print("Content element found!")
|
||||
except:
|
||||
print("Content element not found, continuing anyway")
|
||||
@@ -76,7 +81,9 @@ async def main():
|
||||
await page.evaluate("window.scrollTo(0, document.body.scrollHeight);")
|
||||
return page
|
||||
|
||||
async def before_return_html(page: Page, context: BrowserContext, html:str, **kwargs):
|
||||
async def before_return_html(
|
||||
page: Page, context: BrowserContext, html: str, **kwargs
|
||||
):
|
||||
"""Hook called before returning the HTML content"""
|
||||
print(f"[HOOK] before_return_html - Got HTML content (length: {len(html)})")
|
||||
# Example: You could modify the HTML content here if needed
|
||||
@@ -84,7 +91,9 @@ async def main():
|
||||
|
||||
# Set all the hooks
|
||||
crawler.crawler_strategy.set_hook("on_browser_created", on_browser_created)
|
||||
crawler.crawler_strategy.set_hook("on_page_context_created", on_page_context_created)
|
||||
crawler.crawler_strategy.set_hook(
|
||||
"on_page_context_created", on_page_context_created
|
||||
)
|
||||
crawler.crawler_strategy.set_hook("on_user_agent_updated", on_user_agent_updated)
|
||||
crawler.crawler_strategy.set_hook("on_execution_started", on_execution_started)
|
||||
crawler.crawler_strategy.set_hook("before_goto", before_goto)
|
||||
@@ -95,13 +104,15 @@ async def main():
|
||||
await crawler.start()
|
||||
|
||||
# Example usage: crawl a simple website
|
||||
url = 'https://example.com'
|
||||
url = "https://example.com"
|
||||
result = await crawler.arun(url, config=crawler_run_config)
|
||||
print(f"\nCrawled URL: {result.url}")
|
||||
print(f"HTML length: {len(result.html)}")
|
||||
|
||||
|
||||
await crawler.close()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import asyncio
|
||||
asyncio.run(main())
|
||||
|
||||
asyncio.run(main())
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
import asyncio
|
||||
from crawl4ai import AsyncWebCrawler, AsyncPlaywrightCrawlerStrategy
|
||||
|
||||
|
||||
async def main():
|
||||
# Example 1: Setting language when creating the crawler
|
||||
crawler1 = AsyncWebCrawler(
|
||||
@@ -9,11 +10,15 @@ async def main():
|
||||
)
|
||||
)
|
||||
result1 = await crawler1.arun("https://www.example.com")
|
||||
print("Example 1 result:", result1.extracted_content[:100]) # Print first 100 characters
|
||||
print(
|
||||
"Example 1 result:", result1.extracted_content[:100]
|
||||
) # Print first 100 characters
|
||||
|
||||
# Example 2: Setting language before crawling
|
||||
crawler2 = AsyncWebCrawler()
|
||||
crawler2.crawler_strategy.headers["Accept-Language"] = "es-ES,es;q=0.9,en-US;q=0.8,en;q=0.7"
|
||||
crawler2.crawler_strategy.headers[
|
||||
"Accept-Language"
|
||||
] = "es-ES,es;q=0.9,en-US;q=0.8,en;q=0.7"
|
||||
result2 = await crawler2.arun("https://www.example.com")
|
||||
print("Example 2 result:", result2.extracted_content[:100])
|
||||
|
||||
@@ -21,7 +26,7 @@ async def main():
|
||||
crawler3 = AsyncWebCrawler()
|
||||
result3 = await crawler3.arun(
|
||||
"https://www.example.com",
|
||||
headers={"Accept-Language": "de-DE,de;q=0.9,en-US;q=0.8,en;q=0.7"}
|
||||
headers={"Accept-Language": "de-DE,de;q=0.9,en-US;q=0.8,en;q=0.7"},
|
||||
)
|
||||
print("Example 3 result:", result3.extracted_content[:100])
|
||||
|
||||
@@ -31,15 +36,15 @@ async def main():
|
||||
("https://www.example.org", "es-ES,es;q=0.9"),
|
||||
("https://www.example.net", "de-DE,de;q=0.9"),
|
||||
]
|
||||
|
||||
|
||||
crawler4 = AsyncWebCrawler()
|
||||
results = await asyncio.gather(*[
|
||||
crawler4.arun(url, headers={"Accept-Language": lang})
|
||||
for url, lang in urls
|
||||
])
|
||||
|
||||
results = await asyncio.gather(
|
||||
*[crawler4.arun(url, headers={"Accept-Language": lang}) for url, lang in urls]
|
||||
)
|
||||
|
||||
for url, result in zip([u for u, _ in urls], results):
|
||||
print(f"Result for {url}:", result.extracted_content[:100])
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
asyncio.run(main())
|
||||
asyncio.run(main())
|
||||
|
||||
@@ -3,32 +3,37 @@ from crawl4ai.crawler_strategy import *
|
||||
import asyncio
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
url = r'https://openai.com/api/pricing/'
|
||||
url = r"https://openai.com/api/pricing/"
|
||||
|
||||
|
||||
class OpenAIModelFee(BaseModel):
|
||||
model_name: str = Field(..., description="Name of the OpenAI model.")
|
||||
input_fee: str = Field(..., description="Fee for input token for the OpenAI model.")
|
||||
output_fee: str = Field(..., description="Fee for output token for the OpenAI model.")
|
||||
output_fee: str = Field(
|
||||
..., description="Fee for output token for the OpenAI model."
|
||||
)
|
||||
|
||||
|
||||
from crawl4ai import AsyncWebCrawler
|
||||
|
||||
|
||||
async def main():
|
||||
# Use AsyncWebCrawler
|
||||
async with AsyncWebCrawler() as crawler:
|
||||
result = await crawler.arun(
|
||||
url=url,
|
||||
word_count_threshold=1,
|
||||
extraction_strategy= LLMExtractionStrategy(
|
||||
extraction_strategy=LLMExtractionStrategy(
|
||||
# provider= "openai/gpt-4o", api_token = os.getenv('OPENAI_API_KEY'),
|
||||
provider= "groq/llama-3.1-70b-versatile", api_token = os.getenv('GROQ_API_KEY'),
|
||||
provider="groq/llama-3.1-70b-versatile",
|
||||
api_token=os.getenv("GROQ_API_KEY"),
|
||||
schema=OpenAIModelFee.model_json_schema(),
|
||||
extraction_type="schema",
|
||||
instruction="From the crawled content, extract all mentioned model names along with their " \
|
||||
"fees for input and output tokens. Make sure not to miss anything in the entire content. " \
|
||||
'One extracted model JSON format should look like this: ' \
|
||||
'{ "model_name": "GPT-4", "input_fee": "US$10.00 / 1M tokens", "output_fee": "US$30.00 / 1M tokens" }'
|
||||
instruction="From the crawled content, extract all mentioned model names along with their "
|
||||
"fees for input and output tokens. Make sure not to miss anything in the entire content. "
|
||||
"One extracted model JSON format should look like this: "
|
||||
'{ "model_name": "GPT-4", "input_fee": "US$10.00 / 1M tokens", "output_fee": "US$30.00 / 1M tokens" }',
|
||||
),
|
||||
|
||||
)
|
||||
print("Success:", result.success)
|
||||
model_fees = json.loads(result.extracted_content)
|
||||
@@ -37,4 +42,5 @@ async def main():
|
||||
with open(".data/data.json", "w", encoding="utf-8") as f:
|
||||
f.write(result.extracted_content)
|
||||
|
||||
|
||||
asyncio.run(main())
|
||||
|
||||
@@ -8,12 +8,12 @@ import asyncio
|
||||
import time
|
||||
import json
|
||||
import re
|
||||
from typing import Dict, List
|
||||
from typing import Dict
|
||||
from bs4 import BeautifulSoup
|
||||
from pydantic import BaseModel, Field
|
||||
from crawl4ai import AsyncWebCrawler, CacheMode, BrowserConfig, CrawlerRunConfig
|
||||
from crawl4ai.markdown_generation_strategy import DefaultMarkdownGenerator
|
||||
from crawl4ai.content_filter_strategy import BM25ContentFilter, PruningContentFilter
|
||||
from crawl4ai.content_filter_strategy import PruningContentFilter
|
||||
from crawl4ai.extraction_strategy import (
|
||||
JsonCssExtractionStrategy,
|
||||
LLMExtractionStrategy,
|
||||
@@ -62,6 +62,7 @@ async def clean_content():
|
||||
print(f"Full Markdown Length: {full_markdown_length}")
|
||||
print(f"Fit Markdown Length: {fit_markdown_length}")
|
||||
|
||||
|
||||
async def link_analysis():
|
||||
crawler_config = CrawlerRunConfig(
|
||||
cache_mode=CacheMode.ENABLED,
|
||||
@@ -76,9 +77,10 @@ async def link_analysis():
|
||||
print(f"Found {len(result.links['internal'])} internal links")
|
||||
print(f"Found {len(result.links['external'])} external links")
|
||||
|
||||
for link in result.links['internal'][:5]:
|
||||
for link in result.links["internal"][:5]:
|
||||
print(f"Href: {link['href']}\nText: {link['text']}\n")
|
||||
|
||||
|
||||
# JavaScript Execution Example
|
||||
async def simple_example_with_running_js_code():
|
||||
print("\n--- Executing JavaScript and Using CSS Selectors ---")
|
||||
@@ -112,25 +114,29 @@ async def simple_example_with_css_selector():
|
||||
)
|
||||
print(result.markdown[:500])
|
||||
|
||||
|
||||
async def media_handling():
|
||||
crawler_config = CrawlerRunConfig(cache_mode=CacheMode.BYPASS, exclude_external_images=True, screenshot=True)
|
||||
crawler_config = CrawlerRunConfig(
|
||||
cache_mode=CacheMode.BYPASS, exclude_external_images=True, screenshot=True
|
||||
)
|
||||
async with AsyncWebCrawler() as crawler:
|
||||
result = await crawler.arun(
|
||||
url="https://www.nbcnews.com/business",
|
||||
config=crawler_config
|
||||
url="https://www.nbcnews.com/business", config=crawler_config
|
||||
)
|
||||
for img in result.media['images'][:5]:
|
||||
for img in result.media["images"][:5]:
|
||||
print(f"Image URL: {img['src']}, Alt: {img['alt']}, Score: {img['score']}")
|
||||
|
||||
|
||||
async def custom_hook_workflow(verbose=True):
|
||||
async with AsyncWebCrawler() as crawler:
|
||||
# Set a 'before_goto' hook to run custom code just before navigation
|
||||
crawler.crawler_strategy.set_hook("before_goto", lambda page, context: print("[Hook] Preparing to navigate..."))
|
||||
crawler.crawler_strategy.set_hook(
|
||||
"before_goto",
|
||||
lambda page, context: print("[Hook] Preparing to navigate..."),
|
||||
)
|
||||
|
||||
# Perform the crawl operation
|
||||
result = await crawler.arun(
|
||||
url="https://crawl4ai.com"
|
||||
)
|
||||
result = await crawler.arun(url="https://crawl4ai.com")
|
||||
print(result.markdown_v2.raw_markdown[:500].replace("\n", " -- "))
|
||||
|
||||
|
||||
@@ -412,21 +418,22 @@ async def cosine_similarity_extraction():
|
||||
cache_mode=CacheMode.BYPASS,
|
||||
extraction_strategy=CosineStrategy(
|
||||
word_count_threshold=10,
|
||||
max_dist=0.2, # Maximum distance between two words
|
||||
linkage_method="ward", # Linkage method for hierarchical clustering (ward, complete, average, single)
|
||||
top_k=3, # Number of top keywords to extract
|
||||
sim_threshold=0.3, # Similarity threshold for clustering
|
||||
semantic_filter="McDonald's economic impact, American consumer trends", # Keywords to filter the content semantically using embeddings
|
||||
verbose=True
|
||||
),
|
||||
max_dist=0.2, # Maximum distance between two words
|
||||
linkage_method="ward", # Linkage method for hierarchical clustering (ward, complete, average, single)
|
||||
top_k=3, # Number of top keywords to extract
|
||||
sim_threshold=0.3, # Similarity threshold for clustering
|
||||
semantic_filter="McDonald's economic impact, American consumer trends", # Keywords to filter the content semantically using embeddings
|
||||
verbose=True,
|
||||
),
|
||||
)
|
||||
async with AsyncWebCrawler() as crawler:
|
||||
result = await crawler.arun(
|
||||
url="https://www.nbcnews.com/business/consumer/how-mcdonalds-e-coli-crisis-inflation-politics-reflect-american-story-rcna177156",
|
||||
config=crawl_config
|
||||
config=crawl_config,
|
||||
)
|
||||
print(json.loads(result.extracted_content)[:5])
|
||||
|
||||
|
||||
# Browser Comparison
|
||||
async def crawl_custom_browser_type():
|
||||
print("\n--- Browser Comparison ---")
|
||||
@@ -484,39 +491,42 @@ async def crawl_with_user_simulation():
|
||||
result = await crawler.arun(url="YOUR-URL-HERE", config=crawler_config)
|
||||
print(result.markdown)
|
||||
|
||||
|
||||
async def ssl_certification():
|
||||
# Configure crawler to fetch SSL certificate
|
||||
config = CrawlerRunConfig(
|
||||
fetch_ssl_certificate=True,
|
||||
cache_mode=CacheMode.BYPASS # Bypass cache to always get fresh certificates
|
||||
cache_mode=CacheMode.BYPASS, # Bypass cache to always get fresh certificates
|
||||
)
|
||||
|
||||
async with AsyncWebCrawler() as crawler:
|
||||
result = await crawler.arun(
|
||||
url='https://example.com',
|
||||
config=config
|
||||
)
|
||||
|
||||
result = await crawler.arun(url="https://example.com", config=config)
|
||||
|
||||
if result.success and result.ssl_certificate:
|
||||
cert = result.ssl_certificate
|
||||
|
||||
|
||||
# 1. Access certificate properties directly
|
||||
print("\nCertificate Information:")
|
||||
print(f"Issuer: {cert.issuer.get('CN', '')}")
|
||||
print(f"Valid until: {cert.valid_until}")
|
||||
print(f"Fingerprint: {cert.fingerprint}")
|
||||
|
||||
|
||||
# 2. Export certificate in different formats
|
||||
cert.to_json(os.path.join(tmp_dir, "certificate.json")) # For analysis
|
||||
print("\nCertificate exported to:")
|
||||
print(f"- JSON: {os.path.join(tmp_dir, 'certificate.json')}")
|
||||
|
||||
pem_data = cert.to_pem(os.path.join(tmp_dir, "certificate.pem")) # For web servers
|
||||
|
||||
pem_data = cert.to_pem(
|
||||
os.path.join(tmp_dir, "certificate.pem")
|
||||
) # For web servers
|
||||
print(f"- PEM: {os.path.join(tmp_dir, 'certificate.pem')}")
|
||||
|
||||
der_data = cert.to_der(os.path.join(tmp_dir, "certificate.der")) # For Java apps
|
||||
|
||||
der_data = cert.to_der(
|
||||
os.path.join(tmp_dir, "certificate.der")
|
||||
) # For Java apps
|
||||
print(f"- DER: {os.path.join(tmp_dir, 'certificate.der')}")
|
||||
|
||||
|
||||
# Speed Comparison
|
||||
async def speed_comparison():
|
||||
print("\n--- Speed Comparison ---")
|
||||
|
||||
@@ -1,6 +1,10 @@
|
||||
import os, sys
|
||||
|
||||
# append parent directory to system path
|
||||
sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))); os.environ['FIRECRAWL_API_KEY'] = "fc-84b370ccfad44beabc686b38f1769692";
|
||||
sys.path.append(
|
||||
os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||
)
|
||||
os.environ["FIRECRAWL_API_KEY"] = "fc-84b370ccfad44beabc686b38f1769692"
|
||||
|
||||
import asyncio
|
||||
# import nest_asyncio
|
||||
@@ -15,7 +19,7 @@ from bs4 import BeautifulSoup
|
||||
from pydantic import BaseModel, Field
|
||||
from crawl4ai import AsyncWebCrawler, CacheMode
|
||||
from crawl4ai.markdown_generation_strategy import DefaultMarkdownGenerator
|
||||
from crawl4ai.content_filter_strategy import BM25ContentFilter, PruningContentFilter
|
||||
from crawl4ai.content_filter_strategy import PruningContentFilter
|
||||
from crawl4ai.extraction_strategy import (
|
||||
JsonCssExtractionStrategy,
|
||||
LLMExtractionStrategy,
|
||||
@@ -32,9 +36,12 @@ print("Website: https://crawl4ai.com")
|
||||
async def simple_crawl():
|
||||
print("\n--- Basic Usage ---")
|
||||
async with AsyncWebCrawler(verbose=True) as crawler:
|
||||
result = await crawler.arun(url="https://www.nbcnews.com/business", cache_mode= CacheMode.BYPASS)
|
||||
result = await crawler.arun(
|
||||
url="https://www.nbcnews.com/business", cache_mode=CacheMode.BYPASS
|
||||
)
|
||||
print(result.markdown[:500]) # Print first 500 characters
|
||||
|
||||
|
||||
async def simple_example_with_running_js_code():
|
||||
print("\n--- Executing JavaScript and Using CSS Selectors ---")
|
||||
# New code to handle the wait_for parameter
|
||||
@@ -57,6 +64,7 @@ async def simple_example_with_running_js_code():
|
||||
)
|
||||
print(result.markdown[:500]) # Print first 500 characters
|
||||
|
||||
|
||||
async def simple_example_with_css_selector():
|
||||
print("\n--- Using CSS Selectors ---")
|
||||
async with AsyncWebCrawler(verbose=True) as crawler:
|
||||
@@ -67,42 +75,44 @@ async def simple_example_with_css_selector():
|
||||
)
|
||||
print(result.markdown[:500]) # Print first 500 characters
|
||||
|
||||
|
||||
async def use_proxy():
|
||||
print("\n--- Using a Proxy ---")
|
||||
print(
|
||||
"Note: Replace 'http://your-proxy-url:port' with a working proxy to run this example."
|
||||
)
|
||||
# Uncomment and modify the following lines to use a proxy
|
||||
async with AsyncWebCrawler(verbose=True, proxy="http://your-proxy-url:port") as crawler:
|
||||
async with AsyncWebCrawler(
|
||||
verbose=True, proxy="http://your-proxy-url:port"
|
||||
) as crawler:
|
||||
result = await crawler.arun(
|
||||
url="https://www.nbcnews.com/business",
|
||||
cache_mode= CacheMode.BYPASS
|
||||
url="https://www.nbcnews.com/business", cache_mode=CacheMode.BYPASS
|
||||
)
|
||||
if result.success:
|
||||
print(result.markdown[:500]) # Print first 500 characters
|
||||
|
||||
|
||||
async def capture_and_save_screenshot(url: str, output_path: str):
|
||||
async with AsyncWebCrawler(verbose=True) as crawler:
|
||||
result = await crawler.arun(
|
||||
url=url,
|
||||
screenshot=True,
|
||||
cache_mode= CacheMode.BYPASS
|
||||
url=url, screenshot=True, cache_mode=CacheMode.BYPASS
|
||||
)
|
||||
|
||||
|
||||
if result.success and result.screenshot:
|
||||
import base64
|
||||
|
||||
|
||||
# Decode the base64 screenshot data
|
||||
screenshot_data = base64.b64decode(result.screenshot)
|
||||
|
||||
|
||||
# Save the screenshot as a JPEG file
|
||||
with open(output_path, 'wb') as f:
|
||||
with open(output_path, "wb") as f:
|
||||
f.write(screenshot_data)
|
||||
|
||||
|
||||
print(f"Screenshot saved successfully to {output_path}")
|
||||
else:
|
||||
print("Failed to capture screenshot")
|
||||
|
||||
|
||||
class OpenAIModelFee(BaseModel):
|
||||
model_name: str = Field(..., description="Name of the OpenAI model.")
|
||||
input_fee: str = Field(..., description="Fee for input token for the OpenAI model.")
|
||||
@@ -110,16 +120,19 @@ class OpenAIModelFee(BaseModel):
|
||||
..., description="Fee for output token for the OpenAI model."
|
||||
)
|
||||
|
||||
async def extract_structured_data_using_llm(provider: str, api_token: str = None, extra_headers: Dict[str, str] = None):
|
||||
|
||||
async def extract_structured_data_using_llm(
|
||||
provider: str, api_token: str = None, extra_headers: Dict[str, str] = None
|
||||
):
|
||||
print(f"\n--- Extracting Structured Data with {provider} ---")
|
||||
|
||||
|
||||
if api_token is None and provider != "ollama":
|
||||
print(f"API token is required for {provider}. Skipping this example.")
|
||||
return
|
||||
|
||||
# extra_args = {}
|
||||
extra_args={
|
||||
"temperature": 0,
|
||||
extra_args = {
|
||||
"temperature": 0,
|
||||
"top_p": 0.9,
|
||||
"max_tokens": 2000,
|
||||
# any other supported parameters for litellm
|
||||
@@ -139,52 +152,49 @@ async def extract_structured_data_using_llm(provider: str, api_token: str = None
|
||||
instruction="""From the crawled content, extract all mentioned model names along with their fees for input and output tokens.
|
||||
Do not miss any models in the entire content. One extracted model JSON format should look like this:
|
||||
{"model_name": "GPT-4", "input_fee": "US$10.00 / 1M tokens", "output_fee": "US$30.00 / 1M tokens"}.""",
|
||||
extra_args=extra_args
|
||||
extra_args=extra_args,
|
||||
),
|
||||
cache_mode=CacheMode.BYPASS,
|
||||
)
|
||||
print(result.extracted_content)
|
||||
|
||||
|
||||
async def extract_structured_data_using_css_extractor():
|
||||
print("\n--- Using JsonCssExtractionStrategy for Fast Structured Output ---")
|
||||
schema = {
|
||||
"name": "KidoCode Courses",
|
||||
"baseSelector": "section.charge-methodology .w-tab-content > div",
|
||||
"fields": [
|
||||
{
|
||||
"name": "section_title",
|
||||
"selector": "h3.heading-50",
|
||||
"type": "text",
|
||||
},
|
||||
{
|
||||
"name": "section_description",
|
||||
"selector": ".charge-content",
|
||||
"type": "text",
|
||||
},
|
||||
{
|
||||
"name": "course_name",
|
||||
"selector": ".text-block-93",
|
||||
"type": "text",
|
||||
},
|
||||
{
|
||||
"name": "course_description",
|
||||
"selector": ".course-content-text",
|
||||
"type": "text",
|
||||
},
|
||||
{
|
||||
"name": "course_icon",
|
||||
"selector": ".image-92",
|
||||
"type": "attribute",
|
||||
"attribute": "src"
|
||||
}
|
||||
]
|
||||
}
|
||||
"name": "KidoCode Courses",
|
||||
"baseSelector": "section.charge-methodology .w-tab-content > div",
|
||||
"fields": [
|
||||
{
|
||||
"name": "section_title",
|
||||
"selector": "h3.heading-50",
|
||||
"type": "text",
|
||||
},
|
||||
{
|
||||
"name": "section_description",
|
||||
"selector": ".charge-content",
|
||||
"type": "text",
|
||||
},
|
||||
{
|
||||
"name": "course_name",
|
||||
"selector": ".text-block-93",
|
||||
"type": "text",
|
||||
},
|
||||
{
|
||||
"name": "course_description",
|
||||
"selector": ".course-content-text",
|
||||
"type": "text",
|
||||
},
|
||||
{
|
||||
"name": "course_icon",
|
||||
"selector": ".image-92",
|
||||
"type": "attribute",
|
||||
"attribute": "src",
|
||||
},
|
||||
],
|
||||
}
|
||||
|
||||
async with AsyncWebCrawler(
|
||||
headless=True,
|
||||
verbose=True
|
||||
) as crawler:
|
||||
|
||||
async with AsyncWebCrawler(headless=True, verbose=True) as crawler:
|
||||
# Create the JavaScript that handles clicking multiple times
|
||||
js_click_tabs = """
|
||||
(async () => {
|
||||
@@ -198,19 +208,20 @@ async def extract_structured_data_using_css_extractor():
|
||||
await new Promise(r => setTimeout(r, 500));
|
||||
}
|
||||
})();
|
||||
"""
|
||||
"""
|
||||
|
||||
result = await crawler.arun(
|
||||
url="https://www.kidocode.com/degrees/technology",
|
||||
extraction_strategy=JsonCssExtractionStrategy(schema, verbose=True),
|
||||
js_code=[js_click_tabs],
|
||||
cache_mode=CacheMode.BYPASS
|
||||
cache_mode=CacheMode.BYPASS,
|
||||
)
|
||||
|
||||
companies = json.loads(result.extracted_content)
|
||||
print(f"Successfully extracted {len(companies)} companies")
|
||||
print(json.dumps(companies[0], indent=2))
|
||||
|
||||
|
||||
# Advanced Session-Based Crawling with Dynamic Content 🔄
|
||||
async def crawl_dynamic_content_pages_method_1():
|
||||
print("\n--- Advanced Multi-Page Crawling with JavaScript Execution ---")
|
||||
@@ -267,6 +278,7 @@ async def crawl_dynamic_content_pages_method_1():
|
||||
await crawler.crawler_strategy.kill_session(session_id)
|
||||
print(f"Successfully crawled {len(all_commits)} commits across 3 pages")
|
||||
|
||||
|
||||
async def crawl_dynamic_content_pages_method_2():
|
||||
print("\n--- Advanced Multi-Page Crawling with JavaScript Execution ---")
|
||||
|
||||
@@ -334,8 +346,11 @@ async def crawl_dynamic_content_pages_method_2():
|
||||
await crawler.crawler_strategy.kill_session(session_id)
|
||||
print(f"Successfully crawled {len(all_commits)} commits across 3 pages")
|
||||
|
||||
|
||||
async def crawl_dynamic_content_pages_method_3():
|
||||
print("\n--- Advanced Multi-Page Crawling with JavaScript Execution using `wait_for` ---")
|
||||
print(
|
||||
"\n--- Advanced Multi-Page Crawling with JavaScript Execution using `wait_for` ---"
|
||||
)
|
||||
|
||||
async with AsyncWebCrawler(verbose=True) as crawler:
|
||||
url = "https://github.com/microsoft/TypeScript/commits/main"
|
||||
@@ -357,7 +372,7 @@ async def crawl_dynamic_content_pages_method_3():
|
||||
const firstCommit = commits[0].textContent.trim();
|
||||
return firstCommit !== window.firstCommit;
|
||||
}"""
|
||||
|
||||
|
||||
schema = {
|
||||
"name": "Commit Extractor",
|
||||
"baseSelector": "li.Box-sc-g0xbh4-0",
|
||||
@@ -395,40 +410,53 @@ async def crawl_dynamic_content_pages_method_3():
|
||||
await crawler.crawler_strategy.kill_session(session_id)
|
||||
print(f"Successfully crawled {len(all_commits)} commits across 3 pages")
|
||||
|
||||
|
||||
async def crawl_custom_browser_type():
|
||||
# Use Firefox
|
||||
start = time.time()
|
||||
async with AsyncWebCrawler(browser_type="firefox", verbose=True, headless = True) as crawler:
|
||||
result = await crawler.arun(url="https://www.example.com", cache_mode= CacheMode.BYPASS)
|
||||
async with AsyncWebCrawler(
|
||||
browser_type="firefox", verbose=True, headless=True
|
||||
) as crawler:
|
||||
result = await crawler.arun(
|
||||
url="https://www.example.com", cache_mode=CacheMode.BYPASS
|
||||
)
|
||||
print(result.markdown[:500])
|
||||
print("Time taken: ", time.time() - start)
|
||||
|
||||
# Use WebKit
|
||||
start = time.time()
|
||||
async with AsyncWebCrawler(browser_type="webkit", verbose=True, headless = True) as crawler:
|
||||
result = await crawler.arun(url="https://www.example.com", cache_mode= CacheMode.BYPASS)
|
||||
async with AsyncWebCrawler(
|
||||
browser_type="webkit", verbose=True, headless=True
|
||||
) as crawler:
|
||||
result = await crawler.arun(
|
||||
url="https://www.example.com", cache_mode=CacheMode.BYPASS
|
||||
)
|
||||
print(result.markdown[:500])
|
||||
print("Time taken: ", time.time() - start)
|
||||
|
||||
# Use Chromium (default)
|
||||
start = time.time()
|
||||
async with AsyncWebCrawler(verbose=True, headless = True) as crawler:
|
||||
result = await crawler.arun(url="https://www.example.com", cache_mode= CacheMode.BYPASS)
|
||||
async with AsyncWebCrawler(verbose=True, headless=True) as crawler:
|
||||
result = await crawler.arun(
|
||||
url="https://www.example.com", cache_mode=CacheMode.BYPASS
|
||||
)
|
||||
print(result.markdown[:500])
|
||||
print("Time taken: ", time.time() - start)
|
||||
|
||||
|
||||
async def crawl_with_user_simultion():
|
||||
async with AsyncWebCrawler(verbose=True, headless=True) as crawler:
|
||||
url = "YOUR-URL-HERE"
|
||||
result = await crawler.arun(
|
||||
url=url,
|
||||
url=url,
|
||||
cache_mode=CacheMode.BYPASS,
|
||||
magic = True, # Automatically detects and removes overlays, popups, and other elements that block content
|
||||
magic=True, # Automatically detects and removes overlays, popups, and other elements that block content
|
||||
# simulate_user = True,# Causes a series of random mouse movements and clicks to simulate user interaction
|
||||
# override_navigator = True # Overrides the navigator object to make it look like a real user
|
||||
)
|
||||
|
||||
print(result.markdown)
|
||||
|
||||
print(result.markdown)
|
||||
|
||||
|
||||
async def speed_comparison():
|
||||
# print("\n--- Speed Comparison ---")
|
||||
@@ -439,18 +467,18 @@ async def speed_comparison():
|
||||
# print()
|
||||
# Simulated Firecrawl performance
|
||||
from firecrawl import FirecrawlApp
|
||||
app = FirecrawlApp(api_key=os.environ['FIRECRAWL_API_KEY'])
|
||||
|
||||
app = FirecrawlApp(api_key=os.environ["FIRECRAWL_API_KEY"])
|
||||
start = time.time()
|
||||
scrape_status = app.scrape_url(
|
||||
'https://www.nbcnews.com/business',
|
||||
params={'formats': ['markdown', 'html']}
|
||||
"https://www.nbcnews.com/business", params={"formats": ["markdown", "html"]}
|
||||
)
|
||||
end = time.time()
|
||||
print("Firecrawl:")
|
||||
print(f"Time taken: {end - start:.2f} seconds")
|
||||
print(f"Content length: {len(scrape_status['markdown'])} characters")
|
||||
print(f"Images found: {scrape_status['markdown'].count('cldnry.s-nbcnews.com')}")
|
||||
print()
|
||||
print()
|
||||
|
||||
async with AsyncWebCrawler() as crawler:
|
||||
# Crawl4AI simple crawl
|
||||
@@ -474,7 +502,9 @@ async def speed_comparison():
|
||||
url="https://www.nbcnews.com/business",
|
||||
word_count_threshold=0,
|
||||
markdown_generator=DefaultMarkdownGenerator(
|
||||
content_filter = PruningContentFilter(threshold=0.48, threshold_type="fixed", min_word_threshold=0)
|
||||
content_filter=PruningContentFilter(
|
||||
threshold=0.48, threshold_type="fixed", min_word_threshold=0
|
||||
)
|
||||
# content_filter=BM25ContentFilter(user_query=None, bm25_threshold=1.0)
|
||||
),
|
||||
cache_mode=CacheMode.BYPASS,
|
||||
@@ -498,7 +528,9 @@ async def speed_comparison():
|
||||
word_count_threshold=0,
|
||||
cache_mode=CacheMode.BYPASS,
|
||||
markdown_generator=DefaultMarkdownGenerator(
|
||||
content_filter = PruningContentFilter(threshold=0.48, threshold_type="fixed", min_word_threshold=0)
|
||||
content_filter=PruningContentFilter(
|
||||
threshold=0.48, threshold_type="fixed", min_word_threshold=0
|
||||
)
|
||||
# content_filter=BM25ContentFilter(user_query=None, bm25_threshold=1.0)
|
||||
),
|
||||
verbose=False,
|
||||
@@ -520,11 +552,12 @@ async def speed_comparison():
|
||||
print("If you run these tests in an environment with better network conditions,")
|
||||
print("you may observe an even more significant speed advantage for Crawl4AI.")
|
||||
|
||||
|
||||
async def generate_knowledge_graph():
|
||||
class Entity(BaseModel):
|
||||
name: str
|
||||
description: str
|
||||
|
||||
|
||||
class Relationship(BaseModel):
|
||||
entity1: Entity
|
||||
entity2: Entity
|
||||
@@ -536,11 +569,11 @@ async def generate_knowledge_graph():
|
||||
relationships: List[Relationship]
|
||||
|
||||
extraction_strategy = LLMExtractionStrategy(
|
||||
provider='openai/gpt-4o-mini', # Or any other provider, including Ollama and open source models
|
||||
api_token=os.getenv('OPENAI_API_KEY'), # In case of Ollama just pass "no-token"
|
||||
schema=KnowledgeGraph.model_json_schema(),
|
||||
extraction_type="schema",
|
||||
instruction="""Extract entities and relationships from the given text."""
|
||||
provider="openai/gpt-4o-mini", # Or any other provider, including Ollama and open source models
|
||||
api_token=os.getenv("OPENAI_API_KEY"), # In case of Ollama just pass "no-token"
|
||||
schema=KnowledgeGraph.model_json_schema(),
|
||||
extraction_type="schema",
|
||||
instruction="""Extract entities and relationships from the given text.""",
|
||||
)
|
||||
async with AsyncWebCrawler() as crawler:
|
||||
url = "https://paulgraham.com/love.html"
|
||||
@@ -554,27 +587,22 @@ async def generate_knowledge_graph():
|
||||
with open(os.path.join(__location__, "kb.json"), "w") as f:
|
||||
f.write(result.extracted_content)
|
||||
|
||||
|
||||
async def fit_markdown_remove_overlay():
|
||||
|
||||
async with AsyncWebCrawler(
|
||||
headless=True, # Set to False to see what is happening
|
||||
verbose=True,
|
||||
user_agent_mode="random",
|
||||
user_agent_generator_config={
|
||||
"device_type": "mobile",
|
||||
"os_type": "android"
|
||||
},
|
||||
headless=True, # Set to False to see what is happening
|
||||
verbose=True,
|
||||
user_agent_mode="random",
|
||||
user_agent_generator_config={"device_type": "mobile", "os_type": "android"},
|
||||
) as crawler:
|
||||
result = await crawler.arun(
|
||||
url='https://www.kidocode.com/degrees/technology',
|
||||
url="https://www.kidocode.com/degrees/technology",
|
||||
cache_mode=CacheMode.BYPASS,
|
||||
markdown_generator=DefaultMarkdownGenerator(
|
||||
content_filter=PruningContentFilter(
|
||||
threshold=0.48, threshold_type="fixed", min_word_threshold=0
|
||||
),
|
||||
options={
|
||||
"ignore_links": True
|
||||
}
|
||||
options={"ignore_links": True},
|
||||
),
|
||||
# markdown_generator=DefaultMarkdownGenerator(
|
||||
# content_filter=BM25ContentFilter(user_query="", bm25_threshold=1.0),
|
||||
@@ -583,31 +611,38 @@ async def fit_markdown_remove_overlay():
|
||||
# }
|
||||
# ),
|
||||
)
|
||||
|
||||
|
||||
if result.success:
|
||||
print(len(result.markdown_v2.raw_markdown))
|
||||
print(len(result.markdown_v2.markdown_with_citations))
|
||||
print(len(result.markdown_v2.fit_markdown))
|
||||
|
||||
|
||||
# Save clean html
|
||||
with open(os.path.join(__location__, "output/cleaned_html.html"), "w") as f:
|
||||
f.write(result.cleaned_html)
|
||||
|
||||
with open(os.path.join(__location__, "output/output_raw_markdown.md"), "w") as f:
|
||||
|
||||
with open(
|
||||
os.path.join(__location__, "output/output_raw_markdown.md"), "w"
|
||||
) as f:
|
||||
f.write(result.markdown_v2.raw_markdown)
|
||||
|
||||
with open(os.path.join(__location__, "output/output_markdown_with_citations.md"), "w") as f:
|
||||
f.write(result.markdown_v2.markdown_with_citations)
|
||||
|
||||
with open(os.path.join(__location__, "output/output_fit_markdown.md"), "w") as f:
|
||||
|
||||
with open(
|
||||
os.path.join(__location__, "output/output_markdown_with_citations.md"),
|
||||
"w",
|
||||
) as f:
|
||||
f.write(result.markdown_v2.markdown_with_citations)
|
||||
|
||||
with open(
|
||||
os.path.join(__location__, "output/output_fit_markdown.md"), "w"
|
||||
) as f:
|
||||
f.write(result.markdown_v2.fit_markdown)
|
||||
|
||||
|
||||
print("Done")
|
||||
|
||||
|
||||
async def main():
|
||||
# await extract_structured_data_using_llm("openai/gpt-4o", os.getenv("OPENAI_API_KEY"))
|
||||
|
||||
|
||||
# await simple_crawl()
|
||||
# await simple_example_with_running_js_code()
|
||||
# await simple_example_with_css_selector()
|
||||
@@ -618,7 +653,7 @@ async def main():
|
||||
# LLM extraction examples
|
||||
# await extract_structured_data_using_llm()
|
||||
# await extract_structured_data_using_llm("huggingface/meta-llama/Meta-Llama-3.1-8B-Instruct", os.getenv("HUGGINGFACE_API_KEY"))
|
||||
# await extract_structured_data_using_llm("ollama/llama3.2")
|
||||
# await extract_structured_data_using_llm("ollama/llama3.2")
|
||||
|
||||
# You always can pass custom headers to the extraction strategy
|
||||
# custom_headers = {
|
||||
@@ -626,13 +661,13 @@ async def main():
|
||||
# "X-Custom-Header": "Some-Value"
|
||||
# }
|
||||
# await extract_structured_data_using_llm(extra_headers=custom_headers)
|
||||
|
||||
|
||||
# await crawl_dynamic_content_pages_method_1()
|
||||
# await crawl_dynamic_content_pages_method_2()
|
||||
await crawl_dynamic_content_pages_method_3()
|
||||
|
||||
|
||||
# await crawl_custom_browser_type()
|
||||
|
||||
|
||||
# await speed_comparison()
|
||||
|
||||
|
||||
|
||||
@@ -10,15 +10,17 @@ from functools import lru_cache
|
||||
|
||||
console = Console()
|
||||
|
||||
|
||||
@lru_cache()
|
||||
def create_crawler():
|
||||
crawler = WebCrawler(verbose=True)
|
||||
crawler.warmup()
|
||||
return crawler
|
||||
|
||||
|
||||
def print_result(result):
|
||||
# Print each key in one line and just the first 10 characters of each one's value and three dots
|
||||
console.print(f"\t[bold]Result:[/bold]")
|
||||
console.print("\t[bold]Result:[/bold]")
|
||||
for key, value in result.model_dump().items():
|
||||
if isinstance(value, str) and value:
|
||||
console.print(f"\t{key}: [green]{value[:20]}...[/green]")
|
||||
@@ -33,18 +35,27 @@ def cprint(message, press_any_key=False):
|
||||
console.print("Press any key to continue...", style="")
|
||||
input()
|
||||
|
||||
|
||||
def basic_usage(crawler):
|
||||
cprint("🛠️ [bold cyan]Basic Usage: Simply provide a URL and let Crawl4ai do the magic![/bold cyan]")
|
||||
result = crawler.run(url="https://www.nbcnews.com/business", only_text = True)
|
||||
cprint(
|
||||
"🛠️ [bold cyan]Basic Usage: Simply provide a URL and let Crawl4ai do the magic![/bold cyan]"
|
||||
)
|
||||
result = crawler.run(url="https://www.nbcnews.com/business", only_text=True)
|
||||
cprint("[LOG] 📦 [bold yellow]Basic crawl result:[/bold yellow]")
|
||||
print_result(result)
|
||||
|
||||
|
||||
def basic_usage_some_params(crawler):
|
||||
cprint("🛠️ [bold cyan]Basic Usage: Simply provide a URL and let Crawl4ai do the magic![/bold cyan]")
|
||||
result = crawler.run(url="https://www.nbcnews.com/business", word_count_threshold=1, only_text = True)
|
||||
cprint(
|
||||
"🛠️ [bold cyan]Basic Usage: Simply provide a URL and let Crawl4ai do the magic![/bold cyan]"
|
||||
)
|
||||
result = crawler.run(
|
||||
url="https://www.nbcnews.com/business", word_count_threshold=1, only_text=True
|
||||
)
|
||||
cprint("[LOG] 📦 [bold yellow]Basic crawl result:[/bold yellow]")
|
||||
print_result(result)
|
||||
|
||||
|
||||
def screenshot_usage(crawler):
|
||||
cprint("\n📸 [bold cyan]Let's take a screenshot of the page![/bold cyan]")
|
||||
result = crawler.run(url="https://www.nbcnews.com/business", screenshot=True)
|
||||
@@ -55,16 +66,23 @@ def screenshot_usage(crawler):
|
||||
cprint("Screenshot saved to 'screenshot.png'!")
|
||||
print_result(result)
|
||||
|
||||
|
||||
def understanding_parameters(crawler):
|
||||
cprint("\n🧠 [bold cyan]Understanding 'bypass_cache' and 'include_raw_html' parameters:[/bold cyan]")
|
||||
cprint("By default, Crawl4ai caches the results of your crawls. This means that subsequent crawls of the same URL will be much faster! Let's see this in action.")
|
||||
|
||||
cprint(
|
||||
"\n🧠 [bold cyan]Understanding 'bypass_cache' and 'include_raw_html' parameters:[/bold cyan]"
|
||||
)
|
||||
cprint(
|
||||
"By default, Crawl4ai caches the results of your crawls. This means that subsequent crawls of the same URL will be much faster! Let's see this in action."
|
||||
)
|
||||
|
||||
# First crawl (reads from cache)
|
||||
cprint("1️⃣ First crawl (caches the result):", True)
|
||||
start_time = time.time()
|
||||
result = crawler.run(url="https://www.nbcnews.com/business")
|
||||
end_time = time.time()
|
||||
cprint(f"[LOG] 📦 [bold yellow]First crawl took {end_time - start_time} seconds and result (from cache):[/bold yellow]")
|
||||
cprint(
|
||||
f"[LOG] 📦 [bold yellow]First crawl took {end_time - start_time} seconds and result (from cache):[/bold yellow]"
|
||||
)
|
||||
print_result(result)
|
||||
|
||||
# Force to crawl again
|
||||
@@ -72,169 +90,232 @@ def understanding_parameters(crawler):
|
||||
start_time = time.time()
|
||||
result = crawler.run(url="https://www.nbcnews.com/business", bypass_cache=True)
|
||||
end_time = time.time()
|
||||
cprint(f"[LOG] 📦 [bold yellow]Second crawl took {end_time - start_time} seconds and result (forced to crawl):[/bold yellow]")
|
||||
cprint(
|
||||
f"[LOG] 📦 [bold yellow]Second crawl took {end_time - start_time} seconds and result (forced to crawl):[/bold yellow]"
|
||||
)
|
||||
print_result(result)
|
||||
|
||||
|
||||
def add_chunking_strategy(crawler):
|
||||
# Adding a chunking strategy: RegexChunking
|
||||
cprint("\n🧩 [bold cyan]Let's add a chunking strategy: RegexChunking![/bold cyan]", True)
|
||||
cprint("RegexChunking is a simple chunking strategy that splits the text based on a given regex pattern. Let's see it in action!")
|
||||
cprint(
|
||||
"\n🧩 [bold cyan]Let's add a chunking strategy: RegexChunking![/bold cyan]",
|
||||
True,
|
||||
)
|
||||
cprint(
|
||||
"RegexChunking is a simple chunking strategy that splits the text based on a given regex pattern. Let's see it in action!"
|
||||
)
|
||||
result = crawler.run(
|
||||
url="https://www.nbcnews.com/business",
|
||||
chunking_strategy=RegexChunking(patterns=["\n\n"])
|
||||
chunking_strategy=RegexChunking(patterns=["\n\n"]),
|
||||
)
|
||||
cprint("[LOG] 📦 [bold yellow]RegexChunking result:[/bold yellow]")
|
||||
print_result(result)
|
||||
|
||||
# Adding another chunking strategy: NlpSentenceChunking
|
||||
cprint("\n🔍 [bold cyan]Time to explore another chunking strategy: NlpSentenceChunking![/bold cyan]", True)
|
||||
cprint("NlpSentenceChunking uses NLP techniques to split the text into sentences. Let's see how it performs!")
|
||||
cprint(
|
||||
"\n🔍 [bold cyan]Time to explore another chunking strategy: NlpSentenceChunking![/bold cyan]",
|
||||
True,
|
||||
)
|
||||
cprint(
|
||||
"NlpSentenceChunking uses NLP techniques to split the text into sentences. Let's see how it performs!"
|
||||
)
|
||||
result = crawler.run(
|
||||
url="https://www.nbcnews.com/business",
|
||||
chunking_strategy=NlpSentenceChunking()
|
||||
url="https://www.nbcnews.com/business", chunking_strategy=NlpSentenceChunking()
|
||||
)
|
||||
cprint("[LOG] 📦 [bold yellow]NlpSentenceChunking result:[/bold yellow]")
|
||||
print_result(result)
|
||||
|
||||
|
||||
def add_extraction_strategy(crawler):
|
||||
# Adding an extraction strategy: CosineStrategy
|
||||
cprint("\n🧠 [bold cyan]Let's get smarter with an extraction strategy: CosineStrategy![/bold cyan]", True)
|
||||
cprint("CosineStrategy uses cosine similarity to extract semantically similar blocks of text. Let's see it in action!")
|
||||
cprint(
|
||||
"\n🧠 [bold cyan]Let's get smarter with an extraction strategy: CosineStrategy![/bold cyan]",
|
||||
True,
|
||||
)
|
||||
cprint(
|
||||
"CosineStrategy uses cosine similarity to extract semantically similar blocks of text. Let's see it in action!"
|
||||
)
|
||||
result = crawler.run(
|
||||
url="https://www.nbcnews.com/business",
|
||||
extraction_strategy=CosineStrategy(word_count_threshold=10, max_dist=0.2, linkage_method="ward", top_k=3, sim_threshold = 0.3, verbose=True)
|
||||
extraction_strategy=CosineStrategy(
|
||||
word_count_threshold=10,
|
||||
max_dist=0.2,
|
||||
linkage_method="ward",
|
||||
top_k=3,
|
||||
sim_threshold=0.3,
|
||||
verbose=True,
|
||||
),
|
||||
)
|
||||
cprint("[LOG] 📦 [bold yellow]CosineStrategy result:[/bold yellow]")
|
||||
print_result(result)
|
||||
|
||||
|
||||
# Using semantic_filter with CosineStrategy
|
||||
cprint("You can pass other parameters like 'semantic_filter' to the CosineStrategy to extract semantically similar blocks of text. Let's see it in action!")
|
||||
cprint(
|
||||
"You can pass other parameters like 'semantic_filter' to the CosineStrategy to extract semantically similar blocks of text. Let's see it in action!"
|
||||
)
|
||||
result = crawler.run(
|
||||
url="https://www.nbcnews.com/business",
|
||||
extraction_strategy=CosineStrategy(
|
||||
semantic_filter="inflation rent prices",
|
||||
)
|
||||
),
|
||||
)
|
||||
cprint(
|
||||
"[LOG] 📦 [bold yellow]CosineStrategy result with semantic filter:[/bold yellow]"
|
||||
)
|
||||
cprint("[LOG] 📦 [bold yellow]CosineStrategy result with semantic filter:[/bold yellow]")
|
||||
print_result(result)
|
||||
|
||||
|
||||
def add_llm_extraction_strategy(crawler):
|
||||
# Adding an LLM extraction strategy without instructions
|
||||
cprint("\n🤖 [bold cyan]Time to bring in the big guns: LLMExtractionStrategy without instructions![/bold cyan]", True)
|
||||
cprint("LLMExtractionStrategy uses a large language model to extract relevant information from the web page. Let's see it in action!")
|
||||
cprint(
|
||||
"\n🤖 [bold cyan]Time to bring in the big guns: LLMExtractionStrategy without instructions![/bold cyan]",
|
||||
True,
|
||||
)
|
||||
cprint(
|
||||
"LLMExtractionStrategy uses a large language model to extract relevant information from the web page. Let's see it in action!"
|
||||
)
|
||||
result = crawler.run(
|
||||
url="https://www.nbcnews.com/business",
|
||||
extraction_strategy=LLMExtractionStrategy(provider="openai/gpt-4o", api_token=os.getenv('OPENAI_API_KEY'))
|
||||
extraction_strategy=LLMExtractionStrategy(
|
||||
provider="openai/gpt-4o", api_token=os.getenv("OPENAI_API_KEY")
|
||||
),
|
||||
)
|
||||
cprint(
|
||||
"[LOG] 📦 [bold yellow]LLMExtractionStrategy (no instructions) result:[/bold yellow]"
|
||||
)
|
||||
cprint("[LOG] 📦 [bold yellow]LLMExtractionStrategy (no instructions) result:[/bold yellow]")
|
||||
print_result(result)
|
||||
|
||||
|
||||
# Adding an LLM extraction strategy with instructions
|
||||
cprint("\n📜 [bold cyan]Let's make it even more interesting: LLMExtractionStrategy with instructions![/bold cyan]", True)
|
||||
cprint("Let's say we are only interested in financial news. Let's see how LLMExtractionStrategy performs with instructions!")
|
||||
cprint(
|
||||
"\n📜 [bold cyan]Let's make it even more interesting: LLMExtractionStrategy with instructions![/bold cyan]",
|
||||
True,
|
||||
)
|
||||
cprint(
|
||||
"Let's say we are only interested in financial news. Let's see how LLMExtractionStrategy performs with instructions!"
|
||||
)
|
||||
result = crawler.run(
|
||||
url="https://www.nbcnews.com/business",
|
||||
extraction_strategy=LLMExtractionStrategy(
|
||||
provider="openai/gpt-4o",
|
||||
api_token=os.getenv('OPENAI_API_KEY'),
|
||||
instruction="I am interested in only financial news"
|
||||
)
|
||||
api_token=os.getenv("OPENAI_API_KEY"),
|
||||
instruction="I am interested in only financial news",
|
||||
),
|
||||
)
|
||||
cprint(
|
||||
"[LOG] 📦 [bold yellow]LLMExtractionStrategy (with instructions) result:[/bold yellow]"
|
||||
)
|
||||
cprint("[LOG] 📦 [bold yellow]LLMExtractionStrategy (with instructions) result:[/bold yellow]")
|
||||
print_result(result)
|
||||
|
||||
|
||||
result = crawler.run(
|
||||
url="https://www.nbcnews.com/business",
|
||||
extraction_strategy=LLMExtractionStrategy(
|
||||
provider="openai/gpt-4o",
|
||||
api_token=os.getenv('OPENAI_API_KEY'),
|
||||
instruction="Extract only content related to technology"
|
||||
)
|
||||
api_token=os.getenv("OPENAI_API_KEY"),
|
||||
instruction="Extract only content related to technology",
|
||||
),
|
||||
)
|
||||
cprint(
|
||||
"[LOG] 📦 [bold yellow]LLMExtractionStrategy (with technology instruction) result:[/bold yellow]"
|
||||
)
|
||||
cprint("[LOG] 📦 [bold yellow]LLMExtractionStrategy (with technology instruction) result:[/bold yellow]")
|
||||
print_result(result)
|
||||
|
||||
|
||||
def targeted_extraction(crawler):
|
||||
# Using a CSS selector to extract only H2 tags
|
||||
cprint("\n🎯 [bold cyan]Targeted extraction: Let's use a CSS selector to extract only H2 tags![/bold cyan]", True)
|
||||
result = crawler.run(
|
||||
url="https://www.nbcnews.com/business",
|
||||
css_selector="h2"
|
||||
cprint(
|
||||
"\n🎯 [bold cyan]Targeted extraction: Let's use a CSS selector to extract only H2 tags![/bold cyan]",
|
||||
True,
|
||||
)
|
||||
result = crawler.run(url="https://www.nbcnews.com/business", css_selector="h2")
|
||||
cprint("[LOG] 📦 [bold yellow]CSS Selector (H2 tags) result:[/bold yellow]")
|
||||
print_result(result)
|
||||
|
||||
|
||||
def interactive_extraction(crawler):
|
||||
# Passing JavaScript code to interact with the page
|
||||
cprint("\n🖱️ [bold cyan]Let's get interactive: Passing JavaScript code to click 'Load More' button![/bold cyan]", True)
|
||||
cprint("In this example we try to click the 'Load More' button on the page using JavaScript code.")
|
||||
cprint(
|
||||
"\n🖱️ [bold cyan]Let's get interactive: Passing JavaScript code to click 'Load More' button![/bold cyan]",
|
||||
True,
|
||||
)
|
||||
cprint(
|
||||
"In this example we try to click the 'Load More' button on the page using JavaScript code."
|
||||
)
|
||||
js_code = """
|
||||
const loadMoreButton = Array.from(document.querySelectorAll('button')).find(button => button.textContent.includes('Load More'));
|
||||
loadMoreButton && loadMoreButton.click();
|
||||
"""
|
||||
# crawler_strategy = LocalSeleniumCrawlerStrategy(js_code=js_code)
|
||||
# crawler = WebCrawler(crawler_strategy=crawler_strategy, always_by_pass_cache=True)
|
||||
result = crawler.run(
|
||||
url="https://www.nbcnews.com/business",
|
||||
js = js_code
|
||||
result = crawler.run(url="https://www.nbcnews.com/business", js=js_code)
|
||||
cprint(
|
||||
"[LOG] 📦 [bold yellow]JavaScript Code (Load More button) result:[/bold yellow]"
|
||||
)
|
||||
cprint("[LOG] 📦 [bold yellow]JavaScript Code (Load More button) result:[/bold yellow]")
|
||||
print_result(result)
|
||||
|
||||
|
||||
def multiple_scrip(crawler):
|
||||
# Passing JavaScript code to interact with the page
|
||||
cprint("\n🖱️ [bold cyan]Let's get interactive: Passing JavaScript code to click 'Load More' button![/bold cyan]", True)
|
||||
cprint("In this example we try to click the 'Load More' button on the page using JavaScript code.")
|
||||
js_code = ["""
|
||||
cprint(
|
||||
"\n🖱️ [bold cyan]Let's get interactive: Passing JavaScript code to click 'Load More' button![/bold cyan]",
|
||||
True,
|
||||
)
|
||||
cprint(
|
||||
"In this example we try to click the 'Load More' button on the page using JavaScript code."
|
||||
)
|
||||
js_code = [
|
||||
"""
|
||||
const loadMoreButton = Array.from(document.querySelectorAll('button')).find(button => button.textContent.includes('Load More'));
|
||||
loadMoreButton && loadMoreButton.click();
|
||||
"""] * 2
|
||||
"""
|
||||
] * 2
|
||||
# crawler_strategy = LocalSeleniumCrawlerStrategy(js_code=js_code)
|
||||
# crawler = WebCrawler(crawler_strategy=crawler_strategy, always_by_pass_cache=True)
|
||||
result = crawler.run(
|
||||
url="https://www.nbcnews.com/business",
|
||||
js = js_code
|
||||
result = crawler.run(url="https://www.nbcnews.com/business", js=js_code)
|
||||
cprint(
|
||||
"[LOG] 📦 [bold yellow]JavaScript Code (Load More button) result:[/bold yellow]"
|
||||
)
|
||||
cprint("[LOG] 📦 [bold yellow]JavaScript Code (Load More button) result:[/bold yellow]")
|
||||
print_result(result)
|
||||
|
||||
|
||||
def using_crawler_hooks(crawler):
|
||||
# Example usage of the hooks for authentication and setting a cookie
|
||||
def on_driver_created(driver):
|
||||
print("[HOOK] on_driver_created")
|
||||
# Example customization: maximize the window
|
||||
driver.maximize_window()
|
||||
|
||||
|
||||
# Example customization: logging in to a hypothetical website
|
||||
driver.get('https://example.com/login')
|
||||
|
||||
driver.get("https://example.com/login")
|
||||
|
||||
from selenium.webdriver.support.ui import WebDriverWait
|
||||
from selenium.webdriver.common.by import By
|
||||
from selenium.webdriver.support import expected_conditions as EC
|
||||
|
||||
|
||||
WebDriverWait(driver, 10).until(
|
||||
EC.presence_of_element_located((By.NAME, 'username'))
|
||||
EC.presence_of_element_located((By.NAME, "username"))
|
||||
)
|
||||
driver.find_element(By.NAME, 'username').send_keys('testuser')
|
||||
driver.find_element(By.NAME, 'password').send_keys('password123')
|
||||
driver.find_element(By.NAME, 'login').click()
|
||||
driver.find_element(By.NAME, "username").send_keys("testuser")
|
||||
driver.find_element(By.NAME, "password").send_keys("password123")
|
||||
driver.find_element(By.NAME, "login").click()
|
||||
WebDriverWait(driver, 10).until(
|
||||
EC.presence_of_element_located((By.ID, 'welcome'))
|
||||
EC.presence_of_element_located((By.ID, "welcome"))
|
||||
)
|
||||
# Add a custom cookie
|
||||
driver.add_cookie({'name': 'test_cookie', 'value': 'cookie_value'})
|
||||
return driver
|
||||
|
||||
driver.add_cookie({"name": "test_cookie", "value": "cookie_value"})
|
||||
return driver
|
||||
|
||||
def before_get_url(driver):
|
||||
print("[HOOK] before_get_url")
|
||||
# Example customization: add a custom header
|
||||
# Enable Network domain for sending headers
|
||||
driver.execute_cdp_cmd('Network.enable', {})
|
||||
driver.execute_cdp_cmd("Network.enable", {})
|
||||
# Add a custom header
|
||||
driver.execute_cdp_cmd('Network.setExtraHTTPHeaders', {'headers': {'X-Test-Header': 'test'}})
|
||||
driver.execute_cdp_cmd(
|
||||
"Network.setExtraHTTPHeaders", {"headers": {"X-Test-Header": "test"}}
|
||||
)
|
||||
return driver
|
||||
|
||||
|
||||
def after_get_url(driver):
|
||||
print("[HOOK] after_get_url")
|
||||
# Example customization: log the URL
|
||||
@@ -246,48 +327,59 @@ def using_crawler_hooks(crawler):
|
||||
# Example customization: log the HTML
|
||||
print(len(html))
|
||||
return driver
|
||||
|
||||
cprint("\n🔗 [bold cyan]Using Crawler Hooks: Let's see how we can customize the crawler using hooks![/bold cyan]", True)
|
||||
|
||||
|
||||
cprint(
|
||||
"\n🔗 [bold cyan]Using Crawler Hooks: Let's see how we can customize the crawler using hooks![/bold cyan]",
|
||||
True,
|
||||
)
|
||||
|
||||
crawler_strategy = LocalSeleniumCrawlerStrategy(verbose=True)
|
||||
crawler_strategy.set_hook('on_driver_created', on_driver_created)
|
||||
crawler_strategy.set_hook('before_get_url', before_get_url)
|
||||
crawler_strategy.set_hook('after_get_url', after_get_url)
|
||||
crawler_strategy.set_hook('before_return_html', before_return_html)
|
||||
|
||||
crawler_strategy.set_hook("on_driver_created", on_driver_created)
|
||||
crawler_strategy.set_hook("before_get_url", before_get_url)
|
||||
crawler_strategy.set_hook("after_get_url", after_get_url)
|
||||
crawler_strategy.set_hook("before_return_html", before_return_html)
|
||||
|
||||
crawler = WebCrawler(verbose=True, crawler_strategy=crawler_strategy)
|
||||
crawler.warmup()
|
||||
crawler.warmup()
|
||||
result = crawler.run(url="https://example.com")
|
||||
|
||||
|
||||
cprint("[LOG] 📦 [bold yellow]Crawler Hooks result:[/bold yellow]")
|
||||
print_result(result= result)
|
||||
|
||||
print_result(result=result)
|
||||
|
||||
|
||||
def using_crawler_hooks_dleay_example(crawler):
|
||||
def delay(driver):
|
||||
print("Delaying for 5 seconds...")
|
||||
time.sleep(5)
|
||||
print("Resuming...")
|
||||
|
||||
|
||||
def create_crawler():
|
||||
crawler_strategy = LocalSeleniumCrawlerStrategy(verbose=True)
|
||||
crawler_strategy.set_hook('after_get_url', delay)
|
||||
crawler_strategy.set_hook("after_get_url", delay)
|
||||
crawler = WebCrawler(verbose=True, crawler_strategy=crawler_strategy)
|
||||
crawler.warmup()
|
||||
return crawler
|
||||
|
||||
cprint("\n🔗 [bold cyan]Using Crawler Hooks: Let's add a delay after fetching the url to make sure entire page is fetched.[/bold cyan]")
|
||||
cprint(
|
||||
"\n🔗 [bold cyan]Using Crawler Hooks: Let's add a delay after fetching the url to make sure entire page is fetched.[/bold cyan]"
|
||||
)
|
||||
crawler = create_crawler()
|
||||
result = crawler.run(url="https://google.com", bypass_cache=True)
|
||||
|
||||
result = crawler.run(url="https://google.com", bypass_cache=True)
|
||||
|
||||
cprint("[LOG] 📦 [bold yellow]Crawler Hooks result:[/bold yellow]")
|
||||
print_result(result)
|
||||
|
||||
|
||||
|
||||
|
||||
def main():
|
||||
cprint("🌟 [bold green]Welcome to the Crawl4ai Quickstart Guide! Let's dive into some web crawling fun! 🌐[/bold green]")
|
||||
cprint("⛳️ [bold cyan]First Step: Create an instance of WebCrawler and call the `warmup()` function.[/bold cyan]")
|
||||
cprint("If this is the first time you're running Crawl4ai, this might take a few seconds to load required model files.")
|
||||
cprint(
|
||||
"🌟 [bold green]Welcome to the Crawl4ai Quickstart Guide! Let's dive into some web crawling fun! 🌐[/bold green]"
|
||||
)
|
||||
cprint(
|
||||
"⛳️ [bold cyan]First Step: Create an instance of WebCrawler and call the `warmup()` function.[/bold cyan]"
|
||||
)
|
||||
cprint(
|
||||
"If this is the first time you're running Crawl4ai, this might take a few seconds to load required model files."
|
||||
)
|
||||
|
||||
crawler = create_crawler()
|
||||
|
||||
@@ -295,7 +387,7 @@ def main():
|
||||
basic_usage(crawler)
|
||||
# basic_usage_some_params(crawler)
|
||||
understanding_parameters(crawler)
|
||||
|
||||
|
||||
crawler.always_by_pass_cache = True
|
||||
screenshot_usage(crawler)
|
||||
add_chunking_strategy(crawler)
|
||||
@@ -305,8 +397,10 @@ def main():
|
||||
interactive_extraction(crawler)
|
||||
multiple_scrip(crawler)
|
||||
|
||||
cprint("\n🎉 [bold green]Congratulations! You've made it through the Crawl4ai Quickstart Guide! Now go forth and crawl the web like a pro! 🕸️[/bold green]")
|
||||
cprint(
|
||||
"\n🎉 [bold green]Congratulations! You've made it through the Crawl4ai Quickstart Guide! Now go forth and crawl the web like a pro! 🕸️[/bold green]"
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
||||
|
||||
@@ -11,7 +11,9 @@ from groq import Groq
|
||||
# Import threadpools to run the crawl_url function in a separate thread
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
|
||||
client = AsyncOpenAI(base_url="https://api.groq.com/openai/v1", api_key=os.getenv("GROQ_API_KEY"))
|
||||
client = AsyncOpenAI(
|
||||
base_url="https://api.groq.com/openai/v1", api_key=os.getenv("GROQ_API_KEY")
|
||||
)
|
||||
|
||||
# Instrument the OpenAI client
|
||||
cl.instrument_openai()
|
||||
@@ -25,41 +27,39 @@ settings = {
|
||||
"presence_penalty": 0,
|
||||
}
|
||||
|
||||
|
||||
def extract_urls(text):
|
||||
url_pattern = re.compile(r'(https?://\S+)')
|
||||
url_pattern = re.compile(r"(https?://\S+)")
|
||||
return url_pattern.findall(text)
|
||||
|
||||
|
||||
def crawl_url(url):
|
||||
data = {
|
||||
"urls": [url],
|
||||
"include_raw_html": True,
|
||||
"word_count_threshold": 10,
|
||||
"extraction_strategy": "NoExtractionStrategy",
|
||||
"chunking_strategy": "RegexChunking"
|
||||
"chunking_strategy": "RegexChunking",
|
||||
}
|
||||
response = requests.post("https://crawl4ai.com/crawl", json=data)
|
||||
response_data = response.json()
|
||||
response_data = response_data['results'][0]
|
||||
return response_data['markdown']
|
||||
response_data = response_data["results"][0]
|
||||
return response_data["markdown"]
|
||||
|
||||
|
||||
@cl.on_chat_start
|
||||
async def on_chat_start():
|
||||
cl.user_session.set("session", {
|
||||
"history": [],
|
||||
"context": {}
|
||||
})
|
||||
await cl.Message(
|
||||
content="Welcome to the chat! How can I assist you today?"
|
||||
).send()
|
||||
cl.user_session.set("session", {"history": [], "context": {}})
|
||||
await cl.Message(content="Welcome to the chat! How can I assist you today?").send()
|
||||
|
||||
|
||||
@cl.on_message
|
||||
async def on_message(message: cl.Message):
|
||||
user_session = cl.user_session.get("session")
|
||||
|
||||
|
||||
# Extract URLs from the user's message
|
||||
urls = extract_urls(message.content)
|
||||
|
||||
|
||||
|
||||
futures = []
|
||||
with ThreadPoolExecutor() as executor:
|
||||
for url in urls:
|
||||
@@ -69,16 +69,9 @@ async def on_message(message: cl.Message):
|
||||
|
||||
for url, result in zip(urls, results):
|
||||
ref_number = f"REF_{len(user_session['context']) + 1}"
|
||||
user_session["context"][ref_number] = {
|
||||
"url": url,
|
||||
"content": result
|
||||
}
|
||||
user_session["context"][ref_number] = {"url": url, "content": result}
|
||||
|
||||
|
||||
user_session["history"].append({
|
||||
"role": "user",
|
||||
"content": message.content
|
||||
})
|
||||
user_session["history"].append({"role": "user", "content": message.content})
|
||||
|
||||
# Create a system message that includes the context
|
||||
context_messages = [
|
||||
@@ -95,26 +88,17 @@ async def on_message(message: cl.Message):
|
||||
"If not, there is no need to add a references section. "
|
||||
"At the end of your response, provide a reference section listing the URLs and their REF numbers only if sources from the appendices were used.\n\n"
|
||||
"\n\n".join(context_messages)
|
||||
)
|
||||
),
|
||||
}
|
||||
else:
|
||||
system_message = {
|
||||
"role": "system",
|
||||
"content": "You are a helpful assistant."
|
||||
}
|
||||
|
||||
system_message = {"role": "system", "content": "You are a helpful assistant."}
|
||||
|
||||
msg = cl.Message(content="")
|
||||
await msg.send()
|
||||
|
||||
# Get response from the LLM
|
||||
stream = await client.chat.completions.create(
|
||||
messages=[
|
||||
system_message,
|
||||
*user_session["history"]
|
||||
],
|
||||
stream=True,
|
||||
**settings
|
||||
messages=[system_message, *user_session["history"]], stream=True, **settings
|
||||
)
|
||||
|
||||
assistant_response = ""
|
||||
@@ -124,10 +108,7 @@ async def on_message(message: cl.Message):
|
||||
await msg.stream_token(token)
|
||||
|
||||
# Add assistant message to the history
|
||||
user_session["history"].append({
|
||||
"role": "assistant",
|
||||
"content": assistant_response
|
||||
})
|
||||
user_session["history"].append({"role": "assistant", "content": assistant_response})
|
||||
await msg.update()
|
||||
|
||||
# Append the reference section to the assistant's response
|
||||
@@ -154,10 +135,11 @@ async def on_audio_chunk(chunk: cl.AudioChunk):
|
||||
|
||||
pass
|
||||
|
||||
|
||||
@cl.step(type="tool")
|
||||
async def speech_to_text(audio_file):
|
||||
cli = Groq()
|
||||
|
||||
|
||||
response = await client.audio.transcriptions.create(
|
||||
model="whisper-large-v3", file=audio_file
|
||||
)
|
||||
@@ -172,24 +154,19 @@ async def on_audio_end(elements: list[ElementBased]):
|
||||
audio_buffer.seek(0) # Move the file pointer to the beginning
|
||||
audio_file = audio_buffer.read()
|
||||
audio_mime_type: str = cl.user_session.get("audio_mime_type")
|
||||
|
||||
|
||||
start_time = time.time()
|
||||
whisper_input = (audio_buffer.name, audio_file, audio_mime_type)
|
||||
transcription = await speech_to_text(whisper_input)
|
||||
end_time = time.time()
|
||||
print(f"Transcription took {end_time - start_time} seconds")
|
||||
|
||||
user_msg = cl.Message(
|
||||
author="You",
|
||||
type="user_message",
|
||||
content=transcription
|
||||
)
|
||||
|
||||
user_msg = cl.Message(author="You", type="user_message", content=transcription)
|
||||
await user_msg.send()
|
||||
await on_message(user_msg)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
from chainlit.cli import run_chainlit
|
||||
|
||||
run_chainlit(__file__)
|
||||
|
||||
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
|
||||
import requests, base64, os
|
||||
|
||||
data = {
|
||||
@@ -6,59 +5,50 @@ data = {
|
||||
"screenshot": True,
|
||||
}
|
||||
|
||||
response = requests.post("https://crawl4ai.com/crawl", json=data)
|
||||
result = response.json()['results'][0]
|
||||
response = requests.post("https://crawl4ai.com/crawl", json=data)
|
||||
result = response.json()["results"][0]
|
||||
print(result.keys())
|
||||
# dict_keys(['url', 'html', 'success', 'cleaned_html', 'media',
|
||||
# 'links', 'screenshot', 'markdown', 'extracted_content',
|
||||
# dict_keys(['url', 'html', 'success', 'cleaned_html', 'media',
|
||||
# 'links', 'screenshot', 'markdown', 'extracted_content',
|
||||
# 'metadata', 'error_message'])
|
||||
with open("screenshot.png", "wb") as f:
|
||||
f.write(base64.b64decode(result['screenshot']))
|
||||
|
||||
f.write(base64.b64decode(result["screenshot"]))
|
||||
|
||||
# Example of filtering the content using CSS selectors
|
||||
data = {
|
||||
"urls": [
|
||||
"https://www.nbcnews.com/business"
|
||||
],
|
||||
"urls": ["https://www.nbcnews.com/business"],
|
||||
"css_selector": "article",
|
||||
"screenshot": True,
|
||||
}
|
||||
|
||||
# Example of executing a JS script on the page before extracting the content
|
||||
data = {
|
||||
"urls": [
|
||||
"https://www.nbcnews.com/business"
|
||||
],
|
||||
"urls": ["https://www.nbcnews.com/business"],
|
||||
"screenshot": True,
|
||||
'js' : ["""
|
||||
"js": [
|
||||
"""
|
||||
const loadMoreButton = Array.from(document.querySelectorAll('button')).
|
||||
find(button => button.textContent.includes('Load More'));
|
||||
loadMoreButton && loadMoreButton.click();
|
||||
"""]
|
||||
"""
|
||||
],
|
||||
}
|
||||
|
||||
# Example of using a custom extraction strategy
|
||||
data = {
|
||||
"urls": [
|
||||
"https://www.nbcnews.com/business"
|
||||
],
|
||||
"urls": ["https://www.nbcnews.com/business"],
|
||||
"extraction_strategy": "CosineStrategy",
|
||||
"extraction_strategy_args": {
|
||||
"semantic_filter": "inflation rent prices"
|
||||
},
|
||||
"extraction_strategy_args": {"semantic_filter": "inflation rent prices"},
|
||||
}
|
||||
|
||||
# Example of using LLM to extract content
|
||||
data = {
|
||||
"urls": [
|
||||
"https://www.nbcnews.com/business"
|
||||
],
|
||||
"urls": ["https://www.nbcnews.com/business"],
|
||||
"extraction_strategy": "LLMExtractionStrategy",
|
||||
"extraction_strategy_args": {
|
||||
"provider": "groq/llama3-8b-8192",
|
||||
"api_token": os.environ.get("GROQ_API_KEY"),
|
||||
"instruction": """I am interested in only financial news,
|
||||
and translate them in French."""
|
||||
and translate them in French.""",
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
@@ -5,42 +5,47 @@ import os
|
||||
from crawl4ai import AsyncWebCrawler, CrawlerRunConfig, CacheMode
|
||||
|
||||
# Create tmp directory if it doesn't exist
|
||||
parent_dir = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||
parent_dir = os.path.dirname(
|
||||
os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
|
||||
)
|
||||
tmp_dir = os.path.join(parent_dir, "tmp")
|
||||
os.makedirs(tmp_dir, exist_ok=True)
|
||||
|
||||
|
||||
async def main():
|
||||
# Configure crawler to fetch SSL certificate
|
||||
config = CrawlerRunConfig(
|
||||
fetch_ssl_certificate=True,
|
||||
cache_mode=CacheMode.BYPASS # Bypass cache to always get fresh certificates
|
||||
cache_mode=CacheMode.BYPASS, # Bypass cache to always get fresh certificates
|
||||
)
|
||||
|
||||
async with AsyncWebCrawler() as crawler:
|
||||
result = await crawler.arun(
|
||||
url='https://example.com',
|
||||
config=config
|
||||
)
|
||||
|
||||
result = await crawler.arun(url="https://example.com", config=config)
|
||||
|
||||
if result.success and result.ssl_certificate:
|
||||
cert = result.ssl_certificate
|
||||
|
||||
|
||||
# 1. Access certificate properties directly
|
||||
print("\nCertificate Information:")
|
||||
print(f"Issuer: {cert.issuer.get('CN', '')}")
|
||||
print(f"Valid until: {cert.valid_until}")
|
||||
print(f"Fingerprint: {cert.fingerprint}")
|
||||
|
||||
|
||||
# 2. Export certificate in different formats
|
||||
cert.to_json(os.path.join(tmp_dir, "certificate.json")) # For analysis
|
||||
print("\nCertificate exported to:")
|
||||
print(f"- JSON: {os.path.join(tmp_dir, 'certificate.json')}")
|
||||
|
||||
pem_data = cert.to_pem(os.path.join(tmp_dir, "certificate.pem")) # For web servers
|
||||
|
||||
pem_data = cert.to_pem(
|
||||
os.path.join(tmp_dir, "certificate.pem")
|
||||
) # For web servers
|
||||
print(f"- PEM: {os.path.join(tmp_dir, 'certificate.pem')}")
|
||||
|
||||
der_data = cert.to_der(os.path.join(tmp_dir, "certificate.der")) # For Java apps
|
||||
|
||||
der_data = cert.to_der(
|
||||
os.path.join(tmp_dir, "certificate.der")
|
||||
) # For Java apps
|
||||
print(f"- DER: {os.path.join(tmp_dir, 'certificate.der')}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
asyncio.run(main())
|
||||
|
||||
@@ -1,39 +1,41 @@
|
||||
import os
|
||||
import time
|
||||
import json
|
||||
from crawl4ai.web_crawler import WebCrawler
|
||||
from crawl4ai.chunking_strategy import *
|
||||
from crawl4ai.extraction_strategy import *
|
||||
from crawl4ai.crawler_strategy import *
|
||||
|
||||
url = r'https://marketplace.visualstudio.com/items?itemName=Unclecode.groqopilot'
|
||||
url = r"https://marketplace.visualstudio.com/items?itemName=Unclecode.groqopilot"
|
||||
|
||||
crawler = WebCrawler()
|
||||
crawler.warmup()
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class PageSummary(BaseModel):
|
||||
title: str = Field(..., description="Title of the page.")
|
||||
summary: str = Field(..., description="Summary of the page.")
|
||||
brief_summary: str = Field(..., description="Brief summary of the page.")
|
||||
keywords: list = Field(..., description="Keywords assigned to the page.")
|
||||
|
||||
|
||||
result = crawler.run(
|
||||
url=url,
|
||||
word_count_threshold=1,
|
||||
extraction_strategy= LLMExtractionStrategy(
|
||||
provider= "openai/gpt-4o", api_token = os.getenv('OPENAI_API_KEY'),
|
||||
extraction_strategy=LLMExtractionStrategy(
|
||||
provider="openai/gpt-4o",
|
||||
api_token=os.getenv("OPENAI_API_KEY"),
|
||||
schema=PageSummary.model_json_schema(),
|
||||
extraction_type="schema",
|
||||
apply_chunking =False,
|
||||
instruction="From the crawled content, extract the following details: "\
|
||||
"1. Title of the page "\
|
||||
"2. Summary of the page, which is a detailed summary "\
|
||||
"3. Brief summary of the page, which is a paragraph text "\
|
||||
"4. Keywords assigned to the page, which is a list of keywords. "\
|
||||
'The extracted JSON format should look like this: '\
|
||||
'{ "title": "Page Title", "summary": "Detailed summary of the page.", "brief_summary": "Brief summary in a paragraph.", "keywords": ["keyword1", "keyword2", "keyword3"] }'
|
||||
apply_chunking=False,
|
||||
instruction="From the crawled content, extract the following details: "
|
||||
"1. Title of the page "
|
||||
"2. Summary of the page, which is a detailed summary "
|
||||
"3. Brief summary of the page, which is a paragraph text "
|
||||
"4. Keywords assigned to the page, which is a list of keywords. "
|
||||
"The extracted JSON format should look like this: "
|
||||
'{ "title": "Page Title", "summary": "Detailed summary of the page.", "brief_summary": "Brief summary in a paragraph.", "keywords": ["keyword1", "keyword2", "keyword3"] }',
|
||||
),
|
||||
bypass_cache=True,
|
||||
)
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
import os, sys
|
||||
|
||||
# append the parent directory to the sys.path
|
||||
parent_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
|
||||
sys.path.append(parent_dir)
|
||||
@@ -13,19 +14,18 @@ import json
|
||||
from crawl4ai import AsyncWebCrawler, CacheMode
|
||||
from crawl4ai.content_filter_strategy import BM25ContentFilter
|
||||
|
||||
|
||||
# 1. File Download Processing Example
|
||||
async def download_example():
|
||||
"""Example of downloading files from Python.org"""
|
||||
# downloads_path = os.path.join(os.getcwd(), "downloads")
|
||||
downloads_path = os.path.join(Path.home(), ".crawl4ai", "downloads")
|
||||
os.makedirs(downloads_path, exist_ok=True)
|
||||
|
||||
|
||||
print(f"Downloads will be saved to: {downloads_path}")
|
||||
|
||||
|
||||
async with AsyncWebCrawler(
|
||||
accept_downloads=True,
|
||||
downloads_path=downloads_path,
|
||||
verbose=True
|
||||
accept_downloads=True, downloads_path=downloads_path, verbose=True
|
||||
) as crawler:
|
||||
result = await crawler.arun(
|
||||
url="https://www.python.org/downloads/",
|
||||
@@ -40,9 +40,9 @@ async def download_example():
|
||||
}
|
||||
""",
|
||||
delay_before_return_html=1, # Wait 5 seconds to ensure download starts
|
||||
cache_mode=CacheMode.BYPASS
|
||||
cache_mode=CacheMode.BYPASS,
|
||||
)
|
||||
|
||||
|
||||
if result.downloaded_files:
|
||||
print("\nDownload successful!")
|
||||
print("Downloaded files:")
|
||||
@@ -52,25 +52,26 @@ async def download_example():
|
||||
else:
|
||||
print("\nNo files were downloaded")
|
||||
|
||||
|
||||
# 2. Local File and Raw HTML Processing Example
|
||||
async def local_and_raw_html_example():
|
||||
"""Example of processing local files and raw HTML"""
|
||||
# Create a sample HTML file
|
||||
sample_file = os.path.join(__data__, "sample.html")
|
||||
with open(sample_file, "w") as f:
|
||||
f.write("""
|
||||
f.write(
|
||||
"""
|
||||
<html><body>
|
||||
<h1>Test Content</h1>
|
||||
<p>This is a test paragraph.</p>
|
||||
</body></html>
|
||||
""")
|
||||
|
||||
"""
|
||||
)
|
||||
|
||||
async with AsyncWebCrawler(verbose=True) as crawler:
|
||||
# Process local file
|
||||
local_result = await crawler.arun(
|
||||
url=f"file://{os.path.abspath(sample_file)}"
|
||||
)
|
||||
|
||||
local_result = await crawler.arun(url=f"file://{os.path.abspath(sample_file)}")
|
||||
|
||||
# Process raw HTML
|
||||
raw_html = """
|
||||
<html><body>
|
||||
@@ -78,16 +79,15 @@ async def local_and_raw_html_example():
|
||||
<p>This is a test of raw HTML processing.</p>
|
||||
</body></html>
|
||||
"""
|
||||
raw_result = await crawler.arun(
|
||||
url=f"raw:{raw_html}"
|
||||
)
|
||||
|
||||
raw_result = await crawler.arun(url=f"raw:{raw_html}")
|
||||
|
||||
# Clean up
|
||||
os.remove(sample_file)
|
||||
|
||||
|
||||
print("Local file content:", local_result.markdown)
|
||||
print("\nRaw HTML content:", raw_result.markdown)
|
||||
|
||||
|
||||
# 3. Enhanced Markdown Generation Example
|
||||
async def markdown_generation_example():
|
||||
"""Example of enhanced markdown generation with citations and LLM-friendly features"""
|
||||
@@ -97,58 +97,66 @@ async def markdown_generation_example():
|
||||
# user_query="History and cultivation",
|
||||
bm25_threshold=1.0
|
||||
)
|
||||
|
||||
|
||||
result = await crawler.arun(
|
||||
url="https://en.wikipedia.org/wiki/Apple",
|
||||
css_selector="main div#bodyContent",
|
||||
content_filter=content_filter,
|
||||
cache_mode=CacheMode.BYPASS
|
||||
cache_mode=CacheMode.BYPASS,
|
||||
)
|
||||
|
||||
from crawl4ai import AsyncWebCrawler
|
||||
|
||||
from crawl4ai.content_filter_strategy import BM25ContentFilter
|
||||
|
||||
|
||||
result = await crawler.arun(
|
||||
url="https://en.wikipedia.org/wiki/Apple",
|
||||
css_selector="main div#bodyContent",
|
||||
content_filter=BM25ContentFilter()
|
||||
content_filter=BM25ContentFilter(),
|
||||
)
|
||||
print(result.markdown_v2.fit_markdown)
|
||||
|
||||
|
||||
print("\nMarkdown Generation Results:")
|
||||
print(f"1. Original markdown length: {len(result.markdown)}")
|
||||
print(f"2. New markdown versions (markdown_v2):")
|
||||
print("2. New markdown versions (markdown_v2):")
|
||||
print(f" - Raw markdown length: {len(result.markdown_v2.raw_markdown)}")
|
||||
print(f" - Citations markdown length: {len(result.markdown_v2.markdown_with_citations)}")
|
||||
print(f" - References section length: {len(result.markdown_v2.references_markdown)}")
|
||||
print(
|
||||
f" - Citations markdown length: {len(result.markdown_v2.markdown_with_citations)}"
|
||||
)
|
||||
print(
|
||||
f" - References section length: {len(result.markdown_v2.references_markdown)}"
|
||||
)
|
||||
if result.markdown_v2.fit_markdown:
|
||||
print(f" - Filtered markdown length: {len(result.markdown_v2.fit_markdown)}")
|
||||
|
||||
print(
|
||||
f" - Filtered markdown length: {len(result.markdown_v2.fit_markdown)}"
|
||||
)
|
||||
|
||||
# Save examples to files
|
||||
output_dir = os.path.join(__data__, "markdown_examples")
|
||||
os.makedirs(output_dir, exist_ok=True)
|
||||
|
||||
|
||||
# Save different versions
|
||||
with open(os.path.join(output_dir, "1_raw_markdown.md"), "w") as f:
|
||||
f.write(result.markdown_v2.raw_markdown)
|
||||
|
||||
|
||||
with open(os.path.join(output_dir, "2_citations_markdown.md"), "w") as f:
|
||||
f.write(result.markdown_v2.markdown_with_citations)
|
||||
|
||||
|
||||
with open(os.path.join(output_dir, "3_references.md"), "w") as f:
|
||||
f.write(result.markdown_v2.references_markdown)
|
||||
|
||||
|
||||
if result.markdown_v2.fit_markdown:
|
||||
with open(os.path.join(output_dir, "4_filtered_markdown.md"), "w") as f:
|
||||
f.write(result.markdown_v2.fit_markdown)
|
||||
|
||||
|
||||
print(f"\nMarkdown examples saved to: {output_dir}")
|
||||
|
||||
|
||||
# Show a sample of citations and references
|
||||
print("\nSample of markdown with citations:")
|
||||
print(result.markdown_v2.markdown_with_citations[:500] + "...\n")
|
||||
print("Sample of references:")
|
||||
print('\n'.join(result.markdown_v2.references_markdown.split('\n')[:10]) + "...")
|
||||
print(
|
||||
"\n".join(result.markdown_v2.references_markdown.split("\n")[:10]) + "..."
|
||||
)
|
||||
|
||||
|
||||
# 4. Browser Management Example
|
||||
async def browser_management_example():
|
||||
@@ -156,38 +164,38 @@ async def browser_management_example():
|
||||
# Use the specified user directory path
|
||||
user_data_dir = os.path.join(Path.home(), ".crawl4ai", "browser_profile")
|
||||
os.makedirs(user_data_dir, exist_ok=True)
|
||||
|
||||
|
||||
print(f"Browser profile will be saved to: {user_data_dir}")
|
||||
|
||||
|
||||
async with AsyncWebCrawler(
|
||||
use_managed_browser=True,
|
||||
user_data_dir=user_data_dir,
|
||||
headless=False,
|
||||
verbose=True
|
||||
verbose=True,
|
||||
) as crawler:
|
||||
|
||||
result = await crawler.arun(
|
||||
url="https://crawl4ai.com",
|
||||
# session_id="persistent_session_1",
|
||||
cache_mode=CacheMode.BYPASS
|
||||
)
|
||||
cache_mode=CacheMode.BYPASS,
|
||||
)
|
||||
# Use GitHub as an example - it's a good test for browser management
|
||||
# because it requires proper browser handling
|
||||
result = await crawler.arun(
|
||||
url="https://github.com/trending",
|
||||
# session_id="persistent_session_1",
|
||||
cache_mode=CacheMode.BYPASS
|
||||
cache_mode=CacheMode.BYPASS,
|
||||
)
|
||||
|
||||
|
||||
print("\nBrowser session result:", result.success)
|
||||
if result.success:
|
||||
print("Page title:", result.metadata.get('title', 'No title found'))
|
||||
print("Page title:", result.metadata.get("title", "No title found"))
|
||||
|
||||
|
||||
# 5. API Usage Example
|
||||
async def api_example():
|
||||
"""Example of using the new API endpoints"""
|
||||
api_token = os.getenv('CRAWL4AI_API_TOKEN') or "test_api_code"
|
||||
headers = {'Authorization': f'Bearer {api_token}'}
|
||||
api_token = os.getenv("CRAWL4AI_API_TOKEN") or "test_api_code"
|
||||
headers = {"Authorization": f"Bearer {api_token}"}
|
||||
async with aiohttp.ClientSession() as session:
|
||||
# Submit crawl job
|
||||
crawl_request = {
|
||||
@@ -199,25 +207,17 @@ async def api_example():
|
||||
"name": "Hacker News Articles",
|
||||
"baseSelector": ".athing",
|
||||
"fields": [
|
||||
{
|
||||
"name": "title",
|
||||
"selector": ".title a",
|
||||
"type": "text"
|
||||
},
|
||||
{
|
||||
"name": "score",
|
||||
"selector": ".score",
|
||||
"type": "text"
|
||||
},
|
||||
{"name": "title", "selector": ".title a", "type": "text"},
|
||||
{"name": "score", "selector": ".score", "type": "text"},
|
||||
{
|
||||
"name": "url",
|
||||
"selector": ".title a",
|
||||
"type": "attribute",
|
||||
"attribute": "href"
|
||||
}
|
||||
]
|
||||
"attribute": "href",
|
||||
},
|
||||
],
|
||||
}
|
||||
}
|
||||
},
|
||||
},
|
||||
"crawler_params": {
|
||||
"headless": True,
|
||||
@@ -227,51 +227,50 @@ async def api_example():
|
||||
# "screenshot": True,
|
||||
# "magic": True
|
||||
}
|
||||
|
||||
|
||||
async with session.post(
|
||||
"http://localhost:11235/crawl",
|
||||
json=crawl_request,
|
||||
headers=headers
|
||||
"http://localhost:11235/crawl", json=crawl_request, headers=headers
|
||||
) as response:
|
||||
task_data = await response.json()
|
||||
task_id = task_data["task_id"]
|
||||
|
||||
|
||||
# Check task status
|
||||
while True:
|
||||
async with session.get(
|
||||
f"http://localhost:11235/task/{task_id}",
|
||||
headers=headers
|
||||
f"http://localhost:11235/task/{task_id}", headers=headers
|
||||
) as status_response:
|
||||
result = await status_response.json()
|
||||
print(f"Task status: {result['status']}")
|
||||
|
||||
|
||||
if result["status"] == "completed":
|
||||
print("Task completed!")
|
||||
print("Results:")
|
||||
news = json.loads(result["results"][0]['extracted_content'])
|
||||
news = json.loads(result["results"][0]["extracted_content"])
|
||||
print(json.dumps(news[:4], indent=2))
|
||||
break
|
||||
else:
|
||||
await asyncio.sleep(1)
|
||||
|
||||
|
||||
# Main execution
|
||||
async def main():
|
||||
# print("Running Crawl4AI feature examples...")
|
||||
|
||||
|
||||
# print("\n1. Running Download Example:")
|
||||
# await download_example()
|
||||
|
||||
|
||||
# print("\n2. Running Markdown Generation Example:")
|
||||
# await markdown_generation_example()
|
||||
|
||||
|
||||
# # print("\n3. Running Local and Raw HTML Example:")
|
||||
# await local_and_raw_html_example()
|
||||
|
||||
|
||||
# # print("\n4. Running Browser Management Example:")
|
||||
await browser_management_example()
|
||||
|
||||
|
||||
# print("\n5. Running API Example:")
|
||||
await api_example()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
asyncio.run(main())
|
||||
asyncio.run(main())
|
||||
|
||||
@@ -10,18 +10,17 @@ import asyncio
|
||||
import os
|
||||
import json
|
||||
import re
|
||||
from typing import List, Optional, Dict, Any
|
||||
from pydantic import BaseModel, Field
|
||||
from typing import List
|
||||
from crawl4ai import (
|
||||
AsyncWebCrawler,
|
||||
BrowserConfig,
|
||||
CrawlerRunConfig,
|
||||
CacheMode,
|
||||
LLMExtractionStrategy,
|
||||
JsonCssExtractionStrategy
|
||||
JsonCssExtractionStrategy,
|
||||
)
|
||||
from crawl4ai.content_filter_strategy import RelevantContentFilter
|
||||
from crawl4ai.markdown_generation_strategy import DefaultMarkdownGenerator
|
||||
from crawl4ai.markdown_generation_strategy import DefaultMarkdownGenerator
|
||||
from bs4 import BeautifulSoup
|
||||
|
||||
# Sample HTML for demonstrations
|
||||
@@ -52,17 +51,18 @@ SAMPLE_HTML = """
|
||||
</div>
|
||||
"""
|
||||
|
||||
|
||||
async def demo_ssl_features():
|
||||
"""
|
||||
Enhanced SSL & Security Features Demo
|
||||
-----------------------------------
|
||||
|
||||
|
||||
This example demonstrates the new SSL certificate handling and security features:
|
||||
1. Custom certificate paths
|
||||
2. SSL verification options
|
||||
3. HTTPS error handling
|
||||
4. Certificate validation configurations
|
||||
|
||||
|
||||
These features are particularly useful when:
|
||||
- Working with self-signed certificates
|
||||
- Dealing with corporate proxies
|
||||
@@ -76,14 +76,11 @@ async def demo_ssl_features():
|
||||
|
||||
run_config = CrawlerRunConfig(
|
||||
cache_mode=CacheMode.BYPASS,
|
||||
fetch_ssl_certificate=True # Enable SSL certificate fetching
|
||||
fetch_ssl_certificate=True, # Enable SSL certificate fetching
|
||||
)
|
||||
|
||||
async with AsyncWebCrawler(config=browser_config) as crawler:
|
||||
result = await crawler.arun(
|
||||
url="https://example.com",
|
||||
config=run_config
|
||||
)
|
||||
result = await crawler.arun(url="https://example.com", config=run_config)
|
||||
print(f"SSL Crawl Success: {result.success}")
|
||||
result.ssl_certificate.to_json(
|
||||
os.path.join(os.getcwd(), "ssl_certificate.json")
|
||||
@@ -91,11 +88,12 @@ async def demo_ssl_features():
|
||||
if not result.success:
|
||||
print(f"SSL Error: {result.error_message}")
|
||||
|
||||
|
||||
async def demo_content_filtering():
|
||||
"""
|
||||
Smart Content Filtering Demo
|
||||
----------------------
|
||||
|
||||
|
||||
Demonstrates advanced content filtering capabilities:
|
||||
1. Custom filter to identify and extract specific content
|
||||
2. Integration with markdown generation
|
||||
@@ -110,87 +108,90 @@ async def demo_content_filtering():
|
||||
super().__init__()
|
||||
# Add news-specific patterns
|
||||
self.negative_patterns = re.compile(
|
||||
r'nav|footer|header|sidebar|ads|comment|share|related|recommended|popular|trending',
|
||||
re.I
|
||||
r"nav|footer|header|sidebar|ads|comment|share|related|recommended|popular|trending",
|
||||
re.I,
|
||||
)
|
||||
self.min_word_count = 30 # Higher threshold for news content
|
||||
|
||||
def filter_content(self, html: str, min_word_threshold: int = None) -> List[str]:
|
||||
def filter_content(
|
||||
self, html: str, min_word_threshold: int = None
|
||||
) -> List[str]:
|
||||
"""
|
||||
Implements news-specific content filtering logic.
|
||||
|
||||
|
||||
Args:
|
||||
html (str): HTML content to be filtered
|
||||
min_word_threshold (int, optional): Minimum word count threshold
|
||||
|
||||
|
||||
Returns:
|
||||
List[str]: List of filtered HTML content blocks
|
||||
"""
|
||||
if not html or not isinstance(html, str):
|
||||
return []
|
||||
|
||||
soup = BeautifulSoup(html, 'lxml')
|
||||
|
||||
soup = BeautifulSoup(html, "lxml")
|
||||
if not soup.body:
|
||||
soup = BeautifulSoup(f'<body>{html}</body>', 'lxml')
|
||||
|
||||
body = soup.find('body')
|
||||
|
||||
soup = BeautifulSoup(f"<body>{html}</body>", "lxml")
|
||||
|
||||
body = soup.find("body")
|
||||
|
||||
# Extract chunks with metadata
|
||||
chunks = self.extract_text_chunks(body, min_word_threshold or self.min_word_count)
|
||||
|
||||
chunks = self.extract_text_chunks(
|
||||
body, min_word_threshold or self.min_word_count
|
||||
)
|
||||
|
||||
# Filter chunks based on news-specific criteria
|
||||
filtered_chunks = []
|
||||
for _, text, tag_type, element in chunks:
|
||||
# Skip if element has negative class/id
|
||||
if self.is_excluded(element):
|
||||
continue
|
||||
|
||||
|
||||
# Headers are important in news articles
|
||||
if tag_type == 'header':
|
||||
if tag_type == "header":
|
||||
filtered_chunks.append(self.clean_element(element))
|
||||
continue
|
||||
|
||||
|
||||
# For content, check word count and link density
|
||||
text = element.get_text(strip=True)
|
||||
if len(text.split()) >= (min_word_threshold or self.min_word_count):
|
||||
# Calculate link density
|
||||
links_text = ' '.join(a.get_text(strip=True) for a in element.find_all('a'))
|
||||
links_text = " ".join(
|
||||
a.get_text(strip=True) for a in element.find_all("a")
|
||||
)
|
||||
link_density = len(links_text) / len(text) if text else 1
|
||||
|
||||
|
||||
# Accept if link density is reasonable
|
||||
if link_density < 0.5:
|
||||
filtered_chunks.append(self.clean_element(element))
|
||||
|
||||
|
||||
return filtered_chunks
|
||||
|
||||
# Create markdown generator with custom filter
|
||||
markdown_gen = DefaultMarkdownGenerator(
|
||||
content_filter=CustomNewsFilter()
|
||||
)
|
||||
markdown_gen = DefaultMarkdownGenerator(content_filter=CustomNewsFilter())
|
||||
|
||||
run_config = CrawlerRunConfig(
|
||||
markdown_generator=markdown_gen,
|
||||
cache_mode=CacheMode.BYPASS
|
||||
markdown_generator=markdown_gen, cache_mode=CacheMode.BYPASS
|
||||
)
|
||||
|
||||
async with AsyncWebCrawler() as crawler:
|
||||
result = await crawler.arun(
|
||||
url="https://news.ycombinator.com",
|
||||
config=run_config
|
||||
url="https://news.ycombinator.com", config=run_config
|
||||
)
|
||||
print("Filtered Content Sample:")
|
||||
print(result.markdown[:500]) # Show first 500 chars
|
||||
|
||||
|
||||
async def demo_json_extraction():
|
||||
"""
|
||||
Improved JSON Extraction Demo
|
||||
---------------------------
|
||||
|
||||
|
||||
Demonstrates the enhanced JSON extraction capabilities:
|
||||
1. Base element attributes extraction
|
||||
2. Complex nested structures
|
||||
3. Multiple extraction patterns
|
||||
|
||||
|
||||
Key features shown:
|
||||
- Extracting attributes from base elements (href, data-* attributes)
|
||||
- Processing repeated patterns
|
||||
@@ -206,7 +207,7 @@ async def demo_json_extraction():
|
||||
"baseSelector": "div.article-list",
|
||||
"baseFields": [
|
||||
{"name": "list_id", "type": "attribute", "attribute": "data-list-id"},
|
||||
{"name": "category", "type": "attribute", "attribute": "data-category"}
|
||||
{"name": "category", "type": "attribute", "attribute": "data-category"},
|
||||
],
|
||||
"fields": [
|
||||
{
|
||||
@@ -214,8 +215,16 @@ async def demo_json_extraction():
|
||||
"selector": "article.post",
|
||||
"type": "nested_list",
|
||||
"baseFields": [
|
||||
{"name": "post_id", "type": "attribute", "attribute": "data-post-id"},
|
||||
{"name": "author_id", "type": "attribute", "attribute": "data-author"}
|
||||
{
|
||||
"name": "post_id",
|
||||
"type": "attribute",
|
||||
"attribute": "data-post-id",
|
||||
},
|
||||
{
|
||||
"name": "author_id",
|
||||
"type": "attribute",
|
||||
"attribute": "data-author",
|
||||
},
|
||||
],
|
||||
"fields": [
|
||||
{
|
||||
@@ -223,60 +232,68 @@ async def demo_json_extraction():
|
||||
"selector": "h2.title a",
|
||||
"type": "text",
|
||||
"baseFields": [
|
||||
{"name": "url", "type": "attribute", "attribute": "href"}
|
||||
]
|
||||
{
|
||||
"name": "url",
|
||||
"type": "attribute",
|
||||
"attribute": "href",
|
||||
}
|
||||
],
|
||||
},
|
||||
{
|
||||
"name": "author",
|
||||
"selector": "div.meta a.author",
|
||||
"type": "text",
|
||||
"baseFields": [
|
||||
{"name": "profile_url", "type": "attribute", "attribute": "href"}
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "date",
|
||||
"selector": "span.date",
|
||||
"type": "text"
|
||||
{
|
||||
"name": "profile_url",
|
||||
"type": "attribute",
|
||||
"attribute": "href",
|
||||
}
|
||||
],
|
||||
},
|
||||
{"name": "date", "selector": "span.date", "type": "text"},
|
||||
{
|
||||
"name": "read_more",
|
||||
"selector": "a.read-more",
|
||||
"type": "nested",
|
||||
"fields": [
|
||||
{"name": "text", "type": "text"},
|
||||
{"name": "url", "type": "attribute", "attribute": "href"}
|
||||
]
|
||||
}
|
||||
]
|
||||
{
|
||||
"name": "url",
|
||||
"type": "attribute",
|
||||
"attribute": "href",
|
||||
},
|
||||
],
|
||||
},
|
||||
],
|
||||
}
|
||||
]
|
||||
],
|
||||
}
|
||||
)
|
||||
|
||||
# Demonstrate extraction from raw HTML
|
||||
run_config = CrawlerRunConfig(
|
||||
extraction_strategy=json_strategy,
|
||||
cache_mode=CacheMode.BYPASS
|
||||
extraction_strategy=json_strategy, cache_mode=CacheMode.BYPASS
|
||||
)
|
||||
|
||||
async with AsyncWebCrawler() as crawler:
|
||||
result = await crawler.arun(
|
||||
url="raw:" + SAMPLE_HTML, # Use raw: prefix for raw HTML
|
||||
config=run_config
|
||||
config=run_config,
|
||||
)
|
||||
print("Extracted Content:")
|
||||
print(result.extracted_content)
|
||||
|
||||
|
||||
async def demo_input_formats():
|
||||
"""
|
||||
Input Format Handling Demo
|
||||
----------------------
|
||||
|
||||
|
||||
Demonstrates how LLM extraction can work with different input formats:
|
||||
1. Markdown (default) - Good for simple text extraction
|
||||
2. HTML - Better when you need structure and attributes
|
||||
|
||||
|
||||
This example shows how HTML input can be beneficial when:
|
||||
- You need to understand the DOM structure
|
||||
- You want to extract both visible text and HTML attributes
|
||||
@@ -350,7 +367,7 @@ async def demo_input_formats():
|
||||
</footer>
|
||||
</div>
|
||||
"""
|
||||
|
||||
|
||||
# Use raw:// prefix to pass HTML content directly
|
||||
url = f"raw://{dummy_html}"
|
||||
|
||||
@@ -359,18 +376,30 @@ async def demo_input_formats():
|
||||
|
||||
# Define our schema using Pydantic
|
||||
class JobRequirement(BaseModel):
|
||||
category: str = Field(description="Category of the requirement (e.g., Technical, Soft Skills)")
|
||||
items: List[str] = Field(description="List of specific requirements in this category")
|
||||
priority: str = Field(description="Priority level (Required/Preferred) based on the HTML class or context")
|
||||
category: str = Field(
|
||||
description="Category of the requirement (e.g., Technical, Soft Skills)"
|
||||
)
|
||||
items: List[str] = Field(
|
||||
description="List of specific requirements in this category"
|
||||
)
|
||||
priority: str = Field(
|
||||
description="Priority level (Required/Preferred) based on the HTML class or context"
|
||||
)
|
||||
|
||||
class JobPosting(BaseModel):
|
||||
title: str = Field(description="Job title")
|
||||
department: str = Field(description="Department or team")
|
||||
location: str = Field(description="Job location, including remote options")
|
||||
salary_range: Optional[str] = Field(description="Salary range if specified")
|
||||
requirements: List[JobRequirement] = Field(description="Categorized job requirements")
|
||||
application_deadline: Optional[str] = Field(description="Application deadline if specified")
|
||||
contact_info: Optional[dict] = Field(description="Contact information from footer or contact section")
|
||||
requirements: List[JobRequirement] = Field(
|
||||
description="Categorized job requirements"
|
||||
)
|
||||
application_deadline: Optional[str] = Field(
|
||||
description="Application deadline if specified"
|
||||
)
|
||||
contact_info: Optional[dict] = Field(
|
||||
description="Contact information from footer or contact section"
|
||||
)
|
||||
|
||||
# First try with markdown (default)
|
||||
markdown_strategy = LLMExtractionStrategy(
|
||||
@@ -382,7 +411,7 @@ async def demo_input_formats():
|
||||
Extract job posting details into structured data. Focus on the visible text content
|
||||
and organize requirements into categories.
|
||||
""",
|
||||
input_format="markdown" # default
|
||||
input_format="markdown", # default
|
||||
)
|
||||
|
||||
# Then with HTML for better structure understanding
|
||||
@@ -400,34 +429,25 @@ async def demo_input_formats():
|
||||
|
||||
Use HTML attributes and classes to enhance extraction accuracy.
|
||||
""",
|
||||
input_format="html" # explicitly use HTML
|
||||
input_format="html", # explicitly use HTML
|
||||
)
|
||||
|
||||
async with AsyncWebCrawler() as crawler:
|
||||
# Try with markdown first
|
||||
markdown_config = CrawlerRunConfig(
|
||||
extraction_strategy=markdown_strategy
|
||||
)
|
||||
markdown_result = await crawler.arun(
|
||||
url=url,
|
||||
config=markdown_config
|
||||
)
|
||||
markdown_config = CrawlerRunConfig(extraction_strategy=markdown_strategy)
|
||||
markdown_result = await crawler.arun(url=url, config=markdown_config)
|
||||
print("\nMarkdown-based Extraction Result:")
|
||||
items = json.loads(markdown_result.extracted_content)
|
||||
print(json.dumps(items, indent=2))
|
||||
|
||||
# Then with HTML for better structure understanding
|
||||
html_config = CrawlerRunConfig(
|
||||
extraction_strategy=html_strategy
|
||||
)
|
||||
html_result = await crawler.arun(
|
||||
url=url,
|
||||
config=html_config
|
||||
)
|
||||
html_config = CrawlerRunConfig(extraction_strategy=html_strategy)
|
||||
html_result = await crawler.arun(url=url, config=html_config)
|
||||
print("\nHTML-based Extraction Result:")
|
||||
items = json.loads(html_result.extracted_content)
|
||||
print(json.dumps(items, indent=2))
|
||||
|
||||
|
||||
# Main execution
|
||||
async def main():
|
||||
print("Crawl4AI v0.4.24 Feature Walkthrough")
|
||||
@@ -439,5 +459,6 @@ async def main():
|
||||
await demo_json_extraction()
|
||||
# await demo_input_formats()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
asyncio.run(main())
|
||||
|
||||
106
main.py
106
main.py
@@ -1,14 +1,9 @@
|
||||
import asyncio, os
|
||||
from fastapi import FastAPI, HTTPException, BackgroundTasks, Request
|
||||
from fastapi.responses import JSONResponse
|
||||
from fastapi import FastAPI, HTTPException, Request
|
||||
from fastapi.responses import HTMLResponse, JSONResponse
|
||||
from fastapi import FastAPI, HTTPException
|
||||
from fastapi import FastAPI, HTTPException
|
||||
from fastapi.staticfiles import StaticFiles
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from fastapi.templating import Jinja2Templates
|
||||
from fastapi.exceptions import RequestValidationError
|
||||
from starlette.middleware.base import BaseHTTPMiddleware
|
||||
from starlette.responses import FileResponse
|
||||
from fastapi.responses import RedirectResponse
|
||||
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
|
||||
from fastapi import Depends, Security
|
||||
@@ -18,13 +13,10 @@ from typing import Optional, List, Dict, Any, Union
|
||||
import psutil
|
||||
import time
|
||||
import uuid
|
||||
from collections import defaultdict
|
||||
from urllib.parse import urlparse
|
||||
import math
|
||||
import logging
|
||||
from enum import Enum
|
||||
from dataclasses import dataclass
|
||||
import json
|
||||
from crawl4ai import AsyncWebCrawler, CrawlResult, CacheMode
|
||||
from crawl4ai.config import MIN_WORD_THRESHOLD
|
||||
from crawl4ai.extraction_strategy import (
|
||||
@@ -38,30 +30,36 @@ __location__ = os.path.realpath(os.path.join(os.getcwd(), os.path.dirname(__file
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class TaskStatus(str, Enum):
|
||||
PENDING = "pending"
|
||||
PROCESSING = "processing"
|
||||
COMPLETED = "completed"
|
||||
FAILED = "failed"
|
||||
|
||||
|
||||
class CrawlerType(str, Enum):
|
||||
BASIC = "basic"
|
||||
LLM = "llm"
|
||||
COSINE = "cosine"
|
||||
JSON_CSS = "json_css"
|
||||
|
||||
|
||||
class ExtractionConfig(BaseModel):
|
||||
type: CrawlerType
|
||||
params: Dict[str, Any] = {}
|
||||
|
||||
|
||||
class ChunkingStrategy(BaseModel):
|
||||
type: str
|
||||
params: Dict[str, Any] = {}
|
||||
|
||||
|
||||
class ContentFilter(BaseModel):
|
||||
type: str = "bm25"
|
||||
params: Dict[str, Any] = {}
|
||||
|
||||
|
||||
class CrawlRequest(BaseModel):
|
||||
urls: Union[HttpUrl, List[HttpUrl]]
|
||||
word_count_threshold: int = MIN_WORD_THRESHOLD
|
||||
@@ -77,9 +75,10 @@ class CrawlRequest(BaseModel):
|
||||
session_id: Optional[str] = None
|
||||
cache_mode: Optional[CacheMode] = CacheMode.ENABLED
|
||||
priority: int = Field(default=5, ge=1, le=10)
|
||||
ttl: Optional[int] = 3600
|
||||
ttl: Optional[int] = 3600
|
||||
crawler_params: Dict[str, Any] = {}
|
||||
|
||||
|
||||
@dataclass
|
||||
class TaskInfo:
|
||||
id: str
|
||||
@@ -89,6 +88,7 @@ class TaskInfo:
|
||||
created_at: float = time.time()
|
||||
ttl: int = 3600
|
||||
|
||||
|
||||
class ResourceMonitor:
|
||||
def __init__(self, max_concurrent_tasks: int = 10):
|
||||
self.max_concurrent_tasks = max_concurrent_tasks
|
||||
@@ -106,7 +106,9 @@ class ResourceMonitor:
|
||||
mem_usage = psutil.virtual_memory().percent / 100
|
||||
cpu_usage = psutil.cpu_percent() / 100
|
||||
|
||||
memory_factor = max(0, (self.memory_threshold - mem_usage) / self.memory_threshold)
|
||||
memory_factor = max(
|
||||
0, (self.memory_threshold - mem_usage) / self.memory_threshold
|
||||
)
|
||||
cpu_factor = max(0, (self.cpu_threshold - cpu_usage) / self.cpu_threshold)
|
||||
|
||||
self._last_available_slots = math.floor(
|
||||
@@ -116,6 +118,7 @@ class ResourceMonitor:
|
||||
|
||||
return self._last_available_slots
|
||||
|
||||
|
||||
class TaskManager:
|
||||
def __init__(self, cleanup_interval: int = 300):
|
||||
self.tasks: Dict[str, TaskInfo] = {}
|
||||
@@ -149,12 +152,16 @@ class TaskManager:
|
||||
except asyncio.TimeoutError:
|
||||
try:
|
||||
# Then try low priority
|
||||
_, task_id = await asyncio.wait_for(self.low_priority.get(), timeout=0.1)
|
||||
_, task_id = await asyncio.wait_for(
|
||||
self.low_priority.get(), timeout=0.1
|
||||
)
|
||||
return task_id
|
||||
except asyncio.TimeoutError:
|
||||
return None
|
||||
|
||||
def update_task(self, task_id: str, status: TaskStatus, result: Any = None, error: str = None):
|
||||
def update_task(
|
||||
self, task_id: str, status: TaskStatus, result: Any = None, error: str = None
|
||||
):
|
||||
if task_id in self.tasks:
|
||||
task_info = self.tasks[task_id]
|
||||
task_info.status = status
|
||||
@@ -180,6 +187,7 @@ class TaskManager:
|
||||
except Exception as e:
|
||||
logger.error(f"Error in cleanup loop: {e}")
|
||||
|
||||
|
||||
class CrawlerPool:
|
||||
def __init__(self, max_size: int = 10):
|
||||
self.max_size = max_size
|
||||
@@ -222,6 +230,7 @@ class CrawlerPool:
|
||||
await crawler.__aexit__(None, None, None)
|
||||
self.active_crawlers.clear()
|
||||
|
||||
|
||||
class CrawlerService:
|
||||
def __init__(self, max_concurrent_tasks: int = 10):
|
||||
self.resource_monitor = ResourceMonitor(max_concurrent_tasks)
|
||||
@@ -258,10 +267,10 @@ class CrawlerService:
|
||||
async def submit_task(self, request: CrawlRequest) -> str:
|
||||
task_id = str(uuid.uuid4())
|
||||
await self.task_manager.add_task(task_id, request.priority, request.ttl or 3600)
|
||||
|
||||
|
||||
# Store request data with task
|
||||
self.task_manager.tasks[task_id].request = request
|
||||
|
||||
|
||||
return task_id
|
||||
|
||||
async def _process_queue(self):
|
||||
@@ -286,9 +295,11 @@ class CrawlerService:
|
||||
|
||||
try:
|
||||
crawler = await self.crawler_pool.acquire(**request.crawler_params)
|
||||
|
||||
extraction_strategy = self._create_extraction_strategy(request.extraction_config)
|
||||
|
||||
|
||||
extraction_strategy = self._create_extraction_strategy(
|
||||
request.extraction_config
|
||||
)
|
||||
|
||||
if isinstance(request.urls, list):
|
||||
results = await crawler.arun_many(
|
||||
urls=[str(url) for url in request.urls],
|
||||
@@ -318,16 +329,21 @@ class CrawlerService:
|
||||
)
|
||||
|
||||
await self.crawler_pool.release(crawler)
|
||||
self.task_manager.update_task(task_id, TaskStatus.COMPLETED, results)
|
||||
self.task_manager.update_task(
|
||||
task_id, TaskStatus.COMPLETED, results
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error processing task {task_id}: {str(e)}")
|
||||
self.task_manager.update_task(task_id, TaskStatus.FAILED, error=str(e))
|
||||
self.task_manager.update_task(
|
||||
task_id, TaskStatus.FAILED, error=str(e)
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error in queue processing: {str(e)}")
|
||||
await asyncio.sleep(1)
|
||||
|
||||
|
||||
app = FastAPI(title="Crawl4AI API")
|
||||
|
||||
# CORS configuration
|
||||
@@ -344,6 +360,7 @@ app.add_middleware(
|
||||
security = HTTPBearer()
|
||||
CRAWL4AI_API_TOKEN = os.getenv("CRAWL4AI_API_TOKEN")
|
||||
|
||||
|
||||
async def verify_token(credentials: HTTPAuthorizationCredentials = Security(security)):
|
||||
if not CRAWL4AI_API_TOKEN:
|
||||
return credentials # No token verification if CRAWL4AI_API_TOKEN is not set
|
||||
@@ -351,10 +368,12 @@ async def verify_token(credentials: HTTPAuthorizationCredentials = Security(secu
|
||||
raise HTTPException(status_code=401, detail="Invalid token")
|
||||
return credentials
|
||||
|
||||
|
||||
def secure_endpoint():
|
||||
"""Returns security dependency only if CRAWL4AI_API_TOKEN is set"""
|
||||
return Depends(verify_token) if CRAWL4AI_API_TOKEN else None
|
||||
|
||||
|
||||
# Check if site directory exists
|
||||
if os.path.exists(__location__ + "/site"):
|
||||
# Mount the site directory as a static directory
|
||||
@@ -364,14 +383,17 @@ site_templates = Jinja2Templates(directory=__location__ + "/site")
|
||||
|
||||
crawler_service = CrawlerService()
|
||||
|
||||
|
||||
@app.on_event("startup")
|
||||
async def startup_event():
|
||||
await crawler_service.start()
|
||||
|
||||
|
||||
@app.on_event("shutdown")
|
||||
async def shutdown_event():
|
||||
await crawler_service.stop()
|
||||
|
||||
|
||||
@app.get("/")
|
||||
def read_root():
|
||||
if os.path.exists(__location__ + "/site"):
|
||||
@@ -379,12 +401,16 @@ def read_root():
|
||||
# Return a json response
|
||||
return {"message": "Crawl4AI API service is running"}
|
||||
|
||||
|
||||
@app.post("/crawl", dependencies=[secure_endpoint()] if CRAWL4AI_API_TOKEN else [])
|
||||
async def crawl(request: CrawlRequest) -> Dict[str, str]:
|
||||
task_id = await crawler_service.submit_task(request)
|
||||
return {"task_id": task_id}
|
||||
|
||||
@app.get("/task/{task_id}", dependencies=[secure_endpoint()] if CRAWL4AI_API_TOKEN else [])
|
||||
|
||||
@app.get(
|
||||
"/task/{task_id}", dependencies=[secure_endpoint()] if CRAWL4AI_API_TOKEN else []
|
||||
)
|
||||
async def get_task_status(task_id: str):
|
||||
task_info = crawler_service.task_manager.get_task(task_id)
|
||||
if not task_info:
|
||||
@@ -406,36 +432,45 @@ async def get_task_status(task_id: str):
|
||||
|
||||
return response
|
||||
|
||||
|
||||
@app.post("/crawl_sync", dependencies=[secure_endpoint()] if CRAWL4AI_API_TOKEN else [])
|
||||
async def crawl_sync(request: CrawlRequest) -> Dict[str, Any]:
|
||||
task_id = await crawler_service.submit_task(request)
|
||||
|
||||
|
||||
# Wait up to 60 seconds for task completion
|
||||
for _ in range(60):
|
||||
task_info = crawler_service.task_manager.get_task(task_id)
|
||||
if not task_info:
|
||||
raise HTTPException(status_code=404, detail="Task not found")
|
||||
|
||||
|
||||
if task_info.status == TaskStatus.COMPLETED:
|
||||
# Return same format as /task/{task_id} endpoint
|
||||
if isinstance(task_info.result, list):
|
||||
return {"status": task_info.status, "results": [result.dict() for result in task_info.result]}
|
||||
return {
|
||||
"status": task_info.status,
|
||||
"results": [result.dict() for result in task_info.result],
|
||||
}
|
||||
return {"status": task_info.status, "result": task_info.result.dict()}
|
||||
|
||||
|
||||
if task_info.status == TaskStatus.FAILED:
|
||||
raise HTTPException(status_code=500, detail=task_info.error)
|
||||
|
||||
|
||||
await asyncio.sleep(1)
|
||||
|
||||
|
||||
# If we get here, task didn't complete within timeout
|
||||
raise HTTPException(status_code=408, detail="Task timed out")
|
||||
|
||||
@app.post("/crawl_direct", dependencies=[secure_endpoint()] if CRAWL4AI_API_TOKEN else [])
|
||||
|
||||
@app.post(
|
||||
"/crawl_direct", dependencies=[secure_endpoint()] if CRAWL4AI_API_TOKEN else []
|
||||
)
|
||||
async def crawl_direct(request: CrawlRequest) -> Dict[str, Any]:
|
||||
try:
|
||||
crawler = await crawler_service.crawler_pool.acquire(**request.crawler_params)
|
||||
extraction_strategy = crawler_service._create_extraction_strategy(request.extraction_config)
|
||||
|
||||
extraction_strategy = crawler_service._create_extraction_strategy(
|
||||
request.extraction_config
|
||||
)
|
||||
|
||||
try:
|
||||
if isinstance(request.urls, list):
|
||||
results = await crawler.arun_many(
|
||||
@@ -470,7 +505,8 @@ async def crawl_direct(request: CrawlRequest) -> Dict[str, Any]:
|
||||
except Exception as e:
|
||||
logger.error(f"Error in direct crawl: {str(e)}")
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
|
||||
@app.get("/health")
|
||||
async def health_check():
|
||||
available_slots = await crawler_service.resource_monitor.get_available_slots()
|
||||
@@ -482,6 +518,8 @@ async def health_check():
|
||||
"cpu_usage": psutil.cpu_percent(),
|
||||
}
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import uvicorn
|
||||
uvicorn.run(app, host="0.0.0.0", port=11235)
|
||||
|
||||
uvicorn.run(app, host="0.0.0.0", port=11235)
|
||||
|
||||
4
setup.py
4
setup.py
@@ -51,9 +51,7 @@ setup(
|
||||
author_email="unclecode@kidocode.com",
|
||||
license="MIT",
|
||||
packages=find_packages(),
|
||||
package_data={
|
||||
'crawl4ai': ['js_snippet/*.js']
|
||||
},
|
||||
package_data={"crawl4ai": ["js_snippet/*.js"]},
|
||||
classifiers=[
|
||||
"Development Status :: 3 - Alpha",
|
||||
"Intended Audience :: Developers",
|
||||
|
||||
@@ -1,17 +1,18 @@
|
||||
import os, sys
|
||||
parent_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
|
||||
sys.path.append(parent_dir)
|
||||
__location__ = os.path.realpath( os.path.join(os.getcwd(), os.path.dirname(__file__)))
|
||||
|
||||
import os, sys
|
||||
import os
|
||||
import sys
|
||||
import asyncio
|
||||
from crawl4ai import AsyncWebCrawler, CacheMode
|
||||
from crawl4ai.content_filter_strategy import PruningContentFilter
|
||||
from crawl4ai.markdown_generation_strategy import DefaultMarkdownGenerator
|
||||
|
||||
# Assuming that the changes made allow different configurations
|
||||
parent_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
|
||||
sys.path.append(parent_dir)
|
||||
__location__ = os.path.realpath(os.path.join(os.getcwd(), os.path.dirname(__file__)))
|
||||
|
||||
|
||||
# Assuming that the changes made allow different configurations
|
||||
# for managed browser, persistent context, and so forth.
|
||||
|
||||
|
||||
async def test_default_headless():
|
||||
async with AsyncWebCrawler(
|
||||
headless=True,
|
||||
@@ -24,13 +25,14 @@ async def test_default_headless():
|
||||
# Testing normal ephemeral context
|
||||
) as crawler:
|
||||
result = await crawler.arun(
|
||||
url='https://www.kidocode.com/degrees/technology',
|
||||
url="https://www.kidocode.com/degrees/technology",
|
||||
cache_mode=CacheMode.BYPASS,
|
||||
markdown_generator=DefaultMarkdownGenerator(options={"ignore_links": True}),
|
||||
)
|
||||
print("[test_default_headless] success:", result.success)
|
||||
print("HTML length:", len(result.html if result.html else ""))
|
||||
|
||||
|
||||
|
||||
async def test_managed_browser_persistent():
|
||||
# Treating use_persistent_context=True as managed_browser scenario.
|
||||
async with AsyncWebCrawler(
|
||||
@@ -44,13 +46,14 @@ async def test_managed_browser_persistent():
|
||||
# This should store and reuse profile data across runs
|
||||
) as crawler:
|
||||
result = await crawler.arun(
|
||||
url='https://www.google.com',
|
||||
url="https://www.google.com",
|
||||
cache_mode=CacheMode.BYPASS,
|
||||
markdown_generator=DefaultMarkdownGenerator(options={"ignore_links": True})
|
||||
markdown_generator=DefaultMarkdownGenerator(options={"ignore_links": True}),
|
||||
)
|
||||
print("[test_managed_browser_persistent] success:", result.success)
|
||||
print("HTML length:", len(result.html if result.html else ""))
|
||||
|
||||
|
||||
async def test_session_reuse():
|
||||
# Test creating a session, using it for multiple calls
|
||||
session_id = "my_session"
|
||||
@@ -62,25 +65,25 @@ async def test_session_reuse():
|
||||
use_managed_browser=False,
|
||||
use_persistent_context=False,
|
||||
) as crawler:
|
||||
|
||||
# First call: create session
|
||||
result1 = await crawler.arun(
|
||||
url='https://www.example.com',
|
||||
url="https://www.example.com",
|
||||
cache_mode=CacheMode.BYPASS,
|
||||
session_id=session_id,
|
||||
markdown_generator=DefaultMarkdownGenerator(options={"ignore_links": True})
|
||||
markdown_generator=DefaultMarkdownGenerator(options={"ignore_links": True}),
|
||||
)
|
||||
print("[test_session_reuse first call] success:", result1.success)
|
||||
|
||||
|
||||
# Second call: same session, possibly cookie retained
|
||||
result2 = await crawler.arun(
|
||||
url='https://www.example.com/about',
|
||||
url="https://www.example.com/about",
|
||||
cache_mode=CacheMode.BYPASS,
|
||||
session_id=session_id,
|
||||
markdown_generator=DefaultMarkdownGenerator(options={"ignore_links": True})
|
||||
markdown_generator=DefaultMarkdownGenerator(options={"ignore_links": True}),
|
||||
)
|
||||
print("[test_session_reuse second call] success:", result2.success)
|
||||
|
||||
|
||||
async def test_magic_mode():
|
||||
# Test magic mode with override_navigator and simulate_user
|
||||
async with AsyncWebCrawler(
|
||||
@@ -95,13 +98,14 @@ async def test_magic_mode():
|
||||
simulate_user=True,
|
||||
) as crawler:
|
||||
result = await crawler.arun(
|
||||
url='https://www.kidocode.com/degrees/business',
|
||||
url="https://www.kidocode.com/degrees/business",
|
||||
cache_mode=CacheMode.BYPASS,
|
||||
markdown_generator=DefaultMarkdownGenerator(options={"ignore_links": True})
|
||||
markdown_generator=DefaultMarkdownGenerator(options={"ignore_links": True}),
|
||||
)
|
||||
print("[test_magic_mode] success:", result.success)
|
||||
print("HTML length:", len(result.html if result.html else ""))
|
||||
|
||||
|
||||
async def test_proxy_settings():
|
||||
# Test with a proxy (if available) to ensure code runs with proxy
|
||||
async with AsyncWebCrawler(
|
||||
@@ -113,14 +117,15 @@ async def test_proxy_settings():
|
||||
use_persistent_context=False,
|
||||
) as crawler:
|
||||
result = await crawler.arun(
|
||||
url='https://httpbin.org/ip',
|
||||
url="https://httpbin.org/ip",
|
||||
cache_mode=CacheMode.BYPASS,
|
||||
markdown_generator=DefaultMarkdownGenerator(options={"ignore_links": True})
|
||||
markdown_generator=DefaultMarkdownGenerator(options={"ignore_links": True}),
|
||||
)
|
||||
print("[test_proxy_settings] success:", result.success)
|
||||
if result.success:
|
||||
print("HTML preview:", result.html[:200] if result.html else "")
|
||||
|
||||
|
||||
async def test_ignore_https_errors():
|
||||
# Test ignore HTTPS errors with a self-signed or invalid cert domain
|
||||
# This is just conceptual, the domain should be one that triggers SSL error.
|
||||
@@ -134,12 +139,13 @@ async def test_ignore_https_errors():
|
||||
use_persistent_context=False,
|
||||
) as crawler:
|
||||
result = await crawler.arun(
|
||||
url='https://self-signed.badssl.com/',
|
||||
url="https://self-signed.badssl.com/",
|
||||
cache_mode=CacheMode.BYPASS,
|
||||
markdown_generator=DefaultMarkdownGenerator(options={"ignore_links": True})
|
||||
markdown_generator=DefaultMarkdownGenerator(options={"ignore_links": True}),
|
||||
)
|
||||
print("[test_ignore_https_errors] success:", result.success)
|
||||
|
||||
|
||||
async def main():
|
||||
print("Running tests...")
|
||||
# await test_default_headless()
|
||||
@@ -149,5 +155,6 @@ async def main():
|
||||
# await test_proxy_settings()
|
||||
await test_ignore_https_errors()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
asyncio.run(main())
|
||||
|
||||
@@ -1,15 +1,16 @@
|
||||
import os, sys
|
||||
|
||||
parent_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
|
||||
sys.path.append(parent_dir)
|
||||
__location__ = os.path.realpath(os.path.join(os.getcwd(), os.path.dirname(__file__)))
|
||||
|
||||
import asyncio
|
||||
from crawl4ai import AsyncWebCrawler, CacheMode
|
||||
from crawl4ai.async_configs import BrowserConfig, CrawlerRunConfig
|
||||
from crawl4ai.async_configs import BrowserConfig, CrawlerRunConfig
|
||||
from crawl4ai.content_filter_strategy import PruningContentFilter
|
||||
from crawl4ai.extraction_strategy import JsonCssExtractionStrategy
|
||||
from crawl4ai.chunking_strategy import RegexChunking
|
||||
from crawl4ai.markdown_generation_strategy import DefaultMarkdownGenerator
|
||||
|
||||
|
||||
# Category 1: Browser Configuration Tests
|
||||
async def test_browser_config_object():
|
||||
@@ -21,29 +22,31 @@ async def test_browser_config_object():
|
||||
viewport_height=1080,
|
||||
use_managed_browser=True,
|
||||
user_agent_mode="random",
|
||||
user_agent_generator_config={"device_type": "desktop", "os_type": "windows"}
|
||||
user_agent_generator_config={"device_type": "desktop", "os_type": "windows"},
|
||||
)
|
||||
|
||||
|
||||
async with AsyncWebCrawler(config=browser_config, verbose=True) as crawler:
|
||||
result = await crawler.arun('https://example.com', cache_mode=CacheMode.BYPASS)
|
||||
result = await crawler.arun("https://example.com", cache_mode=CacheMode.BYPASS)
|
||||
assert result.success, "Browser config crawl failed"
|
||||
assert len(result.html) > 0, "No HTML content retrieved"
|
||||
|
||||
|
||||
async def test_browser_performance_config():
|
||||
"""Test browser configurations focused on performance"""
|
||||
browser_config = BrowserConfig(
|
||||
text_mode=True,
|
||||
light_mode=True,
|
||||
extra_args=['--disable-gpu', '--disable-software-rasterizer'],
|
||||
extra_args=["--disable-gpu", "--disable-software-rasterizer"],
|
||||
ignore_https_errors=True,
|
||||
java_script_enabled=False
|
||||
java_script_enabled=False,
|
||||
)
|
||||
|
||||
|
||||
async with AsyncWebCrawler(config=browser_config) as crawler:
|
||||
result = await crawler.arun('https://example.com')
|
||||
result = await crawler.arun("https://example.com")
|
||||
assert result.success, "Performance optimized crawl failed"
|
||||
assert result.status_code == 200, "Unexpected status code"
|
||||
|
||||
|
||||
# Category 2: Content Processing Tests
|
||||
async def test_content_extraction_config():
|
||||
"""Test content extraction with various strategies"""
|
||||
@@ -53,24 +56,20 @@ async def test_content_extraction_config():
|
||||
schema={
|
||||
"name": "article",
|
||||
"baseSelector": "div",
|
||||
"fields": [{
|
||||
"name": "title",
|
||||
"selector": "h1",
|
||||
"type": "text"
|
||||
}]
|
||||
"fields": [{"name": "title", "selector": "h1", "type": "text"}],
|
||||
}
|
||||
),
|
||||
chunking_strategy=RegexChunking(),
|
||||
content_filter=PruningContentFilter()
|
||||
content_filter=PruningContentFilter(),
|
||||
)
|
||||
|
||||
|
||||
async with AsyncWebCrawler() as crawler:
|
||||
result = await crawler.arun(
|
||||
'https://example.com/article',
|
||||
config=crawler_config
|
||||
"https://example.com/article", config=crawler_config
|
||||
)
|
||||
assert result.extracted_content is not None, "Content extraction failed"
|
||||
assert 'title' in result.extracted_content, "Missing expected content field"
|
||||
assert "title" in result.extracted_content, "Missing expected content field"
|
||||
|
||||
|
||||
# Category 3: Cache and Session Management Tests
|
||||
async def test_cache_and_session_management():
|
||||
@@ -79,25 +78,20 @@ async def test_cache_and_session_management():
|
||||
crawler_config = CrawlerRunConfig(
|
||||
cache_mode=CacheMode.WRITE_ONLY,
|
||||
process_iframes=True,
|
||||
remove_overlay_elements=True
|
||||
remove_overlay_elements=True,
|
||||
)
|
||||
|
||||
|
||||
async with AsyncWebCrawler(config=browser_config) as crawler:
|
||||
# First request - should write to cache
|
||||
result1 = await crawler.arun(
|
||||
'https://example.com',
|
||||
config=crawler_config
|
||||
)
|
||||
|
||||
result1 = await crawler.arun("https://example.com", config=crawler_config)
|
||||
|
||||
# Second request - should use fresh fetch due to WRITE_ONLY mode
|
||||
result2 = await crawler.arun(
|
||||
'https://example.com',
|
||||
config=crawler_config
|
||||
)
|
||||
|
||||
result2 = await crawler.arun("https://example.com", config=crawler_config)
|
||||
|
||||
assert result1.success and result2.success, "Cache mode crawl failed"
|
||||
assert result1.html == result2.html, "Inconsistent results between requests"
|
||||
|
||||
|
||||
# Category 4: Media Handling Tests
|
||||
async def test_media_handling_config():
|
||||
"""Test configurations related to media handling"""
|
||||
@@ -107,24 +101,22 @@ async def test_media_handling_config():
|
||||
viewport_width=1920,
|
||||
viewport_height=1080,
|
||||
accept_downloads=True,
|
||||
downloads_path= os.path.expanduser("~/.crawl4ai/downloads")
|
||||
downloads_path=os.path.expanduser("~/.crawl4ai/downloads"),
|
||||
)
|
||||
crawler_config = CrawlerRunConfig(
|
||||
screenshot=True,
|
||||
pdf=True,
|
||||
adjust_viewport_to_content=True,
|
||||
wait_for_images=True,
|
||||
screenshot_height_threshold=20000
|
||||
screenshot_height_threshold=20000,
|
||||
)
|
||||
|
||||
|
||||
async with AsyncWebCrawler(config=browser_config) as crawler:
|
||||
result = await crawler.arun(
|
||||
'https://example.com',
|
||||
config=crawler_config
|
||||
)
|
||||
result = await crawler.arun("https://example.com", config=crawler_config)
|
||||
assert result.screenshot is not None, "Screenshot capture failed"
|
||||
assert result.pdf is not None, "PDF generation failed"
|
||||
|
||||
|
||||
# Category 5: Anti-Bot and Site Interaction Tests
|
||||
async def test_antibot_config():
|
||||
"""Test configurations for handling anti-bot measures"""
|
||||
@@ -135,76 +127,64 @@ async def test_antibot_config():
|
||||
wait_for="js:()=>document.querySelector('body')",
|
||||
delay_before_return_html=1.0,
|
||||
log_console=True,
|
||||
cache_mode=CacheMode.BYPASS
|
||||
cache_mode=CacheMode.BYPASS,
|
||||
)
|
||||
|
||||
|
||||
async with AsyncWebCrawler() as crawler:
|
||||
result = await crawler.arun(
|
||||
'https://example.com',
|
||||
config=crawler_config
|
||||
)
|
||||
result = await crawler.arun("https://example.com", config=crawler_config)
|
||||
assert result.success, "Anti-bot measure handling failed"
|
||||
|
||||
|
||||
# Category 6: Parallel Processing Tests
|
||||
async def test_parallel_processing():
|
||||
"""Test parallel processing capabilities"""
|
||||
crawler_config = CrawlerRunConfig(
|
||||
mean_delay=0.5,
|
||||
max_range=1.0,
|
||||
semaphore_count=5
|
||||
)
|
||||
|
||||
urls = [
|
||||
'https://example.com/1',
|
||||
'https://example.com/2',
|
||||
'https://example.com/3'
|
||||
]
|
||||
|
||||
crawler_config = CrawlerRunConfig(mean_delay=0.5, max_range=1.0, semaphore_count=5)
|
||||
|
||||
urls = ["https://example.com/1", "https://example.com/2", "https://example.com/3"]
|
||||
|
||||
async with AsyncWebCrawler() as crawler:
|
||||
results = await crawler.arun_many(
|
||||
urls,
|
||||
config=crawler_config
|
||||
)
|
||||
results = await crawler.arun_many(urls, config=crawler_config)
|
||||
assert len(results) == len(urls), "Not all URLs were processed"
|
||||
assert all(r.success for r in results), "Some parallel requests failed"
|
||||
|
||||
|
||||
# Category 7: Backwards Compatibility Tests
|
||||
async def test_legacy_parameter_support():
|
||||
"""Test that legacy parameters still work"""
|
||||
async with AsyncWebCrawler(
|
||||
headless=True,
|
||||
browser_type="chromium",
|
||||
viewport_width=1024,
|
||||
viewport_height=768
|
||||
headless=True, browser_type="chromium", viewport_width=1024, viewport_height=768
|
||||
) as crawler:
|
||||
result = await crawler.arun(
|
||||
'https://example.com',
|
||||
"https://example.com",
|
||||
screenshot=True,
|
||||
word_count_threshold=200,
|
||||
bypass_cache=True,
|
||||
css_selector=".main-content"
|
||||
css_selector=".main-content",
|
||||
)
|
||||
assert result.success, "Legacy parameter support failed"
|
||||
|
||||
|
||||
# Category 8: Mixed Configuration Tests
|
||||
async def test_mixed_config_usage():
|
||||
"""Test mixing new config objects with legacy parameters"""
|
||||
browser_config = BrowserConfig(headless=True)
|
||||
crawler_config = CrawlerRunConfig(screenshot=True)
|
||||
|
||||
|
||||
async with AsyncWebCrawler(
|
||||
config=browser_config,
|
||||
verbose=True # legacy parameter
|
||||
verbose=True, # legacy parameter
|
||||
) as crawler:
|
||||
result = await crawler.arun(
|
||||
'https://example.com',
|
||||
"https://example.com",
|
||||
config=crawler_config,
|
||||
cache_mode=CacheMode.BYPASS, # legacy parameter
|
||||
css_selector="body" # legacy parameter
|
||||
css_selector="body", # legacy parameter
|
||||
)
|
||||
assert result.success, "Mixed configuration usage failed"
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
async def run_tests():
|
||||
test_functions = [
|
||||
test_browser_config_object,
|
||||
@@ -217,7 +197,7 @@ if __name__ == "__main__":
|
||||
# test_legacy_parameter_support,
|
||||
# test_mixed_config_usage
|
||||
]
|
||||
|
||||
|
||||
for test in test_functions:
|
||||
print(f"\nRunning {test.__name__}...")
|
||||
try:
|
||||
@@ -227,5 +207,5 @@ if __name__ == "__main__":
|
||||
print(f"✗ {test.__name__} failed: {str(e)}")
|
||||
except Exception as e:
|
||||
print(f"✗ {test.__name__} error: {str(e)}")
|
||||
|
||||
asyncio.run(run_tests())
|
||||
|
||||
asyncio.run(run_tests())
|
||||
|
||||
@@ -4,7 +4,6 @@ import asyncio
|
||||
import shutil
|
||||
from typing import List
|
||||
import tempfile
|
||||
import time
|
||||
|
||||
# Add the parent directory to the Python path
|
||||
parent_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
|
||||
@@ -12,28 +11,27 @@ sys.path.append(parent_dir)
|
||||
|
||||
from crawl4ai.async_webcrawler import AsyncWebCrawler
|
||||
|
||||
|
||||
class TestDownloads:
|
||||
def __init__(self):
|
||||
self.temp_dir = tempfile.mkdtemp(prefix="crawl4ai_test_")
|
||||
self.download_dir = os.path.join(self.temp_dir, "downloads")
|
||||
os.makedirs(self.download_dir, exist_ok=True)
|
||||
self.results: List[str] = []
|
||||
|
||||
|
||||
def cleanup(self):
|
||||
shutil.rmtree(self.temp_dir)
|
||||
|
||||
|
||||
def log_result(self, test_name: str, success: bool, message: str = ""):
|
||||
result = f"{'✅' if success else '❌'} {test_name}: {message}"
|
||||
self.results.append(result)
|
||||
print(result)
|
||||
|
||||
|
||||
async def test_basic_download(self):
|
||||
"""Test basic file download functionality"""
|
||||
try:
|
||||
async with AsyncWebCrawler(
|
||||
accept_downloads=True,
|
||||
downloads_path=self.download_dir,
|
||||
verbose=True
|
||||
accept_downloads=True, downloads_path=self.download_dir, verbose=True
|
||||
) as crawler:
|
||||
# Python.org downloads page typically has stable download links
|
||||
result = await crawler.arun(
|
||||
@@ -42,14 +40,19 @@ class TestDownloads:
|
||||
// Click first download link
|
||||
const downloadLink = document.querySelector('a[href$=".exe"]');
|
||||
if (downloadLink) downloadLink.click();
|
||||
"""
|
||||
""",
|
||||
)
|
||||
|
||||
success = (
|
||||
result.downloaded_files is not None
|
||||
and len(result.downloaded_files) > 0
|
||||
)
|
||||
|
||||
success = result.downloaded_files is not None and len(result.downloaded_files) > 0
|
||||
self.log_result(
|
||||
"Basic Download",
|
||||
success,
|
||||
f"Downloaded {len(result.downloaded_files or [])} files" if success else "No files downloaded"
|
||||
f"Downloaded {len(result.downloaded_files or [])} files"
|
||||
if success
|
||||
else "No files downloaded",
|
||||
)
|
||||
except Exception as e:
|
||||
self.log_result("Basic Download", False, str(e))
|
||||
@@ -59,27 +62,32 @@ class TestDownloads:
|
||||
try:
|
||||
user_data_dir = os.path.join(self.temp_dir, "user_data")
|
||||
os.makedirs(user_data_dir, exist_ok=True)
|
||||
|
||||
|
||||
async with AsyncWebCrawler(
|
||||
accept_downloads=True,
|
||||
downloads_path=self.download_dir,
|
||||
use_persistent_context=True,
|
||||
user_data_dir=user_data_dir,
|
||||
verbose=True
|
||||
verbose=True,
|
||||
) as crawler:
|
||||
result = await crawler.arun(
|
||||
url="https://www.python.org/downloads/",
|
||||
js_code="""
|
||||
const downloadLink = document.querySelector('a[href$=".exe"]');
|
||||
if (downloadLink) downloadLink.click();
|
||||
"""
|
||||
""",
|
||||
)
|
||||
|
||||
success = (
|
||||
result.downloaded_files is not None
|
||||
and len(result.downloaded_files) > 0
|
||||
)
|
||||
|
||||
success = result.downloaded_files is not None and len(result.downloaded_files) > 0
|
||||
self.log_result(
|
||||
"Persistent Context Download",
|
||||
success,
|
||||
f"Downloaded {len(result.downloaded_files or [])} files" if success else "No files downloaded"
|
||||
f"Downloaded {len(result.downloaded_files or [])} files"
|
||||
if success
|
||||
else "No files downloaded",
|
||||
)
|
||||
except Exception as e:
|
||||
self.log_result("Persistent Context Download", False, str(e))
|
||||
@@ -88,9 +96,7 @@ class TestDownloads:
|
||||
"""Test multiple simultaneous downloads"""
|
||||
try:
|
||||
async with AsyncWebCrawler(
|
||||
accept_downloads=True,
|
||||
downloads_path=self.download_dir,
|
||||
verbose=True
|
||||
accept_downloads=True, downloads_path=self.download_dir, verbose=True
|
||||
) as crawler:
|
||||
result = await crawler.arun(
|
||||
url="https://www.python.org/downloads/",
|
||||
@@ -98,14 +104,19 @@ class TestDownloads:
|
||||
// Click multiple download links
|
||||
const downloadLinks = document.querySelectorAll('a[href$=".exe"]');
|
||||
downloadLinks.forEach(link => link.click());
|
||||
"""
|
||||
""",
|
||||
)
|
||||
|
||||
success = (
|
||||
result.downloaded_files is not None
|
||||
and len(result.downloaded_files) > 1
|
||||
)
|
||||
|
||||
success = result.downloaded_files is not None and len(result.downloaded_files) > 1
|
||||
self.log_result(
|
||||
"Multiple Downloads",
|
||||
success,
|
||||
f"Downloaded {len(result.downloaded_files or [])} files" if success else "Not enough files downloaded"
|
||||
f"Downloaded {len(result.downloaded_files or [])} files"
|
||||
if success
|
||||
else "Not enough files downloaded",
|
||||
)
|
||||
except Exception as e:
|
||||
self.log_result("Multiple Downloads", False, str(e))
|
||||
@@ -113,49 +124,51 @@ class TestDownloads:
|
||||
async def test_different_browsers(self):
|
||||
"""Test downloads across different browser types"""
|
||||
browsers = ["chromium", "firefox", "webkit"]
|
||||
|
||||
|
||||
for browser_type in browsers:
|
||||
try:
|
||||
async with AsyncWebCrawler(
|
||||
accept_downloads=True,
|
||||
downloads_path=self.download_dir,
|
||||
browser_type=browser_type,
|
||||
verbose=True
|
||||
verbose=True,
|
||||
) as crawler:
|
||||
result = await crawler.arun(
|
||||
url="https://www.python.org/downloads/",
|
||||
js_code="""
|
||||
const downloadLink = document.querySelector('a[href$=".exe"]');
|
||||
if (downloadLink) downloadLink.click();
|
||||
"""
|
||||
""",
|
||||
)
|
||||
|
||||
success = (
|
||||
result.downloaded_files is not None
|
||||
and len(result.downloaded_files) > 0
|
||||
)
|
||||
|
||||
success = result.downloaded_files is not None and len(result.downloaded_files) > 0
|
||||
self.log_result(
|
||||
f"{browser_type.title()} Download",
|
||||
success,
|
||||
f"Downloaded {len(result.downloaded_files or [])} files" if success else "No files downloaded"
|
||||
f"Downloaded {len(result.downloaded_files or [])} files"
|
||||
if success
|
||||
else "No files downloaded",
|
||||
)
|
||||
except Exception as e:
|
||||
self.log_result(f"{browser_type.title()} Download", False, str(e))
|
||||
|
||||
async def test_edge_cases(self):
|
||||
"""Test various edge cases"""
|
||||
|
||||
|
||||
# Test 1: Downloads without specifying download path
|
||||
try:
|
||||
async with AsyncWebCrawler(
|
||||
accept_downloads=True,
|
||||
verbose=True
|
||||
) as crawler:
|
||||
async with AsyncWebCrawler(accept_downloads=True, verbose=True) as crawler:
|
||||
result = await crawler.arun(
|
||||
url="https://www.python.org/downloads/",
|
||||
js_code="document.querySelector('a[href$=\".exe\"]').click()"
|
||||
js_code="document.querySelector('a[href$=\".exe\"]').click()",
|
||||
)
|
||||
self.log_result(
|
||||
"Default Download Path",
|
||||
True,
|
||||
f"Downloaded to default path: {result.downloaded_files[0] if result.downloaded_files else 'None'}"
|
||||
f"Downloaded to default path: {result.downloaded_files[0] if result.downloaded_files else 'None'}",
|
||||
)
|
||||
except Exception as e:
|
||||
self.log_result("Default Download Path", False, str(e))
|
||||
@@ -165,31 +178,34 @@ class TestDownloads:
|
||||
async with AsyncWebCrawler(
|
||||
accept_downloads=True,
|
||||
downloads_path="/invalid/path/that/doesnt/exist",
|
||||
verbose=True
|
||||
verbose=True,
|
||||
) as crawler:
|
||||
result = await crawler.arun(
|
||||
url="https://www.python.org/downloads/",
|
||||
js_code="document.querySelector('a[href$=\".exe\"]').click()"
|
||||
js_code="document.querySelector('a[href$=\".exe\"]').click()",
|
||||
)
|
||||
self.log_result("Invalid Download Path", False, "Should have raised an error")
|
||||
except Exception as e:
|
||||
self.log_result("Invalid Download Path", True, "Correctly handled invalid path")
|
||||
self.log_result(
|
||||
"Invalid Download Path", False, "Should have raised an error"
|
||||
)
|
||||
except Exception:
|
||||
self.log_result(
|
||||
"Invalid Download Path", True, "Correctly handled invalid path"
|
||||
)
|
||||
|
||||
# Test 3: Download with accept_downloads=False
|
||||
try:
|
||||
async with AsyncWebCrawler(
|
||||
accept_downloads=False,
|
||||
verbose=True
|
||||
) as crawler:
|
||||
async with AsyncWebCrawler(accept_downloads=False, verbose=True) as crawler:
|
||||
result = await crawler.arun(
|
||||
url="https://www.python.org/downloads/",
|
||||
js_code="document.querySelector('a[href$=\".exe\"]').click()"
|
||||
js_code="document.querySelector('a[href$=\".exe\"]').click()",
|
||||
)
|
||||
success = result.downloaded_files is None
|
||||
self.log_result(
|
||||
"Disabled Downloads",
|
||||
success,
|
||||
"Correctly ignored downloads" if success else "Unexpectedly downloaded files"
|
||||
"Correctly ignored downloads"
|
||||
if success
|
||||
else "Unexpectedly downloaded files",
|
||||
)
|
||||
except Exception as e:
|
||||
self.log_result("Disabled Downloads", False, str(e))
|
||||
@@ -197,33 +213,35 @@ class TestDownloads:
|
||||
async def run_all_tests(self):
|
||||
"""Run all test cases"""
|
||||
print("\n🧪 Running Download Tests...\n")
|
||||
|
||||
|
||||
test_methods = [
|
||||
self.test_basic_download,
|
||||
self.test_persistent_context_download,
|
||||
self.test_multiple_downloads,
|
||||
self.test_different_browsers,
|
||||
self.test_edge_cases
|
||||
self.test_edge_cases,
|
||||
]
|
||||
|
||||
|
||||
for test in test_methods:
|
||||
print(f"\n📝 Running {test.__doc__}...")
|
||||
await test()
|
||||
await asyncio.sleep(2) # Brief pause between tests
|
||||
|
||||
|
||||
print("\n📊 Test Results Summary:")
|
||||
for result in self.results:
|
||||
print(result)
|
||||
|
||||
successes = len([r for r in self.results if '✅' in r])
|
||||
|
||||
successes = len([r for r in self.results if "✅" in r])
|
||||
total = len(self.results)
|
||||
print(f"\nTotal: {successes}/{total} tests passed")
|
||||
|
||||
|
||||
self.cleanup()
|
||||
|
||||
|
||||
async def main():
|
||||
tester = TestDownloads()
|
||||
await tester.run_all_tests()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
asyncio.run(main())
|
||||
asyncio.run(main())
|
||||
|
||||
@@ -1,15 +1,17 @@
|
||||
import os
|
||||
import sys
|
||||
import pytest
|
||||
import asyncio
|
||||
import time
|
||||
|
||||
# Add the parent directory to the Python path
|
||||
parent_dir = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||
parent_dir = os.path.dirname(
|
||||
os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
|
||||
)
|
||||
sys.path.append(parent_dir)
|
||||
|
||||
from crawl4ai.async_webcrawler import AsyncWebCrawler
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_successful_crawl():
|
||||
async with AsyncWebCrawler(verbose=True) as crawler:
|
||||
@@ -21,6 +23,7 @@ async def test_successful_crawl():
|
||||
assert result.markdown
|
||||
assert result.cleaned_html
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_invalid_url():
|
||||
async with AsyncWebCrawler(verbose=True) as crawler:
|
||||
@@ -29,19 +32,21 @@ async def test_invalid_url():
|
||||
assert not result.success
|
||||
assert result.error_message
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_multiple_urls():
|
||||
async with AsyncWebCrawler(verbose=True) as crawler:
|
||||
urls = [
|
||||
"https://www.nbcnews.com/business",
|
||||
"https://www.example.com",
|
||||
"https://www.python.org"
|
||||
"https://www.python.org",
|
||||
]
|
||||
results = await crawler.arun_many(urls=urls, bypass_cache=True)
|
||||
assert len(results) == len(urls)
|
||||
assert all(result.success for result in results)
|
||||
assert all(result.html for result in results)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_javascript_execution():
|
||||
async with AsyncWebCrawler(verbose=True) as crawler:
|
||||
@@ -51,6 +56,7 @@ async def test_javascript_execution():
|
||||
assert result.success
|
||||
assert "<h1>Modified by JS</h1>" in result.html
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_concurrent_crawling_performance():
|
||||
async with AsyncWebCrawler(verbose=True) as crawler:
|
||||
@@ -59,23 +65,26 @@ async def test_concurrent_crawling_performance():
|
||||
"https://www.example.com",
|
||||
"https://www.python.org",
|
||||
"https://www.github.com",
|
||||
"https://www.stackoverflow.com"
|
||||
"https://www.stackoverflow.com",
|
||||
]
|
||||
|
||||
|
||||
start_time = time.time()
|
||||
results = await crawler.arun_many(urls=urls, bypass_cache=True)
|
||||
end_time = time.time()
|
||||
|
||||
|
||||
total_time = end_time - start_time
|
||||
print(f"Total time for concurrent crawling: {total_time:.2f} seconds")
|
||||
|
||||
|
||||
assert all(result.success for result in results)
|
||||
assert len(results) == len(urls)
|
||||
|
||||
|
||||
# Assert that concurrent crawling is faster than sequential
|
||||
# This multiplier may need adjustment based on the number of URLs and their complexity
|
||||
assert total_time < len(urls) * 5, f"Concurrent crawling not significantly faster: {total_time:.2f} seconds"
|
||||
assert (
|
||||
total_time < len(urls) * 5
|
||||
), f"Concurrent crawling not significantly faster: {total_time:.2f} seconds"
|
||||
|
||||
|
||||
# Entry point for debugging
|
||||
if __name__ == "__main__":
|
||||
pytest.main([__file__, "-v"])
|
||||
pytest.main([__file__, "-v"])
|
||||
|
||||
@@ -9,74 +9,79 @@ sys.path.append(parent_dir)
|
||||
|
||||
from crawl4ai.async_webcrawler import AsyncWebCrawler
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_caching():
|
||||
async with AsyncWebCrawler(verbose=True) as crawler:
|
||||
url = "https://www.nbcnews.com/business"
|
||||
|
||||
|
||||
# First crawl (should not use cache)
|
||||
start_time = asyncio.get_event_loop().time()
|
||||
result1 = await crawler.arun(url=url, bypass_cache=True)
|
||||
end_time = asyncio.get_event_loop().time()
|
||||
time_taken1 = end_time - start_time
|
||||
|
||||
|
||||
assert result1.success
|
||||
|
||||
|
||||
# Second crawl (should use cache)
|
||||
start_time = asyncio.get_event_loop().time()
|
||||
result2 = await crawler.arun(url=url, bypass_cache=False)
|
||||
end_time = asyncio.get_event_loop().time()
|
||||
time_taken2 = end_time - start_time
|
||||
|
||||
|
||||
assert result2.success
|
||||
assert time_taken2 < time_taken1 # Cached result should be faster
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_bypass_cache():
|
||||
async with AsyncWebCrawler(verbose=True) as crawler:
|
||||
url = "https://www.nbcnews.com/business"
|
||||
|
||||
|
||||
# First crawl
|
||||
result1 = await crawler.arun(url=url, bypass_cache=False)
|
||||
assert result1.success
|
||||
|
||||
|
||||
# Second crawl with bypass_cache=True
|
||||
result2 = await crawler.arun(url=url, bypass_cache=True)
|
||||
assert result2.success
|
||||
|
||||
|
||||
# Content should be different (or at least, not guaranteed to be the same)
|
||||
assert result1.html != result2.html or result1.markdown != result2.markdown
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_clear_cache():
|
||||
async with AsyncWebCrawler(verbose=True) as crawler:
|
||||
url = "https://www.nbcnews.com/business"
|
||||
|
||||
|
||||
# Crawl and cache
|
||||
await crawler.arun(url=url, bypass_cache=False)
|
||||
|
||||
|
||||
# Clear cache
|
||||
await crawler.aclear_cache()
|
||||
|
||||
|
||||
# Check cache size
|
||||
cache_size = await crawler.aget_cache_size()
|
||||
assert cache_size == 0
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_flush_cache():
|
||||
async with AsyncWebCrawler(verbose=True) as crawler:
|
||||
url = "https://www.nbcnews.com/business"
|
||||
|
||||
|
||||
# Crawl and cache
|
||||
await crawler.arun(url=url, bypass_cache=False)
|
||||
|
||||
|
||||
# Flush cache
|
||||
await crawler.aflush_cache()
|
||||
|
||||
|
||||
# Check cache size
|
||||
cache_size = await crawler.aget_cache_size()
|
||||
assert cache_size == 0
|
||||
|
||||
|
||||
# Entry point for debugging
|
||||
if __name__ == "__main__":
|
||||
pytest.main([__file__, "-v"])
|
||||
pytest.main([__file__, "-v"])
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
import os
|
||||
import sys
|
||||
import pytest
|
||||
import asyncio
|
||||
import json
|
||||
|
||||
# Add the parent directory to the Python path
|
||||
@@ -9,8 +8,9 @@ parent_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
|
||||
sys.path.append(parent_dir)
|
||||
|
||||
from crawl4ai.async_webcrawler import AsyncWebCrawler
|
||||
from crawl4ai.chunking_strategy import RegexChunking, NlpSentenceChunking
|
||||
from crawl4ai.extraction_strategy import CosineStrategy, LLMExtractionStrategy
|
||||
from crawl4ai.chunking_strategy import RegexChunking
|
||||
from crawl4ai.extraction_strategy import LLMExtractionStrategy
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_regex_chunking():
|
||||
@@ -18,15 +18,14 @@ async def test_regex_chunking():
|
||||
url = "https://www.nbcnews.com/business"
|
||||
chunking_strategy = RegexChunking(patterns=["\n\n"])
|
||||
result = await crawler.arun(
|
||||
url=url,
|
||||
chunking_strategy=chunking_strategy,
|
||||
bypass_cache=True
|
||||
url=url, chunking_strategy=chunking_strategy, bypass_cache=True
|
||||
)
|
||||
assert result.success
|
||||
assert result.extracted_content
|
||||
chunks = json.loads(result.extracted_content)
|
||||
assert len(chunks) > 1 # Ensure multiple chunks were created
|
||||
|
||||
|
||||
# @pytest.mark.asyncio
|
||||
# async def test_cosine_strategy():
|
||||
# async with AsyncWebCrawler(verbose=True) as crawler:
|
||||
@@ -43,25 +42,25 @@ async def test_regex_chunking():
|
||||
# assert len(extracted_data) > 0
|
||||
# assert all('tags' in item for item in extracted_data)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_llm_extraction_strategy():
|
||||
async with AsyncWebCrawler(verbose=True) as crawler:
|
||||
url = "https://www.nbcnews.com/business"
|
||||
extraction_strategy = LLMExtractionStrategy(
|
||||
provider="openai/gpt-4o-mini",
|
||||
api_token=os.getenv('OPENAI_API_KEY'),
|
||||
instruction="Extract only content related to technology"
|
||||
api_token=os.getenv("OPENAI_API_KEY"),
|
||||
instruction="Extract only content related to technology",
|
||||
)
|
||||
result = await crawler.arun(
|
||||
url=url,
|
||||
extraction_strategy=extraction_strategy,
|
||||
bypass_cache=True
|
||||
url=url, extraction_strategy=extraction_strategy, bypass_cache=True
|
||||
)
|
||||
assert result.success
|
||||
assert result.extracted_content
|
||||
extracted_data = json.loads(result.extracted_content)
|
||||
assert len(extracted_data) > 0
|
||||
assert all('content' in item for item in extracted_data)
|
||||
assert all("content" in item for item in extracted_data)
|
||||
|
||||
|
||||
# @pytest.mark.asyncio
|
||||
# async def test_combined_chunking_and_extraction():
|
||||
@@ -84,4 +83,4 @@ async def test_llm_extraction_strategy():
|
||||
|
||||
# Entry point for debugging
|
||||
if __name__ == "__main__":
|
||||
pytest.main([__file__, "-v"])
|
||||
pytest.main([__file__, "-v"])
|
||||
|
||||
@@ -1,8 +1,6 @@
|
||||
import os
|
||||
import sys
|
||||
import pytest
|
||||
import asyncio
|
||||
import json
|
||||
|
||||
# Add the parent directory to the Python path
|
||||
parent_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
|
||||
@@ -10,6 +8,7 @@ sys.path.append(parent_dir)
|
||||
|
||||
from crawl4ai.async_webcrawler import AsyncWebCrawler
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_extract_markdown():
|
||||
async with AsyncWebCrawler(verbose=True) as crawler:
|
||||
@@ -20,6 +19,7 @@ async def test_extract_markdown():
|
||||
assert isinstance(result.markdown, str)
|
||||
assert len(result.markdown) > 0
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_extract_cleaned_html():
|
||||
async with AsyncWebCrawler(verbose=True) as crawler:
|
||||
@@ -30,6 +30,7 @@ async def test_extract_cleaned_html():
|
||||
assert isinstance(result.cleaned_html, str)
|
||||
assert len(result.cleaned_html) > 0
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_extract_media():
|
||||
async with AsyncWebCrawler(verbose=True) as crawler:
|
||||
@@ -46,6 +47,7 @@ async def test_extract_media():
|
||||
assert "alt" in image
|
||||
assert "type" in image
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_extract_links():
|
||||
async with AsyncWebCrawler(verbose=True) as crawler:
|
||||
@@ -63,6 +65,7 @@ async def test_extract_links():
|
||||
assert "href" in link
|
||||
assert "text" in link
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_extract_metadata():
|
||||
async with AsyncWebCrawler(verbose=True) as crawler:
|
||||
@@ -75,16 +78,20 @@ async def test_extract_metadata():
|
||||
assert "title" in metadata
|
||||
assert isinstance(metadata["title"], str)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_css_selector_extraction():
|
||||
async with AsyncWebCrawler(verbose=True) as crawler:
|
||||
url = "https://www.nbcnews.com/business"
|
||||
css_selector = "h1, h2, h3"
|
||||
result = await crawler.arun(url=url, bypass_cache=True, css_selector=css_selector)
|
||||
result = await crawler.arun(
|
||||
url=url, bypass_cache=True, css_selector=css_selector
|
||||
)
|
||||
assert result.success
|
||||
assert result.markdown
|
||||
assert all(heading in result.markdown for heading in ["#", "##", "###"])
|
||||
|
||||
|
||||
# Entry point for debugging
|
||||
if __name__ == "__main__":
|
||||
pytest.main([__file__, "-v"])
|
||||
pytest.main([__file__, "-v"])
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
import os, sys
|
||||
import pytest
|
||||
from bs4 import BeautifulSoup
|
||||
from typing import List
|
||||
|
||||
# Add the parent directory to the Python path
|
||||
parent_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
|
||||
@@ -9,6 +8,7 @@ sys.path.append(parent_dir)
|
||||
|
||||
from crawl4ai.content_filter_strategy import BM25ContentFilter
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def basic_html():
|
||||
return """
|
||||
@@ -28,6 +28,7 @@ def basic_html():
|
||||
</html>
|
||||
"""
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def wiki_html():
|
||||
return """
|
||||
@@ -46,6 +47,7 @@ def wiki_html():
|
||||
</html>
|
||||
"""
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def no_meta_html():
|
||||
return """
|
||||
@@ -57,26 +59,27 @@ def no_meta_html():
|
||||
</html>
|
||||
"""
|
||||
|
||||
|
||||
class TestBM25ContentFilter:
|
||||
def test_basic_extraction(self, basic_html):
|
||||
"""Test basic content extraction functionality"""
|
||||
filter = BM25ContentFilter()
|
||||
contents = filter.filter_content(basic_html)
|
||||
|
||||
|
||||
assert contents, "Should extract content"
|
||||
assert len(contents) >= 1, "Should extract at least one content block"
|
||||
assert "long paragraph" in ' '.join(contents).lower()
|
||||
assert "navigation" not in ' '.join(contents).lower()
|
||||
assert "long paragraph" in " ".join(contents).lower()
|
||||
assert "navigation" not in " ".join(contents).lower()
|
||||
|
||||
def test_user_query_override(self, basic_html):
|
||||
"""Test that user query overrides metadata extraction"""
|
||||
user_query = "specific test query"
|
||||
filter = BM25ContentFilter(user_query=user_query)
|
||||
|
||||
|
||||
# Access internal state to verify query usage
|
||||
soup = BeautifulSoup(basic_html, 'lxml')
|
||||
extracted_query = filter.extract_page_query(soup.find('head'))
|
||||
|
||||
soup = BeautifulSoup(basic_html, "lxml")
|
||||
extracted_query = filter.extract_page_query(soup.find("head"))
|
||||
|
||||
assert extracted_query == user_query
|
||||
assert "Test description" not in extracted_query
|
||||
|
||||
@@ -84,8 +87,8 @@ class TestBM25ContentFilter:
|
||||
"""Test that headers are properly extracted despite length"""
|
||||
filter = BM25ContentFilter()
|
||||
contents = filter.filter_content(wiki_html)
|
||||
|
||||
combined_content = ' '.join(contents).lower()
|
||||
|
||||
combined_content = " ".join(contents).lower()
|
||||
assert "section 1" in combined_content, "Should include section header"
|
||||
assert "article title" in combined_content, "Should include main title"
|
||||
|
||||
@@ -93,9 +96,11 @@ class TestBM25ContentFilter:
|
||||
"""Test fallback behavior when no metadata is present"""
|
||||
filter = BM25ContentFilter()
|
||||
contents = filter.filter_content(no_meta_html)
|
||||
|
||||
|
||||
assert contents, "Should extract content even without metadata"
|
||||
assert "First paragraph" in ' '.join(contents), "Should use first paragraph content"
|
||||
assert "First paragraph" in " ".join(
|
||||
contents
|
||||
), "Should use first paragraph content"
|
||||
|
||||
def test_empty_input(self):
|
||||
"""Test handling of empty input"""
|
||||
@@ -108,29 +113,30 @@ class TestBM25ContentFilter:
|
||||
malformed_html = "<p>Unclosed paragraph<div>Nested content</p></div>"
|
||||
filter = BM25ContentFilter()
|
||||
contents = filter.filter_content(malformed_html)
|
||||
|
||||
|
||||
assert isinstance(contents, list), "Should return list even with malformed HTML"
|
||||
|
||||
|
||||
def test_threshold_behavior(self, basic_html):
|
||||
"""Test different BM25 threshold values"""
|
||||
strict_filter = BM25ContentFilter(bm25_threshold=2.0)
|
||||
lenient_filter = BM25ContentFilter(bm25_threshold=0.5)
|
||||
|
||||
|
||||
strict_contents = strict_filter.filter_content(basic_html)
|
||||
lenient_contents = lenient_filter.filter_content(basic_html)
|
||||
|
||||
assert len(strict_contents) <= len(lenient_contents), \
|
||||
"Strict threshold should extract fewer elements"
|
||||
|
||||
assert len(strict_contents) <= len(
|
||||
lenient_contents
|
||||
), "Strict threshold should extract fewer elements"
|
||||
|
||||
def test_html_cleaning(self, basic_html):
|
||||
"""Test HTML cleaning functionality"""
|
||||
filter = BM25ContentFilter()
|
||||
contents = filter.filter_content(basic_html)
|
||||
|
||||
cleaned_content = ' '.join(contents)
|
||||
assert 'class=' not in cleaned_content, "Should remove class attributes"
|
||||
assert 'style=' not in cleaned_content, "Should remove style attributes"
|
||||
assert '<script' not in cleaned_content, "Should remove script tags"
|
||||
|
||||
cleaned_content = " ".join(contents)
|
||||
assert "class=" not in cleaned_content, "Should remove class attributes"
|
||||
assert "style=" not in cleaned_content, "Should remove style attributes"
|
||||
assert "<script" not in cleaned_content, "Should remove script tags"
|
||||
|
||||
def test_large_content(self):
|
||||
"""Test handling of large content blocks"""
|
||||
@@ -143,9 +149,9 @@ class TestBM25ContentFilter:
|
||||
contents = filter.filter_content(large_html)
|
||||
assert contents, "Should handle large content blocks"
|
||||
|
||||
@pytest.mark.parametrize("unwanted_tag", [
|
||||
'script', 'style', 'nav', 'footer', 'header'
|
||||
])
|
||||
@pytest.mark.parametrize(
|
||||
"unwanted_tag", ["script", "style", "nav", "footer", "header"]
|
||||
)
|
||||
def test_excluded_tags(self, unwanted_tag):
|
||||
"""Test that specific tags are properly excluded"""
|
||||
html = f"""
|
||||
@@ -156,20 +162,22 @@ class TestBM25ContentFilter:
|
||||
"""
|
||||
filter = BM25ContentFilter()
|
||||
contents = filter.filter_content(html)
|
||||
|
||||
combined_content = ' '.join(contents).lower()
|
||||
|
||||
combined_content = " ".join(contents).lower()
|
||||
assert "should not appear" not in combined_content
|
||||
|
||||
|
||||
def test_performance(self, basic_html):
|
||||
"""Test performance with timer"""
|
||||
filter = BM25ContentFilter()
|
||||
|
||||
|
||||
import time
|
||||
|
||||
start = time.perf_counter()
|
||||
filter.filter_content(basic_html)
|
||||
duration = time.perf_counter() - start
|
||||
|
||||
|
||||
assert duration < 1.0, f"Processing took too long: {duration:.2f} seconds"
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
pytest.main([__file__])
|
||||
pytest.main([__file__])
|
||||
|
||||
@@ -1,12 +1,12 @@
|
||||
import os, sys
|
||||
import pytest
|
||||
from bs4 import BeautifulSoup
|
||||
|
||||
parent_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
|
||||
sys.path.append(parent_dir)
|
||||
|
||||
from crawl4ai.content_filter_strategy import PruningContentFilter
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def basic_html():
|
||||
return """
|
||||
@@ -22,6 +22,7 @@ def basic_html():
|
||||
</html>
|
||||
"""
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def link_heavy_html():
|
||||
return """
|
||||
@@ -40,6 +41,7 @@ def link_heavy_html():
|
||||
</html>
|
||||
"""
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mixed_content_html():
|
||||
return """
|
||||
@@ -60,13 +62,14 @@ def mixed_content_html():
|
||||
</html>
|
||||
"""
|
||||
|
||||
|
||||
class TestPruningContentFilter:
|
||||
def test_basic_pruning(self, basic_html):
|
||||
"""Test basic content pruning functionality"""
|
||||
filter = PruningContentFilter(min_word_threshold=5)
|
||||
contents = filter.filter_content(basic_html)
|
||||
|
||||
combined_content = ' '.join(contents).lower()
|
||||
|
||||
combined_content = " ".join(contents).lower()
|
||||
assert "high-quality paragraph" in combined_content
|
||||
assert "sidebar content" not in combined_content
|
||||
assert "share buttons" not in combined_content
|
||||
@@ -75,40 +78,42 @@ class TestPruningContentFilter:
|
||||
"""Test minimum word threshold filtering"""
|
||||
filter = PruningContentFilter(min_word_threshold=10)
|
||||
contents = filter.filter_content(mixed_content_html)
|
||||
|
||||
combined_content = ' '.join(contents).lower()
|
||||
|
||||
combined_content = " ".join(contents).lower()
|
||||
assert "short summary" not in combined_content
|
||||
assert "long high-quality paragraph" in combined_content
|
||||
assert "short comment" not in combined_content
|
||||
|
||||
def test_threshold_types(self, basic_html):
|
||||
"""Test fixed vs dynamic thresholds"""
|
||||
fixed_filter = PruningContentFilter(threshold_type='fixed', threshold=0.48)
|
||||
dynamic_filter = PruningContentFilter(threshold_type='dynamic', threshold=0.45)
|
||||
|
||||
fixed_filter = PruningContentFilter(threshold_type="fixed", threshold=0.48)
|
||||
dynamic_filter = PruningContentFilter(threshold_type="dynamic", threshold=0.45)
|
||||
|
||||
fixed_contents = fixed_filter.filter_content(basic_html)
|
||||
dynamic_contents = dynamic_filter.filter_content(basic_html)
|
||||
|
||||
assert len(fixed_contents) != len(dynamic_contents), \
|
||||
"Fixed and dynamic thresholds should yield different results"
|
||||
|
||||
assert len(fixed_contents) != len(
|
||||
dynamic_contents
|
||||
), "Fixed and dynamic thresholds should yield different results"
|
||||
|
||||
def test_link_density_impact(self, link_heavy_html):
|
||||
"""Test handling of link-heavy content"""
|
||||
filter = PruningContentFilter(threshold_type='dynamic')
|
||||
filter = PruningContentFilter(threshold_type="dynamic")
|
||||
contents = filter.filter_content(link_heavy_html)
|
||||
|
||||
combined_content = ' '.join(contents).lower()
|
||||
|
||||
combined_content = " ".join(contents).lower()
|
||||
assert "good content paragraph" in combined_content
|
||||
assert len([c for c in contents if 'href' in c]) < 2, \
|
||||
"Should prune link-heavy sections"
|
||||
assert (
|
||||
len([c for c in contents if "href" in c]) < 2
|
||||
), "Should prune link-heavy sections"
|
||||
|
||||
def test_tag_importance(self, mixed_content_html):
|
||||
"""Test tag importance in scoring"""
|
||||
filter = PruningContentFilter(threshold_type='dynamic')
|
||||
filter = PruningContentFilter(threshold_type="dynamic")
|
||||
contents = filter.filter_content(mixed_content_html)
|
||||
|
||||
has_article = any('article' in c.lower() for c in contents)
|
||||
has_h1 = any('h1' in c.lower() for c in contents)
|
||||
|
||||
has_article = any("article" in c.lower() for c in contents)
|
||||
has_h1 = any("h1" in c.lower() for c in contents)
|
||||
assert has_article or has_h1, "Should retain important tags"
|
||||
|
||||
def test_empty_input(self):
|
||||
@@ -127,26 +132,31 @@ class TestPruningContentFilter:
|
||||
def test_performance(self, basic_html):
|
||||
"""Test performance with timer"""
|
||||
filter = PruningContentFilter()
|
||||
|
||||
|
||||
import time
|
||||
|
||||
start = time.perf_counter()
|
||||
filter.filter_content(basic_html)
|
||||
duration = time.perf_counter() - start
|
||||
|
||||
|
||||
# Extra strict on performance since you mentioned milliseconds matter
|
||||
assert duration < 0.1, f"Processing took too long: {duration:.3f} seconds"
|
||||
|
||||
@pytest.mark.parametrize("threshold,expected_count", [
|
||||
(0.3, 4), # Very lenient
|
||||
(0.48, 2), # Default
|
||||
(0.7, 1), # Very strict
|
||||
])
|
||||
@pytest.mark.parametrize(
|
||||
"threshold,expected_count",
|
||||
[
|
||||
(0.3, 4), # Very lenient
|
||||
(0.48, 2), # Default
|
||||
(0.7, 1), # Very strict
|
||||
],
|
||||
)
|
||||
def test_threshold_levels(self, mixed_content_html, threshold, expected_count):
|
||||
"""Test different threshold levels"""
|
||||
filter = PruningContentFilter(threshold_type='fixed', threshold=threshold)
|
||||
filter = PruningContentFilter(threshold_type="fixed", threshold=threshold)
|
||||
contents = filter.filter_content(mixed_content_html)
|
||||
assert len(contents) <= expected_count, \
|
||||
f"Expected {expected_count} or fewer elements with threshold {threshold}"
|
||||
assert (
|
||||
len(contents) <= expected_count
|
||||
), f"Expected {expected_count} or fewer elements with threshold {threshold}"
|
||||
|
||||
def test_consistent_output(self, basic_html):
|
||||
"""Test output consistency across multiple runs"""
|
||||
@@ -155,5 +165,6 @@ class TestPruningContentFilter:
|
||||
second_run = filter.filter_content(basic_html)
|
||||
assert first_run == second_run, "Output should be consistent"
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
pytest.main([__file__])
|
||||
pytest.main([__file__])
|
||||
|
||||
@@ -1,22 +1,24 @@
|
||||
import asyncio
|
||||
from bs4 import BeautifulSoup
|
||||
from typing import Dict, Any
|
||||
import os
|
||||
import sys
|
||||
import time
|
||||
import csv
|
||||
from tabulate import tabulate
|
||||
from dataclasses import dataclass
|
||||
from typing import List, Dict
|
||||
from typing import List
|
||||
|
||||
parent_dir = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||
parent_dir = os.path.dirname(
|
||||
os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
|
||||
)
|
||||
sys.path.append(parent_dir)
|
||||
__location__ = os.path.realpath(os.path.join(os.getcwd(), os.path.dirname(__file__)))
|
||||
|
||||
from crawl4ai.content_scraping_strategy import WebScrapingStrategy
|
||||
from crawl4ai.content_scraping_strategy import WebScrapingStrategy as WebScrapingStrategyCurrent
|
||||
from crawl4ai.content_scraping_strategy import (
|
||||
WebScrapingStrategy as WebScrapingStrategyCurrent,
|
||||
)
|
||||
# from crawl4ai.content_scrapping_strategy_current import WebScrapingStrategy as WebScrapingStrategyCurrent
|
||||
|
||||
|
||||
@dataclass
|
||||
class TestResult:
|
||||
name: str
|
||||
@@ -27,69 +29,71 @@ class TestResult:
|
||||
markdown_length: int
|
||||
execution_time: float
|
||||
|
||||
|
||||
class StrategyTester:
|
||||
def __init__(self):
|
||||
self.new_scraper = WebScrapingStrategy()
|
||||
self.current_scraper = WebScrapingStrategyCurrent()
|
||||
with open(__location__ + '/sample_wikipedia.html', 'r', encoding='utf-8') as f:
|
||||
with open(__location__ + "/sample_wikipedia.html", "r", encoding="utf-8") as f:
|
||||
self.WIKI_HTML = f.read()
|
||||
self.results = {'new': [], 'current': []}
|
||||
|
||||
self.results = {"new": [], "current": []}
|
||||
|
||||
def run_test(self, name: str, **kwargs) -> tuple[TestResult, TestResult]:
|
||||
results = []
|
||||
for scraper in [self.new_scraper, self.current_scraper]:
|
||||
start_time = time.time()
|
||||
result = scraper._get_content_of_website_optimized(
|
||||
url="https://en.wikipedia.org/wiki/Test",
|
||||
html=self.WIKI_HTML,
|
||||
**kwargs
|
||||
url="https://en.wikipedia.org/wiki/Test", html=self.WIKI_HTML, **kwargs
|
||||
)
|
||||
execution_time = time.time() - start_time
|
||||
|
||||
|
||||
test_result = TestResult(
|
||||
name=name,
|
||||
success=result['success'],
|
||||
images=len(result['media']['images']),
|
||||
internal_links=len(result['links']['internal']),
|
||||
external_links=len(result['links']['external']),
|
||||
markdown_length=len(result['markdown']),
|
||||
execution_time=execution_time
|
||||
success=result["success"],
|
||||
images=len(result["media"]["images"]),
|
||||
internal_links=len(result["links"]["internal"]),
|
||||
external_links=len(result["links"]["external"]),
|
||||
markdown_length=len(result["markdown"]),
|
||||
execution_time=execution_time,
|
||||
)
|
||||
results.append(test_result)
|
||||
|
||||
|
||||
return results[0], results[1] # new, current
|
||||
|
||||
def run_all_tests(self):
|
||||
test_cases = [
|
||||
("Basic Extraction", {}),
|
||||
("Exclude Tags", {'excluded_tags': ['table', 'div.infobox', 'div.navbox']}),
|
||||
("Word Threshold", {'word_count_threshold': 50}),
|
||||
("CSS Selector", {'css_selector': 'div.mw-parser-output > p'}),
|
||||
("Link Exclusions", {
|
||||
'exclude_external_links': True,
|
||||
'exclude_social_media_links': True,
|
||||
'exclude_domains': ['facebook.com', 'twitter.com']
|
||||
}),
|
||||
("Media Handling", {
|
||||
'exclude_external_images': True,
|
||||
'image_description_min_word_threshold': 20
|
||||
}),
|
||||
("Text Only", {
|
||||
'only_text': True,
|
||||
'remove_forms': True
|
||||
}),
|
||||
("HTML Cleaning", {
|
||||
'clean_html': True,
|
||||
'keep_data_attributes': True
|
||||
}),
|
||||
("HTML2Text Options", {
|
||||
'html2text': {
|
||||
'skip_internal_links': True,
|
||||
'single_line_break': True,
|
||||
'mark_code': True,
|
||||
'preserve_tags': ['pre', 'code']
|
||||
}
|
||||
})
|
||||
("Exclude Tags", {"excluded_tags": ["table", "div.infobox", "div.navbox"]}),
|
||||
("Word Threshold", {"word_count_threshold": 50}),
|
||||
("CSS Selector", {"css_selector": "div.mw-parser-output > p"}),
|
||||
(
|
||||
"Link Exclusions",
|
||||
{
|
||||
"exclude_external_links": True,
|
||||
"exclude_social_media_links": True,
|
||||
"exclude_domains": ["facebook.com", "twitter.com"],
|
||||
},
|
||||
),
|
||||
(
|
||||
"Media Handling",
|
||||
{
|
||||
"exclude_external_images": True,
|
||||
"image_description_min_word_threshold": 20,
|
||||
},
|
||||
),
|
||||
("Text Only", {"only_text": True, "remove_forms": True}),
|
||||
("HTML Cleaning", {"clean_html": True, "keep_data_attributes": True}),
|
||||
(
|
||||
"HTML2Text Options",
|
||||
{
|
||||
"html2text": {
|
||||
"skip_internal_links": True,
|
||||
"single_line_break": True,
|
||||
"mark_code": True,
|
||||
"preserve_tags": ["pre", "code"],
|
||||
}
|
||||
},
|
||||
),
|
||||
]
|
||||
|
||||
all_results = []
|
||||
@@ -99,64 +103,117 @@ class StrategyTester:
|
||||
all_results.append((name, new_result, current_result))
|
||||
except Exception as e:
|
||||
print(f"Error in {name}: {str(e)}")
|
||||
|
||||
|
||||
self.save_results_to_csv(all_results)
|
||||
self.print_comparison_table(all_results)
|
||||
|
||||
def save_results_to_csv(self, all_results: List[tuple]):
|
||||
csv_file = os.path.join(__location__, 'strategy_comparison_results.csv')
|
||||
with open(csv_file, 'w', newline='') as f:
|
||||
csv_file = os.path.join(__location__, "strategy_comparison_results.csv")
|
||||
with open(csv_file, "w", newline="") as f:
|
||||
writer = csv.writer(f)
|
||||
writer.writerow(['Test Name', 'Strategy', 'Success', 'Images', 'Internal Links',
|
||||
'External Links', 'Markdown Length', 'Execution Time'])
|
||||
|
||||
writer.writerow(
|
||||
[
|
||||
"Test Name",
|
||||
"Strategy",
|
||||
"Success",
|
||||
"Images",
|
||||
"Internal Links",
|
||||
"External Links",
|
||||
"Markdown Length",
|
||||
"Execution Time",
|
||||
]
|
||||
)
|
||||
|
||||
for name, new_result, current_result in all_results:
|
||||
writer.writerow([name, 'New', new_result.success, new_result.images,
|
||||
new_result.internal_links, new_result.external_links,
|
||||
new_result.markdown_length, f"{new_result.execution_time:.3f}"])
|
||||
writer.writerow([name, 'Current', current_result.success, current_result.images,
|
||||
current_result.internal_links, current_result.external_links,
|
||||
current_result.markdown_length, f"{current_result.execution_time:.3f}"])
|
||||
writer.writerow(
|
||||
[
|
||||
name,
|
||||
"New",
|
||||
new_result.success,
|
||||
new_result.images,
|
||||
new_result.internal_links,
|
||||
new_result.external_links,
|
||||
new_result.markdown_length,
|
||||
f"{new_result.execution_time:.3f}",
|
||||
]
|
||||
)
|
||||
writer.writerow(
|
||||
[
|
||||
name,
|
||||
"Current",
|
||||
current_result.success,
|
||||
current_result.images,
|
||||
current_result.internal_links,
|
||||
current_result.external_links,
|
||||
current_result.markdown_length,
|
||||
f"{current_result.execution_time:.3f}",
|
||||
]
|
||||
)
|
||||
|
||||
def print_comparison_table(self, all_results: List[tuple]):
|
||||
table_data = []
|
||||
headers = ['Test Name', 'Strategy', 'Success', 'Images', 'Internal Links',
|
||||
'External Links', 'Markdown Length', 'Time (s)']
|
||||
headers = [
|
||||
"Test Name",
|
||||
"Strategy",
|
||||
"Success",
|
||||
"Images",
|
||||
"Internal Links",
|
||||
"External Links",
|
||||
"Markdown Length",
|
||||
"Time (s)",
|
||||
]
|
||||
|
||||
for name, new_result, current_result in all_results:
|
||||
# Check for differences
|
||||
differences = []
|
||||
if new_result.images != current_result.images: differences.append('images')
|
||||
if new_result.internal_links != current_result.internal_links: differences.append('internal_links')
|
||||
if new_result.external_links != current_result.external_links: differences.append('external_links')
|
||||
if new_result.markdown_length != current_result.markdown_length: differences.append('markdown')
|
||||
|
||||
if new_result.images != current_result.images:
|
||||
differences.append("images")
|
||||
if new_result.internal_links != current_result.internal_links:
|
||||
differences.append("internal_links")
|
||||
if new_result.external_links != current_result.external_links:
|
||||
differences.append("external_links")
|
||||
if new_result.markdown_length != current_result.markdown_length:
|
||||
differences.append("markdown")
|
||||
|
||||
# Add row for new strategy
|
||||
new_row = [
|
||||
name, 'New', new_result.success, new_result.images,
|
||||
new_result.internal_links, new_result.external_links,
|
||||
new_result.markdown_length, f"{new_result.execution_time:.3f}"
|
||||
name,
|
||||
"New",
|
||||
new_result.success,
|
||||
new_result.images,
|
||||
new_result.internal_links,
|
||||
new_result.external_links,
|
||||
new_result.markdown_length,
|
||||
f"{new_result.execution_time:.3f}",
|
||||
]
|
||||
table_data.append(new_row)
|
||||
|
||||
|
||||
# Add row for current strategy
|
||||
current_row = [
|
||||
'', 'Current', current_result.success, current_result.images,
|
||||
current_result.internal_links, current_result.external_links,
|
||||
current_result.markdown_length, f"{current_result.execution_time:.3f}"
|
||||
"",
|
||||
"Current",
|
||||
current_result.success,
|
||||
current_result.images,
|
||||
current_result.internal_links,
|
||||
current_result.external_links,
|
||||
current_result.markdown_length,
|
||||
f"{current_result.execution_time:.3f}",
|
||||
]
|
||||
table_data.append(current_row)
|
||||
|
||||
|
||||
# Add difference summary if any
|
||||
if differences:
|
||||
table_data.append(['', '⚠️ Differences', ', '.join(differences), '', '', '', '', ''])
|
||||
|
||||
table_data.append(
|
||||
["", "⚠️ Differences", ", ".join(differences), "", "", "", "", ""]
|
||||
)
|
||||
|
||||
# Add empty row for better readability
|
||||
table_data.append([''] * len(headers))
|
||||
table_data.append([""] * len(headers))
|
||||
|
||||
print("\nStrategy Comparison Results:")
|
||||
print(tabulate(table_data, headers=headers, tablefmt='grid'))
|
||||
print(tabulate(table_data, headers=headers, tablefmt="grid"))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
tester = StrategyTester()
|
||||
tester.run_all_tests()
|
||||
tester.run_all_tests()
|
||||
|
||||
@@ -1,14 +1,13 @@
|
||||
import os
|
||||
import sys
|
||||
import pytest
|
||||
import asyncio
|
||||
|
||||
# Add the parent directory to the Python path
|
||||
parent_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
|
||||
sys.path.append(parent_dir)
|
||||
|
||||
from crawl4ai.async_webcrawler import AsyncWebCrawler
|
||||
from crawl4ai.async_crawler_strategy import AsyncPlaywrightCrawlerStrategy
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_custom_user_agent():
|
||||
@@ -20,6 +19,7 @@ async def test_custom_user_agent():
|
||||
assert result.success
|
||||
assert custom_user_agent in result.html
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_custom_headers():
|
||||
async with AsyncWebCrawler(verbose=True) as crawler:
|
||||
@@ -31,6 +31,7 @@ async def test_custom_headers():
|
||||
assert "X-Test-Header" in result.html
|
||||
assert "TestValue" in result.html
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_javascript_execution():
|
||||
async with AsyncWebCrawler(verbose=True) as crawler:
|
||||
@@ -40,19 +41,22 @@ async def test_javascript_execution():
|
||||
assert result.success
|
||||
assert "<h1>Modified by JS</h1>" in result.html
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_hook_execution():
|
||||
async with AsyncWebCrawler(verbose=True) as crawler:
|
||||
|
||||
async def test_hook(page):
|
||||
await page.evaluate("document.body.style.backgroundColor = 'red';")
|
||||
return page
|
||||
|
||||
crawler.crawler_strategy.set_hook('after_goto', test_hook)
|
||||
crawler.crawler_strategy.set_hook("after_goto", test_hook)
|
||||
url = "https://www.example.com"
|
||||
result = await crawler.arun(url=url, bypass_cache=True)
|
||||
assert result.success
|
||||
assert "background-color: red" in result.html
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_screenshot():
|
||||
async with AsyncWebCrawler(verbose=True) as crawler:
|
||||
@@ -63,6 +67,7 @@ async def test_screenshot():
|
||||
assert isinstance(result.screenshot, str)
|
||||
assert len(result.screenshot) > 0
|
||||
|
||||
|
||||
# Entry point for debugging
|
||||
if __name__ == "__main__":
|
||||
pytest.main([__file__, "-v"])
|
||||
pytest.main([__file__, "-v"])
|
||||
|
||||
@@ -1,8 +1,6 @@
|
||||
import os
|
||||
import sys
|
||||
import pytest
|
||||
import asyncio
|
||||
import json
|
||||
|
||||
# Add the parent directory to the Python path
|
||||
parent_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
|
||||
@@ -10,6 +8,7 @@ sys.path.append(parent_dir)
|
||||
|
||||
from crawl4ai.async_webcrawler import AsyncWebCrawler
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_cache_url():
|
||||
async with AsyncWebCrawler(verbose=True) as crawler:
|
||||
@@ -23,6 +22,7 @@ async def test_cache_url():
|
||||
assert result2.success
|
||||
assert result2.html == result1.html
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_bypass_cache():
|
||||
async with AsyncWebCrawler(verbose=True) as crawler:
|
||||
@@ -34,25 +34,29 @@ async def test_bypass_cache():
|
||||
# Second run bypassing cache
|
||||
result2 = await crawler.arun(url=url, bypass_cache=True)
|
||||
assert result2.success
|
||||
assert result2.html != result1.html # Content might be different due to dynamic nature of websites
|
||||
assert (
|
||||
result2.html != result1.html
|
||||
) # Content might be different due to dynamic nature of websites
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_cache_size():
|
||||
async with AsyncWebCrawler(verbose=True) as crawler:
|
||||
initial_size = await crawler.aget_cache_size()
|
||||
|
||||
|
||||
url = "https://www.nbcnews.com/business"
|
||||
await crawler.arun(url=url, bypass_cache=True)
|
||||
|
||||
|
||||
new_size = await crawler.aget_cache_size()
|
||||
assert new_size == initial_size + 1
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_clear_cache():
|
||||
async with AsyncWebCrawler(verbose=True) as crawler:
|
||||
url = "https://www.example.org"
|
||||
await crawler.arun(url=url, bypass_cache=True)
|
||||
|
||||
|
||||
initial_size = await crawler.aget_cache_size()
|
||||
assert initial_size > 0
|
||||
|
||||
@@ -60,12 +64,13 @@ async def test_clear_cache():
|
||||
new_size = await crawler.aget_cache_size()
|
||||
assert new_size == 0
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_flush_cache():
|
||||
async with AsyncWebCrawler(verbose=True) as crawler:
|
||||
url = "https://www.example.net"
|
||||
await crawler.arun(url=url, bypass_cache=True)
|
||||
|
||||
|
||||
initial_size = await crawler.aget_cache_size()
|
||||
assert initial_size > 0
|
||||
|
||||
@@ -75,8 +80,11 @@ async def test_flush_cache():
|
||||
|
||||
# Try to retrieve the previously cached URL
|
||||
result = await crawler.arun(url=url, bypass_cache=False)
|
||||
assert result.success # The crawler should still succeed, but it will fetch the content anew
|
||||
assert (
|
||||
result.success
|
||||
) # The crawler should still succeed, but it will fetch the content anew
|
||||
|
||||
|
||||
# Entry point for debugging
|
||||
if __name__ == "__main__":
|
||||
pytest.main([__file__, "-v"])
|
||||
pytest.main([__file__, "-v"])
|
||||
|
||||
@@ -1,114 +1,133 @@
|
||||
import pytest
|
||||
import asyncio, time
|
||||
import time
|
||||
from crawl4ai import (
|
||||
AsyncWebCrawler, BrowserConfig, CrawlerRunConfig,
|
||||
MemoryAdaptiveDispatcher, SemaphoreDispatcher,
|
||||
RateLimiter, CrawlerMonitor, DisplayMode, CacheMode
|
||||
AsyncWebCrawler,
|
||||
BrowserConfig,
|
||||
CrawlerRunConfig,
|
||||
MemoryAdaptiveDispatcher,
|
||||
SemaphoreDispatcher,
|
||||
RateLimiter,
|
||||
CrawlerMonitor,
|
||||
DisplayMode,
|
||||
CacheMode,
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def browser_config():
|
||||
return BrowserConfig(
|
||||
headless=True,
|
||||
verbose=False
|
||||
)
|
||||
return BrowserConfig(headless=True, verbose=False)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def run_config():
|
||||
return CrawlerRunConfig(
|
||||
cache_mode=CacheMode.BYPASS,
|
||||
verbose=False
|
||||
)
|
||||
return CrawlerRunConfig(cache_mode=CacheMode.BYPASS, verbose=False)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def test_urls():
|
||||
return [
|
||||
"http://example.com",
|
||||
"http://example.com/page1",
|
||||
"http://example.com/page2"
|
||||
"http://example.com/page2",
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
class TestDispatchStrategies:
|
||||
|
||||
async def test_memory_adaptive_basic(self, browser_config, run_config, test_urls):
|
||||
async with AsyncWebCrawler(config=browser_config) as crawler:
|
||||
dispatcher = MemoryAdaptiveDispatcher(
|
||||
memory_threshold_percent=70.0,
|
||||
max_session_permit=2,
|
||||
check_interval=0.1
|
||||
memory_threshold_percent=70.0, max_session_permit=2, check_interval=0.1
|
||||
)
|
||||
results = await crawler.arun_many(
|
||||
test_urls, config=run_config, dispatcher=dispatcher
|
||||
)
|
||||
results = await crawler.arun_many(test_urls, config=run_config, dispatcher=dispatcher)
|
||||
assert len(results) == len(test_urls)
|
||||
assert all(r.success for r in results)
|
||||
|
||||
async def test_memory_adaptive_with_rate_limit(self, browser_config, run_config, test_urls):
|
||||
async def test_memory_adaptive_with_rate_limit(
|
||||
self, browser_config, run_config, test_urls
|
||||
):
|
||||
async with AsyncWebCrawler(config=browser_config) as crawler:
|
||||
dispatcher = MemoryAdaptiveDispatcher(
|
||||
memory_threshold_percent=70.0,
|
||||
max_session_permit=2,
|
||||
check_interval=0.1,
|
||||
rate_limiter=RateLimiter(
|
||||
base_delay=(0.1, 0.2),
|
||||
max_delay=1.0,
|
||||
max_retries=2
|
||||
)
|
||||
base_delay=(0.1, 0.2), max_delay=1.0, max_retries=2
|
||||
),
|
||||
)
|
||||
results = await crawler.arun_many(
|
||||
test_urls, config=run_config, dispatcher=dispatcher
|
||||
)
|
||||
results = await crawler.arun_many(test_urls, config=run_config, dispatcher=dispatcher)
|
||||
assert len(results) == len(test_urls)
|
||||
assert all(r.success for r in results)
|
||||
|
||||
async def test_semaphore_basic(self, browser_config, run_config, test_urls):
|
||||
async with AsyncWebCrawler(config=browser_config) as crawler:
|
||||
dispatcher = SemaphoreDispatcher(
|
||||
semaphore_count=2
|
||||
dispatcher = SemaphoreDispatcher(semaphore_count=2)
|
||||
results = await crawler.arun_many(
|
||||
test_urls, config=run_config, dispatcher=dispatcher
|
||||
)
|
||||
results = await crawler.arun_many(test_urls, config=run_config, dispatcher=dispatcher)
|
||||
assert len(results) == len(test_urls)
|
||||
assert all(r.success for r in results)
|
||||
|
||||
async def test_semaphore_with_rate_limit(self, browser_config, run_config, test_urls):
|
||||
async def test_semaphore_with_rate_limit(
|
||||
self, browser_config, run_config, test_urls
|
||||
):
|
||||
async with AsyncWebCrawler(config=browser_config) as crawler:
|
||||
dispatcher = SemaphoreDispatcher(
|
||||
semaphore_count=2,
|
||||
rate_limiter=RateLimiter(
|
||||
base_delay=(0.1, 0.2),
|
||||
max_delay=1.0,
|
||||
max_retries=2
|
||||
)
|
||||
base_delay=(0.1, 0.2), max_delay=1.0, max_retries=2
|
||||
),
|
||||
)
|
||||
results = await crawler.arun_many(
|
||||
test_urls, config=run_config, dispatcher=dispatcher
|
||||
)
|
||||
results = await crawler.arun_many(test_urls, config=run_config, dispatcher=dispatcher)
|
||||
assert len(results) == len(test_urls)
|
||||
assert all(r.success for r in results)
|
||||
|
||||
async def test_memory_adaptive_memory_error(self, browser_config, run_config, test_urls):
|
||||
async def test_memory_adaptive_memory_error(
|
||||
self, browser_config, run_config, test_urls
|
||||
):
|
||||
async with AsyncWebCrawler(config=browser_config) as crawler:
|
||||
dispatcher = MemoryAdaptiveDispatcher(
|
||||
memory_threshold_percent=1.0, # Set unrealistically low threshold
|
||||
max_session_permit=2,
|
||||
check_interval=0.1,
|
||||
memory_wait_timeout=1.0 # Short timeout for testing
|
||||
memory_wait_timeout=1.0, # Short timeout for testing
|
||||
)
|
||||
with pytest.raises(MemoryError):
|
||||
await crawler.arun_many(test_urls, config=run_config, dispatcher=dispatcher)
|
||||
await crawler.arun_many(
|
||||
test_urls, config=run_config, dispatcher=dispatcher
|
||||
)
|
||||
|
||||
async def test_empty_urls(self, browser_config, run_config):
|
||||
async with AsyncWebCrawler(config=browser_config) as crawler:
|
||||
dispatcher = MemoryAdaptiveDispatcher(max_session_permit=2)
|
||||
results = await crawler.arun_many([], config=run_config, dispatcher=dispatcher)
|
||||
results = await crawler.arun_many(
|
||||
[], config=run_config, dispatcher=dispatcher
|
||||
)
|
||||
assert len(results) == 0
|
||||
|
||||
async def test_single_url(self, browser_config, run_config):
|
||||
async with AsyncWebCrawler(config=browser_config) as crawler:
|
||||
dispatcher = MemoryAdaptiveDispatcher(max_session_permit=2)
|
||||
results = await crawler.arun_many(["http://example.com"], config=run_config, dispatcher=dispatcher)
|
||||
results = await crawler.arun_many(
|
||||
["http://example.com"], config=run_config, dispatcher=dispatcher
|
||||
)
|
||||
assert len(results) == 1
|
||||
assert results[0].success
|
||||
|
||||
async def test_invalid_urls(self, browser_config, run_config):
|
||||
async with AsyncWebCrawler(config=browser_config) as crawler:
|
||||
dispatcher = MemoryAdaptiveDispatcher(max_session_permit=2)
|
||||
results = await crawler.arun_many(["http://invalid.url.that.doesnt.exist"], config=run_config, dispatcher=dispatcher)
|
||||
results = await crawler.arun_many(
|
||||
["http://invalid.url.that.doesnt.exist"],
|
||||
config=run_config,
|
||||
dispatcher=dispatcher,
|
||||
)
|
||||
assert len(results) == 1
|
||||
assert not results[0].success
|
||||
|
||||
@@ -121,27 +140,31 @@ class TestDispatchStrategies:
|
||||
base_delay=(0.1, 0.2),
|
||||
max_delay=1.0,
|
||||
max_retries=2,
|
||||
rate_limit_codes=[200] # Force rate limiting for testing
|
||||
)
|
||||
rate_limit_codes=[200], # Force rate limiting for testing
|
||||
),
|
||||
)
|
||||
start_time = time.time()
|
||||
results = await crawler.arun_many(urls, config=run_config, dispatcher=dispatcher)
|
||||
results = await crawler.arun_many(
|
||||
urls, config=run_config, dispatcher=dispatcher
|
||||
)
|
||||
duration = time.time() - start_time
|
||||
assert len(results) == len(urls)
|
||||
assert duration > 1.0 # Ensure rate limiting caused delays
|
||||
|
||||
async def test_monitor_integration(self, browser_config, run_config, test_urls):
|
||||
async with AsyncWebCrawler(config=browser_config) as crawler:
|
||||
monitor = CrawlerMonitor(max_visible_rows=5, display_mode=DisplayMode.DETAILED)
|
||||
dispatcher = MemoryAdaptiveDispatcher(
|
||||
max_session_permit=2,
|
||||
monitor=monitor
|
||||
monitor = CrawlerMonitor(
|
||||
max_visible_rows=5, display_mode=DisplayMode.DETAILED
|
||||
)
|
||||
dispatcher = MemoryAdaptiveDispatcher(max_session_permit=2, monitor=monitor)
|
||||
results = await crawler.arun_many(
|
||||
test_urls, config=run_config, dispatcher=dispatcher
|
||||
)
|
||||
results = await crawler.arun_many(test_urls, config=run_config, dispatcher=dispatcher)
|
||||
assert len(results) == len(test_urls)
|
||||
# Check monitor stats
|
||||
assert len(monitor.stats) == len(test_urls)
|
||||
assert all(stat.end_time is not None for stat in monitor.stats.values())
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
pytest.main([__file__, "-v", "--asyncio-mode=auto"])
|
||||
pytest.main([__file__, "-v", "--asyncio-mode=auto"])
|
||||
|
||||
@@ -2,9 +2,9 @@ import os
|
||||
import re
|
||||
import sys
|
||||
import pytest
|
||||
import json
|
||||
from bs4 import BeautifulSoup
|
||||
import asyncio
|
||||
|
||||
# Add the parent directory to the Python path
|
||||
parent_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
|
||||
sys.path.append(parent_dir)
|
||||
@@ -59,19 +59,21 @@ from crawl4ai.async_webcrawler import AsyncWebCrawler
|
||||
# assert result.success
|
||||
# assert "github" in result.html.lower()
|
||||
|
||||
|
||||
# Add this test to your existing test file
|
||||
@pytest.mark.asyncio
|
||||
async def test_typescript_commits_multi_page():
|
||||
first_commit = ""
|
||||
|
||||
async def on_execution_started(page):
|
||||
nonlocal first_commit
|
||||
nonlocal first_commit
|
||||
try:
|
||||
# Check if the page firct commit h4 text is different from the first commit (use document.querySelector('li.Box-sc-g0xbh4-0 h4'))
|
||||
while True:
|
||||
await page.wait_for_selector('li.Box-sc-g0xbh4-0 h4')
|
||||
commit = await page.query_selector('li.Box-sc-g0xbh4-0 h4')
|
||||
commit = await commit.evaluate('(element) => element.textContent')
|
||||
commit = re.sub(r'\s+', '', commit)
|
||||
await page.wait_for_selector("li.Box-sc-g0xbh4-0 h4")
|
||||
commit = await page.query_selector("li.Box-sc-g0xbh4-0 h4")
|
||||
commit = await commit.evaluate("(element) => element.textContent")
|
||||
commit = re.sub(r"\s+", "", commit)
|
||||
if commit and commit != first_commit:
|
||||
first_commit = commit
|
||||
break
|
||||
@@ -79,9 +81,8 @@ async def test_typescript_commits_multi_page():
|
||||
except Exception as e:
|
||||
print(f"Warning: New content didn't appear after JavaScript execution: {e}")
|
||||
|
||||
|
||||
async with AsyncWebCrawler(verbose=True) as crawler:
|
||||
crawler.crawler_strategy.set_hook('on_execution_started', on_execution_started)
|
||||
crawler.crawler_strategy.set_hook("on_execution_started", on_execution_started)
|
||||
|
||||
url = "https://github.com/microsoft/TypeScript/commits/main"
|
||||
session_id = "typescript_commits_session"
|
||||
@@ -97,19 +98,21 @@ async def test_typescript_commits_multi_page():
|
||||
url=url, # Only use URL for the first page
|
||||
session_id=session_id,
|
||||
css_selector="li.Box-sc-g0xbh4-0",
|
||||
js=js_next_page if page > 0 else None, # Don't click 'next' on the first page
|
||||
js=js_next_page
|
||||
if page > 0
|
||||
else None, # Don't click 'next' on the first page
|
||||
bypass_cache=True,
|
||||
js_only=page > 0 # Use js_only for subsequent pages
|
||||
js_only=page > 0, # Use js_only for subsequent pages
|
||||
)
|
||||
|
||||
assert result.success, f"Failed to crawl page {page + 1}"
|
||||
|
||||
# Parse the HTML and extract commits
|
||||
soup = BeautifulSoup(result.cleaned_html, 'html.parser')
|
||||
soup = BeautifulSoup(result.cleaned_html, "html.parser")
|
||||
commits = soup.select("li")
|
||||
# Take first commit find h4 extract text
|
||||
first_commit = commits[0].find("h4").text
|
||||
first_commit = re.sub(r'\s+', '', first_commit)
|
||||
first_commit = re.sub(r"\s+", "", first_commit)
|
||||
all_commits.extend(commits)
|
||||
|
||||
print(f"Page {page + 1}: Found {len(commits)} commits")
|
||||
@@ -118,10 +121,13 @@ async def test_typescript_commits_multi_page():
|
||||
await crawler.crawler_strategy.kill_session(session_id)
|
||||
|
||||
# Assertions
|
||||
assert len(all_commits) >= 90, f"Expected at least 90 commits, but got {len(all_commits)}"
|
||||
|
||||
print(f"Successfully crawled {len(all_commits)} commits across 3 pages")
|
||||
assert (
|
||||
len(all_commits) >= 90
|
||||
), f"Expected at least 90 commits, but got {len(all_commits)}"
|
||||
|
||||
print(f"Successfully crawled {len(all_commits)} commits across 3 pages")
|
||||
|
||||
|
||||
# Entry point for debugging
|
||||
if __name__ == "__main__":
|
||||
pytest.main([__file__, "-v"])
|
||||
pytest.main([__file__, "-v"])
|
||||
|
||||
@@ -75,4 +75,4 @@
|
||||
|
||||
# # Entry point for debugging
|
||||
# if __name__ == "__main__":
|
||||
# pytest.main([__file__, "-v"])
|
||||
# pytest.main([__file__, "-v"])
|
||||
|
||||
@@ -1,11 +1,15 @@
|
||||
import json
|
||||
import time
|
||||
from bs4 import BeautifulSoup
|
||||
from crawl4ai.content_scraping_strategy import WebScrapingStrategy, LXMLWebScrapingStrategy
|
||||
from typing import Dict, Any, List, Tuple
|
||||
from crawl4ai.content_scraping_strategy import (
|
||||
WebScrapingStrategy,
|
||||
LXMLWebScrapingStrategy,
|
||||
)
|
||||
from typing import Dict, List, Tuple
|
||||
import difflib
|
||||
from lxml import html as lhtml, etree
|
||||
|
||||
|
||||
def normalize_dom(element):
|
||||
"""
|
||||
Recursively normalizes an lxml HTML element:
|
||||
@@ -15,7 +19,7 @@ def normalize_dom(element):
|
||||
Returns the same element (mutated).
|
||||
"""
|
||||
# Remove comment nodes
|
||||
comments = element.xpath('//comment()')
|
||||
comments = element.xpath("//comment()")
|
||||
for c in comments:
|
||||
p = c.getparent()
|
||||
if p is not None:
|
||||
@@ -45,7 +49,7 @@ def strip_html_body(root):
|
||||
"""
|
||||
If 'root' is <html>, find its <body> child and move all of <body>'s children
|
||||
into a new <div>. Return that <div>.
|
||||
|
||||
|
||||
If 'root' is <body>, similarly move all of its children into a new <div> and return it.
|
||||
|
||||
Otherwise, return 'root' as-is.
|
||||
@@ -53,8 +57,8 @@ def strip_html_body(root):
|
||||
tag_name = (root.tag or "").lower()
|
||||
|
||||
# Case 1: The root is <html>
|
||||
if tag_name == 'html':
|
||||
bodies = root.xpath('./body')
|
||||
if tag_name == "html":
|
||||
bodies = root.xpath("./body")
|
||||
if bodies:
|
||||
body = bodies[0]
|
||||
new_div = lhtml.Element("div")
|
||||
@@ -66,7 +70,7 @@ def strip_html_body(root):
|
||||
return root
|
||||
|
||||
# Case 2: The root is <body>
|
||||
elif tag_name == 'body':
|
||||
elif tag_name == "body":
|
||||
new_div = lhtml.Element("div")
|
||||
for child in root:
|
||||
new_div.append(child)
|
||||
@@ -92,7 +96,9 @@ def compare_nodes(node1, node2, differences, path="/"):
|
||||
attrs1 = list(node1.attrib.items())
|
||||
attrs2 = list(node2.attrib.items())
|
||||
if attrs1 != attrs2:
|
||||
differences.append(f"Attribute mismatch at {path}/{node1.tag}: {attrs1} vs. {attrs2}")
|
||||
differences.append(
|
||||
f"Attribute mismatch at {path}/{node1.tag}: {attrs1} vs. {attrs2}"
|
||||
)
|
||||
|
||||
# 3) Compare text (trim or unify whitespace as needed)
|
||||
text1 = (node1.text or "").strip()
|
||||
@@ -102,7 +108,9 @@ def compare_nodes(node1, node2, differences, path="/"):
|
||||
text2 = " ".join(text2.split())
|
||||
if text1 != text2:
|
||||
# If you prefer ignoring newlines or multiple whitespace, do a more robust cleanup
|
||||
differences.append(f"Text mismatch at {path}/{node1.tag}: '{text1}' vs. '{text2}'")
|
||||
differences.append(
|
||||
f"Text mismatch at {path}/{node1.tag}: '{text1}' vs. '{text2}'"
|
||||
)
|
||||
|
||||
# 4) Compare number of children
|
||||
children1 = list(node1)
|
||||
@@ -123,7 +131,9 @@ def compare_nodes(node1, node2, differences, path="/"):
|
||||
tail1 = (node1.tail or "").strip()
|
||||
tail2 = (node2.tail or "").strip()
|
||||
if tail1 != tail2:
|
||||
differences.append(f"Tail mismatch after {path}/{node1.tag}: '{tail1}' vs. '{tail2}'")
|
||||
differences.append(
|
||||
f"Tail mismatch after {path}/{node1.tag}: '{tail1}' vs. '{tail2}'"
|
||||
)
|
||||
|
||||
|
||||
def compare_html_structurally(html1, html2):
|
||||
@@ -156,11 +166,11 @@ def compare_html_structurally(html1, html2):
|
||||
return differences
|
||||
|
||||
|
||||
|
||||
def generate_large_html(n_elements=1000):
|
||||
html = ['<!DOCTYPE html><html><head></head><body>']
|
||||
html = ["<!DOCTYPE html><html><head></head><body>"]
|
||||
for i in range(n_elements):
|
||||
html.append(f'''
|
||||
html.append(
|
||||
f"""
|
||||
<div class="article">
|
||||
<h2>Heading {i}</h2>
|
||||
<p>This is paragraph {i} with some content and a <a href="http://example.com/{i}">link</a></p>
|
||||
@@ -170,13 +180,15 @@ def generate_large_html(n_elements=1000):
|
||||
<li>List item {i}.2</li>
|
||||
</ul>
|
||||
</div>
|
||||
''')
|
||||
html.append('</body></html>')
|
||||
return ''.join(html)
|
||||
"""
|
||||
)
|
||||
html.append("</body></html>")
|
||||
return "".join(html)
|
||||
|
||||
|
||||
def generate_complicated_html():
|
||||
"""
|
||||
HTML with multiple domains, forms, data attributes,
|
||||
HTML with multiple domains, forms, data attributes,
|
||||
various images, comments, style, and noscript to test all parameter toggles.
|
||||
"""
|
||||
return """
|
||||
@@ -258,7 +270,7 @@ def generate_complicated_html():
|
||||
def get_test_scenarios():
|
||||
"""
|
||||
Returns a dictionary of parameter sets (test scenarios) for the scraper.
|
||||
Each scenario name maps to a dictionary of keyword arguments
|
||||
Each scenario name maps to a dictionary of keyword arguments
|
||||
that will be passed into scrap() for testing various features.
|
||||
"""
|
||||
TEST_SCENARIOS = {
|
||||
@@ -341,7 +353,7 @@ def get_test_scenarios():
|
||||
# "exclude_external_links": True
|
||||
# },
|
||||
# "comprehensive_removal": {
|
||||
# # Exclude multiple tags, remove forms & comments,
|
||||
# # Exclude multiple tags, remove forms & comments,
|
||||
# # and also remove targeted selectors
|
||||
# "excluded_tags": ["aside", "noscript", "script"],
|
||||
# "excluded_selector": "#promo-section, .social-widget",
|
||||
@@ -352,19 +364,18 @@ def get_test_scenarios():
|
||||
return TEST_SCENARIOS
|
||||
|
||||
|
||||
|
||||
class ScraperEquivalenceTester:
|
||||
def __init__(self):
|
||||
self.test_cases = {
|
||||
'basic': self.generate_basic_html(),
|
||||
'complex': self.generate_complex_html(),
|
||||
'malformed': self.generate_malformed_html(),
|
||||
"basic": self.generate_basic_html(),
|
||||
"complex": self.generate_complex_html(),
|
||||
"malformed": self.generate_malformed_html(),
|
||||
# 'real_world': self.load_real_samples()
|
||||
}
|
||||
|
||||
|
||||
def generate_basic_html(self):
|
||||
return generate_large_html(1000) # Your existing function
|
||||
|
||||
|
||||
def generate_complex_html(self):
|
||||
return """
|
||||
<html><body>
|
||||
@@ -384,7 +395,7 @@ class ScraperEquivalenceTester:
|
||||
</div>
|
||||
</body></html>
|
||||
"""
|
||||
|
||||
|
||||
def generate_malformed_html(self):
|
||||
return """
|
||||
<div>Unclosed div
|
||||
@@ -395,139 +406,139 @@ class ScraperEquivalenceTester:
|
||||
<!-- Malformed comment -- > -->
|
||||
<![CDATA[Test CDATA]]>
|
||||
"""
|
||||
|
||||
|
||||
def load_real_samples(self):
|
||||
# Load some real-world HTML samples you've collected
|
||||
samples = {
|
||||
'article': open('tests/samples/article.html').read(),
|
||||
'product': open('tests/samples/product.html').read(),
|
||||
'blog': open('tests/samples/blog.html').read()
|
||||
"article": open("tests/samples/article.html").read(),
|
||||
"product": open("tests/samples/product.html").read(),
|
||||
"blog": open("tests/samples/blog.html").read(),
|
||||
}
|
||||
return samples
|
||||
|
||||
|
||||
def deep_compare_links(self, old_links: Dict, new_links: Dict) -> List[str]:
|
||||
"""Detailed comparison of link structures"""
|
||||
differences = []
|
||||
|
||||
for category in ['internal', 'external']:
|
||||
old_urls = {link['href'] for link in old_links[category]}
|
||||
new_urls = {link['href'] for link in new_links[category]}
|
||||
|
||||
|
||||
for category in ["internal", "external"]:
|
||||
old_urls = {link["href"] for link in old_links[category]}
|
||||
new_urls = {link["href"] for link in new_links[category]}
|
||||
|
||||
missing = old_urls - new_urls
|
||||
extra = new_urls - old_urls
|
||||
|
||||
|
||||
if missing:
|
||||
differences.append(f"Missing {category} links: {missing}")
|
||||
if extra:
|
||||
differences.append(f"Extra {category} links: {extra}")
|
||||
|
||||
|
||||
# Compare link attributes for common URLs
|
||||
common = old_urls & new_urls
|
||||
for url in common:
|
||||
old_link = next(l for l in old_links[category] if l['href'] == url)
|
||||
new_link = next(l for l in new_links[category] if l['href'] == url)
|
||||
|
||||
for attr in ['text', 'title']:
|
||||
old_link = next(l for l in old_links[category] if l["href"] == url)
|
||||
new_link = next(l for l in new_links[category] if l["href"] == url)
|
||||
|
||||
for attr in ["text", "title"]:
|
||||
if old_link[attr] != new_link[attr]:
|
||||
differences.append(
|
||||
f"Link attribute mismatch for {url} - {attr}:"
|
||||
f" old='{old_link[attr]}' vs new='{new_link[attr]}'"
|
||||
)
|
||||
|
||||
|
||||
return differences
|
||||
|
||||
def deep_compare_media(self, old_media: Dict, new_media: Dict) -> List[str]:
|
||||
"""Detailed comparison of media elements"""
|
||||
differences = []
|
||||
|
||||
for media_type in ['images', 'videos', 'audios']:
|
||||
old_srcs = {item['src'] for item in old_media[media_type]}
|
||||
new_srcs = {item['src'] for item in new_media[media_type]}
|
||||
|
||||
|
||||
for media_type in ["images", "videos", "audios"]:
|
||||
old_srcs = {item["src"] for item in old_media[media_type]}
|
||||
new_srcs = {item["src"] for item in new_media[media_type]}
|
||||
|
||||
missing = old_srcs - new_srcs
|
||||
extra = new_srcs - old_srcs
|
||||
|
||||
|
||||
if missing:
|
||||
differences.append(f"Missing {media_type}: {missing}")
|
||||
if extra:
|
||||
differences.append(f"Extra {media_type}: {extra}")
|
||||
|
||||
|
||||
# Compare media attributes for common sources
|
||||
common = old_srcs & new_srcs
|
||||
for src in common:
|
||||
old_item = next(m for m in old_media[media_type] if m['src'] == src)
|
||||
new_item = next(m for m in new_media[media_type] if m['src'] == src)
|
||||
|
||||
for attr in ['alt', 'description']:
|
||||
old_item = next(m for m in old_media[media_type] if m["src"] == src)
|
||||
new_item = next(m for m in new_media[media_type] if m["src"] == src)
|
||||
|
||||
for attr in ["alt", "description"]:
|
||||
if old_item.get(attr) != new_item.get(attr):
|
||||
differences.append(
|
||||
f"{media_type} attribute mismatch for {src} - {attr}:"
|
||||
f" old='{old_item.get(attr)}' vs new='{new_item.get(attr)}'"
|
||||
)
|
||||
|
||||
|
||||
return differences
|
||||
|
||||
def compare_html_content(self, old_html: str, new_html: str) -> List[str]:
|
||||
"""Compare HTML content structure and text"""
|
||||
# return compare_html_structurally(old_html, new_html)
|
||||
differences = []
|
||||
|
||||
|
||||
def normalize_html(html: str) -> Tuple[str, str]:
|
||||
soup = BeautifulSoup(html, 'lxml')
|
||||
soup = BeautifulSoup(html, "lxml")
|
||||
# Get both structure and text
|
||||
structure = ' '.join(tag.name for tag in soup.find_all())
|
||||
text = ' '.join(soup.get_text().split())
|
||||
structure = " ".join(tag.name for tag in soup.find_all())
|
||||
text = " ".join(soup.get_text().split())
|
||||
return structure, text
|
||||
|
||||
|
||||
old_structure, old_text = normalize_html(old_html)
|
||||
new_structure, new_text = normalize_html(new_html)
|
||||
|
||||
|
||||
# Compare structure
|
||||
if abs(len(old_structure) - len(new_structure)) > 100:
|
||||
# if old_structure != new_structure:
|
||||
# if old_structure != new_structure:
|
||||
diff = difflib.unified_diff(
|
||||
old_structure.split(),
|
||||
new_structure.split(),
|
||||
lineterm=''
|
||||
old_structure.split(), new_structure.split(), lineterm=""
|
||||
)
|
||||
differences.append("HTML structure differences:\n" + '\n'.join(diff))
|
||||
|
||||
differences.append("HTML structure differences:\n" + "\n".join(diff))
|
||||
|
||||
# Compare text content
|
||||
if abs(len(old_text) - len(new_text)) > 100:
|
||||
# if old_text != new_text:
|
||||
# if old_text != new_text:
|
||||
# Show detailed text differences
|
||||
text_diff = difflib.unified_diff(
|
||||
old_text.split(),
|
||||
new_text.split(),
|
||||
lineterm=''
|
||||
old_text.split(), new_text.split(), lineterm=""
|
||||
)
|
||||
differences.append("Text content differences:\n" + '\n'.join(text_diff))
|
||||
|
||||
differences.append("Text content differences:\n" + "\n".join(text_diff))
|
||||
|
||||
return differences
|
||||
|
||||
def compare_results(self, old_result: Dict, new_result: Dict) -> Dict[str, List[str]]:
|
||||
def compare_results(
|
||||
self, old_result: Dict, new_result: Dict
|
||||
) -> Dict[str, List[str]]:
|
||||
"""Comprehensive comparison of scraper outputs"""
|
||||
differences = {}
|
||||
|
||||
|
||||
# Compare links
|
||||
link_differences = self.deep_compare_links(old_result['links'], new_result['links'])
|
||||
link_differences = self.deep_compare_links(
|
||||
old_result["links"], new_result["links"]
|
||||
)
|
||||
if link_differences:
|
||||
differences['links'] = link_differences
|
||||
|
||||
differences["links"] = link_differences
|
||||
|
||||
# Compare media
|
||||
media_differences = self.deep_compare_media(old_result['media'], new_result['media'])
|
||||
media_differences = self.deep_compare_media(
|
||||
old_result["media"], new_result["media"]
|
||||
)
|
||||
if media_differences:
|
||||
differences['media'] = media_differences
|
||||
|
||||
differences["media"] = media_differences
|
||||
|
||||
# Compare HTML
|
||||
html_differences = self.compare_html_content(
|
||||
old_result['cleaned_html'],
|
||||
new_result['cleaned_html']
|
||||
old_result["cleaned_html"], new_result["cleaned_html"]
|
||||
)
|
||||
if html_differences:
|
||||
differences['html'] = html_differences
|
||||
|
||||
differences["html"] = html_differences
|
||||
|
||||
return differences
|
||||
|
||||
def run_tests(self) -> Dict:
|
||||
@@ -535,52 +546,49 @@ class ScraperEquivalenceTester:
|
||||
# We'll still keep some "test_cases" logic from above (basic, complex, malformed).
|
||||
# But we add a new section for the complicated HTML scenarios.
|
||||
|
||||
results = {
|
||||
'tests': [],
|
||||
'summary': {'passed': 0, 'failed': 0}
|
||||
}
|
||||
results = {"tests": [], "summary": {"passed": 0, "failed": 0}}
|
||||
|
||||
# 1) First, run the existing 3 built-in test cases (basic, complex, malformed).
|
||||
# for case_name, html in self.test_cases.items():
|
||||
# print(f"\nTesting built-in case: {case_name}...")
|
||||
|
||||
|
||||
# original = WebScrapingStrategy()
|
||||
# lxml = LXMLWebScrapingStrategy()
|
||||
|
||||
|
||||
# start = time.time()
|
||||
# orig_result = original.scrap("http://test.com", html)
|
||||
# orig_time = time.time() - start
|
||||
|
||||
|
||||
# print("\nOriginal Mode:")
|
||||
# print(f"Cleaned HTML size: {len(orig_result['cleaned_html'])/1024:.2f} KB")
|
||||
# print(f"Images: {len(orig_result['media']['images'])}")
|
||||
# print(f"External links: {len(orig_result['links']['external'])}")
|
||||
# print(f"Times - Original: {orig_time:.3f}s")
|
||||
|
||||
|
||||
# start = time.time()
|
||||
# lxml_result = lxml.scrap("http://test.com", html)
|
||||
# lxml_time = time.time() - start
|
||||
|
||||
|
||||
# print("\nLXML Mode:")
|
||||
# print(f"Cleaned HTML size: {len(lxml_result['cleaned_html'])/1024:.2f} KB")
|
||||
# print(f"Images: {len(lxml_result['media']['images'])}")
|
||||
# print(f"External links: {len(lxml_result['links']['external'])}")
|
||||
# print(f"Times - LXML: {lxml_time:.3f}s")
|
||||
|
||||
|
||||
# # Compare
|
||||
# diffs = {}
|
||||
# link_diff = self.deep_compare_links(orig_result['links'], lxml_result['links'])
|
||||
# if link_diff:
|
||||
# diffs['links'] = link_diff
|
||||
|
||||
|
||||
# media_diff = self.deep_compare_media(orig_result['media'], lxml_result['media'])
|
||||
# if media_diff:
|
||||
# diffs['media'] = media_diff
|
||||
|
||||
|
||||
# html_diff = self.compare_html_content(orig_result['cleaned_html'], lxml_result['cleaned_html'])
|
||||
# if html_diff:
|
||||
# diffs['html'] = html_diff
|
||||
|
||||
|
||||
# test_result = {
|
||||
# 'case': case_name,
|
||||
# 'lxml_mode': {
|
||||
@@ -590,7 +598,7 @@ class ScraperEquivalenceTester:
|
||||
# 'original_time': orig_time
|
||||
# }
|
||||
# results['tests'].append(test_result)
|
||||
|
||||
|
||||
# if not diffs:
|
||||
# results['summary']['passed'] += 1
|
||||
# else:
|
||||
@@ -599,50 +607,55 @@ class ScraperEquivalenceTester:
|
||||
# 2) Now, run the complicated HTML with multiple parameter scenarios.
|
||||
complicated_html = generate_complicated_html()
|
||||
print("\n=== Testing complicated HTML with multiple parameter scenarios ===")
|
||||
|
||||
|
||||
# Create the scrapers once (or you can re-create if needed)
|
||||
original = WebScrapingStrategy()
|
||||
lxml = LXMLWebScrapingStrategy()
|
||||
|
||||
for scenario_name, params in get_test_scenarios().items():
|
||||
print(f"\nScenario: {scenario_name}")
|
||||
|
||||
|
||||
start = time.time()
|
||||
orig_result = original.scrap("http://test.com", complicated_html, **params)
|
||||
orig_time = time.time() - start
|
||||
|
||||
|
||||
start = time.time()
|
||||
lxml_result = lxml.scrap("http://test.com", complicated_html, **params)
|
||||
lxml_time = time.time() - start
|
||||
|
||||
|
||||
diffs = {}
|
||||
link_diff = self.deep_compare_links(orig_result['links'], lxml_result['links'])
|
||||
link_diff = self.deep_compare_links(
|
||||
orig_result["links"], lxml_result["links"]
|
||||
)
|
||||
if link_diff:
|
||||
diffs['links'] = link_diff
|
||||
diffs["links"] = link_diff
|
||||
|
||||
media_diff = self.deep_compare_media(orig_result['media'], lxml_result['media'])
|
||||
media_diff = self.deep_compare_media(
|
||||
orig_result["media"], lxml_result["media"]
|
||||
)
|
||||
if media_diff:
|
||||
diffs['media'] = media_diff
|
||||
diffs["media"] = media_diff
|
||||
|
||||
html_diff = self.compare_html_content(orig_result['cleaned_html'], lxml_result['cleaned_html'])
|
||||
html_diff = self.compare_html_content(
|
||||
orig_result["cleaned_html"], lxml_result["cleaned_html"]
|
||||
)
|
||||
if html_diff:
|
||||
diffs['html'] = html_diff
|
||||
|
||||
diffs["html"] = html_diff
|
||||
|
||||
test_result = {
|
||||
'case': f"complicated_{scenario_name}",
|
||||
'lxml_mode': {
|
||||
'differences': diffs,
|
||||
'execution_time': lxml_time
|
||||
},
|
||||
'original_time': orig_time
|
||||
"case": f"complicated_{scenario_name}",
|
||||
"lxml_mode": {"differences": diffs, "execution_time": lxml_time},
|
||||
"original_time": orig_time,
|
||||
}
|
||||
results['tests'].append(test_result)
|
||||
|
||||
results["tests"].append(test_result)
|
||||
|
||||
if not diffs:
|
||||
results['summary']['passed'] += 1
|
||||
print(f"✅ [OK] No differences found. Time(Orig: {orig_time:.3f}s, LXML: {lxml_time:.3f}s)")
|
||||
results["summary"]["passed"] += 1
|
||||
print(
|
||||
f"✅ [OK] No differences found. Time(Orig: {orig_time:.3f}s, LXML: {lxml_time:.3f}s)"
|
||||
)
|
||||
else:
|
||||
results['summary']['failed'] += 1
|
||||
results["summary"]["failed"] += 1
|
||||
print("❌ Differences found:")
|
||||
for category, dlist in diffs.items():
|
||||
print(f" {category}:")
|
||||
@@ -657,20 +670,22 @@ class ScraperEquivalenceTester:
|
||||
print(f"Total Cases: {len(results['tests'])}")
|
||||
print(f"Passed: {results['summary']['passed']}")
|
||||
print(f"Failed: {results['summary']['failed']}")
|
||||
|
||||
for test in results['tests']:
|
||||
|
||||
for test in results["tests"]:
|
||||
print(f"\nTest Case: {test['case']}")
|
||||
|
||||
if not test['lxml_mode']['differences']:
|
||||
|
||||
if not test["lxml_mode"]["differences"]:
|
||||
print("✅ All implementations produced identical results")
|
||||
print(f"Times - Original: {test['original_time']:.3f}s, "
|
||||
f"LXML: {test['lxml_mode']['execution_time']:.3f}s")
|
||||
print(
|
||||
f"Times - Original: {test['original_time']:.3f}s, "
|
||||
f"LXML: {test['lxml_mode']['execution_time']:.3f}s"
|
||||
)
|
||||
else:
|
||||
print("❌ Differences found:")
|
||||
|
||||
if test['lxml_mode']['differences']:
|
||||
|
||||
if test["lxml_mode"]["differences"]:
|
||||
print("\nLXML Mode Differences:")
|
||||
for category, diffs in test['lxml_mode']['differences'].items():
|
||||
for category, diffs in test["lxml_mode"]["differences"].items():
|
||||
print(f"\n{category}:")
|
||||
for diff in diffs:
|
||||
print(f" - {diff}")
|
||||
@@ -680,11 +695,11 @@ def main():
|
||||
tester = ScraperEquivalenceTester()
|
||||
results = tester.run_tests()
|
||||
tester.print_report(results)
|
||||
|
||||
|
||||
# Save detailed results for debugging
|
||||
with open('scraper_equivalence_results.json', 'w') as f:
|
||||
with open("scraper_equivalence_results.json", "w") as f:
|
||||
json.dump(results, f, indent=2)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
main()
|
||||
|
||||
@@ -4,10 +4,10 @@
|
||||
# - **State:** open
|
||||
|
||||
import os, sys, time
|
||||
|
||||
parent_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
|
||||
sys.path.append(parent_dir)
|
||||
__location__ = os.path.realpath( os.path.join(os.getcwd(), os.path.dirname(__file__)))
|
||||
import asyncio
|
||||
__location__ = os.path.realpath(os.path.join(os.getcwd(), os.path.dirname(__file__)))
|
||||
import os
|
||||
import time
|
||||
from typing import Dict, Any
|
||||
@@ -16,18 +16,18 @@ from crawl4ai.markdown_generation_strategy import DefaultMarkdownGenerator
|
||||
# Get current directory
|
||||
__location__ = os.path.realpath(os.path.join(os.getcwd(), os.path.dirname(__file__)))
|
||||
|
||||
|
||||
def print_test_result(name: str, result: Dict[str, Any], execution_time: float):
|
||||
"""Helper function to print test results."""
|
||||
print(f"\n{'='*20} {name} {'='*20}")
|
||||
print(f"Execution time: {execution_time:.4f} seconds")
|
||||
|
||||
|
||||
|
||||
# Save markdown to files
|
||||
for key, content in result.items():
|
||||
if isinstance(content, str):
|
||||
with open(__location__ + f"/output/{name.lower()}_{key}.md", "w") as f:
|
||||
f.write(content)
|
||||
|
||||
|
||||
# # Print first few lines of each markdown version
|
||||
# for key, content in result.items():
|
||||
# if isinstance(content, str):
|
||||
@@ -36,32 +36,39 @@ def print_test_result(name: str, result: Dict[str, Any], execution_time: float):
|
||||
# print(preview)
|
||||
# print(f"Total length: {len(content)} characters")
|
||||
|
||||
|
||||
def test_basic_markdown_conversion():
|
||||
"""Test basic markdown conversion with links."""
|
||||
with open(__location__ + "/data/wikipedia.html", "r") as f:
|
||||
cleaned_html = f.read()
|
||||
|
||||
generator = DefaultMarkdownGenerator()
|
||||
|
||||
|
||||
start_time = time.perf_counter()
|
||||
result = generator.generate_markdown(
|
||||
cleaned_html=cleaned_html,
|
||||
base_url="https://en.wikipedia.org"
|
||||
cleaned_html=cleaned_html, base_url="https://en.wikipedia.org"
|
||||
)
|
||||
execution_time = time.perf_counter() - start_time
|
||||
|
||||
print_test_result("Basic Markdown Conversion", {
|
||||
'raw': result.raw_markdown,
|
||||
'with_citations': result.markdown_with_citations,
|
||||
'references': result.references_markdown
|
||||
}, execution_time)
|
||||
|
||||
|
||||
print_test_result(
|
||||
"Basic Markdown Conversion",
|
||||
{
|
||||
"raw": result.raw_markdown,
|
||||
"with_citations": result.markdown_with_citations,
|
||||
"references": result.references_markdown,
|
||||
},
|
||||
execution_time,
|
||||
)
|
||||
|
||||
# Basic assertions
|
||||
assert result.raw_markdown, "Raw markdown should not be empty"
|
||||
assert result.markdown_with_citations, "Markdown with citations should not be empty"
|
||||
assert result.references_markdown, "References should not be empty"
|
||||
assert "⟨" in result.markdown_with_citations, "Citations should use ⟨⟩ brackets"
|
||||
assert "## References" in result.references_markdown, "Should contain references section"
|
||||
assert (
|
||||
"## References" in result.references_markdown
|
||||
), "Should contain references section"
|
||||
|
||||
|
||||
def test_relative_links():
|
||||
"""Test handling of relative links with base URL."""
|
||||
@@ -69,97 +76,106 @@ def test_relative_links():
|
||||
Here's a [relative link](/wiki/Apple) and an [absolute link](https://example.com).
|
||||
Also an [image](/images/test.png) and another [page](/wiki/Banana).
|
||||
"""
|
||||
|
||||
|
||||
generator = DefaultMarkdownGenerator()
|
||||
result = generator.generate_markdown(
|
||||
cleaned_html=markdown,
|
||||
base_url="https://en.wikipedia.org"
|
||||
cleaned_html=markdown, base_url="https://en.wikipedia.org"
|
||||
)
|
||||
|
||||
|
||||
assert "https://en.wikipedia.org/wiki/Apple" in result.references_markdown
|
||||
assert "https://example.com" in result.references_markdown
|
||||
assert "https://en.wikipedia.org/images/test.png" in result.references_markdown
|
||||
|
||||
|
||||
def test_duplicate_links():
|
||||
"""Test handling of duplicate links."""
|
||||
markdown = """
|
||||
Here's a [link](/test) and another [link](/test) and a [different link](/other).
|
||||
"""
|
||||
|
||||
|
||||
generator = DefaultMarkdownGenerator()
|
||||
result = generator.generate_markdown(
|
||||
cleaned_html=markdown,
|
||||
base_url="https://example.com"
|
||||
cleaned_html=markdown, base_url="https://example.com"
|
||||
)
|
||||
|
||||
|
||||
# Count citations in markdown
|
||||
citations = result.markdown_with_citations.count("⟨1⟩")
|
||||
assert citations == 2, "Same link should use same citation number"
|
||||
|
||||
|
||||
def test_link_descriptions():
|
||||
"""Test handling of link titles and descriptions."""
|
||||
markdown = """
|
||||
Here's a [link with title](/test "Test Title") and a [link with description](/other) to test.
|
||||
"""
|
||||
|
||||
|
||||
generator = DefaultMarkdownGenerator()
|
||||
result = generator.generate_markdown(
|
||||
cleaned_html=markdown,
|
||||
base_url="https://example.com"
|
||||
cleaned_html=markdown, base_url="https://example.com"
|
||||
)
|
||||
|
||||
assert "Test Title" in result.references_markdown, "Link title should be in references"
|
||||
assert "link with description" in result.references_markdown, "Link text should be in references"
|
||||
|
||||
assert (
|
||||
"Test Title" in result.references_markdown
|
||||
), "Link title should be in references"
|
||||
assert (
|
||||
"link with description" in result.references_markdown
|
||||
), "Link text should be in references"
|
||||
|
||||
|
||||
def test_performance_large_document():
|
||||
"""Test performance with large document."""
|
||||
with open(__location__ + "/data/wikipedia.md", "r") as f:
|
||||
markdown = f.read()
|
||||
|
||||
|
||||
# Test with multiple iterations
|
||||
iterations = 5
|
||||
times = []
|
||||
|
||||
|
||||
generator = DefaultMarkdownGenerator()
|
||||
|
||||
|
||||
for i in range(iterations):
|
||||
start_time = time.perf_counter()
|
||||
result = generator.generate_markdown(
|
||||
cleaned_html=markdown,
|
||||
base_url="https://en.wikipedia.org"
|
||||
cleaned_html=markdown, base_url="https://en.wikipedia.org"
|
||||
)
|
||||
end_time = time.perf_counter()
|
||||
times.append(end_time - start_time)
|
||||
|
||||
|
||||
avg_time = sum(times) / len(times)
|
||||
print(f"\n{'='*20} Performance Test {'='*20}")
|
||||
print(f"Average execution time over {iterations} iterations: {avg_time:.4f} seconds")
|
||||
print(
|
||||
f"Average execution time over {iterations} iterations: {avg_time:.4f} seconds"
|
||||
)
|
||||
print(f"Min time: {min(times):.4f} seconds")
|
||||
print(f"Max time: {max(times):.4f} seconds")
|
||||
|
||||
|
||||
def test_image_links():
|
||||
"""Test handling of image links."""
|
||||
markdown = """
|
||||
Here's an  and another .
|
||||
And a regular [link](/page).
|
||||
"""
|
||||
|
||||
|
||||
generator = DefaultMarkdownGenerator()
|
||||
result = generator.generate_markdown(
|
||||
cleaned_html=markdown,
|
||||
base_url="https://example.com"
|
||||
cleaned_html=markdown, base_url="https://example.com"
|
||||
)
|
||||
|
||||
assert "![" in result.markdown_with_citations, "Image markdown syntax should be preserved"
|
||||
assert "Image Title" in result.references_markdown, "Image title should be in references"
|
||||
|
||||
assert (
|
||||
"![" in result.markdown_with_citations
|
||||
), "Image markdown syntax should be preserved"
|
||||
assert (
|
||||
"Image Title" in result.references_markdown
|
||||
), "Image title should be in references"
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
print("Running markdown generation strategy tests...")
|
||||
|
||||
|
||||
test_basic_markdown_conversion()
|
||||
test_relative_links()
|
||||
test_duplicate_links()
|
||||
test_link_descriptions()
|
||||
test_performance_large_document()
|
||||
test_image_links()
|
||||
|
||||
@@ -1,8 +1,6 @@
|
||||
import os
|
||||
import sys
|
||||
import pytest
|
||||
import asyncio
|
||||
import json
|
||||
|
||||
# Add the parent directory to the Python path
|
||||
parent_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
|
||||
@@ -10,24 +8,37 @@ sys.path.append(parent_dir)
|
||||
|
||||
from crawl4ai.async_webcrawler import AsyncWebCrawler
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_word_count_threshold():
|
||||
async with AsyncWebCrawler(verbose=True) as crawler:
|
||||
url = "https://www.nbcnews.com/business"
|
||||
result_no_threshold = await crawler.arun(url=url, word_count_threshold=0, bypass_cache=True)
|
||||
result_with_threshold = await crawler.arun(url=url, word_count_threshold=50, bypass_cache=True)
|
||||
|
||||
result_no_threshold = await crawler.arun(
|
||||
url=url, word_count_threshold=0, bypass_cache=True
|
||||
)
|
||||
result_with_threshold = await crawler.arun(
|
||||
url=url, word_count_threshold=50, bypass_cache=True
|
||||
)
|
||||
|
||||
assert len(result_no_threshold.markdown) > len(result_with_threshold.markdown)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_css_selector():
|
||||
async with AsyncWebCrawler(verbose=True) as crawler:
|
||||
url = "https://www.nbcnews.com/business"
|
||||
css_selector = "h1, h2, h3"
|
||||
result = await crawler.arun(url=url, css_selector=css_selector, bypass_cache=True)
|
||||
|
||||
result = await crawler.arun(
|
||||
url=url, css_selector=css_selector, bypass_cache=True
|
||||
)
|
||||
|
||||
assert result.success
|
||||
assert "<h1" in result.cleaned_html or "<h2" in result.cleaned_html or "<h3" in result.cleaned_html
|
||||
assert (
|
||||
"<h1" in result.cleaned_html
|
||||
or "<h2" in result.cleaned_html
|
||||
or "<h3" in result.cleaned_html
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_javascript_execution():
|
||||
@@ -36,59 +47,70 @@ async def test_javascript_execution():
|
||||
|
||||
# Crawl without JS
|
||||
result_without_more = await crawler.arun(url=url, bypass_cache=True)
|
||||
|
||||
js_code = ["const loadMoreButton = Array.from(document.querySelectorAll('button')).find(button => button.textContent.includes('Load More')); loadMoreButton && loadMoreButton.click();"]
|
||||
|
||||
js_code = [
|
||||
"const loadMoreButton = Array.from(document.querySelectorAll('button')).find(button => button.textContent.includes('Load More')); loadMoreButton && loadMoreButton.click();"
|
||||
]
|
||||
result_with_more = await crawler.arun(url=url, js=js_code, bypass_cache=True)
|
||||
|
||||
|
||||
assert result_with_more.success
|
||||
assert len(result_with_more.markdown) > len(result_without_more.markdown)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_screenshot():
|
||||
async with AsyncWebCrawler(verbose=True) as crawler:
|
||||
url = "https://www.nbcnews.com/business"
|
||||
result = await crawler.arun(url=url, screenshot=True, bypass_cache=True)
|
||||
|
||||
|
||||
assert result.success
|
||||
assert result.screenshot
|
||||
assert isinstance(result.screenshot, str) # Should be a base64 encoded string
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_custom_user_agent():
|
||||
async with AsyncWebCrawler(verbose=True) as crawler:
|
||||
url = "https://www.nbcnews.com/business"
|
||||
custom_user_agent = "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.124 Safari/537.36 Crawl4AI/1.0"
|
||||
result = await crawler.arun(url=url, user_agent=custom_user_agent, bypass_cache=True)
|
||||
|
||||
result = await crawler.arun(
|
||||
url=url, user_agent=custom_user_agent, bypass_cache=True
|
||||
)
|
||||
|
||||
assert result.success
|
||||
# Note: We can't directly verify the user agent in the result, but we can check if the crawl was successful
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_extract_media_and_links():
|
||||
async with AsyncWebCrawler(verbose=True) as crawler:
|
||||
url = "https://www.nbcnews.com/business"
|
||||
result = await crawler.arun(url=url, bypass_cache=True)
|
||||
|
||||
|
||||
assert result.success
|
||||
assert result.media
|
||||
assert isinstance(result.media, dict)
|
||||
assert 'images' in result.media
|
||||
assert "images" in result.media
|
||||
assert result.links
|
||||
assert isinstance(result.links, dict)
|
||||
assert 'internal' in result.links and 'external' in result.links
|
||||
assert "internal" in result.links and "external" in result.links
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_metadata_extraction():
|
||||
async with AsyncWebCrawler(verbose=True) as crawler:
|
||||
url = "https://www.nbcnews.com/business"
|
||||
result = await crawler.arun(url=url, bypass_cache=True)
|
||||
|
||||
|
||||
assert result.success
|
||||
assert result.metadata
|
||||
assert isinstance(result.metadata, dict)
|
||||
# Check for common metadata fields
|
||||
assert any(key in result.metadata for key in ['title', 'description', 'keywords'])
|
||||
assert any(
|
||||
key in result.metadata for key in ["title", "description", "keywords"]
|
||||
)
|
||||
|
||||
|
||||
# Entry point for debugging
|
||||
if __name__ == "__main__":
|
||||
pytest.main([__file__, "-v"])
|
||||
pytest.main([__file__, "-v"])
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
import os
|
||||
import sys
|
||||
import pytest
|
||||
import asyncio
|
||||
import time
|
||||
|
||||
# Add the parent directory to the Python path
|
||||
@@ -10,6 +9,7 @@ sys.path.append(parent_dir)
|
||||
|
||||
from crawl4ai.async_webcrawler import AsyncWebCrawler
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_crawl_speed():
|
||||
async with AsyncWebCrawler(verbose=True) as crawler:
|
||||
@@ -17,13 +17,14 @@ async def test_crawl_speed():
|
||||
start_time = time.time()
|
||||
result = await crawler.arun(url=url, bypass_cache=True)
|
||||
end_time = time.time()
|
||||
|
||||
|
||||
assert result.success
|
||||
crawl_time = end_time - start_time
|
||||
print(f"Crawl time: {crawl_time:.2f} seconds")
|
||||
|
||||
|
||||
assert crawl_time < 10, f"Crawl took too long: {crawl_time:.2f} seconds"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_concurrent_crawling_performance():
|
||||
async with AsyncWebCrawler(verbose=True) as crawler:
|
||||
@@ -32,41 +33,47 @@ async def test_concurrent_crawling_performance():
|
||||
"https://www.example.com",
|
||||
"https://www.python.org",
|
||||
"https://www.github.com",
|
||||
"https://www.stackoverflow.com"
|
||||
"https://www.stackoverflow.com",
|
||||
]
|
||||
|
||||
|
||||
start_time = time.time()
|
||||
results = await crawler.arun_many(urls=urls, bypass_cache=True)
|
||||
end_time = time.time()
|
||||
|
||||
|
||||
total_time = end_time - start_time
|
||||
print(f"Total time for concurrent crawling: {total_time:.2f} seconds")
|
||||
|
||||
|
||||
assert all(result.success for result in results)
|
||||
assert len(results) == len(urls)
|
||||
|
||||
assert total_time < len(urls) * 5, f"Concurrent crawling not significantly faster: {total_time:.2f} seconds"
|
||||
|
||||
assert (
|
||||
total_time < len(urls) * 5
|
||||
), f"Concurrent crawling not significantly faster: {total_time:.2f} seconds"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_crawl_speed_with_caching():
|
||||
async with AsyncWebCrawler(verbose=True) as crawler:
|
||||
url = "https://www.nbcnews.com/business"
|
||||
|
||||
|
||||
start_time = time.time()
|
||||
result1 = await crawler.arun(url=url, bypass_cache=True)
|
||||
end_time = time.time()
|
||||
first_crawl_time = end_time - start_time
|
||||
|
||||
|
||||
start_time = time.time()
|
||||
result2 = await crawler.arun(url=url, bypass_cache=False)
|
||||
end_time = time.time()
|
||||
second_crawl_time = end_time - start_time
|
||||
|
||||
|
||||
assert result1.success and result2.success
|
||||
print(f"First crawl time: {first_crawl_time:.2f} seconds")
|
||||
print(f"Second crawl time (cached): {second_crawl_time:.2f} seconds")
|
||||
|
||||
assert second_crawl_time < first_crawl_time / 2, "Cached crawl not significantly faster"
|
||||
|
||||
assert (
|
||||
second_crawl_time < first_crawl_time / 2
|
||||
), "Cached crawl not significantly faster"
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
pytest.main([__file__, "-v"])
|
||||
pytest.main([__file__, "-v"])
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
import os
|
||||
import sys
|
||||
import pytest
|
||||
import asyncio
|
||||
import base64
|
||||
from PIL import Image
|
||||
import io
|
||||
@@ -12,113 +11,112 @@ sys.path.append(parent_dir)
|
||||
|
||||
from crawl4ai.async_webcrawler import AsyncWebCrawler
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_basic_screenshot():
|
||||
async with AsyncWebCrawler(verbose=True) as crawler:
|
||||
url = "https://example.com" # A static website
|
||||
result = await crawler.arun(url=url, bypass_cache=True, screenshot=True)
|
||||
|
||||
|
||||
assert result.success
|
||||
assert result.screenshot is not None
|
||||
|
||||
|
||||
# Verify the screenshot is a valid image
|
||||
image_data = base64.b64decode(result.screenshot)
|
||||
image = Image.open(io.BytesIO(image_data))
|
||||
assert image.format == "PNG"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_screenshot_with_wait_for():
|
||||
async with AsyncWebCrawler(verbose=True) as crawler:
|
||||
# Using a website with dynamic content
|
||||
url = "https://www.youtube.com"
|
||||
wait_for = "css:#content" # Wait for the main content to load
|
||||
|
||||
|
||||
result = await crawler.arun(
|
||||
url=url,
|
||||
bypass_cache=True,
|
||||
screenshot=True,
|
||||
wait_for=wait_for
|
||||
url=url, bypass_cache=True, screenshot=True, wait_for=wait_for
|
||||
)
|
||||
|
||||
|
||||
assert result.success
|
||||
assert result.screenshot is not None
|
||||
|
||||
|
||||
# Verify the screenshot is a valid image
|
||||
image_data = base64.b64decode(result.screenshot)
|
||||
image = Image.open(io.BytesIO(image_data))
|
||||
assert image.format == "PNG"
|
||||
|
||||
|
||||
# You might want to add more specific checks here, like image dimensions
|
||||
# or even use image recognition to verify certain elements are present
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_screenshot_with_js_wait_for():
|
||||
async with AsyncWebCrawler(verbose=True) as crawler:
|
||||
url = "https://www.amazon.com"
|
||||
wait_for = "js:() => document.querySelector('#nav-logo-sprites') !== null"
|
||||
|
||||
|
||||
result = await crawler.arun(
|
||||
url=url,
|
||||
bypass_cache=True,
|
||||
screenshot=True,
|
||||
wait_for=wait_for
|
||||
url=url, bypass_cache=True, screenshot=True, wait_for=wait_for
|
||||
)
|
||||
|
||||
|
||||
assert result.success
|
||||
assert result.screenshot is not None
|
||||
|
||||
|
||||
image_data = base64.b64decode(result.screenshot)
|
||||
image = Image.open(io.BytesIO(image_data))
|
||||
assert image.format == "PNG"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_screenshot_without_wait_for():
|
||||
async with AsyncWebCrawler(verbose=True) as crawler:
|
||||
url = "https://www.nytimes.com" # A website with lots of dynamic content
|
||||
|
||||
|
||||
result = await crawler.arun(url=url, bypass_cache=True, screenshot=True)
|
||||
|
||||
|
||||
assert result.success
|
||||
assert result.screenshot is not None
|
||||
|
||||
|
||||
image_data = base64.b64decode(result.screenshot)
|
||||
image = Image.open(io.BytesIO(image_data))
|
||||
assert image.format == "PNG"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_screenshot_comparison():
|
||||
async with AsyncWebCrawler(verbose=True) as crawler:
|
||||
url = "https://www.reddit.com"
|
||||
wait_for = "css:#SHORTCUT_FOCUSABLE_DIV"
|
||||
|
||||
|
||||
# Take screenshot without wait_for
|
||||
result_without_wait = await crawler.arun(
|
||||
url=url,
|
||||
bypass_cache=True,
|
||||
screenshot=True
|
||||
url=url, bypass_cache=True, screenshot=True
|
||||
)
|
||||
|
||||
|
||||
# Take screenshot with wait_for
|
||||
result_with_wait = await crawler.arun(
|
||||
url=url,
|
||||
bypass_cache=True,
|
||||
screenshot=True,
|
||||
wait_for=wait_for
|
||||
url=url, bypass_cache=True, screenshot=True, wait_for=wait_for
|
||||
)
|
||||
|
||||
|
||||
assert result_without_wait.success and result_with_wait.success
|
||||
assert result_without_wait.screenshot is not None
|
||||
assert result_with_wait.screenshot is not None
|
||||
|
||||
|
||||
# Compare the two screenshots
|
||||
image_without_wait = Image.open(io.BytesIO(base64.b64decode(result_without_wait.screenshot)))
|
||||
image_with_wait = Image.open(io.BytesIO(base64.b64decode(result_with_wait.screenshot)))
|
||||
|
||||
image_without_wait = Image.open(
|
||||
io.BytesIO(base64.b64decode(result_without_wait.screenshot))
|
||||
)
|
||||
image_with_wait = Image.open(
|
||||
io.BytesIO(base64.b64decode(result_with_wait.screenshot))
|
||||
)
|
||||
|
||||
# This is a simple size comparison. In a real-world scenario, you might want to use
|
||||
# more sophisticated image comparison techniques.
|
||||
assert image_with_wait.size[0] >= image_without_wait.size[0]
|
||||
assert image_with_wait.size[1] >= image_without_wait.size[1]
|
||||
|
||||
|
||||
# Entry point for debugging
|
||||
if __name__ == "__main__":
|
||||
pytest.main([__file__, "-v"])
|
||||
pytest.main([__file__, "-v"])
|
||||
|
||||
@@ -6,53 +6,72 @@ import base64
|
||||
import os
|
||||
from typing import Dict, Any
|
||||
|
||||
|
||||
class Crawl4AiTester:
|
||||
def __init__(self, base_url: str = "http://localhost:11235", api_token: str = None):
|
||||
self.base_url = base_url
|
||||
self.api_token = api_token or os.getenv('CRAWL4AI_API_TOKEN') # Check environment variable as fallback
|
||||
self.headers = {'Authorization': f'Bearer {self.api_token}'} if self.api_token else {}
|
||||
|
||||
def submit_and_wait(self, request_data: Dict[str, Any], timeout: int = 300) -> Dict[str, Any]:
|
||||
self.api_token = api_token or os.getenv(
|
||||
"CRAWL4AI_API_TOKEN"
|
||||
) # Check environment variable as fallback
|
||||
self.headers = (
|
||||
{"Authorization": f"Bearer {self.api_token}"} if self.api_token else {}
|
||||
)
|
||||
|
||||
def submit_and_wait(
|
||||
self, request_data: Dict[str, Any], timeout: int = 300
|
||||
) -> Dict[str, Any]:
|
||||
# Submit crawl job
|
||||
response = requests.post(f"{self.base_url}/crawl", json=request_data, headers=self.headers)
|
||||
response = requests.post(
|
||||
f"{self.base_url}/crawl", json=request_data, headers=self.headers
|
||||
)
|
||||
if response.status_code == 403:
|
||||
raise Exception("API token is invalid or missing")
|
||||
task_id = response.json()["task_id"]
|
||||
print(f"Task ID: {task_id}")
|
||||
|
||||
|
||||
# Poll for result
|
||||
start_time = time.time()
|
||||
while True:
|
||||
if time.time() - start_time > timeout:
|
||||
raise TimeoutError(f"Task {task_id} did not complete within {timeout} seconds")
|
||||
|
||||
result = requests.get(f"{self.base_url}/task/{task_id}", headers=self.headers)
|
||||
raise TimeoutError(
|
||||
f"Task {task_id} did not complete within {timeout} seconds"
|
||||
)
|
||||
|
||||
result = requests.get(
|
||||
f"{self.base_url}/task/{task_id}", headers=self.headers
|
||||
)
|
||||
status = result.json()
|
||||
|
||||
|
||||
if status["status"] == "failed":
|
||||
print("Task failed:", status.get("error"))
|
||||
raise Exception(f"Task failed: {status.get('error')}")
|
||||
|
||||
|
||||
if status["status"] == "completed":
|
||||
return status
|
||||
|
||||
|
||||
time.sleep(2)
|
||||
|
||||
|
||||
def submit_sync(self, request_data: Dict[str, Any]) -> Dict[str, Any]:
|
||||
response = requests.post(f"{self.base_url}/crawl_sync", json=request_data, headers=self.headers, timeout=60)
|
||||
response = requests.post(
|
||||
f"{self.base_url}/crawl_sync",
|
||||
json=request_data,
|
||||
headers=self.headers,
|
||||
timeout=60,
|
||||
)
|
||||
if response.status_code == 408:
|
||||
raise TimeoutError("Task did not complete within server timeout")
|
||||
response.raise_for_status()
|
||||
return response.json()
|
||||
|
||||
|
||||
def test_docker_deployment(version="basic"):
|
||||
tester = Crawl4AiTester(
|
||||
# base_url="http://localhost:11235" ,
|
||||
base_url="https://crawl4ai-sby74.ondigitalocean.app",
|
||||
api_token="test"
|
||||
api_token="test",
|
||||
)
|
||||
print(f"Testing Crawl4AI Docker {version} version")
|
||||
|
||||
|
||||
# Health check with timeout and retry
|
||||
max_retries = 5
|
||||
for i in range(max_retries):
|
||||
@@ -60,18 +79,18 @@ def test_docker_deployment(version="basic"):
|
||||
health = requests.get(f"{tester.base_url}/health", timeout=10)
|
||||
print("Health check:", health.json())
|
||||
break
|
||||
except requests.exceptions.RequestException as e:
|
||||
except requests.exceptions.RequestException:
|
||||
if i == max_retries - 1:
|
||||
print(f"Failed to connect after {max_retries} attempts")
|
||||
sys.exit(1)
|
||||
print(f"Waiting for service to start (attempt {i+1}/{max_retries})...")
|
||||
time.sleep(5)
|
||||
|
||||
|
||||
# Test cases based on version
|
||||
test_basic_crawl(tester)
|
||||
test_basic_crawl(tester)
|
||||
test_basic_crawl_sync(tester)
|
||||
|
||||
|
||||
# if version in ["full", "transformer"]:
|
||||
# test_cosine_extraction(tester)
|
||||
|
||||
@@ -81,35 +100,37 @@ def test_docker_deployment(version="basic"):
|
||||
# test_llm_extraction(tester)
|
||||
# test_llm_with_ollama(tester)
|
||||
# test_screenshot(tester)
|
||||
|
||||
|
||||
|
||||
def test_basic_crawl(tester: Crawl4AiTester):
|
||||
print("\n=== Testing Basic Crawl ===")
|
||||
request = {
|
||||
"urls": "https://www.nbcnews.com/business",
|
||||
"priority": 10,
|
||||
"session_id": "test"
|
||||
"priority": 10,
|
||||
"session_id": "test",
|
||||
}
|
||||
|
||||
|
||||
result = tester.submit_and_wait(request)
|
||||
print(f"Basic crawl result length: {len(result['result']['markdown'])}")
|
||||
assert result["result"]["success"]
|
||||
assert len(result["result"]["markdown"]) > 0
|
||||
|
||||
|
||||
def test_basic_crawl_sync(tester: Crawl4AiTester):
|
||||
print("\n=== Testing Basic Crawl (Sync) ===")
|
||||
request = {
|
||||
"urls": "https://www.nbcnews.com/business",
|
||||
"priority": 10,
|
||||
"session_id": "test"
|
||||
"session_id": "test",
|
||||
}
|
||||
|
||||
|
||||
result = tester.submit_sync(request)
|
||||
print(f"Basic crawl result length: {len(result['result']['markdown'])}")
|
||||
assert result['status'] == 'completed'
|
||||
assert result['result']['success']
|
||||
assert len(result['result']['markdown']) > 0
|
||||
|
||||
assert result["status"] == "completed"
|
||||
assert result["result"]["success"]
|
||||
assert len(result["result"]["markdown"]) > 0
|
||||
|
||||
|
||||
def test_js_execution(tester: Crawl4AiTester):
|
||||
print("\n=== Testing JS Execution ===")
|
||||
request = {
|
||||
@@ -119,32 +140,29 @@ def test_js_execution(tester: Crawl4AiTester):
|
||||
"const loadMoreButton = Array.from(document.querySelectorAll('button')).find(button => button.textContent.includes('Load More')); loadMoreButton && loadMoreButton.click();"
|
||||
],
|
||||
"wait_for": "article.tease-card:nth-child(10)",
|
||||
"crawler_params": {
|
||||
"headless": True
|
||||
}
|
||||
"crawler_params": {"headless": True},
|
||||
}
|
||||
|
||||
|
||||
result = tester.submit_and_wait(request)
|
||||
print(f"JS execution result length: {len(result['result']['markdown'])}")
|
||||
assert result["result"]["success"]
|
||||
|
||||
|
||||
def test_css_selector(tester: Crawl4AiTester):
|
||||
print("\n=== Testing CSS Selector ===")
|
||||
request = {
|
||||
"urls": "https://www.nbcnews.com/business",
|
||||
"priority": 7,
|
||||
"css_selector": ".wide-tease-item__description",
|
||||
"crawler_params": {
|
||||
"headless": True
|
||||
},
|
||||
"extra": {"word_count_threshold": 10}
|
||||
|
||||
"crawler_params": {"headless": True},
|
||||
"extra": {"word_count_threshold": 10},
|
||||
}
|
||||
|
||||
|
||||
result = tester.submit_and_wait(request)
|
||||
print(f"CSS selector result length: {len(result['result']['markdown'])}")
|
||||
assert result["result"]["success"]
|
||||
|
||||
|
||||
def test_structured_extraction(tester: Crawl4AiTester):
|
||||
print("\n=== Testing Structured Extraction ===")
|
||||
schema = {
|
||||
@@ -165,21 +183,16 @@ def test_structured_extraction(tester: Crawl4AiTester):
|
||||
"name": "price",
|
||||
"selector": "td:nth-child(2)",
|
||||
"type": "text",
|
||||
}
|
||||
},
|
||||
],
|
||||
}
|
||||
|
||||
|
||||
request = {
|
||||
"urls": "https://www.coinbase.com/explore",
|
||||
"priority": 9,
|
||||
"extraction_config": {
|
||||
"type": "json_css",
|
||||
"params": {
|
||||
"schema": schema
|
||||
}
|
||||
}
|
||||
"extraction_config": {"type": "json_css", "params": {"schema": schema}},
|
||||
}
|
||||
|
||||
|
||||
result = tester.submit_and_wait(request)
|
||||
extracted = json.loads(result["result"]["extracted_content"])
|
||||
print(f"Extracted {len(extracted)} items")
|
||||
@@ -187,6 +200,7 @@ def test_structured_extraction(tester: Crawl4AiTester):
|
||||
assert result["result"]["success"]
|
||||
assert len(extracted) > 0
|
||||
|
||||
|
||||
def test_llm_extraction(tester: Crawl4AiTester):
|
||||
print("\n=== Testing LLM Extraction ===")
|
||||
schema = {
|
||||
@@ -194,20 +208,20 @@ def test_llm_extraction(tester: Crawl4AiTester):
|
||||
"properties": {
|
||||
"model_name": {
|
||||
"type": "string",
|
||||
"description": "Name of the OpenAI model."
|
||||
"description": "Name of the OpenAI model.",
|
||||
},
|
||||
"input_fee": {
|
||||
"type": "string",
|
||||
"description": "Fee for input token for the OpenAI model."
|
||||
"description": "Fee for input token for the OpenAI model.",
|
||||
},
|
||||
"output_fee": {
|
||||
"type": "string",
|
||||
"description": "Fee for output token for the OpenAI model."
|
||||
}
|
||||
"description": "Fee for output token for the OpenAI model.",
|
||||
},
|
||||
},
|
||||
"required": ["model_name", "input_fee", "output_fee"]
|
||||
"required": ["model_name", "input_fee", "output_fee"],
|
||||
}
|
||||
|
||||
|
||||
request = {
|
||||
"urls": "https://openai.com/api/pricing",
|
||||
"priority": 8,
|
||||
@@ -218,12 +232,12 @@ def test_llm_extraction(tester: Crawl4AiTester):
|
||||
"api_token": os.getenv("OPENAI_API_KEY"),
|
||||
"schema": schema,
|
||||
"extraction_type": "schema",
|
||||
"instruction": """From the crawled content, extract all mentioned model names along with their fees for input and output tokens."""
|
||||
}
|
||||
"instruction": """From the crawled content, extract all mentioned model names along with their fees for input and output tokens.""",
|
||||
},
|
||||
},
|
||||
"crawler_params": {"word_count_threshold": 1}
|
||||
"crawler_params": {"word_count_threshold": 1},
|
||||
}
|
||||
|
||||
|
||||
try:
|
||||
result = tester.submit_and_wait(request)
|
||||
extracted = json.loads(result["result"]["extracted_content"])
|
||||
@@ -233,6 +247,7 @@ def test_llm_extraction(tester: Crawl4AiTester):
|
||||
except Exception as e:
|
||||
print(f"LLM extraction test failed (might be due to missing API key): {str(e)}")
|
||||
|
||||
|
||||
def test_llm_with_ollama(tester: Crawl4AiTester):
|
||||
print("\n=== Testing LLM with Ollama ===")
|
||||
schema = {
|
||||
@@ -240,20 +255,20 @@ def test_llm_with_ollama(tester: Crawl4AiTester):
|
||||
"properties": {
|
||||
"article_title": {
|
||||
"type": "string",
|
||||
"description": "The main title of the news article"
|
||||
"description": "The main title of the news article",
|
||||
},
|
||||
"summary": {
|
||||
"type": "string",
|
||||
"description": "A brief summary of the article content"
|
||||
"description": "A brief summary of the article content",
|
||||
},
|
||||
"main_topics": {
|
||||
"type": "array",
|
||||
"items": {"type": "string"},
|
||||
"description": "Main topics or themes discussed in the article"
|
||||
}
|
||||
}
|
||||
"description": "Main topics or themes discussed in the article",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
request = {
|
||||
"urls": "https://www.nbcnews.com/business",
|
||||
"priority": 8,
|
||||
@@ -263,13 +278,13 @@ def test_llm_with_ollama(tester: Crawl4AiTester):
|
||||
"provider": "ollama/llama2",
|
||||
"schema": schema,
|
||||
"extraction_type": "schema",
|
||||
"instruction": "Extract the main article information including title, summary, and main topics."
|
||||
}
|
||||
"instruction": "Extract the main article information including title, summary, and main topics.",
|
||||
},
|
||||
},
|
||||
"extra": {"word_count_threshold": 1},
|
||||
"crawler_params": {"verbose": True}
|
||||
"crawler_params": {"verbose": True},
|
||||
}
|
||||
|
||||
|
||||
try:
|
||||
result = tester.submit_and_wait(request)
|
||||
extracted = json.loads(result["result"]["extracted_content"])
|
||||
@@ -278,6 +293,7 @@ def test_llm_with_ollama(tester: Crawl4AiTester):
|
||||
except Exception as e:
|
||||
print(f"Ollama extraction test failed: {str(e)}")
|
||||
|
||||
|
||||
def test_cosine_extraction(tester: Crawl4AiTester):
|
||||
print("\n=== Testing Cosine Extraction ===")
|
||||
request = {
|
||||
@@ -289,11 +305,11 @@ def test_cosine_extraction(tester: Crawl4AiTester):
|
||||
"semantic_filter": "business finance economy",
|
||||
"word_count_threshold": 10,
|
||||
"max_dist": 0.2,
|
||||
"top_k": 3
|
||||
}
|
||||
}
|
||||
"top_k": 3,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
try:
|
||||
result = tester.submit_and_wait(request)
|
||||
extracted = json.loads(result["result"]["extracted_content"])
|
||||
@@ -303,30 +319,30 @@ def test_cosine_extraction(tester: Crawl4AiTester):
|
||||
except Exception as e:
|
||||
print(f"Cosine extraction test failed: {str(e)}")
|
||||
|
||||
|
||||
def test_screenshot(tester: Crawl4AiTester):
|
||||
print("\n=== Testing Screenshot ===")
|
||||
request = {
|
||||
"urls": "https://www.nbcnews.com/business",
|
||||
"priority": 5,
|
||||
"screenshot": True,
|
||||
"crawler_params": {
|
||||
"headless": True
|
||||
}
|
||||
"crawler_params": {"headless": True},
|
||||
}
|
||||
|
||||
|
||||
result = tester.submit_and_wait(request)
|
||||
print("Screenshot captured:", bool(result["result"]["screenshot"]))
|
||||
|
||||
|
||||
if result["result"]["screenshot"]:
|
||||
# Save screenshot
|
||||
screenshot_data = base64.b64decode(result["result"]["screenshot"])
|
||||
with open("test_screenshot.jpg", "wb") as f:
|
||||
f.write(screenshot_data)
|
||||
print("Screenshot saved as test_screenshot.jpg")
|
||||
|
||||
|
||||
assert result["result"]["success"]
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
version = sys.argv[1] if len(sys.argv) > 1 else "basic"
|
||||
# version = "full"
|
||||
test_docker_deployment(version)
|
||||
test_docker_deployment(version)
|
||||
|
||||
@@ -1,13 +1,13 @@
|
||||
import asyncio
|
||||
from pathlib import Path
|
||||
from crawl4ai.docs_manager import DocsManager
|
||||
from click.testing import CliRunner
|
||||
from crawl4ai.cli import cli
|
||||
|
||||
|
||||
def test_cli():
|
||||
"""Test all CLI commands"""
|
||||
runner = CliRunner()
|
||||
|
||||
|
||||
print("\n1. Testing docs update...")
|
||||
# Use sync version for testing
|
||||
docs_manager = DocsManager()
|
||||
@@ -27,17 +27,18 @@ def test_cli():
|
||||
# print("\n3. Testing search...")
|
||||
# result = runner.invoke(cli, ['docs', 'search', 'how to use crawler', '--build-index'])
|
||||
# print(f"Status: {'✅' if result.exit_code == 0 else '❌'}")
|
||||
# print(f"First 200 chars: {result.output[:200]}...")
|
||||
|
||||
# print(f"First 200 chars: {result.output[:200]}...")
|
||||
|
||||
# print("\n4. Testing combine with sections...")
|
||||
# result = runner.invoke(cli, ['docs', 'combine', 'chunking_strategies', 'extraction_strategies', '--mode', 'extended'])
|
||||
# print(f"Status: {'✅' if result.exit_code == 0 else '❌'}")
|
||||
# print(f"First 200 chars: {result.output[:200]}...")
|
||||
|
||||
print("\n5. Testing combine all sections...")
|
||||
result = runner.invoke(cli, ['docs', 'combine', '--mode', 'condensed'])
|
||||
result = runner.invoke(cli, ["docs", "combine", "--mode", "condensed"])
|
||||
print(f"Status: {'✅' if result.exit_code == 0 else '❌'}")
|
||||
print(f"First 200 chars: {result.output[:200]}...")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test_cli()
|
||||
test_cli()
|
||||
|
||||
@@ -6,38 +6,44 @@ import base64
|
||||
import os
|
||||
from typing import Dict, Any
|
||||
|
||||
|
||||
class Crawl4AiTester:
|
||||
def __init__(self, base_url: str = "http://localhost:11235"):
|
||||
self.base_url = base_url
|
||||
|
||||
def submit_and_wait(self, request_data: Dict[str, Any], timeout: int = 300) -> Dict[str, Any]:
|
||||
|
||||
def submit_and_wait(
|
||||
self, request_data: Dict[str, Any], timeout: int = 300
|
||||
) -> Dict[str, Any]:
|
||||
# Submit crawl job
|
||||
response = requests.post(f"{self.base_url}/crawl", json=request_data)
|
||||
task_id = response.json()["task_id"]
|
||||
print(f"Task ID: {task_id}")
|
||||
|
||||
|
||||
# Poll for result
|
||||
start_time = time.time()
|
||||
while True:
|
||||
if time.time() - start_time > timeout:
|
||||
raise TimeoutError(f"Task {task_id} did not complete within {timeout} seconds")
|
||||
|
||||
raise TimeoutError(
|
||||
f"Task {task_id} did not complete within {timeout} seconds"
|
||||
)
|
||||
|
||||
result = requests.get(f"{self.base_url}/task/{task_id}")
|
||||
status = result.json()
|
||||
|
||||
|
||||
if status["status"] == "failed":
|
||||
print("Task failed:", status.get("error"))
|
||||
raise Exception(f"Task failed: {status.get('error')}")
|
||||
|
||||
|
||||
if status["status"] == "completed":
|
||||
return status
|
||||
|
||||
|
||||
time.sleep(2)
|
||||
|
||||
|
||||
def test_docker_deployment(version="basic"):
|
||||
tester = Crawl4AiTester()
|
||||
print(f"Testing Crawl4AI Docker {version} version")
|
||||
|
||||
|
||||
# Health check with timeout and retry
|
||||
max_retries = 5
|
||||
for i in range(max_retries):
|
||||
@@ -45,16 +51,16 @@ def test_docker_deployment(version="basic"):
|
||||
health = requests.get(f"{tester.base_url}/health", timeout=10)
|
||||
print("Health check:", health.json())
|
||||
break
|
||||
except requests.exceptions.RequestException as e:
|
||||
except requests.exceptions.RequestException:
|
||||
if i == max_retries - 1:
|
||||
print(f"Failed to connect after {max_retries} attempts")
|
||||
sys.exit(1)
|
||||
print(f"Waiting for service to start (attempt {i+1}/{max_retries})...")
|
||||
time.sleep(5)
|
||||
|
||||
|
||||
# Test cases based on version
|
||||
test_basic_crawl(tester)
|
||||
|
||||
|
||||
# if version in ["full", "transformer"]:
|
||||
# test_cosine_extraction(tester)
|
||||
|
||||
@@ -64,20 +70,18 @@ def test_docker_deployment(version="basic"):
|
||||
# test_llm_extraction(tester)
|
||||
# test_llm_with_ollama(tester)
|
||||
# test_screenshot(tester)
|
||||
|
||||
|
||||
|
||||
def test_basic_crawl(tester: Crawl4AiTester):
|
||||
print("\n=== Testing Basic Crawl ===")
|
||||
request = {
|
||||
"urls": "https://www.nbcnews.com/business",
|
||||
"priority": 10
|
||||
}
|
||||
|
||||
request = {"urls": "https://www.nbcnews.com/business", "priority": 10}
|
||||
|
||||
result = tester.submit_and_wait(request)
|
||||
print(f"Basic crawl result length: {len(result['result']['markdown'])}")
|
||||
assert result["result"]["success"]
|
||||
assert len(result["result"]["markdown"]) > 0
|
||||
|
||||
|
||||
def test_js_execution(tester: Crawl4AiTester):
|
||||
print("\n=== Testing JS Execution ===")
|
||||
request = {
|
||||
@@ -87,32 +91,29 @@ def test_js_execution(tester: Crawl4AiTester):
|
||||
"const loadMoreButton = Array.from(document.querySelectorAll('button')).find(button => button.textContent.includes('Load More')); loadMoreButton && loadMoreButton.click();"
|
||||
],
|
||||
"wait_for": "article.tease-card:nth-child(10)",
|
||||
"crawler_params": {
|
||||
"headless": True
|
||||
}
|
||||
"crawler_params": {"headless": True},
|
||||
}
|
||||
|
||||
|
||||
result = tester.submit_and_wait(request)
|
||||
print(f"JS execution result length: {len(result['result']['markdown'])}")
|
||||
assert result["result"]["success"]
|
||||
|
||||
|
||||
def test_css_selector(tester: Crawl4AiTester):
|
||||
print("\n=== Testing CSS Selector ===")
|
||||
request = {
|
||||
"urls": "https://www.nbcnews.com/business",
|
||||
"priority": 7,
|
||||
"css_selector": ".wide-tease-item__description",
|
||||
"crawler_params": {
|
||||
"headless": True
|
||||
},
|
||||
"extra": {"word_count_threshold": 10}
|
||||
|
||||
"crawler_params": {"headless": True},
|
||||
"extra": {"word_count_threshold": 10},
|
||||
}
|
||||
|
||||
|
||||
result = tester.submit_and_wait(request)
|
||||
print(f"CSS selector result length: {len(result['result']['markdown'])}")
|
||||
assert result["result"]["success"]
|
||||
|
||||
|
||||
def test_structured_extraction(tester: Crawl4AiTester):
|
||||
print("\n=== Testing Structured Extraction ===")
|
||||
schema = {
|
||||
@@ -133,21 +134,16 @@ def test_structured_extraction(tester: Crawl4AiTester):
|
||||
"name": "price",
|
||||
"selector": "td:nth-child(2)",
|
||||
"type": "text",
|
||||
}
|
||||
},
|
||||
],
|
||||
}
|
||||
|
||||
|
||||
request = {
|
||||
"urls": "https://www.coinbase.com/explore",
|
||||
"priority": 9,
|
||||
"extraction_config": {
|
||||
"type": "json_css",
|
||||
"params": {
|
||||
"schema": schema
|
||||
}
|
||||
}
|
||||
"extraction_config": {"type": "json_css", "params": {"schema": schema}},
|
||||
}
|
||||
|
||||
|
||||
result = tester.submit_and_wait(request)
|
||||
extracted = json.loads(result["result"]["extracted_content"])
|
||||
print(f"Extracted {len(extracted)} items")
|
||||
@@ -155,6 +151,7 @@ def test_structured_extraction(tester: Crawl4AiTester):
|
||||
assert result["result"]["success"]
|
||||
assert len(extracted) > 0
|
||||
|
||||
|
||||
def test_llm_extraction(tester: Crawl4AiTester):
|
||||
print("\n=== Testing LLM Extraction ===")
|
||||
schema = {
|
||||
@@ -162,20 +159,20 @@ def test_llm_extraction(tester: Crawl4AiTester):
|
||||
"properties": {
|
||||
"model_name": {
|
||||
"type": "string",
|
||||
"description": "Name of the OpenAI model."
|
||||
"description": "Name of the OpenAI model.",
|
||||
},
|
||||
"input_fee": {
|
||||
"type": "string",
|
||||
"description": "Fee for input token for the OpenAI model."
|
||||
"description": "Fee for input token for the OpenAI model.",
|
||||
},
|
||||
"output_fee": {
|
||||
"type": "string",
|
||||
"description": "Fee for output token for the OpenAI model."
|
||||
}
|
||||
"description": "Fee for output token for the OpenAI model.",
|
||||
},
|
||||
},
|
||||
"required": ["model_name", "input_fee", "output_fee"]
|
||||
"required": ["model_name", "input_fee", "output_fee"],
|
||||
}
|
||||
|
||||
|
||||
request = {
|
||||
"urls": "https://openai.com/api/pricing",
|
||||
"priority": 8,
|
||||
@@ -186,12 +183,12 @@ def test_llm_extraction(tester: Crawl4AiTester):
|
||||
"api_token": os.getenv("OPENAI_API_KEY"),
|
||||
"schema": schema,
|
||||
"extraction_type": "schema",
|
||||
"instruction": """From the crawled content, extract all mentioned model names along with their fees for input and output tokens."""
|
||||
}
|
||||
"instruction": """From the crawled content, extract all mentioned model names along with their fees for input and output tokens.""",
|
||||
},
|
||||
},
|
||||
"crawler_params": {"word_count_threshold": 1}
|
||||
"crawler_params": {"word_count_threshold": 1},
|
||||
}
|
||||
|
||||
|
||||
try:
|
||||
result = tester.submit_and_wait(request)
|
||||
extracted = json.loads(result["result"]["extracted_content"])
|
||||
@@ -201,6 +198,7 @@ def test_llm_extraction(tester: Crawl4AiTester):
|
||||
except Exception as e:
|
||||
print(f"LLM extraction test failed (might be due to missing API key): {str(e)}")
|
||||
|
||||
|
||||
def test_llm_with_ollama(tester: Crawl4AiTester):
|
||||
print("\n=== Testing LLM with Ollama ===")
|
||||
schema = {
|
||||
@@ -208,20 +206,20 @@ def test_llm_with_ollama(tester: Crawl4AiTester):
|
||||
"properties": {
|
||||
"article_title": {
|
||||
"type": "string",
|
||||
"description": "The main title of the news article"
|
||||
"description": "The main title of the news article",
|
||||
},
|
||||
"summary": {
|
||||
"type": "string",
|
||||
"description": "A brief summary of the article content"
|
||||
"description": "A brief summary of the article content",
|
||||
},
|
||||
"main_topics": {
|
||||
"type": "array",
|
||||
"items": {"type": "string"},
|
||||
"description": "Main topics or themes discussed in the article"
|
||||
}
|
||||
}
|
||||
"description": "Main topics or themes discussed in the article",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
request = {
|
||||
"urls": "https://www.nbcnews.com/business",
|
||||
"priority": 8,
|
||||
@@ -231,13 +229,13 @@ def test_llm_with_ollama(tester: Crawl4AiTester):
|
||||
"provider": "ollama/llama2",
|
||||
"schema": schema,
|
||||
"extraction_type": "schema",
|
||||
"instruction": "Extract the main article information including title, summary, and main topics."
|
||||
}
|
||||
"instruction": "Extract the main article information including title, summary, and main topics.",
|
||||
},
|
||||
},
|
||||
"extra": {"word_count_threshold": 1},
|
||||
"crawler_params": {"verbose": True}
|
||||
"crawler_params": {"verbose": True},
|
||||
}
|
||||
|
||||
|
||||
try:
|
||||
result = tester.submit_and_wait(request)
|
||||
extracted = json.loads(result["result"]["extracted_content"])
|
||||
@@ -246,6 +244,7 @@ def test_llm_with_ollama(tester: Crawl4AiTester):
|
||||
except Exception as e:
|
||||
print(f"Ollama extraction test failed: {str(e)}")
|
||||
|
||||
|
||||
def test_cosine_extraction(tester: Crawl4AiTester):
|
||||
print("\n=== Testing Cosine Extraction ===")
|
||||
request = {
|
||||
@@ -257,11 +256,11 @@ def test_cosine_extraction(tester: Crawl4AiTester):
|
||||
"semantic_filter": "business finance economy",
|
||||
"word_count_threshold": 10,
|
||||
"max_dist": 0.2,
|
||||
"top_k": 3
|
||||
}
|
||||
}
|
||||
"top_k": 3,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
try:
|
||||
result = tester.submit_and_wait(request)
|
||||
extracted = json.loads(result["result"]["extracted_content"])
|
||||
@@ -271,30 +270,30 @@ def test_cosine_extraction(tester: Crawl4AiTester):
|
||||
except Exception as e:
|
||||
print(f"Cosine extraction test failed: {str(e)}")
|
||||
|
||||
|
||||
def test_screenshot(tester: Crawl4AiTester):
|
||||
print("\n=== Testing Screenshot ===")
|
||||
request = {
|
||||
"urls": "https://www.nbcnews.com/business",
|
||||
"priority": 5,
|
||||
"screenshot": True,
|
||||
"crawler_params": {
|
||||
"headless": True
|
||||
}
|
||||
"crawler_params": {"headless": True},
|
||||
}
|
||||
|
||||
|
||||
result = tester.submit_and_wait(request)
|
||||
print("Screenshot captured:", bool(result["result"]["screenshot"]))
|
||||
|
||||
|
||||
if result["result"]["screenshot"]:
|
||||
# Save screenshot
|
||||
screenshot_data = base64.b64decode(result["result"]["screenshot"])
|
||||
with open("test_screenshot.jpg", "wb") as f:
|
||||
f.write(screenshot_data)
|
||||
print("Screenshot saved as test_screenshot.jpg")
|
||||
|
||||
|
||||
assert result["result"]["success"]
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
version = sys.argv[1] if len(sys.argv) > 1 else "basic"
|
||||
# version = "full"
|
||||
test_docker_deployment(version)
|
||||
test_docker_deployment(version)
|
||||
|
||||
@@ -3,20 +3,21 @@ from crawl4ai.async_logger import AsyncLogger
|
||||
from pathlib import Path
|
||||
import asyncio
|
||||
|
||||
|
||||
async def main():
|
||||
current_file = Path(__file__).resolve()
|
||||
# base_dir = current_file.parent.parent / "local/_docs/llm.txt/test_docs"
|
||||
base_dir = current_file.parent.parent / "local/_docs/llm.txt"
|
||||
docs_dir = base_dir
|
||||
|
||||
|
||||
# Create directory if it doesn't exist
|
||||
docs_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
|
||||
# Initialize logger
|
||||
logger = AsyncLogger()
|
||||
# Updated initialization with default batching params
|
||||
# manager = AsyncLLMTextManager(docs_dir, logger, max_concurrent_calls=3, batch_size=2)
|
||||
manager = AsyncLLMTextManager(docs_dir, logger, batch_size=2)
|
||||
manager = AsyncLLMTextManager(docs_dir, logger, batch_size=2)
|
||||
|
||||
# Let's first check what files we have
|
||||
print("\nAvailable files:")
|
||||
@@ -26,8 +27,7 @@ async def main():
|
||||
# Generate index files
|
||||
print("\nGenerating index files...")
|
||||
await manager.generate_index_files(
|
||||
force_generate_facts=False,
|
||||
clear_bm25_cache=False
|
||||
force_generate_facts=False, clear_bm25_cache=False
|
||||
)
|
||||
|
||||
# Test some relevant queries about Crawl4AI
|
||||
@@ -41,9 +41,12 @@ async def main():
|
||||
results = manager.search(query, top_k=2)
|
||||
print(f"Results length: {len(results)} characters")
|
||||
if results:
|
||||
print("First 200 chars of results:", results[:200].replace('\n', ' '), "...")
|
||||
print(
|
||||
"First 200 chars of results:", results[:200].replace("\n", " "), "..."
|
||||
)
|
||||
else:
|
||||
print("No results found")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
asyncio.run(main())
|
||||
asyncio.run(main())
|
||||
|
||||
@@ -3,8 +3,8 @@ import aiohttp
|
||||
import json
|
||||
import time
|
||||
import os
|
||||
from typing import Optional, Dict, Any
|
||||
from pydantic import BaseModel, HttpUrl
|
||||
from typing import Dict, Any
|
||||
|
||||
|
||||
class NBCNewsAPITest:
|
||||
def __init__(self, base_url: str = "http://localhost:8000"):
|
||||
@@ -20,7 +20,9 @@ class NBCNewsAPITest:
|
||||
await self.session.close()
|
||||
|
||||
async def submit_crawl(self, request_data: Dict[str, Any]) -> str:
|
||||
async with self.session.post(f"{self.base_url}/crawl", json=request_data) as response:
|
||||
async with self.session.post(
|
||||
f"{self.base_url}/crawl", json=request_data
|
||||
) as response:
|
||||
result = await response.json()
|
||||
return result["task_id"]
|
||||
|
||||
@@ -28,11 +30,15 @@ class NBCNewsAPITest:
|
||||
async with self.session.get(f"{self.base_url}/task/{task_id}") as response:
|
||||
return await response.json()
|
||||
|
||||
async def wait_for_task(self, task_id: str, timeout: int = 300, poll_interval: int = 2) -> Dict[str, Any]:
|
||||
async def wait_for_task(
|
||||
self, task_id: str, timeout: int = 300, poll_interval: int = 2
|
||||
) -> Dict[str, Any]:
|
||||
start_time = time.time()
|
||||
while True:
|
||||
if time.time() - start_time > timeout:
|
||||
raise TimeoutError(f"Task {task_id} did not complete within {timeout} seconds")
|
||||
raise TimeoutError(
|
||||
f"Task {task_id} did not complete within {timeout} seconds"
|
||||
)
|
||||
|
||||
status = await self.get_task_status(task_id)
|
||||
if status["status"] in ["completed", "failed"]:
|
||||
@@ -44,13 +50,11 @@ class NBCNewsAPITest:
|
||||
async with self.session.get(f"{self.base_url}/health") as response:
|
||||
return await response.json()
|
||||
|
||||
|
||||
async def test_basic_crawl():
|
||||
print("\n=== Testing Basic Crawl ===")
|
||||
async with NBCNewsAPITest() as api:
|
||||
request = {
|
||||
"urls": "https://www.nbcnews.com/business",
|
||||
"priority": 10
|
||||
}
|
||||
request = {"urls": "https://www.nbcnews.com/business", "priority": 10}
|
||||
task_id = await api.submit_crawl(request)
|
||||
result = await api.wait_for_task(task_id)
|
||||
print(f"Basic crawl result length: {len(result['result']['markdown'])}")
|
||||
@@ -58,6 +62,7 @@ async def test_basic_crawl():
|
||||
assert "result" in result
|
||||
assert result["result"]["success"]
|
||||
|
||||
|
||||
async def test_js_execution():
|
||||
print("\n=== Testing JS Execution ===")
|
||||
async with NBCNewsAPITest() as api:
|
||||
@@ -68,9 +73,7 @@ async def test_js_execution():
|
||||
"const loadMoreButton = Array.from(document.querySelectorAll('button')).find(button => button.textContent.includes('Load More')); loadMoreButton && loadMoreButton.click();"
|
||||
],
|
||||
"wait_for": "article.tease-card:nth-child(10)",
|
||||
"crawler_params": {
|
||||
"headless": True
|
||||
}
|
||||
"crawler_params": {"headless": True},
|
||||
}
|
||||
task_id = await api.submit_crawl(request)
|
||||
result = await api.wait_for_task(task_id)
|
||||
@@ -78,13 +81,14 @@ async def test_js_execution():
|
||||
assert result["status"] == "completed"
|
||||
assert result["result"]["success"]
|
||||
|
||||
|
||||
async def test_css_selector():
|
||||
print("\n=== Testing CSS Selector ===")
|
||||
async with NBCNewsAPITest() as api:
|
||||
request = {
|
||||
"urls": "https://www.nbcnews.com/business",
|
||||
"priority": 7,
|
||||
"css_selector": ".wide-tease-item__description"
|
||||
"css_selector": ".wide-tease-item__description",
|
||||
}
|
||||
task_id = await api.submit_crawl(request)
|
||||
result = await api.wait_for_task(task_id)
|
||||
@@ -92,6 +96,7 @@ async def test_css_selector():
|
||||
assert result["status"] == "completed"
|
||||
assert result["result"]["success"]
|
||||
|
||||
|
||||
async def test_structured_extraction():
|
||||
print("\n=== Testing Structured Extraction ===")
|
||||
async with NBCNewsAPITest() as api:
|
||||
@@ -99,34 +104,25 @@ async def test_structured_extraction():
|
||||
"name": "NBC News Articles",
|
||||
"baseSelector": "article.tease-card",
|
||||
"fields": [
|
||||
{
|
||||
"name": "title",
|
||||
"selector": "h2",
|
||||
"type": "text"
|
||||
},
|
||||
{"name": "title", "selector": "h2", "type": "text"},
|
||||
{
|
||||
"name": "description",
|
||||
"selector": ".tease-card__description",
|
||||
"type": "text"
|
||||
"type": "text",
|
||||
},
|
||||
{
|
||||
"name": "link",
|
||||
"selector": "a",
|
||||
"type": "attribute",
|
||||
"attribute": "href"
|
||||
}
|
||||
]
|
||||
"attribute": "href",
|
||||
},
|
||||
],
|
||||
}
|
||||
|
||||
|
||||
request = {
|
||||
"urls": "https://www.nbcnews.com/business",
|
||||
"priority": 9,
|
||||
"extraction_config": {
|
||||
"type": "json_css",
|
||||
"params": {
|
||||
"schema": schema
|
||||
}
|
||||
}
|
||||
"extraction_config": {"type": "json_css", "params": {"schema": schema}},
|
||||
}
|
||||
task_id = await api.submit_crawl(request)
|
||||
result = await api.wait_for_task(task_id)
|
||||
@@ -136,6 +132,7 @@ async def test_structured_extraction():
|
||||
assert result["result"]["success"]
|
||||
assert len(extracted) > 0
|
||||
|
||||
|
||||
async def test_batch_crawl():
|
||||
print("\n=== Testing Batch Crawl ===")
|
||||
async with NBCNewsAPITest() as api:
|
||||
@@ -143,12 +140,10 @@ async def test_batch_crawl():
|
||||
"urls": [
|
||||
"https://www.nbcnews.com/business",
|
||||
"https://www.nbcnews.com/business/consumer",
|
||||
"https://www.nbcnews.com/business/economy"
|
||||
"https://www.nbcnews.com/business/economy",
|
||||
],
|
||||
"priority": 6,
|
||||
"crawler_params": {
|
||||
"headless": True
|
||||
}
|
||||
"crawler_params": {"headless": True},
|
||||
}
|
||||
task_id = await api.submit_crawl(request)
|
||||
result = await api.wait_for_task(task_id)
|
||||
@@ -157,6 +152,7 @@ async def test_batch_crawl():
|
||||
assert "results" in result
|
||||
assert len(result["results"]) == 3
|
||||
|
||||
|
||||
async def test_llm_extraction():
|
||||
print("\n=== Testing LLM Extraction with Ollama ===")
|
||||
async with NBCNewsAPITest() as api:
|
||||
@@ -165,19 +161,19 @@ async def test_llm_extraction():
|
||||
"properties": {
|
||||
"article_title": {
|
||||
"type": "string",
|
||||
"description": "The main title of the news article"
|
||||
"description": "The main title of the news article",
|
||||
},
|
||||
"summary": {
|
||||
"type": "string",
|
||||
"description": "A brief summary of the article content"
|
||||
"description": "A brief summary of the article content",
|
||||
},
|
||||
"main_topics": {
|
||||
"type": "array",
|
||||
"items": {"type": "string"},
|
||||
"description": "Main topics or themes discussed in the article"
|
||||
}
|
||||
"description": "Main topics or themes discussed in the article",
|
||||
},
|
||||
},
|
||||
"required": ["article_title", "summary", "main_topics"]
|
||||
"required": ["article_title", "summary", "main_topics"],
|
||||
}
|
||||
|
||||
request = {
|
||||
@@ -191,26 +187,24 @@ async def test_llm_extraction():
|
||||
"schema": schema,
|
||||
"extraction_type": "schema",
|
||||
"instruction": """Extract the main article information including title, a brief summary, and main topics discussed.
|
||||
Focus on the primary business news article on the page."""
|
||||
}
|
||||
Focus on the primary business news article on the page.""",
|
||||
},
|
||||
},
|
||||
"crawler_params": {
|
||||
"headless": True,
|
||||
"word_count_threshold": 1
|
||||
}
|
||||
"crawler_params": {"headless": True, "word_count_threshold": 1},
|
||||
}
|
||||
|
||||
|
||||
task_id = await api.submit_crawl(request)
|
||||
result = await api.wait_for_task(task_id)
|
||||
|
||||
|
||||
if result["status"] == "completed":
|
||||
extracted = json.loads(result["result"]["extracted_content"])
|
||||
print(f"Extracted article analysis:")
|
||||
print("Extracted article analysis:")
|
||||
print(json.dumps(extracted, indent=2))
|
||||
|
||||
|
||||
assert result["status"] == "completed"
|
||||
assert result["result"]["success"]
|
||||
|
||||
|
||||
async def test_screenshot():
|
||||
print("\n=== Testing Screenshot ===")
|
||||
async with NBCNewsAPITest() as api:
|
||||
@@ -218,9 +212,7 @@ async def test_screenshot():
|
||||
"urls": "https://www.nbcnews.com/business",
|
||||
"priority": 5,
|
||||
"screenshot": True,
|
||||
"crawler_params": {
|
||||
"headless": True
|
||||
}
|
||||
"crawler_params": {"headless": True},
|
||||
}
|
||||
task_id = await api.submit_crawl(request)
|
||||
result = await api.wait_for_task(task_id)
|
||||
@@ -229,6 +221,7 @@ async def test_screenshot():
|
||||
assert result["result"]["success"]
|
||||
assert result["result"]["screenshot"] is not None
|
||||
|
||||
|
||||
async def test_priority_handling():
|
||||
print("\n=== Testing Priority Handling ===")
|
||||
async with NBCNewsAPITest() as api:
|
||||
@@ -236,7 +229,7 @@ async def test_priority_handling():
|
||||
low_priority = {
|
||||
"urls": "https://www.nbcnews.com/business",
|
||||
"priority": 1,
|
||||
"crawler_params": {"headless": True}
|
||||
"crawler_params": {"headless": True},
|
||||
}
|
||||
low_task_id = await api.submit_crawl(low_priority)
|
||||
|
||||
@@ -244,7 +237,7 @@ async def test_priority_handling():
|
||||
high_priority = {
|
||||
"urls": "https://www.nbcnews.com/business/consumer",
|
||||
"priority": 10,
|
||||
"crawler_params": {"headless": True}
|
||||
"crawler_params": {"headless": True},
|
||||
}
|
||||
high_task_id = await api.submit_crawl(high_priority)
|
||||
|
||||
@@ -256,6 +249,7 @@ async def test_priority_handling():
|
||||
assert high_result["status"] == "completed"
|
||||
assert low_result["status"] == "completed"
|
||||
|
||||
|
||||
async def main():
|
||||
try:
|
||||
# Start with health check
|
||||
@@ -277,5 +271,6 @@ async def main():
|
||||
print(f"Test failed: {str(e)}")
|
||||
raise
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
asyncio.run(main())
|
||||
asyncio.run(main())
|
||||
|
||||
@@ -1,21 +1,26 @@
|
||||
import nest_asyncio
|
||||
|
||||
nest_asyncio.apply()
|
||||
|
||||
import asyncio
|
||||
from crawl4ai import AsyncWebCrawler, CrawlerRunConfig, LXMLWebScrapingStrategy, CacheMode
|
||||
from crawl4ai import (
|
||||
AsyncWebCrawler,
|
||||
CrawlerRunConfig,
|
||||
LXMLWebScrapingStrategy,
|
||||
CacheMode,
|
||||
)
|
||||
|
||||
|
||||
async def main():
|
||||
config = CrawlerRunConfig(
|
||||
cache_mode=CacheMode.BYPASS,
|
||||
scraping_strategy=LXMLWebScrapingStrategy() # Faster alternative to default BeautifulSoup
|
||||
scraping_strategy=LXMLWebScrapingStrategy(), # Faster alternative to default BeautifulSoup
|
||||
)
|
||||
async with AsyncWebCrawler() as crawler:
|
||||
result = await crawler.arun(
|
||||
url="https://example.com",
|
||||
config=config
|
||||
)
|
||||
result = await crawler.arun(url="https://example.com", config=config)
|
||||
print(f"Success: {result.success}")
|
||||
print(f"Markdown length: {len(result.markdown_v2.raw_markdown)}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
asyncio.run(main())
|
||||
|
||||
@@ -1,79 +1,105 @@
|
||||
import unittest, os
|
||||
from crawl4ai.web_crawler import WebCrawler
|
||||
from crawl4ai.chunking_strategy import RegexChunking, FixedLengthWordChunking, SlidingWindowChunking
|
||||
from crawl4ai.extraction_strategy import CosineStrategy, LLMExtractionStrategy, TopicExtractionStrategy, NoExtractionStrategy
|
||||
from crawl4ai.chunking_strategy import (
|
||||
RegexChunking,
|
||||
FixedLengthWordChunking,
|
||||
SlidingWindowChunking,
|
||||
)
|
||||
from crawl4ai.extraction_strategy import (
|
||||
CosineStrategy,
|
||||
LLMExtractionStrategy,
|
||||
TopicExtractionStrategy,
|
||||
NoExtractionStrategy,
|
||||
)
|
||||
|
||||
|
||||
class TestWebCrawler(unittest.TestCase):
|
||||
|
||||
def setUp(self):
|
||||
self.crawler = WebCrawler()
|
||||
|
||||
|
||||
def test_warmup(self):
|
||||
self.crawler.warmup()
|
||||
self.assertTrue(self.crawler.ready, "WebCrawler failed to warm up")
|
||||
|
||||
|
||||
def test_run_default_strategies(self):
|
||||
result = self.crawler.run(
|
||||
url='https://www.nbcnews.com/business',
|
||||
url="https://www.nbcnews.com/business",
|
||||
word_count_threshold=5,
|
||||
chunking_strategy=RegexChunking(),
|
||||
extraction_strategy=CosineStrategy(), bypass_cache=True
|
||||
extraction_strategy=CosineStrategy(),
|
||||
bypass_cache=True,
|
||||
)
|
||||
self.assertTrue(result.success, "Failed to crawl and extract using default strategies")
|
||||
|
||||
self.assertTrue(
|
||||
result.success, "Failed to crawl and extract using default strategies"
|
||||
)
|
||||
|
||||
def test_run_different_strategies(self):
|
||||
url = 'https://www.nbcnews.com/business'
|
||||
|
||||
url = "https://www.nbcnews.com/business"
|
||||
|
||||
# Test with FixedLengthWordChunking and LLMExtractionStrategy
|
||||
result = self.crawler.run(
|
||||
url=url,
|
||||
word_count_threshold=5,
|
||||
chunking_strategy=FixedLengthWordChunking(chunk_size=100),
|
||||
extraction_strategy=LLMExtractionStrategy(provider="openai/gpt-3.5-turbo", api_token=os.getenv('OPENAI_API_KEY')), bypass_cache=True
|
||||
extraction_strategy=LLMExtractionStrategy(
|
||||
provider="openai/gpt-3.5-turbo", api_token=os.getenv("OPENAI_API_KEY")
|
||||
),
|
||||
bypass_cache=True,
|
||||
)
|
||||
self.assertTrue(result.success, "Failed to crawl and extract with FixedLengthWordChunking and LLMExtractionStrategy")
|
||||
|
||||
self.assertTrue(
|
||||
result.success,
|
||||
"Failed to crawl and extract with FixedLengthWordChunking and LLMExtractionStrategy",
|
||||
)
|
||||
|
||||
# Test with SlidingWindowChunking and TopicExtractionStrategy
|
||||
result = self.crawler.run(
|
||||
url=url,
|
||||
word_count_threshold=5,
|
||||
chunking_strategy=SlidingWindowChunking(window_size=100, step=50),
|
||||
extraction_strategy=TopicExtractionStrategy(num_keywords=5), bypass_cache=True
|
||||
extraction_strategy=TopicExtractionStrategy(num_keywords=5),
|
||||
bypass_cache=True,
|
||||
)
|
||||
self.assertTrue(result.success, "Failed to crawl and extract with SlidingWindowChunking and TopicExtractionStrategy")
|
||||
|
||||
self.assertTrue(
|
||||
result.success,
|
||||
"Failed to crawl and extract with SlidingWindowChunking and TopicExtractionStrategy",
|
||||
)
|
||||
|
||||
def test_invalid_url(self):
|
||||
with self.assertRaises(Exception) as context:
|
||||
self.crawler.run(url='invalid_url', bypass_cache=True)
|
||||
self.crawler.run(url="invalid_url", bypass_cache=True)
|
||||
self.assertIn("Invalid URL", str(context.exception))
|
||||
|
||||
|
||||
def test_unsupported_extraction_strategy(self):
|
||||
with self.assertRaises(Exception) as context:
|
||||
self.crawler.run(url='https://www.nbcnews.com/business', extraction_strategy="UnsupportedStrategy", bypass_cache=True)
|
||||
self.crawler.run(
|
||||
url="https://www.nbcnews.com/business",
|
||||
extraction_strategy="UnsupportedStrategy",
|
||||
bypass_cache=True,
|
||||
)
|
||||
self.assertIn("Unsupported extraction strategy", str(context.exception))
|
||||
|
||||
|
||||
def test_invalid_css_selector(self):
|
||||
with self.assertRaises(ValueError) as context:
|
||||
self.crawler.run(url='https://www.nbcnews.com/business', css_selector="invalid_selector", bypass_cache=True)
|
||||
self.crawler.run(
|
||||
url="https://www.nbcnews.com/business",
|
||||
css_selector="invalid_selector",
|
||||
bypass_cache=True,
|
||||
)
|
||||
self.assertIn("Invalid CSS selector", str(context.exception))
|
||||
|
||||
|
||||
def test_crawl_with_cache_and_bypass_cache(self):
|
||||
url = 'https://www.nbcnews.com/business'
|
||||
|
||||
url = "https://www.nbcnews.com/business"
|
||||
|
||||
# First crawl with cache enabled
|
||||
result = self.crawler.run(url=url, bypass_cache=False)
|
||||
self.assertTrue(result.success, "Failed to crawl and cache the result")
|
||||
|
||||
|
||||
# Second crawl with bypass_cache=True
|
||||
result = self.crawler.run(url=url, bypass_cache=True)
|
||||
self.assertTrue(result.success, "Failed to bypass cache and fetch fresh data")
|
||||
|
||||
|
||||
def test_fetch_multiple_pages(self):
|
||||
urls = [
|
||||
'https://www.nbcnews.com/business',
|
||||
'https://www.bbc.com/news'
|
||||
]
|
||||
urls = ["https://www.nbcnews.com/business", "https://www.bbc.com/news"]
|
||||
results = []
|
||||
for url in urls:
|
||||
result = self.crawler.run(
|
||||
@@ -81,31 +107,42 @@ class TestWebCrawler(unittest.TestCase):
|
||||
word_count_threshold=5,
|
||||
chunking_strategy=RegexChunking(),
|
||||
extraction_strategy=CosineStrategy(),
|
||||
bypass_cache=True
|
||||
bypass_cache=True,
|
||||
)
|
||||
results.append(result)
|
||||
|
||||
|
||||
self.assertEqual(len(results), 2, "Failed to crawl and extract multiple pages")
|
||||
for result in results:
|
||||
self.assertTrue(result.success, "Failed to crawl and extract a page in the list")
|
||||
|
||||
self.assertTrue(
|
||||
result.success, "Failed to crawl and extract a page in the list"
|
||||
)
|
||||
|
||||
def test_run_fixed_length_word_chunking_and_no_extraction(self):
|
||||
result = self.crawler.run(
|
||||
url='https://www.nbcnews.com/business',
|
||||
url="https://www.nbcnews.com/business",
|
||||
word_count_threshold=5,
|
||||
chunking_strategy=FixedLengthWordChunking(chunk_size=100),
|
||||
extraction_strategy=NoExtractionStrategy(), bypass_cache=True
|
||||
extraction_strategy=NoExtractionStrategy(),
|
||||
bypass_cache=True,
|
||||
)
|
||||
self.assertTrue(
|
||||
result.success,
|
||||
"Failed to crawl and extract with FixedLengthWordChunking and NoExtractionStrategy",
|
||||
)
|
||||
self.assertTrue(result.success, "Failed to crawl and extract with FixedLengthWordChunking and NoExtractionStrategy")
|
||||
|
||||
def test_run_sliding_window_and_no_extraction(self):
|
||||
result = self.crawler.run(
|
||||
url='https://www.nbcnews.com/business',
|
||||
url="https://www.nbcnews.com/business",
|
||||
word_count_threshold=5,
|
||||
chunking_strategy=SlidingWindowChunking(window_size=100, step=50),
|
||||
extraction_strategy=NoExtractionStrategy(), bypass_cache=True
|
||||
extraction_strategy=NoExtractionStrategy(),
|
||||
bypass_cache=True,
|
||||
)
|
||||
self.assertTrue(
|
||||
result.success,
|
||||
"Failed to crawl and extract with SlidingWindowChunking and NoExtractionStrategy",
|
||||
)
|
||||
self.assertTrue(result.success, "Failed to crawl and extract with SlidingWindowChunking and NoExtractionStrategy")
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
||||
Reference in New Issue
Block a user