Apply Ruff Corrections
This commit is contained in:
8
.pre-commit-config.yaml
Normal file
8
.pre-commit-config.yaml
Normal file
@@ -0,0 +1,8 @@
|
|||||||
|
# .pre-commit-config.yaml
|
||||||
|
repos:
|
||||||
|
- repo: https://github.com/astral-sh/ruff-pre-commit
|
||||||
|
rev: v0.1.11
|
||||||
|
hooks:
|
||||||
|
- id: ruff
|
||||||
|
args: [--fix]
|
||||||
|
- id: ruff-format
|
||||||
@@ -2,14 +2,28 @@
|
|||||||
|
|
||||||
from .async_webcrawler import AsyncWebCrawler, CacheMode
|
from .async_webcrawler import AsyncWebCrawler, CacheMode
|
||||||
from .async_configs import BrowserConfig, CrawlerRunConfig
|
from .async_configs import BrowserConfig, CrawlerRunConfig
|
||||||
from .content_scraping_strategy import ContentScrapingStrategy, WebScrapingStrategy, LXMLWebScrapingStrategy
|
from .content_scraping_strategy import (
|
||||||
from .extraction_strategy import ExtractionStrategy, LLMExtractionStrategy, CosineStrategy, JsonCssExtractionStrategy
|
ContentScrapingStrategy,
|
||||||
|
WebScrapingStrategy,
|
||||||
|
LXMLWebScrapingStrategy,
|
||||||
|
)
|
||||||
|
from .extraction_strategy import (
|
||||||
|
ExtractionStrategy,
|
||||||
|
LLMExtractionStrategy,
|
||||||
|
CosineStrategy,
|
||||||
|
JsonCssExtractionStrategy,
|
||||||
|
)
|
||||||
from .chunking_strategy import ChunkingStrategy, RegexChunking
|
from .chunking_strategy import ChunkingStrategy, RegexChunking
|
||||||
from .markdown_generation_strategy import DefaultMarkdownGenerator
|
from .markdown_generation_strategy import DefaultMarkdownGenerator
|
||||||
from .content_filter_strategy import PruningContentFilter, BM25ContentFilter
|
from .content_filter_strategy import PruningContentFilter, BM25ContentFilter
|
||||||
from .models import CrawlResult, MarkdownGenerationResult
|
from .models import CrawlResult, MarkdownGenerationResult
|
||||||
from .async_dispatcher import MemoryAdaptiveDispatcher, SemaphoreDispatcher, RateLimiter, CrawlerMonitor, DisplayMode
|
from .async_dispatcher import (
|
||||||
from .__version__ import __version__
|
MemoryAdaptiveDispatcher,
|
||||||
|
SemaphoreDispatcher,
|
||||||
|
RateLimiter,
|
||||||
|
CrawlerMonitor,
|
||||||
|
DisplayMode,
|
||||||
|
)
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"AsyncWebCrawler",
|
"AsyncWebCrawler",
|
||||||
@@ -18,39 +32,44 @@ __all__ = [
|
|||||||
"ContentScrapingStrategy",
|
"ContentScrapingStrategy",
|
||||||
"WebScrapingStrategy",
|
"WebScrapingStrategy",
|
||||||
"LXMLWebScrapingStrategy",
|
"LXMLWebScrapingStrategy",
|
||||||
'BrowserConfig',
|
"BrowserConfig",
|
||||||
'CrawlerRunConfig',
|
"CrawlerRunConfig",
|
||||||
'ExtractionStrategy',
|
"ExtractionStrategy",
|
||||||
'LLMExtractionStrategy',
|
"LLMExtractionStrategy",
|
||||||
'CosineStrategy',
|
"CosineStrategy",
|
||||||
'JsonCssExtractionStrategy',
|
"JsonCssExtractionStrategy",
|
||||||
'ChunkingStrategy',
|
"ChunkingStrategy",
|
||||||
'RegexChunking',
|
"RegexChunking",
|
||||||
'DefaultMarkdownGenerator',
|
"DefaultMarkdownGenerator",
|
||||||
'PruningContentFilter',
|
"PruningContentFilter",
|
||||||
'BM25ContentFilter',
|
"BM25ContentFilter",
|
||||||
'MemoryAdaptiveDispatcher',
|
"MemoryAdaptiveDispatcher",
|
||||||
'SemaphoreDispatcher',
|
"SemaphoreDispatcher",
|
||||||
'RateLimiter',
|
"RateLimiter",
|
||||||
'CrawlerMonitor',
|
"CrawlerMonitor",
|
||||||
'DisplayMode',
|
"DisplayMode",
|
||||||
'MarkdownGenerationResult',
|
"MarkdownGenerationResult",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
def is_sync_version_installed():
|
def is_sync_version_installed():
|
||||||
try:
|
try:
|
||||||
import selenium
|
import selenium
|
||||||
|
|
||||||
return True
|
return True
|
||||||
except ImportError:
|
except ImportError:
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
|
||||||
if is_sync_version_installed():
|
if is_sync_version_installed():
|
||||||
try:
|
try:
|
||||||
from .web_crawler import WebCrawler
|
from .web_crawler import WebCrawler
|
||||||
|
|
||||||
__all__.append("WebCrawler")
|
__all__.append("WebCrawler")
|
||||||
except ImportError:
|
except ImportError:
|
||||||
import warnings
|
print(
|
||||||
print("Warning: Failed to import WebCrawler even though selenium is installed. This might be due to other missing dependencies.")
|
"Warning: Failed to import WebCrawler even though selenium is installed. This might be due to other missing dependencies."
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
WebCrawler = None
|
WebCrawler = None
|
||||||
# import warnings
|
# import warnings
|
||||||
|
|||||||
@@ -5,7 +5,6 @@ from .config import (
|
|||||||
PAGE_TIMEOUT,
|
PAGE_TIMEOUT,
|
||||||
IMAGE_SCORE_THRESHOLD,
|
IMAGE_SCORE_THRESHOLD,
|
||||||
SOCIAL_MEDIA_DOMAINS,
|
SOCIAL_MEDIA_DOMAINS,
|
||||||
|
|
||||||
)
|
)
|
||||||
from .user_agent_generator import UserAgentGenerator
|
from .user_agent_generator import UserAgentGenerator
|
||||||
from .extraction_strategy import ExtractionStrategy
|
from .extraction_strategy import ExtractionStrategy
|
||||||
@@ -14,6 +13,7 @@ from .markdown_generation_strategy import MarkdownGenerationStrategy
|
|||||||
from .content_scraping_strategy import ContentScrapingStrategy, WebScrapingStrategy
|
from .content_scraping_strategy import ContentScrapingStrategy, WebScrapingStrategy
|
||||||
from typing import Union, List
|
from typing import Union, List
|
||||||
|
|
||||||
|
|
||||||
class BrowserConfig:
|
class BrowserConfig:
|
||||||
"""
|
"""
|
||||||
Configuration class for setting up a browser instance and its context in AsyncPlaywrightCrawlerStrategy.
|
Configuration class for setting up a browser instance and its context in AsyncPlaywrightCrawlerStrategy.
|
||||||
@@ -335,10 +335,8 @@ class CrawlerRunConfig:
|
|||||||
prettiify: bool = False,
|
prettiify: bool = False,
|
||||||
parser_type: str = "lxml",
|
parser_type: str = "lxml",
|
||||||
scraping_strategy: ContentScrapingStrategy = None,
|
scraping_strategy: ContentScrapingStrategy = None,
|
||||||
|
|
||||||
# SSL Parameters
|
# SSL Parameters
|
||||||
fetch_ssl_certificate: bool = False,
|
fetch_ssl_certificate: bool = False,
|
||||||
|
|
||||||
# Caching Parameters
|
# Caching Parameters
|
||||||
cache_mode=None,
|
cache_mode=None,
|
||||||
session_id: str = None,
|
session_id: str = None,
|
||||||
@@ -346,7 +344,6 @@ class CrawlerRunConfig:
|
|||||||
disable_cache: bool = False,
|
disable_cache: bool = False,
|
||||||
no_cache_read: bool = False,
|
no_cache_read: bool = False,
|
||||||
no_cache_write: bool = False,
|
no_cache_write: bool = False,
|
||||||
|
|
||||||
# Page Navigation and Timing Parameters
|
# Page Navigation and Timing Parameters
|
||||||
wait_until: str = "domcontentloaded",
|
wait_until: str = "domcontentloaded",
|
||||||
page_timeout: int = PAGE_TIMEOUT,
|
page_timeout: int = PAGE_TIMEOUT,
|
||||||
@@ -356,7 +353,6 @@ class CrawlerRunConfig:
|
|||||||
mean_delay: float = 0.1,
|
mean_delay: float = 0.1,
|
||||||
max_range: float = 0.3,
|
max_range: float = 0.3,
|
||||||
semaphore_count: int = 5,
|
semaphore_count: int = 5,
|
||||||
|
|
||||||
# Page Interaction Parameters
|
# Page Interaction Parameters
|
||||||
js_code: Union[str, List[str]] = None,
|
js_code: Union[str, List[str]] = None,
|
||||||
js_only: bool = False,
|
js_only: bool = False,
|
||||||
@@ -369,7 +365,6 @@ class CrawlerRunConfig:
|
|||||||
override_navigator: bool = False,
|
override_navigator: bool = False,
|
||||||
magic: bool = False,
|
magic: bool = False,
|
||||||
adjust_viewport_to_content: bool = False,
|
adjust_viewport_to_content: bool = False,
|
||||||
|
|
||||||
# Media Handling Parameters
|
# Media Handling Parameters
|
||||||
screenshot: bool = False,
|
screenshot: bool = False,
|
||||||
screenshot_wait_for: float = None,
|
screenshot_wait_for: float = None,
|
||||||
@@ -378,17 +373,14 @@ class CrawlerRunConfig:
|
|||||||
image_description_min_word_threshold: int = IMAGE_DESCRIPTION_MIN_WORD_THRESHOLD,
|
image_description_min_word_threshold: int = IMAGE_DESCRIPTION_MIN_WORD_THRESHOLD,
|
||||||
image_score_threshold: int = IMAGE_SCORE_THRESHOLD,
|
image_score_threshold: int = IMAGE_SCORE_THRESHOLD,
|
||||||
exclude_external_images: bool = False,
|
exclude_external_images: bool = False,
|
||||||
|
|
||||||
# Link and Domain Handling Parameters
|
# Link and Domain Handling Parameters
|
||||||
exclude_social_media_domains: list = None,
|
exclude_social_media_domains: list = None,
|
||||||
exclude_external_links: bool = False,
|
exclude_external_links: bool = False,
|
||||||
exclude_social_media_links: bool = False,
|
exclude_social_media_links: bool = False,
|
||||||
exclude_domains: list = None,
|
exclude_domains: list = None,
|
||||||
|
|
||||||
# Debugging and Logging Parameters
|
# Debugging and Logging Parameters
|
||||||
verbose: bool = True,
|
verbose: bool = True,
|
||||||
log_console: bool = False,
|
log_console: bool = False,
|
||||||
|
|
||||||
url: str = None,
|
url: str = None,
|
||||||
):
|
):
|
||||||
self.url = url
|
self.url = url
|
||||||
@@ -453,7 +445,9 @@ class CrawlerRunConfig:
|
|||||||
self.exclude_external_images = exclude_external_images
|
self.exclude_external_images = exclude_external_images
|
||||||
|
|
||||||
# Link and Domain Handling Parameters
|
# Link and Domain Handling Parameters
|
||||||
self.exclude_social_media_domains = exclude_social_media_domains or SOCIAL_MEDIA_DOMAINS
|
self.exclude_social_media_domains = (
|
||||||
|
exclude_social_media_domains or SOCIAL_MEDIA_DOMAINS
|
||||||
|
)
|
||||||
self.exclude_external_links = exclude_external_links
|
self.exclude_external_links = exclude_external_links
|
||||||
self.exclude_social_media_links = exclude_social_media_links
|
self.exclude_social_media_links = exclude_social_media_links
|
||||||
self.exclude_domains = exclude_domains or []
|
self.exclude_domains = exclude_domains or []
|
||||||
@@ -466,11 +460,15 @@ class CrawlerRunConfig:
|
|||||||
if self.extraction_strategy is not None and not isinstance(
|
if self.extraction_strategy is not None and not isinstance(
|
||||||
self.extraction_strategy, ExtractionStrategy
|
self.extraction_strategy, ExtractionStrategy
|
||||||
):
|
):
|
||||||
raise ValueError("extraction_strategy must be an instance of ExtractionStrategy")
|
raise ValueError(
|
||||||
|
"extraction_strategy must be an instance of ExtractionStrategy"
|
||||||
|
)
|
||||||
if self.chunking_strategy is not None and not isinstance(
|
if self.chunking_strategy is not None and not isinstance(
|
||||||
self.chunking_strategy, ChunkingStrategy
|
self.chunking_strategy, ChunkingStrategy
|
||||||
):
|
):
|
||||||
raise ValueError("chunking_strategy must be an instance of ChunkingStrategy")
|
raise ValueError(
|
||||||
|
"chunking_strategy must be an instance of ChunkingStrategy"
|
||||||
|
)
|
||||||
|
|
||||||
# Set default chunking strategy if None
|
# Set default chunking strategy if None
|
||||||
if self.chunking_strategy is None:
|
if self.chunking_strategy is None:
|
||||||
@@ -494,10 +492,8 @@ class CrawlerRunConfig:
|
|||||||
prettiify=kwargs.get("prettiify", False),
|
prettiify=kwargs.get("prettiify", False),
|
||||||
parser_type=kwargs.get("parser_type", "lxml"),
|
parser_type=kwargs.get("parser_type", "lxml"),
|
||||||
scraping_strategy=kwargs.get("scraping_strategy"),
|
scraping_strategy=kwargs.get("scraping_strategy"),
|
||||||
|
|
||||||
# SSL Parameters
|
# SSL Parameters
|
||||||
fetch_ssl_certificate=kwargs.get("fetch_ssl_certificate", False),
|
fetch_ssl_certificate=kwargs.get("fetch_ssl_certificate", False),
|
||||||
|
|
||||||
# Caching Parameters
|
# Caching Parameters
|
||||||
cache_mode=kwargs.get("cache_mode"),
|
cache_mode=kwargs.get("cache_mode"),
|
||||||
session_id=kwargs.get("session_id"),
|
session_id=kwargs.get("session_id"),
|
||||||
@@ -505,7 +501,6 @@ class CrawlerRunConfig:
|
|||||||
disable_cache=kwargs.get("disable_cache", False),
|
disable_cache=kwargs.get("disable_cache", False),
|
||||||
no_cache_read=kwargs.get("no_cache_read", False),
|
no_cache_read=kwargs.get("no_cache_read", False),
|
||||||
no_cache_write=kwargs.get("no_cache_write", False),
|
no_cache_write=kwargs.get("no_cache_write", False),
|
||||||
|
|
||||||
# Page Navigation and Timing Parameters
|
# Page Navigation and Timing Parameters
|
||||||
wait_until=kwargs.get("wait_until", "domcontentloaded"),
|
wait_until=kwargs.get("wait_until", "domcontentloaded"),
|
||||||
page_timeout=kwargs.get("page_timeout", 60000),
|
page_timeout=kwargs.get("page_timeout", 60000),
|
||||||
@@ -515,7 +510,6 @@ class CrawlerRunConfig:
|
|||||||
mean_delay=kwargs.get("mean_delay", 0.1),
|
mean_delay=kwargs.get("mean_delay", 0.1),
|
||||||
max_range=kwargs.get("max_range", 0.3),
|
max_range=kwargs.get("max_range", 0.3),
|
||||||
semaphore_count=kwargs.get("semaphore_count", 5),
|
semaphore_count=kwargs.get("semaphore_count", 5),
|
||||||
|
|
||||||
# Page Interaction Parameters
|
# Page Interaction Parameters
|
||||||
js_code=kwargs.get("js_code"),
|
js_code=kwargs.get("js_code"),
|
||||||
js_only=kwargs.get("js_only", False),
|
js_only=kwargs.get("js_only", False),
|
||||||
@@ -528,26 +522,31 @@ class CrawlerRunConfig:
|
|||||||
override_navigator=kwargs.get("override_navigator", False),
|
override_navigator=kwargs.get("override_navigator", False),
|
||||||
magic=kwargs.get("magic", False),
|
magic=kwargs.get("magic", False),
|
||||||
adjust_viewport_to_content=kwargs.get("adjust_viewport_to_content", False),
|
adjust_viewport_to_content=kwargs.get("adjust_viewport_to_content", False),
|
||||||
|
|
||||||
# Media Handling Parameters
|
# Media Handling Parameters
|
||||||
screenshot=kwargs.get("screenshot", False),
|
screenshot=kwargs.get("screenshot", False),
|
||||||
screenshot_wait_for=kwargs.get("screenshot_wait_for"),
|
screenshot_wait_for=kwargs.get("screenshot_wait_for"),
|
||||||
screenshot_height_threshold=kwargs.get("screenshot_height_threshold", SCREENSHOT_HEIGHT_TRESHOLD),
|
screenshot_height_threshold=kwargs.get(
|
||||||
|
"screenshot_height_threshold", SCREENSHOT_HEIGHT_TRESHOLD
|
||||||
|
),
|
||||||
pdf=kwargs.get("pdf", False),
|
pdf=kwargs.get("pdf", False),
|
||||||
image_description_min_word_threshold=kwargs.get("image_description_min_word_threshold", IMAGE_DESCRIPTION_MIN_WORD_THRESHOLD),
|
image_description_min_word_threshold=kwargs.get(
|
||||||
image_score_threshold=kwargs.get("image_score_threshold", IMAGE_SCORE_THRESHOLD),
|
"image_description_min_word_threshold",
|
||||||
|
IMAGE_DESCRIPTION_MIN_WORD_THRESHOLD,
|
||||||
|
),
|
||||||
|
image_score_threshold=kwargs.get(
|
||||||
|
"image_score_threshold", IMAGE_SCORE_THRESHOLD
|
||||||
|
),
|
||||||
exclude_external_images=kwargs.get("exclude_external_images", False),
|
exclude_external_images=kwargs.get("exclude_external_images", False),
|
||||||
|
|
||||||
# Link and Domain Handling Parameters
|
# Link and Domain Handling Parameters
|
||||||
exclude_social_media_domains=kwargs.get("exclude_social_media_domains", SOCIAL_MEDIA_DOMAINS),
|
exclude_social_media_domains=kwargs.get(
|
||||||
|
"exclude_social_media_domains", SOCIAL_MEDIA_DOMAINS
|
||||||
|
),
|
||||||
exclude_external_links=kwargs.get("exclude_external_links", False),
|
exclude_external_links=kwargs.get("exclude_external_links", False),
|
||||||
exclude_social_media_links=kwargs.get("exclude_social_media_links", False),
|
exclude_social_media_links=kwargs.get("exclude_social_media_links", False),
|
||||||
exclude_domains=kwargs.get("exclude_domains", []),
|
exclude_domains=kwargs.get("exclude_domains", []),
|
||||||
|
|
||||||
# Debugging and Logging Parameters
|
# Debugging and Logging Parameters
|
||||||
verbose=kwargs.get("verbose", True),
|
verbose=kwargs.get("verbose", True),
|
||||||
log_console=kwargs.get("log_console", False),
|
log_console=kwargs.get("log_console", False),
|
||||||
|
|
||||||
url=kwargs.get("url"),
|
url=kwargs.get("url"),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@@ -2,27 +2,25 @@ import asyncio
|
|||||||
import base64
|
import base64
|
||||||
import time
|
import time
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from typing import Callable, Dict, Any, List, Optional, Awaitable, Union
|
from typing import Callable, Dict, Any, List, Optional, Union
|
||||||
import os, sys, shutil
|
import os
|
||||||
import tempfile, subprocess
|
import sys
|
||||||
from playwright.async_api import async_playwright, Page, Browser, Error, BrowserContext
|
import shutil
|
||||||
|
import tempfile
|
||||||
|
import subprocess
|
||||||
|
from playwright.async_api import Page, Error, BrowserContext
|
||||||
from playwright.async_api import TimeoutError as PlaywrightTimeoutError
|
from playwright.async_api import TimeoutError as PlaywrightTimeoutError
|
||||||
from io import BytesIO
|
from io import BytesIO
|
||||||
from PIL import Image, ImageDraw, ImageFont
|
from PIL import Image, ImageDraw, ImageFont
|
||||||
from pathlib import Path
|
|
||||||
from playwright.async_api import ProxySettings
|
|
||||||
from pydantic import BaseModel
|
|
||||||
import hashlib
|
import hashlib
|
||||||
import json
|
|
||||||
import uuid
|
import uuid
|
||||||
from .js_snippet import load_js_script
|
from .js_snippet import load_js_script
|
||||||
from .models import AsyncCrawlResponse
|
from .models import AsyncCrawlResponse
|
||||||
from .utils import get_error_context
|
|
||||||
from .user_agent_generator import UserAgentGenerator
|
from .user_agent_generator import UserAgentGenerator
|
||||||
from .config import SCREENSHOT_HEIGHT_TRESHOLD, DOWNLOAD_PAGE_TIMEOUT
|
from .config import SCREENSHOT_HEIGHT_TRESHOLD, DOWNLOAD_PAGE_TIMEOUT
|
||||||
from .async_configs import BrowserConfig, CrawlerRunConfig
|
from .async_configs import BrowserConfig, CrawlerRunConfig
|
||||||
from .async_logger import AsyncLogger
|
from .async_logger import AsyncLogger
|
||||||
from playwright_stealth import StealthConfig, stealth_async
|
from playwright_stealth import StealthConfig
|
||||||
from .ssl_certificate import SSLCertificate
|
from .ssl_certificate import SSLCertificate
|
||||||
|
|
||||||
stealth_config = StealthConfig(
|
stealth_config = StealthConfig(
|
||||||
@@ -94,6 +92,7 @@ class ManagedBrowser:
|
|||||||
temp_dir: str
|
temp_dir: str
|
||||||
debugging_port: int
|
debugging_port: int
|
||||||
host: str
|
host: str
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
browser_type: str = "chromium",
|
browser_type: str = "chromium",
|
||||||
@@ -139,7 +138,7 @@ class ManagedBrowser:
|
|||||||
self.user_data_dir = self.temp_dir
|
self.user_data_dir = self.temp_dir
|
||||||
|
|
||||||
# Get browser path and args based on OS and browser type
|
# Get browser path and args based on OS and browser type
|
||||||
browser_path = self._get_browser_path()
|
# browser_path = self._get_browser_path()
|
||||||
args = self._get_browser_args()
|
args = self._get_browser_args()
|
||||||
|
|
||||||
# Start browser process
|
# Start browser process
|
||||||
@@ -300,6 +299,7 @@ class BrowserManager:
|
|||||||
sessions (dict): Dictionary to store session information
|
sessions (dict): Dictionary to store session information
|
||||||
session_ttl (int): Session timeout in seconds
|
session_ttl (int): Session timeout in seconds
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, browser_config: BrowserConfig, logger=None):
|
def __init__(self, browser_config: BrowserConfig, logger=None):
|
||||||
"""
|
"""
|
||||||
Initialize the BrowserManager with a browser configuration.
|
Initialize the BrowserManager with a browser configuration.
|
||||||
@@ -453,7 +453,12 @@ class BrowserManager:
|
|||||||
|
|
||||||
return browser_args
|
return browser_args
|
||||||
|
|
||||||
async def setup_context(self, context: BrowserContext, crawlerRunConfig: CrawlerRunConfig = None, is_default=False):
|
async def setup_context(
|
||||||
|
self,
|
||||||
|
context: BrowserContext,
|
||||||
|
crawlerRunConfig: CrawlerRunConfig = None,
|
||||||
|
is_default=False,
|
||||||
|
):
|
||||||
"""
|
"""
|
||||||
Set up a browser context with the configured options.
|
Set up a browser context with the configured options.
|
||||||
|
|
||||||
@@ -496,9 +501,9 @@ class BrowserManager:
|
|||||||
context.set_default_navigation_timeout(DOWNLOAD_PAGE_TIMEOUT)
|
context.set_default_navigation_timeout(DOWNLOAD_PAGE_TIMEOUT)
|
||||||
if self.config.downloads_path:
|
if self.config.downloads_path:
|
||||||
context._impl_obj._options["accept_downloads"] = True
|
context._impl_obj._options["accept_downloads"] = True
|
||||||
context._impl_obj._options["downloads_path"] = (
|
context._impl_obj._options[
|
||||||
self.config.downloads_path
|
"downloads_path"
|
||||||
)
|
] = self.config.downloads_path
|
||||||
|
|
||||||
# Handle user agent and browser hints
|
# Handle user agent and browser hints
|
||||||
if self.config.user_agent:
|
if self.config.user_agent:
|
||||||
@@ -511,7 +516,15 @@ class BrowserManager:
|
|||||||
|
|
||||||
# Add default cookie
|
# Add default cookie
|
||||||
await context.add_cookies(
|
await context.add_cookies(
|
||||||
[{"name": "cookiesEnabled", "value": "true", "url": crawlerRunConfig.url if crawlerRunConfig else "https://crawl4ai.com/"}]
|
[
|
||||||
|
{
|
||||||
|
"name": "cookiesEnabled",
|
||||||
|
"value": "true",
|
||||||
|
"url": crawlerRunConfig.url
|
||||||
|
if crawlerRunConfig
|
||||||
|
else "https://crawl4ai.com/",
|
||||||
|
}
|
||||||
|
]
|
||||||
)
|
)
|
||||||
|
|
||||||
# Handle navigator overrides
|
# Handle navigator overrides
|
||||||
@@ -541,20 +554,57 @@ class BrowserManager:
|
|||||||
|
|
||||||
blocked_extensions = [
|
blocked_extensions = [
|
||||||
# Images
|
# Images
|
||||||
'jpg', 'jpeg', 'png', 'gif', 'webp', 'svg', 'ico', 'bmp', 'tiff', 'psd',
|
"jpg",
|
||||||
|
"jpeg",
|
||||||
|
"png",
|
||||||
|
"gif",
|
||||||
|
"webp",
|
||||||
|
"svg",
|
||||||
|
"ico",
|
||||||
|
"bmp",
|
||||||
|
"tiff",
|
||||||
|
"psd",
|
||||||
# Fonts
|
# Fonts
|
||||||
'woff', 'woff2', 'ttf', 'otf', 'eot',
|
"woff",
|
||||||
|
"woff2",
|
||||||
|
"ttf",
|
||||||
|
"otf",
|
||||||
|
"eot",
|
||||||
# Styles
|
# Styles
|
||||||
# 'css', 'less', 'scss', 'sass',
|
# 'css', 'less', 'scss', 'sass',
|
||||||
# Media
|
# Media
|
||||||
'mp4', 'webm', 'ogg', 'avi', 'mov', 'wmv', 'flv', 'm4v',
|
"mp4",
|
||||||
'mp3', 'wav', 'aac', 'm4a', 'opus', 'flac',
|
"webm",
|
||||||
|
"ogg",
|
||||||
|
"avi",
|
||||||
|
"mov",
|
||||||
|
"wmv",
|
||||||
|
"flv",
|
||||||
|
"m4v",
|
||||||
|
"mp3",
|
||||||
|
"wav",
|
||||||
|
"aac",
|
||||||
|
"m4a",
|
||||||
|
"opus",
|
||||||
|
"flac",
|
||||||
# Documents
|
# Documents
|
||||||
'pdf', 'doc', 'docx', 'xls', 'xlsx', 'ppt', 'pptx',
|
"pdf",
|
||||||
|
"doc",
|
||||||
|
"docx",
|
||||||
|
"xls",
|
||||||
|
"xlsx",
|
||||||
|
"ppt",
|
||||||
|
"pptx",
|
||||||
# Archives
|
# Archives
|
||||||
'zip', 'rar', '7z', 'tar', 'gz',
|
"zip",
|
||||||
|
"rar",
|
||||||
|
"7z",
|
||||||
|
"tar",
|
||||||
|
"gz",
|
||||||
# Scripts and data
|
# Scripts and data
|
||||||
'xml', 'swf', 'wasm'
|
"xml",
|
||||||
|
"swf",
|
||||||
|
"wasm",
|
||||||
]
|
]
|
||||||
|
|
||||||
# Common context settings
|
# Common context settings
|
||||||
@@ -672,12 +722,12 @@ class AsyncCrawlerStrategy(ABC):
|
|||||||
Abstract base class for crawler strategies.
|
Abstract base class for crawler strategies.
|
||||||
Subclasses must implement the crawl method.
|
Subclasses must implement the crawl method.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
async def crawl(self, url: str, **kwargs) -> AsyncCrawlResponse:
|
async def crawl(self, url: str, **kwargs) -> AsyncCrawlResponse:
|
||||||
pass # 4 + 3
|
pass # 4 + 3
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
class AsyncPlaywrightCrawlerStrategy(AsyncCrawlerStrategy):
|
class AsyncPlaywrightCrawlerStrategy(AsyncCrawlerStrategy):
|
||||||
"""
|
"""
|
||||||
Crawler strategy using Playwright.
|
Crawler strategy using Playwright.
|
||||||
@@ -706,6 +756,7 @@ class AsyncPlaywrightCrawlerStrategy(AsyncCrawlerStrategy):
|
|||||||
Run the crawler for a single URL.
|
Run the crawler for a single URL.
|
||||||
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self, browser_config: BrowserConfig = None, logger: AsyncLogger = None, **kwargs
|
self, browser_config: BrowserConfig = None, logger: AsyncLogger = None, **kwargs
|
||||||
):
|
):
|
||||||
@@ -917,7 +968,9 @@ class AsyncPlaywrightCrawlerStrategy(AsyncCrawlerStrategy):
|
|||||||
"or explicitly prefixed with 'js:' or 'css:'."
|
"or explicitly prefixed with 'js:' or 'css:'."
|
||||||
)
|
)
|
||||||
|
|
||||||
async def csp_compliant_wait( self, page: Page, user_wait_function: str, timeout: float = 30000 ):
|
async def csp_compliant_wait(
|
||||||
|
self, page: Page, user_wait_function: str, timeout: float = 30000
|
||||||
|
):
|
||||||
"""
|
"""
|
||||||
Wait for a condition in a CSP-compliant way.
|
Wait for a condition in a CSP-compliant way.
|
||||||
|
|
||||||
@@ -1045,7 +1098,9 @@ class AsyncPlaywrightCrawlerStrategy(AsyncCrawlerStrategy):
|
|||||||
page, context = await self.browser_manager.get_page(session_id, user_agent)
|
page, context = await self.browser_manager.get_page(session_id, user_agent)
|
||||||
return session_id
|
return session_id
|
||||||
|
|
||||||
async def crawl( self, url: str, config: CrawlerRunConfig, **kwargs ) -> AsyncCrawlResponse:
|
async def crawl(
|
||||||
|
self, url: str, config: CrawlerRunConfig, **kwargs
|
||||||
|
) -> AsyncCrawlResponse:
|
||||||
"""
|
"""
|
||||||
Crawls a given URL or processes raw HTML/local file content based on the URL prefix.
|
Crawls a given URL or processes raw HTML/local file content based on the URL prefix.
|
||||||
|
|
||||||
@@ -1104,7 +1159,9 @@ class AsyncPlaywrightCrawlerStrategy(AsyncCrawlerStrategy):
|
|||||||
"URL must start with 'http://', 'https://', 'file://', or 'raw:'"
|
"URL must start with 'http://', 'https://', 'file://', or 'raw:'"
|
||||||
)
|
)
|
||||||
|
|
||||||
async def _crawl_web( self, url: str, config: CrawlerRunConfig ) -> AsyncCrawlResponse:
|
async def _crawl_web(
|
||||||
|
self, url: str, config: CrawlerRunConfig
|
||||||
|
) -> AsyncCrawlResponse:
|
||||||
"""
|
"""
|
||||||
Internal method to crawl web URLs with the specified configuration.
|
Internal method to crawl web URLs with the specified configuration.
|
||||||
|
|
||||||
@@ -1190,9 +1247,11 @@ class AsyncPlaywrightCrawlerStrategy(AsyncCrawlerStrategy):
|
|||||||
nonce = hashlib.sha256(os.urandom(32)).hexdigest()
|
nonce = hashlib.sha256(os.urandom(32)).hexdigest()
|
||||||
|
|
||||||
# Add CSP headers to the request
|
# Add CSP headers to the request
|
||||||
await page.set_extra_http_headers({
|
await page.set_extra_http_headers(
|
||||||
'Content-Security-Policy': f"default-src 'self'; script-src 'self' 'nonce-{nonce}' 'strict-dynamic'"
|
{
|
||||||
})
|
"Content-Security-Policy": f"default-src 'self'; script-src 'self' 'nonce-{nonce}' 'strict-dynamic'"
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
response = await page.goto(
|
response = await page.goto(
|
||||||
url, wait_until=config.wait_until, timeout=config.page_timeout
|
url, wait_until=config.wait_until, timeout=config.page_timeout
|
||||||
@@ -1200,7 +1259,9 @@ class AsyncPlaywrightCrawlerStrategy(AsyncCrawlerStrategy):
|
|||||||
except Error as e:
|
except Error as e:
|
||||||
raise RuntimeError(f"Failed on navigating ACS-GOTO:\n{str(e)}")
|
raise RuntimeError(f"Failed on navigating ACS-GOTO:\n{str(e)}")
|
||||||
|
|
||||||
await self.execute_hook("after_goto", page, context=context, url=url, response=response)
|
await self.execute_hook(
|
||||||
|
"after_goto", page, context=context, url=url, response=response
|
||||||
|
)
|
||||||
|
|
||||||
if response is None:
|
if response is None:
|
||||||
status_code = 200
|
status_code = 200
|
||||||
@@ -1229,14 +1290,14 @@ class AsyncPlaywrightCrawlerStrategy(AsyncCrawlerStrategy):
|
|||||||
style.opacity !== '0';
|
style.opacity !== '0';
|
||||||
return isVisible;
|
return isVisible;
|
||||||
}""",
|
}""",
|
||||||
timeout=30000
|
timeout=30000,
|
||||||
)
|
)
|
||||||
|
|
||||||
if not is_visible and not config.ignore_body_visibility:
|
if not is_visible and not config.ignore_body_visibility:
|
||||||
visibility_info = await self.check_visibility(page)
|
visibility_info = await self.check_visibility(page)
|
||||||
raise Error(f"Body element is hidden: {visibility_info}")
|
raise Error(f"Body element is hidden: {visibility_info}")
|
||||||
|
|
||||||
except Error as e:
|
except Error:
|
||||||
visibility_info = await self.check_visibility(page)
|
visibility_info = await self.check_visibility(page)
|
||||||
|
|
||||||
if self.config.verbose:
|
if self.config.verbose:
|
||||||
@@ -1249,7 +1310,6 @@ class AsyncPlaywrightCrawlerStrategy(AsyncCrawlerStrategy):
|
|||||||
if not config.ignore_body_visibility:
|
if not config.ignore_body_visibility:
|
||||||
raise Error(f"Body element is hidden: {visibility_info}")
|
raise Error(f"Body element is hidden: {visibility_info}")
|
||||||
|
|
||||||
|
|
||||||
# try:
|
# try:
|
||||||
# await page.wait_for_selector("body", state="attached", timeout=30000)
|
# await page.wait_for_selector("body", state="attached", timeout=30000)
|
||||||
|
|
||||||
@@ -1303,7 +1363,7 @@ class AsyncPlaywrightCrawlerStrategy(AsyncCrawlerStrategy):
|
|||||||
images_loaded = await self.csp_compliant_wait(
|
images_loaded = await self.csp_compliant_wait(
|
||||||
page,
|
page,
|
||||||
"() => Array.from(document.getElementsByTagName('img')).every(img => img.complete)",
|
"() => Array.from(document.getElementsByTagName('img')).every(img => img.complete)",
|
||||||
timeout=1000
|
timeout=1000,
|
||||||
)
|
)
|
||||||
|
|
||||||
if not images_loaded and self.logger:
|
if not images_loaded and self.logger:
|
||||||
@@ -1316,8 +1376,8 @@ class AsyncPlaywrightCrawlerStrategy(AsyncCrawlerStrategy):
|
|||||||
if not self.browser_config.text_mode and config.adjust_viewport_to_content:
|
if not self.browser_config.text_mode and config.adjust_viewport_to_content:
|
||||||
try:
|
try:
|
||||||
dimensions = await self.get_page_dimensions(page)
|
dimensions = await self.get_page_dimensions(page)
|
||||||
page_height = dimensions['height']
|
page_height = dimensions["height"]
|
||||||
page_width = dimensions['width']
|
page_width = dimensions["width"]
|
||||||
# page_width = await page.evaluate(
|
# page_width = await page.evaluate(
|
||||||
# "document.documentElement.scrollWidth"
|
# "document.documentElement.scrollWidth"
|
||||||
# )
|
# )
|
||||||
@@ -1364,12 +1424,14 @@ class AsyncPlaywrightCrawlerStrategy(AsyncCrawlerStrategy):
|
|||||||
|
|
||||||
if config.js_code:
|
if config.js_code:
|
||||||
# execution_result = await self.execute_user_script(page, config.js_code)
|
# execution_result = await self.execute_user_script(page, config.js_code)
|
||||||
execution_result = await self.robust_execute_user_script(page, config.js_code)
|
execution_result = await self.robust_execute_user_script(
|
||||||
|
page, config.js_code
|
||||||
|
)
|
||||||
if not execution_result["success"]:
|
if not execution_result["success"]:
|
||||||
self.logger.warning(
|
self.logger.warning(
|
||||||
message="User script execution had issues: {error}",
|
message="User script execution had issues: {error}",
|
||||||
tag="JS_EXEC",
|
tag="JS_EXEC",
|
||||||
params={"error": execution_result.get("error")}
|
params={"error": execution_result.get("error")},
|
||||||
)
|
)
|
||||||
|
|
||||||
await self.execute_hook("on_execution_started", page, context=context)
|
await self.execute_hook("on_execution_started", page, context=context)
|
||||||
@@ -1425,7 +1487,9 @@ class AsyncPlaywrightCrawlerStrategy(AsyncCrawlerStrategy):
|
|||||||
|
|
||||||
# Get final HTML content
|
# Get final HTML content
|
||||||
html = await page.content()
|
html = await page.content()
|
||||||
await self.execute_hook("before_return_html", page = page, html = html, context=context)
|
await self.execute_hook(
|
||||||
|
"before_return_html", page=page, html=html, context=context
|
||||||
|
)
|
||||||
|
|
||||||
# Handle PDF and screenshot generation
|
# Handle PDF and screenshot generation
|
||||||
start_export_time = time.perf_counter()
|
start_export_time = time.perf_counter()
|
||||||
@@ -1511,7 +1575,7 @@ class AsyncPlaywrightCrawlerStrategy(AsyncCrawlerStrategy):
|
|||||||
|
|
||||||
# total_height = await page.evaluate("document.documentElement.scrollHeight")
|
# total_height = await page.evaluate("document.documentElement.scrollHeight")
|
||||||
dimensions = await self.get_page_dimensions(page)
|
dimensions = await self.get_page_dimensions(page)
|
||||||
total_height = dimensions['height']
|
total_height = dimensions["height"]
|
||||||
|
|
||||||
while current_position < total_height:
|
while current_position < total_height:
|
||||||
current_position = min(current_position + viewport_height, total_height)
|
current_position = min(current_position + viewport_height, total_height)
|
||||||
@@ -1521,7 +1585,7 @@ class AsyncPlaywrightCrawlerStrategy(AsyncCrawlerStrategy):
|
|||||||
|
|
||||||
# new_height = await page.evaluate("document.documentElement.scrollHeight")
|
# new_height = await page.evaluate("document.documentElement.scrollHeight")
|
||||||
dimensions = await self.get_page_dimensions(page)
|
dimensions = await self.get_page_dimensions(page)
|
||||||
new_height = dimensions['height']
|
new_height = dimensions["height"]
|
||||||
|
|
||||||
if new_height > total_height:
|
if new_height > total_height:
|
||||||
total_height = new_height
|
total_height = new_height
|
||||||
@@ -1598,7 +1662,8 @@ class AsyncPlaywrightCrawlerStrategy(AsyncCrawlerStrategy):
|
|||||||
remove_overlays_js = load_js_script("remove_overlay_elements")
|
remove_overlays_js = load_js_script("remove_overlay_elements")
|
||||||
|
|
||||||
try:
|
try:
|
||||||
await page.evaluate(f"""
|
await page.evaluate(
|
||||||
|
f"""
|
||||||
(() => {{
|
(() => {{
|
||||||
try {{
|
try {{
|
||||||
{remove_overlays_js}
|
{remove_overlays_js}
|
||||||
@@ -1611,7 +1676,8 @@ class AsyncPlaywrightCrawlerStrategy(AsyncCrawlerStrategy):
|
|||||||
}};
|
}};
|
||||||
}}
|
}}
|
||||||
}})()
|
}})()
|
||||||
""")
|
"""
|
||||||
|
)
|
||||||
await page.wait_for_timeout(500) # Wait for any animations to complete
|
await page.wait_for_timeout(500) # Wait for any animations to complete
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
self.logger.warning(
|
self.logger.warning(
|
||||||
@@ -1707,8 +1773,8 @@ class AsyncPlaywrightCrawlerStrategy(AsyncCrawlerStrategy):
|
|||||||
try:
|
try:
|
||||||
# Get page height
|
# Get page height
|
||||||
dimensions = await self.get_page_dimensions(page)
|
dimensions = await self.get_page_dimensions(page)
|
||||||
page_width = dimensions['width']
|
page_width = dimensions["width"]
|
||||||
page_height = dimensions['height']
|
page_height = dimensions["height"]
|
||||||
# page_height = await page.evaluate("document.documentElement.scrollHeight")
|
# page_height = await page.evaluate("document.documentElement.scrollHeight")
|
||||||
# page_width = await page.evaluate("document.documentElement.scrollWidth")
|
# page_width = await page.evaluate("document.documentElement.scrollWidth")
|
||||||
|
|
||||||
@@ -1826,7 +1892,9 @@ class AsyncPlaywrightCrawlerStrategy(AsyncCrawlerStrategy):
|
|||||||
tag="WARNING",
|
tag="WARNING",
|
||||||
)
|
)
|
||||||
|
|
||||||
async def robust_execute_user_script(self, page: Page, js_code: Union[str, List[str]]) -> Dict[str, Any]:
|
async def robust_execute_user_script(
|
||||||
|
self, page: Page, js_code: Union[str, List[str]]
|
||||||
|
) -> Dict[str, Any]:
|
||||||
"""
|
"""
|
||||||
Executes user-provided JavaScript code with proper error handling and context,
|
Executes user-provided JavaScript code with proper error handling and context,
|
||||||
supporting both synchronous and async user code, plus navigations.
|
supporting both synchronous and async user code, plus navigations.
|
||||||
@@ -1846,7 +1914,7 @@ class AsyncPlaywrightCrawlerStrategy(AsyncCrawlerStrategy):
|
|||||||
Dict[str, Any]: The results of the execution
|
Dict[str, Any]: The results of the execution
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
await page.wait_for_load_state('domcontentloaded')
|
await page.wait_for_load_state("domcontentloaded")
|
||||||
|
|
||||||
if isinstance(js_code, str):
|
if isinstance(js_code, str):
|
||||||
scripts = [js_code]
|
scripts = [js_code]
|
||||||
@@ -1861,7 +1929,8 @@ class AsyncPlaywrightCrawlerStrategy(AsyncCrawlerStrategy):
|
|||||||
# then wait for the new page to load before continuing
|
# then wait for the new page to load before continuing
|
||||||
result = None
|
result = None
|
||||||
try:
|
try:
|
||||||
result = await page.evaluate(f"""
|
result = await page.evaluate(
|
||||||
|
f"""
|
||||||
(async () => {{
|
(async () => {{
|
||||||
try {{
|
try {{
|
||||||
{script}
|
{script}
|
||||||
@@ -1870,51 +1939,60 @@ class AsyncPlaywrightCrawlerStrategy(AsyncCrawlerStrategy):
|
|||||||
return {{ success: false, error: err.toString(), stack: err.stack }};
|
return {{ success: false, error: err.toString(), stack: err.stack }};
|
||||||
}}
|
}}
|
||||||
}})();
|
}})();
|
||||||
""")
|
"""
|
||||||
|
)
|
||||||
except Error as e:
|
except Error as e:
|
||||||
# If it's due to navigation destroying the context, handle gracefully
|
# If it's due to navigation destroying the context, handle gracefully
|
||||||
if "Execution context was destroyed" in str(e):
|
if "Execution context was destroyed" in str(e):
|
||||||
self.logger.info("Navigation triggered by script, waiting for load state", tag="JS_EXEC")
|
self.logger.info(
|
||||||
|
"Navigation triggered by script, waiting for load state",
|
||||||
|
tag="JS_EXEC",
|
||||||
|
)
|
||||||
try:
|
try:
|
||||||
await page.wait_for_load_state('load', timeout=30000)
|
await page.wait_for_load_state("load", timeout=30000)
|
||||||
except Error as nav_err:
|
except Error as nav_err:
|
||||||
self.logger.warning(
|
self.logger.warning(
|
||||||
message="Navigation wait failed: {error}",
|
message="Navigation wait failed: {error}",
|
||||||
tag="JS_EXEC",
|
tag="JS_EXEC",
|
||||||
params={"error": str(nav_err)}
|
params={"error": str(nav_err)},
|
||||||
)
|
)
|
||||||
try:
|
try:
|
||||||
await page.wait_for_load_state('networkidle', timeout=30000)
|
await page.wait_for_load_state(
|
||||||
|
"networkidle", timeout=30000
|
||||||
|
)
|
||||||
except Error as nav_err:
|
except Error as nav_err:
|
||||||
self.logger.warning(
|
self.logger.warning(
|
||||||
message="Network idle wait failed: {error}",
|
message="Network idle wait failed: {error}",
|
||||||
tag="JS_EXEC",
|
tag="JS_EXEC",
|
||||||
params={"error": str(nav_err)}
|
params={"error": str(nav_err)},
|
||||||
)
|
)
|
||||||
# Return partial success, or adapt as you see fit
|
# Return partial success, or adapt as you see fit
|
||||||
result = {
|
result = {
|
||||||
"success": True,
|
"success": True,
|
||||||
"info": "Navigation triggered, ignoring context destroyed error"
|
"info": "Navigation triggered, ignoring context destroyed error",
|
||||||
}
|
}
|
||||||
else:
|
else:
|
||||||
# It's some other error, log and continue
|
# It's some other error, log and continue
|
||||||
self.logger.error(
|
self.logger.error(
|
||||||
message="Playwright execution error: {error}",
|
message="Playwright execution error: {error}",
|
||||||
tag="JS_EXEC",
|
tag="JS_EXEC",
|
||||||
params={"error": str(e)}
|
params={"error": str(e)},
|
||||||
)
|
)
|
||||||
result = {"success": False, "error": str(e)}
|
result = {"success": False, "error": str(e)}
|
||||||
|
|
||||||
# If we made it this far with no repeated error, do post-load waits
|
# If we made it this far with no repeated error, do post-load waits
|
||||||
t1 = time.time()
|
t1 = time.time()
|
||||||
try:
|
try:
|
||||||
await page.wait_for_load_state('domcontentloaded', timeout=5000)
|
await page.wait_for_load_state("domcontentloaded", timeout=5000)
|
||||||
print("DOM content loaded after script execution in", time.time() - t1)
|
print(
|
||||||
|
"DOM content loaded after script execution in",
|
||||||
|
time.time() - t1,
|
||||||
|
)
|
||||||
except Error as e:
|
except Error as e:
|
||||||
self.logger.warning(
|
self.logger.warning(
|
||||||
message="DOM content load timeout: {error}",
|
message="DOM content load timeout: {error}",
|
||||||
tag="JS_EXEC",
|
tag="JS_EXEC",
|
||||||
params={"error": str(e)}
|
params={"error": str(e)},
|
||||||
)
|
)
|
||||||
|
|
||||||
# t1 = time.time()
|
# t1 = time.time()
|
||||||
@@ -1935,7 +2013,7 @@ class AsyncPlaywrightCrawlerStrategy(AsyncCrawlerStrategy):
|
|||||||
self.logger.error(
|
self.logger.error(
|
||||||
message="Script chunk failed: {error}",
|
message="Script chunk failed: {error}",
|
||||||
tag="JS_EXEC",
|
tag="JS_EXEC",
|
||||||
params={"error": str(e)}
|
params={"error": str(e)},
|
||||||
)
|
)
|
||||||
results.append({"success": False, "error": str(e)})
|
results.append({"success": False, "error": str(e)})
|
||||||
|
|
||||||
@@ -1945,11 +2023,13 @@ class AsyncPlaywrightCrawlerStrategy(AsyncCrawlerStrategy):
|
|||||||
self.logger.error(
|
self.logger.error(
|
||||||
message="Script execution failed: {error}",
|
message="Script execution failed: {error}",
|
||||||
tag="JS_EXEC",
|
tag="JS_EXEC",
|
||||||
params={"error": str(e)}
|
params={"error": str(e)},
|
||||||
)
|
)
|
||||||
return {"success": False, "error": str(e)}
|
return {"success": False, "error": str(e)}
|
||||||
|
|
||||||
async def execute_user_script(self, page: Page, js_code: Union[str, List[str]]) -> Dict[str, Any]:
|
async def execute_user_script(
|
||||||
|
self, page: Page, js_code: Union[str, List[str]]
|
||||||
|
) -> Dict[str, Any]:
|
||||||
"""
|
"""
|
||||||
Executes user-provided JavaScript code with proper error handling and context.
|
Executes user-provided JavaScript code with proper error handling and context.
|
||||||
|
|
||||||
@@ -1962,7 +2042,7 @@ class AsyncPlaywrightCrawlerStrategy(AsyncCrawlerStrategy):
|
|||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
# Ensure the page is ready for script execution
|
# Ensure the page is ready for script execution
|
||||||
await page.wait_for_load_state('domcontentloaded')
|
await page.wait_for_load_state("domcontentloaded")
|
||||||
|
|
||||||
# Handle single script or multiple scripts
|
# Handle single script or multiple scripts
|
||||||
if isinstance(js_code, str):
|
if isinstance(js_code, str):
|
||||||
@@ -1974,7 +2054,8 @@ class AsyncPlaywrightCrawlerStrategy(AsyncCrawlerStrategy):
|
|||||||
for script in scripts:
|
for script in scripts:
|
||||||
try:
|
try:
|
||||||
# Execute the script and wait for network idle
|
# Execute the script and wait for network idle
|
||||||
result = await page.evaluate(f"""
|
result = await page.evaluate(
|
||||||
|
f"""
|
||||||
(() => {{
|
(() => {{
|
||||||
return new Promise((resolve) => {{
|
return new Promise((resolve) => {{
|
||||||
try {{
|
try {{
|
||||||
@@ -2007,15 +2088,18 @@ class AsyncPlaywrightCrawlerStrategy(AsyncCrawlerStrategy):
|
|||||||
}}
|
}}
|
||||||
}});
|
}});
|
||||||
}})()
|
}})()
|
||||||
""")
|
"""
|
||||||
|
)
|
||||||
|
|
||||||
# Wait for network idle after script execution
|
# Wait for network idle after script execution
|
||||||
t1 = time.time()
|
t1 = time.time()
|
||||||
await page.wait_for_load_state('domcontentloaded', timeout=5000)
|
await page.wait_for_load_state("domcontentloaded", timeout=5000)
|
||||||
print("DOM content loaded after script execution in", time.time() - t1)
|
print(
|
||||||
|
"DOM content loaded after script execution in", time.time() - t1
|
||||||
|
)
|
||||||
|
|
||||||
t1 = time.time()
|
t1 = time.time()
|
||||||
await page.wait_for_load_state('networkidle', timeout=5000)
|
await page.wait_for_load_state("networkidle", timeout=5000)
|
||||||
print("Network idle after script execution in", time.time() - t1)
|
print("Network idle after script execution in", time.time() - t1)
|
||||||
|
|
||||||
results.append(result if result else {"success": True})
|
results.append(result if result else {"success": True})
|
||||||
@@ -2025,7 +2109,7 @@ class AsyncPlaywrightCrawlerStrategy(AsyncCrawlerStrategy):
|
|||||||
self.logger.error(
|
self.logger.error(
|
||||||
message="Playwright execution error: {error}",
|
message="Playwright execution error: {error}",
|
||||||
tag="JS_EXEC",
|
tag="JS_EXEC",
|
||||||
params={"error": str(e)}
|
params={"error": str(e)},
|
||||||
)
|
)
|
||||||
results.append({"success": False, "error": str(e)})
|
results.append({"success": False, "error": str(e)})
|
||||||
|
|
||||||
@@ -2035,7 +2119,7 @@ class AsyncPlaywrightCrawlerStrategy(AsyncCrawlerStrategy):
|
|||||||
self.logger.error(
|
self.logger.error(
|
||||||
message="Script execution failed: {error}",
|
message="Script execution failed: {error}",
|
||||||
tag="JS_EXEC",
|
tag="JS_EXEC",
|
||||||
params={"error": str(e)}
|
params={"error": str(e)},
|
||||||
)
|
)
|
||||||
return {"success": False, "error": str(e)}
|
return {"success": False, "error": str(e)}
|
||||||
|
|
||||||
@@ -2043,7 +2127,7 @@ class AsyncPlaywrightCrawlerStrategy(AsyncCrawlerStrategy):
|
|||||||
self.logger.error(
|
self.logger.error(
|
||||||
message="Script execution failed: {error}",
|
message="Script execution failed: {error}",
|
||||||
tag="JS_EXEC",
|
tag="JS_EXEC",
|
||||||
params={"error": str(e)}
|
params={"error": str(e)},
|
||||||
)
|
)
|
||||||
return {"success": False, "error": str(e)}
|
return {"success": False, "error": str(e)}
|
||||||
|
|
||||||
@@ -2057,7 +2141,8 @@ class AsyncPlaywrightCrawlerStrategy(AsyncCrawlerStrategy):
|
|||||||
Returns:
|
Returns:
|
||||||
Boolean indicating visibility
|
Boolean indicating visibility
|
||||||
"""
|
"""
|
||||||
return await page.evaluate("""
|
return await page.evaluate(
|
||||||
|
"""
|
||||||
() => {
|
() => {
|
||||||
const element = document.body;
|
const element = document.body;
|
||||||
if (!element) return false;
|
if (!element) return false;
|
||||||
@@ -2067,7 +2152,8 @@ class AsyncPlaywrightCrawlerStrategy(AsyncCrawlerStrategy):
|
|||||||
style.opacity !== '0';
|
style.opacity !== '0';
|
||||||
return isVisible;
|
return isVisible;
|
||||||
}
|
}
|
||||||
""")
|
"""
|
||||||
|
)
|
||||||
|
|
||||||
async def safe_scroll(self, page: Page, x: int, y: int, delay: float = 0.1):
|
async def safe_scroll(self, page: Page, x: int, y: int, delay: float = 0.1):
|
||||||
"""
|
"""
|
||||||
@@ -2079,7 +2165,7 @@ class AsyncPlaywrightCrawlerStrategy(AsyncCrawlerStrategy):
|
|||||||
y: Vertical scroll position
|
y: Vertical scroll position
|
||||||
"""
|
"""
|
||||||
result = await self.csp_scroll_to(page, x, y)
|
result = await self.csp_scroll_to(page, x, y)
|
||||||
if result['success']:
|
if result["success"]:
|
||||||
await page.wait_for_timeout(delay * 1000)
|
await page.wait_for_timeout(delay * 1000)
|
||||||
return result
|
return result
|
||||||
|
|
||||||
@@ -2126,11 +2212,11 @@ class AsyncPlaywrightCrawlerStrategy(AsyncCrawlerStrategy):
|
|||||||
}}"""
|
}}"""
|
||||||
)
|
)
|
||||||
|
|
||||||
if not result['success']:
|
if not result["success"]:
|
||||||
self.logger.warning(
|
self.logger.warning(
|
||||||
message="Scroll operation failed: {error}",
|
message="Scroll operation failed: {error}",
|
||||||
tag="SCROLL",
|
tag="SCROLL",
|
||||||
params={"error": result.get('error')}
|
params={"error": result.get("error")},
|
||||||
)
|
)
|
||||||
|
|
||||||
return result
|
return result
|
||||||
@@ -2139,12 +2225,9 @@ class AsyncPlaywrightCrawlerStrategy(AsyncCrawlerStrategy):
|
|||||||
self.logger.error(
|
self.logger.error(
|
||||||
message="Failed to execute scroll: {error}",
|
message="Failed to execute scroll: {error}",
|
||||||
tag="SCROLL",
|
tag="SCROLL",
|
||||||
params={"error": str(e)}
|
params={"error": str(e)},
|
||||||
)
|
)
|
||||||
return {
|
return {"success": False, "error": str(e)}
|
||||||
"success": False,
|
|
||||||
"error": str(e)
|
|
||||||
}
|
|
||||||
|
|
||||||
async def get_page_dimensions(self, page: Page):
|
async def get_page_dimensions(self, page: Page):
|
||||||
"""
|
"""
|
||||||
@@ -2156,12 +2239,14 @@ class AsyncPlaywrightCrawlerStrategy(AsyncCrawlerStrategy):
|
|||||||
Returns:
|
Returns:
|
||||||
Dict containing width and height of the page
|
Dict containing width and height of the page
|
||||||
"""
|
"""
|
||||||
return await page.evaluate("""
|
return await page.evaluate(
|
||||||
|
"""
|
||||||
() => {
|
() => {
|
||||||
const {scrollWidth, scrollHeight} = document.documentElement;
|
const {scrollWidth, scrollHeight} = document.documentElement;
|
||||||
return {width: scrollWidth, height: scrollHeight};
|
return {width: scrollWidth, height: scrollHeight};
|
||||||
}
|
}
|
||||||
""")
|
"""
|
||||||
|
)
|
||||||
|
|
||||||
async def page_need_scroll(self, page: Page) -> bool:
|
async def page_need_scroll(self, page: Page) -> bool:
|
||||||
"""
|
"""
|
||||||
@@ -2174,18 +2259,20 @@ class AsyncPlaywrightCrawlerStrategy(AsyncCrawlerStrategy):
|
|||||||
bool: True if page needs scrolling
|
bool: True if page needs scrolling
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
need_scroll = await page.evaluate("""
|
need_scroll = await page.evaluate(
|
||||||
|
"""
|
||||||
() => {
|
() => {
|
||||||
const scrollHeight = document.documentElement.scrollHeight;
|
const scrollHeight = document.documentElement.scrollHeight;
|
||||||
const viewportHeight = window.innerHeight;
|
const viewportHeight = window.innerHeight;
|
||||||
return scrollHeight > viewportHeight;
|
return scrollHeight > viewportHeight;
|
||||||
}
|
}
|
||||||
""")
|
"""
|
||||||
|
)
|
||||||
return need_scroll
|
return need_scroll
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
self.logger.warning(
|
self.logger.warning(
|
||||||
message="Failed to check scroll need: {error}. Defaulting to True for safety.",
|
message="Failed to check scroll need: {error}. Defaulting to True for safety.",
|
||||||
tag="SCROLL",
|
tag="SCROLL",
|
||||||
params={"error": str(e)}
|
params={"error": str(e)},
|
||||||
)
|
)
|
||||||
return True # Default to scrolling if check fails
|
return True # Default to scrolling if check fails
|
||||||
@@ -1,27 +1,29 @@
|
|||||||
import os, sys
|
import os
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
import aiosqlite
|
import aiosqlite
|
||||||
import asyncio
|
import asyncio
|
||||||
from typing import Optional, Tuple, Dict
|
from typing import Optional, Dict
|
||||||
from contextlib import asynccontextmanager
|
from contextlib import asynccontextmanager
|
||||||
import logging
|
import logging
|
||||||
import json # Added for serialization/deserialization
|
import json # Added for serialization/deserialization
|
||||||
from .utils import ensure_content_dirs, generate_content_hash
|
from .utils import ensure_content_dirs, generate_content_hash
|
||||||
from .models import CrawlResult, MarkdownGenerationResult
|
from .models import CrawlResult, MarkdownGenerationResult
|
||||||
import xxhash
|
|
||||||
import aiofiles
|
import aiofiles
|
||||||
from .config import NEED_MIGRATION
|
|
||||||
from .version_manager import VersionManager
|
from .version_manager import VersionManager
|
||||||
from .async_logger import AsyncLogger
|
from .async_logger import AsyncLogger
|
||||||
from .utils import get_error_context, create_box_message
|
from .utils import get_error_context, create_box_message
|
||||||
|
|
||||||
# Set up logging
|
# Set up logging
|
||||||
logging.basicConfig(level=logging.INFO)
|
logging.basicConfig(level=logging.INFO)
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
base_directory = DB_PATH = os.path.join(os.getenv("CRAWL4_AI_BASE_DIRECTORY", Path.home()), ".crawl4ai")
|
base_directory = DB_PATH = os.path.join(
|
||||||
|
os.getenv("CRAWL4_AI_BASE_DIRECTORY", Path.home()), ".crawl4ai"
|
||||||
|
)
|
||||||
os.makedirs(DB_PATH, exist_ok=True)
|
os.makedirs(DB_PATH, exist_ok=True)
|
||||||
DB_PATH = os.path.join(base_directory, "crawl4ai.db")
|
DB_PATH = os.path.join(base_directory, "crawl4ai.db")
|
||||||
|
|
||||||
|
|
||||||
class AsyncDatabaseManager:
|
class AsyncDatabaseManager:
|
||||||
def __init__(self, pool_size: int = 10, max_retries: int = 3):
|
def __init__(self, pool_size: int = 10, max_retries: int = 3):
|
||||||
self.db_path = DB_PATH
|
self.db_path = DB_PATH
|
||||||
@@ -37,10 +39,9 @@ class AsyncDatabaseManager:
|
|||||||
self.logger = AsyncLogger(
|
self.logger = AsyncLogger(
|
||||||
log_file=os.path.join(base_directory, ".crawl4ai", "crawler_db.log"),
|
log_file=os.path.join(base_directory, ".crawl4ai", "crawler_db.log"),
|
||||||
verbose=False,
|
verbose=False,
|
||||||
tag_width=10
|
tag_width=10,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
async def initialize(self):
|
async def initialize(self):
|
||||||
"""Initialize the database and connection pool"""
|
"""Initialize the database and connection pool"""
|
||||||
try:
|
try:
|
||||||
@@ -67,28 +68,32 @@ class AsyncDatabaseManager:
|
|||||||
if needs_update:
|
if needs_update:
|
||||||
self.logger.info("New version detected, running updates", tag="INIT")
|
self.logger.info("New version detected, running updates", tag="INIT")
|
||||||
await self.update_db_schema()
|
await self.update_db_schema()
|
||||||
from .migrations import run_migration # Import here to avoid circular imports
|
from .migrations import (
|
||||||
|
run_migration,
|
||||||
|
) # Import here to avoid circular imports
|
||||||
|
|
||||||
await run_migration()
|
await run_migration()
|
||||||
self.version_manager.update_version() # Update stored version after successful migration
|
self.version_manager.update_version() # Update stored version after successful migration
|
||||||
self.logger.success("Version update completed successfully", tag="COMPLETE")
|
self.logger.success(
|
||||||
|
"Version update completed successfully", tag="COMPLETE"
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
self.logger.success("Database initialization completed successfully", tag="COMPLETE")
|
self.logger.success(
|
||||||
|
"Database initialization completed successfully", tag="COMPLETE"
|
||||||
|
)
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
self.logger.error(
|
self.logger.error(
|
||||||
message="Database initialization error: {error}",
|
message="Database initialization error: {error}",
|
||||||
tag="ERROR",
|
tag="ERROR",
|
||||||
params={"error": str(e)}
|
params={"error": str(e)},
|
||||||
)
|
)
|
||||||
self.logger.info(
|
self.logger.info(
|
||||||
message="Database will be initialized on first use",
|
message="Database will be initialized on first use", tag="INIT"
|
||||||
tag="INIT"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
raise
|
raise
|
||||||
|
|
||||||
|
|
||||||
async def cleanup(self):
|
async def cleanup(self):
|
||||||
"""Cleanup connections when shutting down"""
|
"""Cleanup connections when shutting down"""
|
||||||
async with self.pool_lock:
|
async with self.pool_lock:
|
||||||
@@ -107,6 +112,7 @@ class AsyncDatabaseManager:
|
|||||||
self._initialized = True
|
self._initialized = True
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
import sys
|
import sys
|
||||||
|
|
||||||
error_context = get_error_context(sys.exc_info())
|
error_context = get_error_context(sys.exc_info())
|
||||||
self.logger.error(
|
self.logger.error(
|
||||||
message="Database initialization failed:\n{error}\n\nContext:\n{context}\n\nTraceback:\n{traceback}",
|
message="Database initialization failed:\n{error}\n\nContext:\n{context}\n\nTraceback:\n{traceback}",
|
||||||
@@ -115,8 +121,8 @@ class AsyncDatabaseManager:
|
|||||||
params={
|
params={
|
||||||
"error": str(e),
|
"error": str(e),
|
||||||
"context": error_context["code_context"],
|
"context": error_context["code_context"],
|
||||||
"traceback": error_context["full_traceback"]
|
"traceback": error_context["full_traceback"],
|
||||||
}
|
},
|
||||||
)
|
)
|
||||||
raise
|
raise
|
||||||
|
|
||||||
@@ -127,29 +133,40 @@ class AsyncDatabaseManager:
|
|||||||
async with self.pool_lock:
|
async with self.pool_lock:
|
||||||
if task_id not in self.connection_pool:
|
if task_id not in self.connection_pool:
|
||||||
try:
|
try:
|
||||||
conn = await aiosqlite.connect(
|
conn = await aiosqlite.connect(self.db_path, timeout=30.0)
|
||||||
self.db_path,
|
await conn.execute("PRAGMA journal_mode = WAL")
|
||||||
timeout=30.0
|
await conn.execute("PRAGMA busy_timeout = 5000")
|
||||||
)
|
|
||||||
await conn.execute('PRAGMA journal_mode = WAL')
|
|
||||||
await conn.execute('PRAGMA busy_timeout = 5000')
|
|
||||||
|
|
||||||
# Verify database structure
|
# Verify database structure
|
||||||
async with conn.execute("PRAGMA table_info(crawled_data)") as cursor:
|
async with conn.execute(
|
||||||
|
"PRAGMA table_info(crawled_data)"
|
||||||
|
) as cursor:
|
||||||
columns = await cursor.fetchall()
|
columns = await cursor.fetchall()
|
||||||
column_names = [col[1] for col in columns]
|
column_names = [col[1] for col in columns]
|
||||||
expected_columns = {
|
expected_columns = {
|
||||||
'url', 'html', 'cleaned_html', 'markdown', 'extracted_content',
|
"url",
|
||||||
'success', 'media', 'links', 'metadata', 'screenshot',
|
"html",
|
||||||
'response_headers', 'downloaded_files'
|
"cleaned_html",
|
||||||
|
"markdown",
|
||||||
|
"extracted_content",
|
||||||
|
"success",
|
||||||
|
"media",
|
||||||
|
"links",
|
||||||
|
"metadata",
|
||||||
|
"screenshot",
|
||||||
|
"response_headers",
|
||||||
|
"downloaded_files",
|
||||||
}
|
}
|
||||||
missing_columns = expected_columns - set(column_names)
|
missing_columns = expected_columns - set(column_names)
|
||||||
if missing_columns:
|
if missing_columns:
|
||||||
raise ValueError(f"Database missing columns: {missing_columns}")
|
raise ValueError(
|
||||||
|
f"Database missing columns: {missing_columns}"
|
||||||
|
)
|
||||||
|
|
||||||
self.connection_pool[task_id] = conn
|
self.connection_pool[task_id] = conn
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
import sys
|
import sys
|
||||||
|
|
||||||
error_context = get_error_context(sys.exc_info())
|
error_context = get_error_context(sys.exc_info())
|
||||||
error_message = (
|
error_message = (
|
||||||
f"Unexpected error in db get_connection at line {error_context['line_no']} "
|
f"Unexpected error in db get_connection at line {error_context['line_no']} "
|
||||||
@@ -167,6 +184,7 @@ class AsyncDatabaseManager:
|
|||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
import sys
|
import sys
|
||||||
|
|
||||||
error_context = get_error_context(sys.exc_info())
|
error_context = get_error_context(sys.exc_info())
|
||||||
error_message = (
|
error_message = (
|
||||||
f"Unexpected error in db get_connection at line {error_context['line_no']} "
|
f"Unexpected error in db get_connection at line {error_context['line_no']} "
|
||||||
@@ -185,7 +203,6 @@ class AsyncDatabaseManager:
|
|||||||
del self.connection_pool[task_id]
|
del self.connection_pool[task_id]
|
||||||
self.connection_semaphore.release()
|
self.connection_semaphore.release()
|
||||||
|
|
||||||
|
|
||||||
async def execute_with_retry(self, operation, *args):
|
async def execute_with_retry(self, operation, *args):
|
||||||
"""Execute database operations with retry logic"""
|
"""Execute database operations with retry logic"""
|
||||||
for attempt in range(self.max_retries):
|
for attempt in range(self.max_retries):
|
||||||
@@ -200,10 +217,7 @@ class AsyncDatabaseManager:
|
|||||||
message="Operation failed after {retries} attempts: {error}",
|
message="Operation failed after {retries} attempts: {error}",
|
||||||
tag="ERROR",
|
tag="ERROR",
|
||||||
force_verbose=True,
|
force_verbose=True,
|
||||||
params={
|
params={"retries": self.max_retries, "error": str(e)},
|
||||||
"retries": self.max_retries,
|
|
||||||
"error": str(e)
|
|
||||||
}
|
|
||||||
)
|
)
|
||||||
raise
|
raise
|
||||||
await asyncio.sleep(1 * (attempt + 1)) # Exponential backoff
|
await asyncio.sleep(1 * (attempt + 1)) # Exponential backoff
|
||||||
@@ -211,7 +225,8 @@ class AsyncDatabaseManager:
|
|||||||
async def ainit_db(self):
|
async def ainit_db(self):
|
||||||
"""Initialize database schema"""
|
"""Initialize database schema"""
|
||||||
async with aiosqlite.connect(self.db_path, timeout=30.0) as db:
|
async with aiosqlite.connect(self.db_path, timeout=30.0) as db:
|
||||||
await db.execute('''
|
await db.execute(
|
||||||
|
"""
|
||||||
CREATE TABLE IF NOT EXISTS crawled_data (
|
CREATE TABLE IF NOT EXISTS crawled_data (
|
||||||
url TEXT PRIMARY KEY,
|
url TEXT PRIMARY KEY,
|
||||||
html TEXT,
|
html TEXT,
|
||||||
@@ -226,11 +241,10 @@ class AsyncDatabaseManager:
|
|||||||
response_headers TEXT DEFAULT "{}",
|
response_headers TEXT DEFAULT "{}",
|
||||||
downloaded_files TEXT DEFAULT "{}" -- New column added
|
downloaded_files TEXT DEFAULT "{}" -- New column added
|
||||||
)
|
)
|
||||||
''')
|
"""
|
||||||
|
)
|
||||||
await db.commit()
|
await db.commit()
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
async def update_db_schema(self):
|
async def update_db_schema(self):
|
||||||
"""Update database schema if needed"""
|
"""Update database schema if needed"""
|
||||||
async with aiosqlite.connect(self.db_path, timeout=30.0) as db:
|
async with aiosqlite.connect(self.db_path, timeout=30.0) as db:
|
||||||
@@ -239,7 +253,14 @@ class AsyncDatabaseManager:
|
|||||||
column_names = [column[1] for column in columns]
|
column_names = [column[1] for column in columns]
|
||||||
|
|
||||||
# List of new columns to add
|
# List of new columns to add
|
||||||
new_columns = ['media', 'links', 'metadata', 'screenshot', 'response_headers', 'downloaded_files']
|
new_columns = [
|
||||||
|
"media",
|
||||||
|
"links",
|
||||||
|
"metadata",
|
||||||
|
"screenshot",
|
||||||
|
"response_headers",
|
||||||
|
"downloaded_files",
|
||||||
|
]
|
||||||
|
|
||||||
for column in new_columns:
|
for column in new_columns:
|
||||||
if column not in column_names:
|
if column not in column_names:
|
||||||
@@ -248,22 +269,26 @@ class AsyncDatabaseManager:
|
|||||||
|
|
||||||
async def aalter_db_add_column(self, new_column: str, db):
|
async def aalter_db_add_column(self, new_column: str, db):
|
||||||
"""Add new column to the database"""
|
"""Add new column to the database"""
|
||||||
if new_column == 'response_headers':
|
if new_column == "response_headers":
|
||||||
await db.execute(f'ALTER TABLE crawled_data ADD COLUMN {new_column} TEXT DEFAULT "{{}}"')
|
await db.execute(
|
||||||
|
f'ALTER TABLE crawled_data ADD COLUMN {new_column} TEXT DEFAULT "{{}}"'
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
await db.execute(f'ALTER TABLE crawled_data ADD COLUMN {new_column} TEXT DEFAULT ""')
|
await db.execute(
|
||||||
|
f'ALTER TABLE crawled_data ADD COLUMN {new_column} TEXT DEFAULT ""'
|
||||||
|
)
|
||||||
self.logger.info(
|
self.logger.info(
|
||||||
message="Added column '{column}' to the database",
|
message="Added column '{column}' to the database",
|
||||||
tag="INIT",
|
tag="INIT",
|
||||||
params={"column": new_column}
|
params={"column": new_column},
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
async def aget_cached_url(self, url: str) -> Optional[CrawlResult]:
|
async def aget_cached_url(self, url: str) -> Optional[CrawlResult]:
|
||||||
"""Retrieve cached URL data as CrawlResult"""
|
"""Retrieve cached URL data as CrawlResult"""
|
||||||
|
|
||||||
async def _get(db):
|
async def _get(db):
|
||||||
async with db.execute(
|
async with db.execute(
|
||||||
'SELECT * FROM crawled_data WHERE url = ?', (url,)
|
"SELECT * FROM crawled_data WHERE url = ?", (url,)
|
||||||
) as cursor:
|
) as cursor:
|
||||||
row = await cursor.fetchone()
|
row = await cursor.fetchone()
|
||||||
if not row:
|
if not row:
|
||||||
@@ -276,42 +301,54 @@ class AsyncDatabaseManager:
|
|||||||
|
|
||||||
# Load content from files using stored hashes
|
# Load content from files using stored hashes
|
||||||
content_fields = {
|
content_fields = {
|
||||||
'html': row_dict['html'],
|
"html": row_dict["html"],
|
||||||
'cleaned_html': row_dict['cleaned_html'],
|
"cleaned_html": row_dict["cleaned_html"],
|
||||||
'markdown': row_dict['markdown'],
|
"markdown": row_dict["markdown"],
|
||||||
'extracted_content': row_dict['extracted_content'],
|
"extracted_content": row_dict["extracted_content"],
|
||||||
'screenshot': row_dict['screenshot'],
|
"screenshot": row_dict["screenshot"],
|
||||||
'screenshots': row_dict['screenshot'],
|
"screenshots": row_dict["screenshot"],
|
||||||
}
|
}
|
||||||
|
|
||||||
for field, hash_value in content_fields.items():
|
for field, hash_value in content_fields.items():
|
||||||
if hash_value:
|
if hash_value:
|
||||||
content = await self._load_content(
|
content = await self._load_content(
|
||||||
hash_value,
|
hash_value,
|
||||||
field.split('_')[0] # Get content type from field name
|
field.split("_")[0], # Get content type from field name
|
||||||
)
|
)
|
||||||
row_dict[field] = content or ""
|
row_dict[field] = content or ""
|
||||||
else:
|
else:
|
||||||
row_dict[field] = ""
|
row_dict[field] = ""
|
||||||
|
|
||||||
# Parse JSON fields
|
# Parse JSON fields
|
||||||
json_fields = ['media', 'links', 'metadata', 'response_headers', 'markdown']
|
json_fields = [
|
||||||
|
"media",
|
||||||
|
"links",
|
||||||
|
"metadata",
|
||||||
|
"response_headers",
|
||||||
|
"markdown",
|
||||||
|
]
|
||||||
for field in json_fields:
|
for field in json_fields:
|
||||||
try:
|
try:
|
||||||
row_dict[field] = json.loads(row_dict[field]) if row_dict[field] else {}
|
row_dict[field] = (
|
||||||
|
json.loads(row_dict[field]) if row_dict[field] else {}
|
||||||
|
)
|
||||||
except json.JSONDecodeError:
|
except json.JSONDecodeError:
|
||||||
row_dict[field] = {}
|
row_dict[field] = {}
|
||||||
|
|
||||||
if isinstance(row_dict['markdown'], Dict):
|
if isinstance(row_dict["markdown"], Dict):
|
||||||
row_dict['markdown_v2'] = row_dict['markdown']
|
row_dict["markdown_v2"] = row_dict["markdown"]
|
||||||
if row_dict['markdown'].get('raw_markdown'):
|
if row_dict["markdown"].get("raw_markdown"):
|
||||||
row_dict['markdown'] = row_dict['markdown']['raw_markdown']
|
row_dict["markdown"] = row_dict["markdown"]["raw_markdown"]
|
||||||
|
|
||||||
# Parse downloaded_files
|
# Parse downloaded_files
|
||||||
try:
|
try:
|
||||||
row_dict['downloaded_files'] = json.loads(row_dict['downloaded_files']) if row_dict['downloaded_files'] else []
|
row_dict["downloaded_files"] = (
|
||||||
|
json.loads(row_dict["downloaded_files"])
|
||||||
|
if row_dict["downloaded_files"]
|
||||||
|
else []
|
||||||
|
)
|
||||||
except json.JSONDecodeError:
|
except json.JSONDecodeError:
|
||||||
row_dict['downloaded_files'] = []
|
row_dict["downloaded_files"] = []
|
||||||
|
|
||||||
# Remove any fields not in CrawlResult model
|
# Remove any fields not in CrawlResult model
|
||||||
valid_fields = CrawlResult.__annotations__.keys()
|
valid_fields = CrawlResult.__annotations__.keys()
|
||||||
@@ -326,7 +363,7 @@ class AsyncDatabaseManager:
|
|||||||
message="Error retrieving cached URL: {error}",
|
message="Error retrieving cached URL: {error}",
|
||||||
tag="ERROR",
|
tag="ERROR",
|
||||||
force_verbose=True,
|
force_verbose=True,
|
||||||
params={"error": str(e)}
|
params={"error": str(e)},
|
||||||
)
|
)
|
||||||
return None
|
return None
|
||||||
|
|
||||||
@@ -334,37 +371,52 @@ class AsyncDatabaseManager:
|
|||||||
"""Cache CrawlResult data"""
|
"""Cache CrawlResult data"""
|
||||||
# Store content files and get hashes
|
# Store content files and get hashes
|
||||||
content_map = {
|
content_map = {
|
||||||
'html': (result.html, 'html'),
|
"html": (result.html, "html"),
|
||||||
'cleaned_html': (result.cleaned_html or "", 'cleaned'),
|
"cleaned_html": (result.cleaned_html or "", "cleaned"),
|
||||||
'markdown': None,
|
"markdown": None,
|
||||||
'extracted_content': (result.extracted_content or "", 'extracted'),
|
"extracted_content": (result.extracted_content or "", "extracted"),
|
||||||
'screenshot': (result.screenshot or "", 'screenshots')
|
"screenshot": (result.screenshot or "", "screenshots"),
|
||||||
}
|
}
|
||||||
|
|
||||||
try:
|
try:
|
||||||
if isinstance(result.markdown, MarkdownGenerationResult):
|
if isinstance(result.markdown, MarkdownGenerationResult):
|
||||||
content_map['markdown'] = (result.markdown.model_dump_json(), 'markdown')
|
content_map["markdown"] = (
|
||||||
elif hasattr(result, 'markdown_v2'):
|
result.markdown.model_dump_json(),
|
||||||
content_map['markdown'] = (result.markdown_v2.model_dump_json(), 'markdown')
|
"markdown",
|
||||||
|
)
|
||||||
|
elif hasattr(result, "markdown_v2"):
|
||||||
|
content_map["markdown"] = (
|
||||||
|
result.markdown_v2.model_dump_json(),
|
||||||
|
"markdown",
|
||||||
|
)
|
||||||
elif isinstance(result.markdown, str):
|
elif isinstance(result.markdown, str):
|
||||||
markdown_result = MarkdownGenerationResult(raw_markdown=result.markdown)
|
markdown_result = MarkdownGenerationResult(raw_markdown=result.markdown)
|
||||||
content_map['markdown'] = (markdown_result.model_dump_json(), 'markdown')
|
content_map["markdown"] = (
|
||||||
|
markdown_result.model_dump_json(),
|
||||||
|
"markdown",
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
content_map['markdown'] = (MarkdownGenerationResult().model_dump_json(), 'markdown')
|
content_map["markdown"] = (
|
||||||
|
MarkdownGenerationResult().model_dump_json(),
|
||||||
|
"markdown",
|
||||||
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
self.logger.warning(
|
self.logger.warning(
|
||||||
message=f"Error processing markdown content: {str(e)}",
|
message=f"Error processing markdown content: {str(e)}", tag="WARNING"
|
||||||
tag="WARNING"
|
|
||||||
)
|
)
|
||||||
# Fallback to empty markdown result
|
# Fallback to empty markdown result
|
||||||
content_map['markdown'] = (MarkdownGenerationResult().model_dump_json(), 'markdown')
|
content_map["markdown"] = (
|
||||||
|
MarkdownGenerationResult().model_dump_json(),
|
||||||
|
"markdown",
|
||||||
|
)
|
||||||
|
|
||||||
content_hashes = {}
|
content_hashes = {}
|
||||||
for field, (content, content_type) in content_map.items():
|
for field, (content, content_type) in content_map.items():
|
||||||
content_hashes[field] = await self._store_content(content, content_type)
|
content_hashes[field] = await self._store_content(content, content_type)
|
||||||
|
|
||||||
async def _cache(db):
|
async def _cache(db):
|
||||||
await db.execute('''
|
await db.execute(
|
||||||
|
"""
|
||||||
INSERT INTO crawled_data (
|
INSERT INTO crawled_data (
|
||||||
url, html, cleaned_html, markdown,
|
url, html, cleaned_html, markdown,
|
||||||
extracted_content, success, media, links, metadata,
|
extracted_content, success, media, links, metadata,
|
||||||
@@ -383,20 +435,22 @@ class AsyncDatabaseManager:
|
|||||||
screenshot = excluded.screenshot,
|
screenshot = excluded.screenshot,
|
||||||
response_headers = excluded.response_headers,
|
response_headers = excluded.response_headers,
|
||||||
downloaded_files = excluded.downloaded_files
|
downloaded_files = excluded.downloaded_files
|
||||||
''', (
|
""",
|
||||||
|
(
|
||||||
result.url,
|
result.url,
|
||||||
content_hashes['html'],
|
content_hashes["html"],
|
||||||
content_hashes['cleaned_html'],
|
content_hashes["cleaned_html"],
|
||||||
content_hashes['markdown'],
|
content_hashes["markdown"],
|
||||||
content_hashes['extracted_content'],
|
content_hashes["extracted_content"],
|
||||||
result.success,
|
result.success,
|
||||||
json.dumps(result.media),
|
json.dumps(result.media),
|
||||||
json.dumps(result.links),
|
json.dumps(result.links),
|
||||||
json.dumps(result.metadata or {}),
|
json.dumps(result.metadata or {}),
|
||||||
content_hashes['screenshot'],
|
content_hashes["screenshot"],
|
||||||
json.dumps(result.response_headers or {}),
|
json.dumps(result.response_headers or {}),
|
||||||
json.dumps(result.downloaded_files or [])
|
json.dumps(result.downloaded_files or []),
|
||||||
))
|
),
|
||||||
|
)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
await self.execute_with_retry(_cache)
|
await self.execute_with_retry(_cache)
|
||||||
@@ -405,14 +459,14 @@ class AsyncDatabaseManager:
|
|||||||
message="Error caching URL: {error}",
|
message="Error caching URL: {error}",
|
||||||
tag="ERROR",
|
tag="ERROR",
|
||||||
force_verbose=True,
|
force_verbose=True,
|
||||||
params={"error": str(e)}
|
params={"error": str(e)},
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
async def aget_total_count(self) -> int:
|
async def aget_total_count(self) -> int:
|
||||||
"""Get total number of cached URLs"""
|
"""Get total number of cached URLs"""
|
||||||
|
|
||||||
async def _count(db):
|
async def _count(db):
|
||||||
async with db.execute('SELECT COUNT(*) FROM crawled_data') as cursor:
|
async with db.execute("SELECT COUNT(*) FROM crawled_data") as cursor:
|
||||||
result = await cursor.fetchone()
|
result = await cursor.fetchone()
|
||||||
return result[0] if result else 0
|
return result[0] if result else 0
|
||||||
|
|
||||||
@@ -423,14 +477,15 @@ class AsyncDatabaseManager:
|
|||||||
message="Error getting total count: {error}",
|
message="Error getting total count: {error}",
|
||||||
tag="ERROR",
|
tag="ERROR",
|
||||||
force_verbose=True,
|
force_verbose=True,
|
||||||
params={"error": str(e)}
|
params={"error": str(e)},
|
||||||
)
|
)
|
||||||
return 0
|
return 0
|
||||||
|
|
||||||
async def aclear_db(self):
|
async def aclear_db(self):
|
||||||
"""Clear all data from the database"""
|
"""Clear all data from the database"""
|
||||||
|
|
||||||
async def _clear(db):
|
async def _clear(db):
|
||||||
await db.execute('DELETE FROM crawled_data')
|
await db.execute("DELETE FROM crawled_data")
|
||||||
|
|
||||||
try:
|
try:
|
||||||
await self.execute_with_retry(_clear)
|
await self.execute_with_retry(_clear)
|
||||||
@@ -439,13 +494,14 @@ class AsyncDatabaseManager:
|
|||||||
message="Error clearing database: {error}",
|
message="Error clearing database: {error}",
|
||||||
tag="ERROR",
|
tag="ERROR",
|
||||||
force_verbose=True,
|
force_verbose=True,
|
||||||
params={"error": str(e)}
|
params={"error": str(e)},
|
||||||
)
|
)
|
||||||
|
|
||||||
async def aflush_db(self):
|
async def aflush_db(self):
|
||||||
"""Drop the entire table"""
|
"""Drop the entire table"""
|
||||||
|
|
||||||
async def _flush(db):
|
async def _flush(db):
|
||||||
await db.execute('DROP TABLE IF EXISTS crawled_data')
|
await db.execute("DROP TABLE IF EXISTS crawled_data")
|
||||||
|
|
||||||
try:
|
try:
|
||||||
await self.execute_with_retry(_flush)
|
await self.execute_with_retry(_flush)
|
||||||
@@ -454,10 +510,9 @@ class AsyncDatabaseManager:
|
|||||||
message="Error flushing database: {error}",
|
message="Error flushing database: {error}",
|
||||||
tag="ERROR",
|
tag="ERROR",
|
||||||
force_verbose=True,
|
force_verbose=True,
|
||||||
params={"error": str(e)}
|
params={"error": str(e)},
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
async def _store_content(self, content: str, content_type: str) -> str:
|
async def _store_content(self, content: str, content_type: str) -> str:
|
||||||
"""Store content in filesystem and return hash"""
|
"""Store content in filesystem and return hash"""
|
||||||
if not content:
|
if not content:
|
||||||
@@ -468,28 +523,31 @@ class AsyncDatabaseManager:
|
|||||||
|
|
||||||
# Only write if file doesn't exist
|
# Only write if file doesn't exist
|
||||||
if not os.path.exists(file_path):
|
if not os.path.exists(file_path):
|
||||||
async with aiofiles.open(file_path, 'w', encoding='utf-8') as f:
|
async with aiofiles.open(file_path, "w", encoding="utf-8") as f:
|
||||||
await f.write(content)
|
await f.write(content)
|
||||||
|
|
||||||
return content_hash
|
return content_hash
|
||||||
|
|
||||||
async def _load_content(self, content_hash: str, content_type: str) -> Optional[str]:
|
async def _load_content(
|
||||||
|
self, content_hash: str, content_type: str
|
||||||
|
) -> Optional[str]:
|
||||||
"""Load content from filesystem by hash"""
|
"""Load content from filesystem by hash"""
|
||||||
if not content_hash:
|
if not content_hash:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
file_path = os.path.join(self.content_paths[content_type], content_hash)
|
file_path = os.path.join(self.content_paths[content_type], content_hash)
|
||||||
try:
|
try:
|
||||||
async with aiofiles.open(file_path, 'r', encoding='utf-8') as f:
|
async with aiofiles.open(file_path, "r", encoding="utf-8") as f:
|
||||||
return await f.read()
|
return await f.read()
|
||||||
except:
|
except:
|
||||||
self.logger.error(
|
self.logger.error(
|
||||||
message="Failed to load content: {file_path}",
|
message="Failed to load content: {file_path}",
|
||||||
tag="ERROR",
|
tag="ERROR",
|
||||||
force_verbose=True,
|
force_verbose=True,
|
||||||
params={"file_path": file_path}
|
params={"file_path": file_path},
|
||||||
)
|
)
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
||||||
# Create a singleton instance
|
# Create a singleton instance
|
||||||
async_db_manager = AsyncDatabaseManager()
|
async_db_manager = AsyncDatabaseManager()
|
||||||
|
|||||||
@@ -1,14 +1,19 @@
|
|||||||
from typing import Dict, Optional, List
|
from typing import Dict, Optional, List, Tuple
|
||||||
from .async_configs import *
|
from .async_configs import CrawlerRunConfig
|
||||||
from .models import *
|
from .models import (
|
||||||
|
CrawlResult,
|
||||||
|
CrawlerTaskResult,
|
||||||
|
CrawlStatus,
|
||||||
|
DisplayMode,
|
||||||
|
CrawlStats,
|
||||||
|
DomainState,
|
||||||
|
)
|
||||||
|
|
||||||
from rich.live import Live
|
from rich.live import Live
|
||||||
from rich.table import Table
|
from rich.table import Table
|
||||||
from rich.console import Console
|
from rich.console import Console
|
||||||
from rich.style import Style
|
|
||||||
from rich import box
|
from rich import box
|
||||||
from datetime import datetime, timedelta
|
from datetime import datetime, timedelta
|
||||||
from dataclasses import dataclass
|
|
||||||
|
|
||||||
import time
|
import time
|
||||||
import psutil
|
import psutil
|
||||||
@@ -26,7 +31,7 @@ class RateLimiter:
|
|||||||
base_delay: Tuple[float, float] = (1.0, 3.0),
|
base_delay: Tuple[float, float] = (1.0, 3.0),
|
||||||
max_delay: float = 60.0,
|
max_delay: float = 60.0,
|
||||||
max_retries: int = 3,
|
max_retries: int = 3,
|
||||||
rate_limit_codes: List[int] = None
|
rate_limit_codes: List[int] = None,
|
||||||
):
|
):
|
||||||
self.base_delay = base_delay
|
self.base_delay = base_delay
|
||||||
self.max_delay = max_delay
|
self.max_delay = max_delay
|
||||||
@@ -68,21 +73,24 @@ class RateLimiter:
|
|||||||
|
|
||||||
# Exponential backoff with random jitter
|
# Exponential backoff with random jitter
|
||||||
state.current_delay = min(
|
state.current_delay = min(
|
||||||
state.current_delay * 2 * random.uniform(0.75, 1.25),
|
state.current_delay * 2 * random.uniform(0.75, 1.25), self.max_delay
|
||||||
self.max_delay
|
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
# Gradually reduce delay on success
|
# Gradually reduce delay on success
|
||||||
state.current_delay = max(
|
state.current_delay = max(
|
||||||
random.uniform(*self.base_delay),
|
random.uniform(*self.base_delay), state.current_delay * 0.75
|
||||||
state.current_delay * 0.75
|
|
||||||
)
|
)
|
||||||
state.fail_count = 0
|
state.fail_count = 0
|
||||||
|
|
||||||
return True
|
return True
|
||||||
|
|
||||||
|
|
||||||
class CrawlerMonitor:
|
class CrawlerMonitor:
|
||||||
def __init__(self, max_visible_rows: int = 15, display_mode: DisplayMode = DisplayMode.DETAILED):
|
def __init__(
|
||||||
|
self,
|
||||||
|
max_visible_rows: int = 15,
|
||||||
|
display_mode: DisplayMode = DisplayMode.DETAILED,
|
||||||
|
):
|
||||||
self.console = Console()
|
self.console = Console()
|
||||||
self.max_visible_rows = max_visible_rows
|
self.max_visible_rows = max_visible_rows
|
||||||
self.display_mode = display_mode
|
self.display_mode = display_mode
|
||||||
@@ -98,7 +106,9 @@ class CrawlerMonitor:
|
|||||||
self.live.stop()
|
self.live.stop()
|
||||||
|
|
||||||
def add_task(self, task_id: str, url: str):
|
def add_task(self, task_id: str, url: str):
|
||||||
self.stats[task_id] = CrawlStats(task_id=task_id, url=url, status=CrawlStatus.QUEUED)
|
self.stats[task_id] = CrawlStats(
|
||||||
|
task_id=task_id, url=url, status=CrawlStatus.QUEUED
|
||||||
|
)
|
||||||
self.live.update(self._create_table())
|
self.live.update(self._create_table())
|
||||||
|
|
||||||
def update_task(self, task_id: str, **kwargs):
|
def update_task(self, task_id: str, **kwargs):
|
||||||
@@ -114,20 +124,30 @@ class CrawlerMonitor:
|
|||||||
title="Crawler Status Overview",
|
title="Crawler Status Overview",
|
||||||
title_style="bold magenta",
|
title_style="bold magenta",
|
||||||
header_style="bold blue",
|
header_style="bold blue",
|
||||||
show_lines=True
|
show_lines=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Calculate statistics
|
# Calculate statistics
|
||||||
total_tasks = len(self.stats)
|
total_tasks = len(self.stats)
|
||||||
queued = sum(1 for stat in self.stats.values() if stat.status == CrawlStatus.QUEUED)
|
queued = sum(
|
||||||
in_progress = sum(1 for stat in self.stats.values() if stat.status == CrawlStatus.IN_PROGRESS)
|
1 for stat in self.stats.values() if stat.status == CrawlStatus.QUEUED
|
||||||
completed = sum(1 for stat in self.stats.values() if stat.status == CrawlStatus.COMPLETED)
|
)
|
||||||
failed = sum(1 for stat in self.stats.values() if stat.status == CrawlStatus.FAILED)
|
in_progress = sum(
|
||||||
|
1 for stat in self.stats.values() if stat.status == CrawlStatus.IN_PROGRESS
|
||||||
|
)
|
||||||
|
completed = sum(
|
||||||
|
1 for stat in self.stats.values() if stat.status == CrawlStatus.COMPLETED
|
||||||
|
)
|
||||||
|
failed = sum(
|
||||||
|
1 for stat in self.stats.values() if stat.status == CrawlStatus.FAILED
|
||||||
|
)
|
||||||
|
|
||||||
# Memory statistics
|
# Memory statistics
|
||||||
current_memory = self.process.memory_info().rss / (1024 * 1024)
|
current_memory = self.process.memory_info().rss / (1024 * 1024)
|
||||||
total_task_memory = sum(stat.memory_usage for stat in self.stats.values())
|
total_task_memory = sum(stat.memory_usage for stat in self.stats.values())
|
||||||
peak_memory = max((stat.peak_memory for stat in self.stats.values()), default=0.0)
|
peak_memory = max(
|
||||||
|
(stat.peak_memory for stat in self.stats.values()), default=0.0
|
||||||
|
)
|
||||||
|
|
||||||
# Duration
|
# Duration
|
||||||
duration = datetime.now() - self.start_time
|
duration = datetime.now() - self.start_time
|
||||||
@@ -137,53 +157,43 @@ class CrawlerMonitor:
|
|||||||
table.add_column("Count", justify="right")
|
table.add_column("Count", justify="right")
|
||||||
table.add_column("Percentage", justify="right")
|
table.add_column("Percentage", justify="right")
|
||||||
|
|
||||||
table.add_row(
|
table.add_row("Total Tasks", str(total_tasks), "100%")
|
||||||
"Total Tasks",
|
|
||||||
str(total_tasks),
|
|
||||||
"100%"
|
|
||||||
)
|
|
||||||
table.add_row(
|
table.add_row(
|
||||||
"[yellow]In Queue[/yellow]",
|
"[yellow]In Queue[/yellow]",
|
||||||
str(queued),
|
str(queued),
|
||||||
f"{(queued/total_tasks*100):.1f}%" if total_tasks > 0 else "0%"
|
f"{(queued/total_tasks*100):.1f}%" if total_tasks > 0 else "0%",
|
||||||
)
|
)
|
||||||
table.add_row(
|
table.add_row(
|
||||||
"[blue]In Progress[/blue]",
|
"[blue]In Progress[/blue]",
|
||||||
str(in_progress),
|
str(in_progress),
|
||||||
f"{(in_progress/total_tasks*100):.1f}%" if total_tasks > 0 else "0%"
|
f"{(in_progress/total_tasks*100):.1f}%" if total_tasks > 0 else "0%",
|
||||||
)
|
)
|
||||||
table.add_row(
|
table.add_row(
|
||||||
"[green]Completed[/green]",
|
"[green]Completed[/green]",
|
||||||
str(completed),
|
str(completed),
|
||||||
f"{(completed/total_tasks*100):.1f}%" if total_tasks > 0 else "0%"
|
f"{(completed/total_tasks*100):.1f}%" if total_tasks > 0 else "0%",
|
||||||
)
|
)
|
||||||
table.add_row(
|
table.add_row(
|
||||||
"[red]Failed[/red]",
|
"[red]Failed[/red]",
|
||||||
str(failed),
|
str(failed),
|
||||||
f"{(failed/total_tasks*100):.1f}%" if total_tasks > 0 else "0%"
|
f"{(failed/total_tasks*100):.1f}%" if total_tasks > 0 else "0%",
|
||||||
)
|
)
|
||||||
|
|
||||||
# Add memory information
|
# Add memory information
|
||||||
table.add_section()
|
table.add_section()
|
||||||
table.add_row(
|
table.add_row(
|
||||||
"[magenta]Current Memory[/magenta]",
|
"[magenta]Current Memory[/magenta]", f"{current_memory:.1f} MB", ""
|
||||||
f"{current_memory:.1f} MB",
|
|
||||||
""
|
|
||||||
)
|
)
|
||||||
table.add_row(
|
table.add_row(
|
||||||
"[magenta]Total Task Memory[/magenta]",
|
"[magenta]Total Task Memory[/magenta]", f"{total_task_memory:.1f} MB", ""
|
||||||
f"{total_task_memory:.1f} MB",
|
|
||||||
""
|
|
||||||
)
|
)
|
||||||
table.add_row(
|
table.add_row(
|
||||||
"[magenta]Peak Task Memory[/magenta]",
|
"[magenta]Peak Task Memory[/magenta]", f"{peak_memory:.1f} MB", ""
|
||||||
f"{peak_memory:.1f} MB",
|
|
||||||
""
|
|
||||||
)
|
)
|
||||||
table.add_row(
|
table.add_row(
|
||||||
"[yellow]Runtime[/yellow]",
|
"[yellow]Runtime[/yellow]",
|
||||||
str(timedelta(seconds=int(duration.total_seconds()))),
|
str(timedelta(seconds=int(duration.total_seconds()))),
|
||||||
""
|
"",
|
||||||
)
|
)
|
||||||
|
|
||||||
return table
|
return table
|
||||||
@@ -193,7 +203,7 @@ class CrawlerMonitor:
|
|||||||
box=box.ROUNDED,
|
box=box.ROUNDED,
|
||||||
title="Crawler Performance Monitor",
|
title="Crawler Performance Monitor",
|
||||||
title_style="bold magenta",
|
title_style="bold magenta",
|
||||||
header_style="bold blue"
|
header_style="bold blue",
|
||||||
)
|
)
|
||||||
|
|
||||||
# Add columns
|
# Add columns
|
||||||
@@ -207,12 +217,15 @@ class CrawlerMonitor:
|
|||||||
|
|
||||||
# Add summary row
|
# Add summary row
|
||||||
total_memory = sum(stat.memory_usage for stat in self.stats.values())
|
total_memory = sum(stat.memory_usage for stat in self.stats.values())
|
||||||
active_count = sum(1 for stat in self.stats.values()
|
active_count = sum(
|
||||||
if stat.status == CrawlStatus.IN_PROGRESS)
|
1 for stat in self.stats.values() if stat.status == CrawlStatus.IN_PROGRESS
|
||||||
completed_count = sum(1 for stat in self.stats.values()
|
)
|
||||||
if stat.status == CrawlStatus.COMPLETED)
|
completed_count = sum(
|
||||||
failed_count = sum(1 for stat in self.stats.values()
|
1 for stat in self.stats.values() if stat.status == CrawlStatus.COMPLETED
|
||||||
if stat.status == CrawlStatus.FAILED)
|
)
|
||||||
|
failed_count = sum(
|
||||||
|
1 for stat in self.stats.values() if stat.status == CrawlStatus.FAILED
|
||||||
|
)
|
||||||
|
|
||||||
table.add_row(
|
table.add_row(
|
||||||
"[bold yellow]SUMMARY",
|
"[bold yellow]SUMMARY",
|
||||||
@@ -220,9 +233,13 @@ class CrawlerMonitor:
|
|||||||
f"Active: {active_count}",
|
f"Active: {active_count}",
|
||||||
f"{total_memory:.1f}",
|
f"{total_memory:.1f}",
|
||||||
f"{self.process.memory_info().rss / (1024 * 1024):.1f}",
|
f"{self.process.memory_info().rss / (1024 * 1024):.1f}",
|
||||||
str(timedelta(seconds=int((datetime.now() - self.start_time).total_seconds()))),
|
str(
|
||||||
|
timedelta(
|
||||||
|
seconds=int((datetime.now() - self.start_time).total_seconds())
|
||||||
|
)
|
||||||
|
),
|
||||||
f"✓{completed_count} ✗{failed_count}",
|
f"✓{completed_count} ✗{failed_count}",
|
||||||
style="bold"
|
style="bold",
|
||||||
)
|
)
|
||||||
|
|
||||||
table.add_section()
|
table.add_section()
|
||||||
@@ -233,8 +250,8 @@ class CrawlerMonitor:
|
|||||||
key=lambda x: (
|
key=lambda x: (
|
||||||
x.status != CrawlStatus.IN_PROGRESS,
|
x.status != CrawlStatus.IN_PROGRESS,
|
||||||
x.status != CrawlStatus.QUEUED,
|
x.status != CrawlStatus.QUEUED,
|
||||||
x.end_time or datetime.max
|
x.end_time or datetime.max,
|
||||||
)
|
),
|
||||||
)[: self.max_visible_rows]
|
)[: self.max_visible_rows]
|
||||||
|
|
||||||
for stat in visible_stats:
|
for stat in visible_stats:
|
||||||
@@ -242,7 +259,7 @@ class CrawlerMonitor:
|
|||||||
CrawlStatus.QUEUED: "white",
|
CrawlStatus.QUEUED: "white",
|
||||||
CrawlStatus.IN_PROGRESS: "yellow",
|
CrawlStatus.IN_PROGRESS: "yellow",
|
||||||
CrawlStatus.COMPLETED: "green",
|
CrawlStatus.COMPLETED: "green",
|
||||||
CrawlStatus.FAILED: "red"
|
CrawlStatus.FAILED: "red",
|
||||||
}[stat.status]
|
}[stat.status]
|
||||||
|
|
||||||
table.add_row(
|
table.add_row(
|
||||||
@@ -252,7 +269,7 @@ class CrawlerMonitor:
|
|||||||
f"{stat.memory_usage:.1f}",
|
f"{stat.memory_usage:.1f}",
|
||||||
f"{stat.peak_memory:.1f}",
|
f"{stat.peak_memory:.1f}",
|
||||||
stat.duration,
|
stat.duration,
|
||||||
stat.error_message[:40] if stat.error_message else ""
|
stat.error_message[:40] if stat.error_message else "",
|
||||||
)
|
)
|
||||||
|
|
||||||
return table
|
return table
|
||||||
@@ -268,7 +285,7 @@ class BaseDispatcher(ABC):
|
|||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
rate_limiter: Optional[RateLimiter] = None,
|
rate_limiter: Optional[RateLimiter] = None,
|
||||||
monitor: Optional[CrawlerMonitor] = None
|
monitor: Optional[CrawlerMonitor] = None,
|
||||||
):
|
):
|
||||||
self.crawler = None
|
self.crawler = None
|
||||||
self._domain_last_hit: Dict[str, float] = {}
|
self._domain_last_hit: Dict[str, float] = {}
|
||||||
@@ -282,7 +299,7 @@ class BaseDispatcher(ABC):
|
|||||||
url: str,
|
url: str,
|
||||||
config: CrawlerRunConfig,
|
config: CrawlerRunConfig,
|
||||||
task_id: str,
|
task_id: str,
|
||||||
monitor: Optional[CrawlerMonitor] = None
|
monitor: Optional[CrawlerMonitor] = None,
|
||||||
) -> CrawlerTaskResult:
|
) -> CrawlerTaskResult:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@@ -290,12 +307,13 @@ class BaseDispatcher(ABC):
|
|||||||
async def run_urls(
|
async def run_urls(
|
||||||
self,
|
self,
|
||||||
urls: List[str],
|
urls: List[str],
|
||||||
crawler: "AsyncWebCrawler",
|
crawler: "AsyncWebCrawler", # noqa: F821
|
||||||
config: CrawlerRunConfig,
|
config: CrawlerRunConfig,
|
||||||
monitor: Optional[CrawlerMonitor] = None
|
monitor: Optional[CrawlerMonitor] = None,
|
||||||
) -> List[CrawlerTaskResult]:
|
) -> List[CrawlerTaskResult]:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
class MemoryAdaptiveDispatcher(BaseDispatcher):
|
class MemoryAdaptiveDispatcher(BaseDispatcher):
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
@@ -304,7 +322,7 @@ class MemoryAdaptiveDispatcher(BaseDispatcher):
|
|||||||
max_session_permit: int = 20,
|
max_session_permit: int = 20,
|
||||||
memory_wait_timeout: float = 300.0, # 5 minutes default timeout
|
memory_wait_timeout: float = 300.0, # 5 minutes default timeout
|
||||||
rate_limiter: Optional[RateLimiter] = None,
|
rate_limiter: Optional[RateLimiter] = None,
|
||||||
monitor: Optional[CrawlerMonitor] = None
|
monitor: Optional[CrawlerMonitor] = None,
|
||||||
):
|
):
|
||||||
super().__init__(rate_limiter, monitor)
|
super().__init__(rate_limiter, monitor)
|
||||||
self.memory_threshold_percent = memory_threshold_percent
|
self.memory_threshold_percent = memory_threshold_percent
|
||||||
@@ -324,7 +342,9 @@ class MemoryAdaptiveDispatcher(BaseDispatcher):
|
|||||||
|
|
||||||
try:
|
try:
|
||||||
if self.monitor:
|
if self.monitor:
|
||||||
self.monitor.update_task(task_id, status=CrawlStatus.IN_PROGRESS, start_time=start_time)
|
self.monitor.update_task(
|
||||||
|
task_id, status=CrawlStatus.IN_PROGRESS, start_time=start_time
|
||||||
|
)
|
||||||
self.concurrent_sessions += 1
|
self.concurrent_sessions += 1
|
||||||
|
|
||||||
if self.rate_limiter:
|
if self.rate_limiter:
|
||||||
@@ -350,7 +370,7 @@ class MemoryAdaptiveDispatcher(BaseDispatcher):
|
|||||||
peak_memory=peak_memory,
|
peak_memory=peak_memory,
|
||||||
start_time=start_time,
|
start_time=start_time,
|
||||||
end_time=datetime.now(),
|
end_time=datetime.now(),
|
||||||
error_message=error_message
|
error_message=error_message,
|
||||||
)
|
)
|
||||||
|
|
||||||
if not result.success:
|
if not result.success:
|
||||||
@@ -364,7 +384,9 @@ class MemoryAdaptiveDispatcher(BaseDispatcher):
|
|||||||
error_message = str(e)
|
error_message = str(e)
|
||||||
if self.monitor:
|
if self.monitor:
|
||||||
self.monitor.update_task(task_id, status=CrawlStatus.FAILED)
|
self.monitor.update_task(task_id, status=CrawlStatus.FAILED)
|
||||||
result = CrawlResult(url=url, html="", metadata={}, success=False, error_message=str(e))
|
result = CrawlResult(
|
||||||
|
url=url, html="", metadata={}, success=False, error_message=str(e)
|
||||||
|
)
|
||||||
|
|
||||||
finally:
|
finally:
|
||||||
end_time = datetime.now()
|
end_time = datetime.now()
|
||||||
@@ -374,7 +396,7 @@ class MemoryAdaptiveDispatcher(BaseDispatcher):
|
|||||||
end_time=end_time,
|
end_time=end_time,
|
||||||
memory_usage=memory_usage,
|
memory_usage=memory_usage,
|
||||||
peak_memory=peak_memory,
|
peak_memory=peak_memory,
|
||||||
error_message=error_message
|
error_message=error_message,
|
||||||
)
|
)
|
||||||
self.concurrent_sessions -= 1
|
self.concurrent_sessions -= 1
|
||||||
|
|
||||||
@@ -386,13 +408,13 @@ class MemoryAdaptiveDispatcher(BaseDispatcher):
|
|||||||
peak_memory=peak_memory,
|
peak_memory=peak_memory,
|
||||||
start_time=start_time,
|
start_time=start_time,
|
||||||
end_time=end_time,
|
end_time=end_time,
|
||||||
error_message=error_message
|
error_message=error_message,
|
||||||
)
|
)
|
||||||
|
|
||||||
async def run_urls(
|
async def run_urls(
|
||||||
self,
|
self,
|
||||||
urls: List[str],
|
urls: List[str],
|
||||||
crawler: "AsyncWebCrawler",
|
crawler: "AsyncWebCrawler", # noqa: F821
|
||||||
config: CrawlerRunConfig,
|
config: CrawlerRunConfig,
|
||||||
) -> List[CrawlerTaskResult]:
|
) -> List[CrawlerTaskResult]:
|
||||||
self.crawler = crawler
|
self.crawler = crawler
|
||||||
@@ -417,7 +439,9 @@ class MemoryAdaptiveDispatcher(BaseDispatcher):
|
|||||||
if psutil.virtual_memory().percent >= self.memory_threshold_percent:
|
if psutil.virtual_memory().percent >= self.memory_threshold_percent:
|
||||||
# Check if we've exceeded the timeout
|
# Check if we've exceeded the timeout
|
||||||
if time.time() - wait_start_time > self.memory_wait_timeout:
|
if time.time() - wait_start_time > self.memory_wait_timeout:
|
||||||
raise MemoryError(f"Memory usage above threshold ({self.memory_threshold_percent}%) for more than {self.memory_wait_timeout} seconds")
|
raise MemoryError(
|
||||||
|
f"Memory usage above threshold ({self.memory_threshold_percent}%) for more than {self.memory_wait_timeout} seconds"
|
||||||
|
)
|
||||||
await asyncio.sleep(self.check_interval)
|
await asyncio.sleep(self.check_interval)
|
||||||
continue
|
continue
|
||||||
|
|
||||||
@@ -430,8 +454,7 @@ class MemoryAdaptiveDispatcher(BaseDispatcher):
|
|||||||
continue
|
continue
|
||||||
|
|
||||||
done, pending = await asyncio.wait(
|
done, pending = await asyncio.wait(
|
||||||
active_tasks,
|
active_tasks, return_when=asyncio.FIRST_COMPLETED
|
||||||
return_when=asyncio.FIRST_COMPLETED
|
|
||||||
)
|
)
|
||||||
|
|
||||||
pending_tasks.extend(done)
|
pending_tasks.extend(done)
|
||||||
@@ -442,13 +465,14 @@ class MemoryAdaptiveDispatcher(BaseDispatcher):
|
|||||||
if self.monitor:
|
if self.monitor:
|
||||||
self.monitor.stop()
|
self.monitor.stop()
|
||||||
|
|
||||||
|
|
||||||
class SemaphoreDispatcher(BaseDispatcher):
|
class SemaphoreDispatcher(BaseDispatcher):
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
semaphore_count: int = 5,
|
semaphore_count: int = 5,
|
||||||
max_session_permit: int = 20,
|
max_session_permit: int = 20,
|
||||||
rate_limiter: Optional[RateLimiter] = None,
|
rate_limiter: Optional[RateLimiter] = None,
|
||||||
monitor: Optional[CrawlerMonitor] = None
|
monitor: Optional[CrawlerMonitor] = None,
|
||||||
):
|
):
|
||||||
super().__init__(rate_limiter, monitor)
|
super().__init__(rate_limiter, monitor)
|
||||||
self.semaphore_count = semaphore_count
|
self.semaphore_count = semaphore_count
|
||||||
@@ -459,7 +483,7 @@ class SemaphoreDispatcher(BaseDispatcher):
|
|||||||
url: str,
|
url: str,
|
||||||
config: CrawlerRunConfig,
|
config: CrawlerRunConfig,
|
||||||
task_id: str,
|
task_id: str,
|
||||||
semaphore: asyncio.Semaphore = None
|
semaphore: asyncio.Semaphore = None,
|
||||||
) -> CrawlerTaskResult:
|
) -> CrawlerTaskResult:
|
||||||
start_time = datetime.now()
|
start_time = datetime.now()
|
||||||
error_message = ""
|
error_message = ""
|
||||||
@@ -467,7 +491,9 @@ class SemaphoreDispatcher(BaseDispatcher):
|
|||||||
|
|
||||||
try:
|
try:
|
||||||
if self.monitor:
|
if self.monitor:
|
||||||
self.monitor.update_task(task_id, status=CrawlStatus.IN_PROGRESS, start_time=start_time)
|
self.monitor.update_task(
|
||||||
|
task_id, status=CrawlStatus.IN_PROGRESS, start_time=start_time
|
||||||
|
)
|
||||||
|
|
||||||
if self.rate_limiter:
|
if self.rate_limiter:
|
||||||
await self.rate_limiter.wait_if_needed(url)
|
await self.rate_limiter.wait_if_needed(url)
|
||||||
@@ -493,7 +519,7 @@ class SemaphoreDispatcher(BaseDispatcher):
|
|||||||
peak_memory=peak_memory,
|
peak_memory=peak_memory,
|
||||||
start_time=start_time,
|
start_time=start_time,
|
||||||
end_time=datetime.now(),
|
end_time=datetime.now(),
|
||||||
error_message=error_message
|
error_message=error_message,
|
||||||
)
|
)
|
||||||
|
|
||||||
if not result.success:
|
if not result.success:
|
||||||
@@ -507,7 +533,9 @@ class SemaphoreDispatcher(BaseDispatcher):
|
|||||||
error_message = str(e)
|
error_message = str(e)
|
||||||
if self.monitor:
|
if self.monitor:
|
||||||
self.monitor.update_task(task_id, status=CrawlStatus.FAILED)
|
self.monitor.update_task(task_id, status=CrawlStatus.FAILED)
|
||||||
result = CrawlResult(url=url, html="", metadata={}, success=False, error_message=str(e))
|
result = CrawlResult(
|
||||||
|
url=url, html="", metadata={}, success=False, error_message=str(e)
|
||||||
|
)
|
||||||
|
|
||||||
finally:
|
finally:
|
||||||
end_time = datetime.now()
|
end_time = datetime.now()
|
||||||
@@ -517,7 +545,7 @@ class SemaphoreDispatcher(BaseDispatcher):
|
|||||||
end_time=end_time,
|
end_time=end_time,
|
||||||
memory_usage=memory_usage,
|
memory_usage=memory_usage,
|
||||||
peak_memory=peak_memory,
|
peak_memory=peak_memory,
|
||||||
error_message=error_message
|
error_message=error_message,
|
||||||
)
|
)
|
||||||
|
|
||||||
return CrawlerTaskResult(
|
return CrawlerTaskResult(
|
||||||
@@ -528,12 +556,12 @@ class SemaphoreDispatcher(BaseDispatcher):
|
|||||||
peak_memory=peak_memory,
|
peak_memory=peak_memory,
|
||||||
start_time=start_time,
|
start_time=start_time,
|
||||||
end_time=end_time,
|
end_time=end_time,
|
||||||
error_message=error_message
|
error_message=error_message,
|
||||||
)
|
)
|
||||||
|
|
||||||
async def run_urls(
|
async def run_urls(
|
||||||
self,
|
self,
|
||||||
crawler: "AsyncWebCrawler",
|
crawler: "AsyncWebCrawler", # noqa: F821
|
||||||
urls: List[str],
|
urls: List[str],
|
||||||
config: CrawlerRunConfig,
|
config: CrawlerRunConfig,
|
||||||
) -> List[CrawlerTaskResult]:
|
) -> List[CrawlerTaskResult]:
|
||||||
|
|||||||
@@ -1,10 +1,10 @@
|
|||||||
from enum import Enum
|
from enum import Enum
|
||||||
from typing import Optional, Dict, Any, Union
|
from typing import Optional, Dict, Any
|
||||||
from colorama import Fore, Back, Style, init
|
from colorama import Fore, Style, init
|
||||||
import time
|
|
||||||
import os
|
import os
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
|
|
||||||
|
|
||||||
class LogLevel(Enum):
|
class LogLevel(Enum):
|
||||||
DEBUG = 1
|
DEBUG = 1
|
||||||
INFO = 2
|
INFO = 2
|
||||||
@@ -12,6 +12,7 @@ class LogLevel(Enum):
|
|||||||
WARNING = 4
|
WARNING = 4
|
||||||
ERROR = 5
|
ERROR = 5
|
||||||
|
|
||||||
|
|
||||||
class AsyncLogger:
|
class AsyncLogger:
|
||||||
"""
|
"""
|
||||||
Asynchronous logger with support for colored console output and file logging.
|
Asynchronous logger with support for colored console output and file logging.
|
||||||
@@ -19,16 +20,16 @@ class AsyncLogger:
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
DEFAULT_ICONS = {
|
DEFAULT_ICONS = {
|
||||||
'INIT': '→',
|
"INIT": "→",
|
||||||
'READY': '✓',
|
"READY": "✓",
|
||||||
'FETCH': '↓',
|
"FETCH": "↓",
|
||||||
'SCRAPE': '◆',
|
"SCRAPE": "◆",
|
||||||
'EXTRACT': '■',
|
"EXTRACT": "■",
|
||||||
'COMPLETE': '●',
|
"COMPLETE": "●",
|
||||||
'ERROR': '×',
|
"ERROR": "×",
|
||||||
'DEBUG': '⋯',
|
"DEBUG": "⋯",
|
||||||
'INFO': 'ℹ',
|
"INFO": "ℹ",
|
||||||
'WARNING': '⚠',
|
"WARNING": "⚠",
|
||||||
}
|
}
|
||||||
|
|
||||||
DEFAULT_COLORS = {
|
DEFAULT_COLORS = {
|
||||||
@@ -46,7 +47,7 @@ class AsyncLogger:
|
|||||||
tag_width: int = 10,
|
tag_width: int = 10,
|
||||||
icons: Optional[Dict[str, str]] = None,
|
icons: Optional[Dict[str, str]] = None,
|
||||||
colors: Optional[Dict[LogLevel, str]] = None,
|
colors: Optional[Dict[LogLevel, str]] = None,
|
||||||
verbose: bool = True
|
verbose: bool = True,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Initialize the logger.
|
Initialize the logger.
|
||||||
@@ -77,18 +78,20 @@ class AsyncLogger:
|
|||||||
|
|
||||||
def _get_icon(self, tag: str) -> str:
|
def _get_icon(self, tag: str) -> str:
|
||||||
"""Get the icon for a tag, defaulting to info icon if not found."""
|
"""Get the icon for a tag, defaulting to info icon if not found."""
|
||||||
return self.icons.get(tag, self.icons['INFO'])
|
return self.icons.get(tag, self.icons["INFO"])
|
||||||
|
|
||||||
def _write_to_file(self, message: str):
|
def _write_to_file(self, message: str):
|
||||||
"""Write a message to the log file if configured."""
|
"""Write a message to the log file if configured."""
|
||||||
if self.log_file:
|
if self.log_file:
|
||||||
timestamp = datetime.now().strftime('%Y-%m-%d %H:%M:%S.%f')[:-3]
|
timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S.%f")[:-3]
|
||||||
with open(self.log_file, 'a', encoding='utf-8') as f:
|
with open(self.log_file, "a", encoding="utf-8") as f:
|
||||||
# Strip ANSI color codes for file output
|
# Strip ANSI color codes for file output
|
||||||
clean_message = message.replace(Fore.RESET, '').replace(Style.RESET_ALL, '')
|
clean_message = message.replace(Fore.RESET, "").replace(
|
||||||
|
Style.RESET_ALL, ""
|
||||||
|
)
|
||||||
for color in vars(Fore).values():
|
for color in vars(Fore).values():
|
||||||
if isinstance(color, str):
|
if isinstance(color, str):
|
||||||
clean_message = clean_message.replace(color, '')
|
clean_message = clean_message.replace(color, "")
|
||||||
f.write(f"[{timestamp}] {clean_message}\n")
|
f.write(f"[{timestamp}] {clean_message}\n")
|
||||||
|
|
||||||
def _log(
|
def _log(
|
||||||
@@ -99,7 +102,7 @@ class AsyncLogger:
|
|||||||
params: Optional[Dict[str, Any]] = None,
|
params: Optional[Dict[str, Any]] = None,
|
||||||
colors: Optional[Dict[str, str]] = None,
|
colors: Optional[Dict[str, str]] = None,
|
||||||
base_color: Optional[str] = None,
|
base_color: Optional[str] = None,
|
||||||
**kwargs
|
**kwargs,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Core logging method that handles message formatting and output.
|
Core logging method that handles message formatting and output.
|
||||||
@@ -128,12 +131,13 @@ class AsyncLogger:
|
|||||||
if key in params:
|
if key in params:
|
||||||
value_str = str(params[key])
|
value_str = str(params[key])
|
||||||
formatted_message = formatted_message.replace(
|
formatted_message = formatted_message.replace(
|
||||||
value_str,
|
value_str, f"{color}{value_str}{Style.RESET_ALL}"
|
||||||
f"{color}{value_str}{Style.RESET_ALL}"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
except KeyError as e:
|
except KeyError as e:
|
||||||
formatted_message = f"LOGGING ERROR: Missing parameter {e} in message template"
|
formatted_message = (
|
||||||
|
f"LOGGING ERROR: Missing parameter {e} in message template"
|
||||||
|
)
|
||||||
level = LogLevel.ERROR
|
level = LogLevel.ERROR
|
||||||
else:
|
else:
|
||||||
formatted_message = message
|
formatted_message = message
|
||||||
@@ -175,7 +179,7 @@ class AsyncLogger:
|
|||||||
success: bool,
|
success: bool,
|
||||||
timing: float,
|
timing: float,
|
||||||
tag: str = "FETCH",
|
tag: str = "FETCH",
|
||||||
url_length: int = 50
|
url_length: int = 50,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Convenience method for logging URL fetch status.
|
Convenience method for logging URL fetch status.
|
||||||
@@ -195,20 +199,16 @@ class AsyncLogger:
|
|||||||
"url": url,
|
"url": url,
|
||||||
"url_length": url_length,
|
"url_length": url_length,
|
||||||
"status": success,
|
"status": success,
|
||||||
"timing": timing
|
"timing": timing,
|
||||||
},
|
},
|
||||||
colors={
|
colors={
|
||||||
"status": Fore.GREEN if success else Fore.RED,
|
"status": Fore.GREEN if success else Fore.RED,
|
||||||
"timing": Fore.YELLOW
|
"timing": Fore.YELLOW,
|
||||||
}
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
def error_status(
|
def error_status(
|
||||||
self,
|
self, url: str, error: str, tag: str = "ERROR", url_length: int = 50
|
||||||
url: str,
|
|
||||||
error: str,
|
|
||||||
tag: str = "ERROR",
|
|
||||||
url_length: int = 50
|
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Convenience method for logging error status.
|
Convenience method for logging error status.
|
||||||
@@ -223,9 +223,5 @@ class AsyncLogger:
|
|||||||
level=LogLevel.ERROR,
|
level=LogLevel.ERROR,
|
||||||
message="{url:.{url_length}}... | Error: {error}",
|
message="{url:.{url_length}}... | Error: {error}",
|
||||||
tag=tag,
|
tag=tag,
|
||||||
params={
|
params={"url": url, "url_length": url_length, "error": error},
|
||||||
"url": url,
|
|
||||||
"url_length": url_length,
|
|
||||||
"error": error
|
|
||||||
}
|
|
||||||
)
|
)
|
||||||
@@ -1,42 +1,47 @@
|
|||||||
import os, sys
|
import os
|
||||||
|
import sys
|
||||||
import time
|
import time
|
||||||
import warnings
|
import warnings
|
||||||
from enum import Enum
|
from colorama import Fore
|
||||||
from colorama import init, Fore, Back, Style
|
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Optional, List, Union
|
from typing import Optional, List
|
||||||
import json
|
import json
|
||||||
import asyncio
|
import asyncio
|
||||||
|
|
||||||
# from contextlib import nullcontext, asynccontextmanager
|
# from contextlib import nullcontext, asynccontextmanager
|
||||||
from contextlib import asynccontextmanager
|
from contextlib import asynccontextmanager
|
||||||
from .models import CrawlResult, MarkdownGenerationResult, CrawlerTaskResult
|
from .models import CrawlResult, MarkdownGenerationResult, CrawlerTaskResult, DispatchResult, RateLimiter
|
||||||
from .async_database import async_db_manager
|
from .async_database import async_db_manager
|
||||||
from .chunking_strategy import *
|
from .chunking_strategy import * # noqa: F403
|
||||||
from .content_filter_strategy import *
|
from .chunking_strategy import RegexChunking, ChunkingStrategy, IdentityChunking
|
||||||
from .extraction_strategy import *
|
from .content_filter_strategy import * # noqa: F403
|
||||||
from .async_crawler_strategy import AsyncCrawlerStrategy, AsyncPlaywrightCrawlerStrategy, AsyncCrawlResponse
|
from .content_filter_strategy import RelevantContentFilter
|
||||||
|
from .extraction_strategy import * # noqa: F403
|
||||||
|
from .extraction_strategy import NoExtractionStrategy, ExtractionStrategy
|
||||||
|
from .async_crawler_strategy import (
|
||||||
|
AsyncCrawlerStrategy,
|
||||||
|
AsyncPlaywrightCrawlerStrategy,
|
||||||
|
AsyncCrawlResponse,
|
||||||
|
)
|
||||||
from .cache_context import CacheMode, CacheContext, _legacy_to_cache_mode
|
from .cache_context import CacheMode, CacheContext, _legacy_to_cache_mode
|
||||||
from .markdown_generation_strategy import DefaultMarkdownGenerator, MarkdownGenerationStrategy
|
from .markdown_generation_strategy import (
|
||||||
from .content_scraping_strategy import WebScrapingStrategy, LXMLWebScrapingStrategy
|
DefaultMarkdownGenerator,
|
||||||
|
MarkdownGenerationStrategy,
|
||||||
|
)
|
||||||
from .async_logger import AsyncLogger
|
from .async_logger import AsyncLogger
|
||||||
from .async_configs import BrowserConfig, CrawlerRunConfig
|
from .async_configs import BrowserConfig, CrawlerRunConfig
|
||||||
from .async_dispatcher import *
|
from .async_dispatcher import * # noqa: F403
|
||||||
|
from .async_dispatcher import BaseDispatcher, MemoryAdaptiveDispatcher
|
||||||
|
|
||||||
from .config import (
|
from .config import MIN_WORD_THRESHOLD
|
||||||
MIN_WORD_THRESHOLD,
|
|
||||||
IMAGE_DESCRIPTION_MIN_WORD_THRESHOLD,
|
|
||||||
URL_LOG_SHORTEN_LENGTH
|
|
||||||
)
|
|
||||||
from .utils import (
|
from .utils import (
|
||||||
sanitize_input_encode,
|
sanitize_input_encode,
|
||||||
InvalidCSSSelectorError,
|
InvalidCSSSelectorError,
|
||||||
format_html,
|
|
||||||
fast_format_html,
|
fast_format_html,
|
||||||
create_box_message
|
create_box_message,
|
||||||
|
get_error_context,
|
||||||
)
|
)
|
||||||
|
|
||||||
from urllib.parse import urlparse
|
|
||||||
import random
|
|
||||||
from .__version__ import __version__ as crawl4ai_version
|
from .__version__ import __version__ as crawl4ai_version
|
||||||
|
|
||||||
|
|
||||||
@@ -104,6 +109,7 @@ class AsyncWebCrawler:
|
|||||||
result = await crawler.arun(url="https://example.com", config=crawler_config)
|
result = await crawler.arun(url="https://example.com", config=crawler_config)
|
||||||
print(result.markdown)
|
print(result.markdown)
|
||||||
"""
|
"""
|
||||||
|
|
||||||
_domain_last_hit = {}
|
_domain_last_hit = {}
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
@@ -131,10 +137,18 @@ class AsyncWebCrawler:
|
|||||||
# Handle browser configuration
|
# Handle browser configuration
|
||||||
browser_config = config
|
browser_config = config
|
||||||
if browser_config is not None:
|
if browser_config is not None:
|
||||||
if any(k in kwargs for k in ["browser_type", "headless", "viewport_width", "viewport_height"]):
|
if any(
|
||||||
|
k in kwargs
|
||||||
|
for k in [
|
||||||
|
"browser_type",
|
||||||
|
"headless",
|
||||||
|
"viewport_width",
|
||||||
|
"viewport_height",
|
||||||
|
]
|
||||||
|
):
|
||||||
self.logger.warning(
|
self.logger.warning(
|
||||||
message="Both browser_config and legacy browser parameters provided. browser_config will take precedence.",
|
message="Both browser_config and legacy browser parameters provided. browser_config will take precedence.",
|
||||||
tag="WARNING"
|
tag="WARNING",
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
# Create browser config from kwargs for backwards compatibility
|
# Create browser config from kwargs for backwards compatibility
|
||||||
@@ -146,18 +160,15 @@ class AsyncWebCrawler:
|
|||||||
self.logger = AsyncLogger(
|
self.logger = AsyncLogger(
|
||||||
log_file=os.path.join(base_directory, ".crawl4ai", "crawler.log"),
|
log_file=os.path.join(base_directory, ".crawl4ai", "crawler.log"),
|
||||||
verbose=self.browser_config.verbose,
|
verbose=self.browser_config.verbose,
|
||||||
tag_width=10
|
tag_width=10,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
# Initialize crawler strategy
|
# Initialize crawler strategy
|
||||||
params = {
|
params = {k: v for k, v in kwargs.items() if k in ["browser_congig", "logger"]}
|
||||||
k:v for k, v in kwargs.items() if k in ['browser_congig', 'logger']
|
|
||||||
}
|
|
||||||
self.crawler_strategy = crawler_strategy or AsyncPlaywrightCrawlerStrategy(
|
self.crawler_strategy = crawler_strategy or AsyncPlaywrightCrawlerStrategy(
|
||||||
browser_config=browser_config,
|
browser_config=browser_config,
|
||||||
logger=self.logger,
|
logger=self.logger,
|
||||||
**params # Pass remaining kwargs for backwards compatibility
|
**params, # Pass remaining kwargs for backwards compatibility
|
||||||
)
|
)
|
||||||
|
|
||||||
# If the crawler strategy doesn't have a logger, use the crawler's logger
|
# If the crawler strategy doesn't have a logger, use the crawler's logger
|
||||||
@@ -172,7 +183,7 @@ class AsyncWebCrawler:
|
|||||||
"Use 'always_bypass_cache' instead. "
|
"Use 'always_bypass_cache' instead. "
|
||||||
"Pass warning=False to suppress this warning.",
|
"Pass warning=False to suppress this warning.",
|
||||||
DeprecationWarning,
|
DeprecationWarning,
|
||||||
stacklevel=2
|
stacklevel=2,
|
||||||
)
|
)
|
||||||
self.always_bypass_cache = always_by_pass_cache
|
self.always_bypass_cache = always_by_pass_cache
|
||||||
else:
|
else:
|
||||||
@@ -323,7 +334,7 @@ class AsyncWebCrawler:
|
|||||||
"screenshot": screenshot,
|
"screenshot": screenshot,
|
||||||
"pdf": pdf,
|
"pdf": pdf,
|
||||||
"verbose": verbose,
|
"verbose": verbose,
|
||||||
**kwargs
|
**kwargs,
|
||||||
}
|
}
|
||||||
config = CrawlerRunConfig.from_kwargs(config_kwargs)
|
config = CrawlerRunConfig.from_kwargs(config_kwargs)
|
||||||
|
|
||||||
@@ -334,7 +345,7 @@ class AsyncWebCrawler:
|
|||||||
"Cache control boolean flags are deprecated and will be removed in version 0.5.0. "
|
"Cache control boolean flags are deprecated and will be removed in version 0.5.0. "
|
||||||
"Use 'cache_mode' parameter instead.",
|
"Use 'cache_mode' parameter instead.",
|
||||||
DeprecationWarning,
|
DeprecationWarning,
|
||||||
stacklevel=2
|
stacklevel=2,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Convert legacy parameters if cache_mode not provided
|
# Convert legacy parameters if cache_mode not provided
|
||||||
@@ -343,7 +354,7 @@ class AsyncWebCrawler:
|
|||||||
disable_cache=disable_cache,
|
disable_cache=disable_cache,
|
||||||
bypass_cache=bypass_cache,
|
bypass_cache=bypass_cache,
|
||||||
no_cache_read=no_cache_read,
|
no_cache_read=no_cache_read,
|
||||||
no_cache_write=no_cache_write
|
no_cache_write=no_cache_write,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Default to ENABLED if no cache mode specified
|
# Default to ENABLED if no cache mode specified
|
||||||
@@ -351,7 +362,9 @@ class AsyncWebCrawler:
|
|||||||
config.cache_mode = CacheMode.ENABLED
|
config.cache_mode = CacheMode.ENABLED
|
||||||
|
|
||||||
# Create cache context
|
# Create cache context
|
||||||
cache_context = CacheContext(url, config.cache_mode, self.always_bypass_cache)
|
cache_context = CacheContext(
|
||||||
|
url, config.cache_mode, self.always_bypass_cache
|
||||||
|
)
|
||||||
|
|
||||||
# Initialize processing variables
|
# Initialize processing variables
|
||||||
async_response: AsyncCrawlResponse = None
|
async_response: AsyncCrawlResponse = None
|
||||||
@@ -367,8 +380,14 @@ class AsyncWebCrawler:
|
|||||||
|
|
||||||
if cached_result:
|
if cached_result:
|
||||||
html = sanitize_input_encode(cached_result.html)
|
html = sanitize_input_encode(cached_result.html)
|
||||||
extracted_content = sanitize_input_encode(cached_result.extracted_content or "")
|
extracted_content = sanitize_input_encode(
|
||||||
extracted_content = None if not extracted_content or extracted_content == "[]" else extracted_content
|
cached_result.extracted_content or ""
|
||||||
|
)
|
||||||
|
extracted_content = (
|
||||||
|
None
|
||||||
|
if not extracted_content or extracted_content == "[]"
|
||||||
|
else extracted_content
|
||||||
|
)
|
||||||
# If a screenshot is requested but it's not in the cache, then set cached_result to None
|
# If a screenshot is requested but it's not in the cache, then set cached_result to None
|
||||||
screenshot_data = cached_result.screenshot
|
screenshot_data = cached_result.screenshot
|
||||||
pdf_data = cached_result.pdf
|
pdf_data = cached_result.pdf
|
||||||
@@ -379,7 +398,7 @@ class AsyncWebCrawler:
|
|||||||
url=cache_context.display_url,
|
url=cache_context.display_url,
|
||||||
success=bool(html),
|
success=bool(html),
|
||||||
timing=time.perf_counter() - start_time,
|
timing=time.perf_counter() - start_time,
|
||||||
tag="FETCH"
|
tag="FETCH",
|
||||||
)
|
)
|
||||||
|
|
||||||
# Fetch fresh content if needed
|
# Fetch fresh content if needed
|
||||||
@@ -392,7 +411,7 @@ class AsyncWebCrawler:
|
|||||||
# Pass config to crawl method
|
# Pass config to crawl method
|
||||||
async_response = await self.crawler_strategy.crawl(
|
async_response = await self.crawler_strategy.crawl(
|
||||||
url,
|
url,
|
||||||
config=config # Pass the entire config object
|
config=config, # Pass the entire config object
|
||||||
)
|
)
|
||||||
|
|
||||||
html = sanitize_input_encode(async_response.html)
|
html = sanitize_input_encode(async_response.html)
|
||||||
@@ -404,7 +423,7 @@ class AsyncWebCrawler:
|
|||||||
url=cache_context.display_url,
|
url=cache_context.display_url,
|
||||||
success=bool(html),
|
success=bool(html),
|
||||||
timing=t2 - t1,
|
timing=t2 - t1,
|
||||||
tag="FETCH"
|
tag="FETCH",
|
||||||
)
|
)
|
||||||
|
|
||||||
# Process the HTML content
|
# Process the HTML content
|
||||||
@@ -417,13 +436,15 @@ class AsyncWebCrawler:
|
|||||||
pdf_data=pdf_data,
|
pdf_data=pdf_data,
|
||||||
verbose=config.verbose,
|
verbose=config.verbose,
|
||||||
is_raw_html=True if url.startswith("raw:") else False,
|
is_raw_html=True if url.startswith("raw:") else False,
|
||||||
**kwargs
|
**kwargs,
|
||||||
)
|
)
|
||||||
|
|
||||||
crawl_result.status_code = async_response.status_code
|
crawl_result.status_code = async_response.status_code
|
||||||
crawl_result.response_headers = async_response.response_headers
|
crawl_result.response_headers = async_response.response_headers
|
||||||
crawl_result.downloaded_files = async_response.downloaded_files
|
crawl_result.downloaded_files = async_response.downloaded_files
|
||||||
crawl_result.ssl_certificate = async_response.ssl_certificate # Add SSL certificate
|
crawl_result.ssl_certificate = (
|
||||||
|
async_response.ssl_certificate
|
||||||
|
) # Add SSL certificate
|
||||||
|
|
||||||
# # Check and set values from async_response to crawl_result
|
# # Check and set values from async_response to crawl_result
|
||||||
# try:
|
# try:
|
||||||
@@ -446,7 +467,7 @@ class AsyncWebCrawler:
|
|||||||
# )
|
# )
|
||||||
|
|
||||||
crawl_result.success = bool(html)
|
crawl_result.success = bool(html)
|
||||||
crawl_result.session_id = getattr(config, 'session_id', None)
|
crawl_result.session_id = getattr(config, "session_id", None)
|
||||||
|
|
||||||
self.logger.success(
|
self.logger.success(
|
||||||
message="{url:.50}... | Status: {status} | Total: {timing}",
|
message="{url:.50}... | Status: {status} | Total: {timing}",
|
||||||
@@ -454,12 +475,12 @@ class AsyncWebCrawler:
|
|||||||
params={
|
params={
|
||||||
"url": cache_context.display_url,
|
"url": cache_context.display_url,
|
||||||
"status": crawl_result.success,
|
"status": crawl_result.success,
|
||||||
"timing": f"{time.perf_counter() - start_time:.2f}s"
|
"timing": f"{time.perf_counter() - start_time:.2f}s",
|
||||||
},
|
},
|
||||||
colors={
|
colors={
|
||||||
"status": Fore.GREEN if crawl_result.success else Fore.RED,
|
"status": Fore.GREEN if crawl_result.success else Fore.RED,
|
||||||
"timing": Fore.YELLOW
|
"timing": Fore.YELLOW,
|
||||||
}
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
# Update cache if appropriate
|
# Update cache if appropriate
|
||||||
@@ -475,16 +496,13 @@ class AsyncWebCrawler:
|
|||||||
params={
|
params={
|
||||||
"url": cache_context.display_url,
|
"url": cache_context.display_url,
|
||||||
"status": True,
|
"status": True,
|
||||||
"timing": f"{time.perf_counter() - start_time:.2f}s"
|
"timing": f"{time.perf_counter() - start_time:.2f}s",
|
||||||
},
|
},
|
||||||
colors={
|
colors={"status": Fore.GREEN, "timing": Fore.YELLOW},
|
||||||
"status": Fore.GREEN,
|
|
||||||
"timing": Fore.YELLOW
|
|
||||||
}
|
|
||||||
)
|
)
|
||||||
|
|
||||||
cached_result.success = bool(html)
|
cached_result.success = bool(html)
|
||||||
cached_result.session_id = getattr(config, 'session_id', None)
|
cached_result.session_id = getattr(config, "session_id", None)
|
||||||
return cached_result
|
return cached_result
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
@@ -502,14 +520,11 @@ class AsyncWebCrawler:
|
|||||||
self.logger.error_status(
|
self.logger.error_status(
|
||||||
url=url,
|
url=url,
|
||||||
error=create_box_message(error_message, type="error"),
|
error=create_box_message(error_message, type="error"),
|
||||||
tag="ERROR"
|
tag="ERROR",
|
||||||
)
|
)
|
||||||
|
|
||||||
return CrawlResult(
|
return CrawlResult(
|
||||||
url=url,
|
url=url, html="", success=False, error_message=error_message
|
||||||
html="",
|
|
||||||
success=False,
|
|
||||||
error_message=error_message
|
|
||||||
)
|
)
|
||||||
|
|
||||||
async def aprocess_html(
|
async def aprocess_html(
|
||||||
@@ -553,21 +568,19 @@ class AsyncWebCrawler:
|
|||||||
# Add keys from kwargs to params that don't already exist in params
|
# Add keys from kwargs to params that don't already exist in params
|
||||||
params.update({k: v for k, v in kwargs.items() if k not in params.keys()})
|
params.update({k: v for k, v in kwargs.items() if k not in params.keys()})
|
||||||
|
|
||||||
result = scraping_strategy.scrap(
|
result = scraping_strategy.scrap(url, html, **params)
|
||||||
url,
|
|
||||||
html,
|
|
||||||
**params
|
|
||||||
)
|
|
||||||
|
|
||||||
if result is None:
|
if result is None:
|
||||||
raise ValueError(f"Process HTML, Failed to extract content from the website: {url}")
|
raise ValueError(
|
||||||
|
f"Process HTML, Failed to extract content from the website: {url}"
|
||||||
|
)
|
||||||
|
|
||||||
except InvalidCSSSelectorError as e:
|
except InvalidCSSSelectorError as e:
|
||||||
raise ValueError(str(e))
|
raise ValueError(str(e))
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
raise ValueError(f"Process HTML, Failed to extract content from the website: {url}, error: {str(e)}")
|
raise ValueError(
|
||||||
|
f"Process HTML, Failed to extract content from the website: {url}, error: {str(e)}"
|
||||||
|
)
|
||||||
|
|
||||||
# Extract results - handle both dict and ScrapingResult
|
# Extract results - handle both dict and ScrapingResult
|
||||||
if isinstance(result, dict):
|
if isinstance(result, dict):
|
||||||
@@ -582,17 +595,21 @@ class AsyncWebCrawler:
|
|||||||
metadata = result.metadata
|
metadata = result.metadata
|
||||||
|
|
||||||
# Markdown Generation
|
# Markdown Generation
|
||||||
markdown_generator: Optional[MarkdownGenerationStrategy] = config.markdown_generator or DefaultMarkdownGenerator()
|
markdown_generator: Optional[MarkdownGenerationStrategy] = (
|
||||||
|
config.markdown_generator or DefaultMarkdownGenerator()
|
||||||
|
)
|
||||||
|
|
||||||
# Uncomment if by default we want to use PruningContentFilter
|
# Uncomment if by default we want to use PruningContentFilter
|
||||||
# if not config.content_filter and not markdown_generator.content_filter:
|
# if not config.content_filter and not markdown_generator.content_filter:
|
||||||
# markdown_generator.content_filter = PruningContentFilter()
|
# markdown_generator.content_filter = PruningContentFilter()
|
||||||
|
|
||||||
markdown_result: MarkdownGenerationResult = markdown_generator.generate_markdown(
|
markdown_result: MarkdownGenerationResult = (
|
||||||
|
markdown_generator.generate_markdown(
|
||||||
cleaned_html=cleaned_html,
|
cleaned_html=cleaned_html,
|
||||||
base_url=url,
|
base_url=url,
|
||||||
# html2text_options=kwargs.get('html2text', {})
|
# html2text_options=kwargs.get('html2text', {})
|
||||||
)
|
)
|
||||||
|
)
|
||||||
markdown_v2 = markdown_result
|
markdown_v2 = markdown_result
|
||||||
markdown = sanitize_input_encode(markdown_result.raw_markdown)
|
markdown = sanitize_input_encode(markdown_result.raw_markdown)
|
||||||
|
|
||||||
@@ -600,15 +617,15 @@ class AsyncWebCrawler:
|
|||||||
self.logger.info(
|
self.logger.info(
|
||||||
message="Processed {url:.50}... | Time: {timing}ms",
|
message="Processed {url:.50}... | Time: {timing}ms",
|
||||||
tag="SCRAPE",
|
tag="SCRAPE",
|
||||||
params={
|
params={"url": _url, "timing": int((time.perf_counter() - t1) * 1000)},
|
||||||
"url": _url,
|
|
||||||
"timing": int((time.perf_counter() - t1) * 1000)
|
|
||||||
}
|
|
||||||
)
|
)
|
||||||
|
|
||||||
# Handle content extraction if needed
|
# Handle content extraction if needed
|
||||||
if (not bool(extracted_content) and config.extraction_strategy and not isinstance(config.extraction_strategy, NoExtractionStrategy)):
|
if (
|
||||||
|
not bool(extracted_content)
|
||||||
|
and config.extraction_strategy
|
||||||
|
and not isinstance(config.extraction_strategy, NoExtractionStrategy)
|
||||||
|
):
|
||||||
t1 = time.perf_counter()
|
t1 = time.perf_counter()
|
||||||
|
|
||||||
# Choose content based on input_format
|
# Choose content based on input_format
|
||||||
@@ -617,30 +634,33 @@ class AsyncWebCrawler:
|
|||||||
self.logger.warning(
|
self.logger.warning(
|
||||||
message="Fit markdown requested but not available. Falling back to raw markdown.",
|
message="Fit markdown requested but not available. Falling back to raw markdown.",
|
||||||
tag="EXTRACT",
|
tag="EXTRACT",
|
||||||
params={"url": _url}
|
params={"url": _url},
|
||||||
)
|
)
|
||||||
content_format = "markdown"
|
content_format = "markdown"
|
||||||
|
|
||||||
content = {
|
content = {
|
||||||
"markdown": markdown,
|
"markdown": markdown,
|
||||||
"html": html,
|
"html": html,
|
||||||
"fit_markdown": markdown_result.raw_markdown
|
"fit_markdown": markdown_result.raw_markdown,
|
||||||
}.get(content_format, markdown)
|
}.get(content_format, markdown)
|
||||||
|
|
||||||
# Use IdentityChunking for HTML input, otherwise use provided chunking strategy
|
# Use IdentityChunking for HTML input, otherwise use provided chunking strategy
|
||||||
chunking = IdentityChunking() if content_format == "html" else config.chunking_strategy
|
chunking = (
|
||||||
|
IdentityChunking()
|
||||||
|
if content_format == "html"
|
||||||
|
else config.chunking_strategy
|
||||||
|
)
|
||||||
sections = chunking.chunk(content)
|
sections = chunking.chunk(content)
|
||||||
extracted_content = config.extraction_strategy.run(url, sections)
|
extracted_content = config.extraction_strategy.run(url, sections)
|
||||||
extracted_content = json.dumps(extracted_content, indent=4, default=str, ensure_ascii=False)
|
extracted_content = json.dumps(
|
||||||
|
extracted_content, indent=4, default=str, ensure_ascii=False
|
||||||
|
)
|
||||||
|
|
||||||
# Log extraction completion
|
# Log extraction completion
|
||||||
self.logger.info(
|
self.logger.info(
|
||||||
message="Completed for {url:.50}... | Time: {timing}s",
|
message="Completed for {url:.50}... | Time: {timing}s",
|
||||||
tag="EXTRACT",
|
tag="EXTRACT",
|
||||||
params={
|
params={"url": _url, "timing": time.perf_counter() - t1},
|
||||||
"url": _url,
|
|
||||||
"timing": time.perf_counter() - t1
|
|
||||||
}
|
|
||||||
)
|
)
|
||||||
|
|
||||||
# Handle screenshot and PDF data
|
# Handle screenshot and PDF data
|
||||||
@@ -739,7 +759,7 @@ class AsyncWebCrawler:
|
|||||||
screenshot=screenshot,
|
screenshot=screenshot,
|
||||||
pdf=pdf,
|
pdf=pdf,
|
||||||
verbose=verbose,
|
verbose=verbose,
|
||||||
**kwargs
|
**kwargs,
|
||||||
)
|
)
|
||||||
|
|
||||||
# # Initialize the dispatcher with the selected strategy
|
# # Initialize the dispatcher with the selected strategy
|
||||||
@@ -754,17 +774,13 @@ class AsyncWebCrawler:
|
|||||||
dispatcher = MemoryAdaptiveDispatcher(
|
dispatcher = MemoryAdaptiveDispatcher(
|
||||||
self,
|
self,
|
||||||
rate_limiter=RateLimiter(
|
rate_limiter=RateLimiter(
|
||||||
base_delay=(1.0, 3.0),
|
base_delay=(1.0, 3.0), max_delay=60.0, max_retries=3
|
||||||
max_delay=60.0,
|
),
|
||||||
max_retries=3
|
|
||||||
)
|
|
||||||
)
|
)
|
||||||
|
|
||||||
# Run the URLs through the dispatcher
|
# Run the URLs through the dispatcher
|
||||||
_results: List[CrawlerTaskResult] = await dispatcher.run_urls(
|
_results: List[CrawlerTaskResult] = await dispatcher.run_urls(
|
||||||
crawler=self,
|
crawler=self, urls=urls, config=config
|
||||||
urls=urls,
|
|
||||||
config=config
|
|
||||||
)
|
)
|
||||||
|
|
||||||
results: CrawlResult = []
|
results: CrawlResult = []
|
||||||
@@ -776,7 +792,7 @@ class AsyncWebCrawler:
|
|||||||
peak_memory=res.peak_memory,
|
peak_memory=res.peak_memory,
|
||||||
start_time=res.start_time,
|
start_time=res.start_time,
|
||||||
end_time=res.end_time,
|
end_time=res.end_time,
|
||||||
error_message=res.error_message
|
error_message=res.error_message,
|
||||||
)
|
)
|
||||||
_res.dispatch_result = dispatch_result
|
_res.dispatch_result = dispatch_result
|
||||||
results.append(_res)
|
results.append(_res)
|
||||||
|
|||||||
@@ -12,6 +12,7 @@ class CacheMode(Enum):
|
|||||||
- WRITE_ONLY: Only write to cache, don't read
|
- WRITE_ONLY: Only write to cache, don't read
|
||||||
- BYPASS: Bypass cache for this operation
|
- BYPASS: Bypass cache for this operation
|
||||||
"""
|
"""
|
||||||
|
|
||||||
ENABLED = "enabled"
|
ENABLED = "enabled"
|
||||||
DISABLED = "disabled"
|
DISABLED = "disabled"
|
||||||
READ_ONLY = "read_only"
|
READ_ONLY = "read_only"
|
||||||
@@ -36,6 +37,7 @@ class CacheContext:
|
|||||||
is_raw_html (bool): True if the URL is raw HTML, False otherwise.
|
is_raw_html (bool): True if the URL is raw HTML, False otherwise.
|
||||||
_url_display (str): The display name for the URL (web, local file, or raw HTML).
|
_url_display (str): The display name for the URL (web, local file, or raw HTML).
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, url: str, cache_mode: CacheMode, always_bypass: bool = False):
|
def __init__(self, url: str, cache_mode: CacheMode, always_bypass: bool = False):
|
||||||
"""
|
"""
|
||||||
Initializes the CacheContext with the provided URL and cache mode.
|
Initializes the CacheContext with the provided URL and cache mode.
|
||||||
@@ -48,8 +50,8 @@ class CacheContext:
|
|||||||
self.url = url
|
self.url = url
|
||||||
self.cache_mode = cache_mode
|
self.cache_mode = cache_mode
|
||||||
self.always_bypass = always_bypass
|
self.always_bypass = always_bypass
|
||||||
self.is_cacheable = url.startswith(('http://', 'https://', 'file://'))
|
self.is_cacheable = url.startswith(("http://", "https://", "file://"))
|
||||||
self.is_web_url = url.startswith(('http://', 'https://'))
|
self.is_web_url = url.startswith(("http://", "https://"))
|
||||||
self.is_local_file = url.startswith("file://")
|
self.is_local_file = url.startswith("file://")
|
||||||
self.is_raw_html = url.startswith("raw:")
|
self.is_raw_html = url.startswith("raw:")
|
||||||
self._url_display = url if not self.is_raw_html else "Raw HTML"
|
self._url_display = url if not self.is_raw_html else "Raw HTML"
|
||||||
@@ -94,7 +96,7 @@ def _legacy_to_cache_mode(
|
|||||||
disable_cache: bool = False,
|
disable_cache: bool = False,
|
||||||
bypass_cache: bool = False,
|
bypass_cache: bool = False,
|
||||||
no_cache_read: bool = False,
|
no_cache_read: bool = False,
|
||||||
no_cache_write: bool = False
|
no_cache_write: bool = False,
|
||||||
) -> CacheMode:
|
) -> CacheMode:
|
||||||
"""
|
"""
|
||||||
Converts legacy cache parameters to the new CacheMode enum.
|
Converts legacy cache parameters to the new CacheMode enum.
|
||||||
|
|||||||
@@ -3,7 +3,7 @@ import re
|
|||||||
from collections import Counter
|
from collections import Counter
|
||||||
import string
|
import string
|
||||||
from .model_loader import load_nltk_punkt
|
from .model_loader import load_nltk_punkt
|
||||||
from .utils import *
|
|
||||||
|
|
||||||
# Define the abstract base class for chunking strategies
|
# Define the abstract base class for chunking strategies
|
||||||
class ChunkingStrategy(ABC):
|
class ChunkingStrategy(ABC):
|
||||||
@@ -24,19 +24,23 @@ class ChunkingStrategy(ABC):
|
|||||||
"""
|
"""
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
# Create an identity chunking strategy f(x) = [x]
|
# Create an identity chunking strategy f(x) = [x]
|
||||||
class IdentityChunking(ChunkingStrategy):
|
class IdentityChunking(ChunkingStrategy):
|
||||||
"""
|
"""
|
||||||
Chunking strategy that returns the input text as a single chunk.
|
Chunking strategy that returns the input text as a single chunk.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def chunk(self, text: str) -> list:
|
def chunk(self, text: str) -> list:
|
||||||
return [text]
|
return [text]
|
||||||
|
|
||||||
|
|
||||||
# Regex-based chunking
|
# Regex-based chunking
|
||||||
class RegexChunking(ChunkingStrategy):
|
class RegexChunking(ChunkingStrategy):
|
||||||
"""
|
"""
|
||||||
Chunking strategy that splits text based on regular expression patterns.
|
Chunking strategy that splits text based on regular expression patterns.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, patterns=None, **kwargs):
|
def __init__(self, patterns=None, **kwargs):
|
||||||
"""
|
"""
|
||||||
Initialize the RegexChunking object.
|
Initialize the RegexChunking object.
|
||||||
@@ -45,7 +49,7 @@ class RegexChunking(ChunkingStrategy):
|
|||||||
patterns (list): A list of regular expression patterns to split text.
|
patterns (list): A list of regular expression patterns to split text.
|
||||||
"""
|
"""
|
||||||
if patterns is None:
|
if patterns is None:
|
||||||
patterns = [r'\n\n'] # Default split pattern
|
patterns = [r"\n\n"] # Default split pattern
|
||||||
self.patterns = patterns
|
self.patterns = patterns
|
||||||
|
|
||||||
def chunk(self, text: str) -> list:
|
def chunk(self, text: str) -> list:
|
||||||
@@ -57,18 +61,19 @@ class RegexChunking(ChunkingStrategy):
|
|||||||
paragraphs = new_paragraphs
|
paragraphs = new_paragraphs
|
||||||
return paragraphs
|
return paragraphs
|
||||||
|
|
||||||
|
|
||||||
# NLP-based sentence chunking
|
# NLP-based sentence chunking
|
||||||
class NlpSentenceChunking(ChunkingStrategy):
|
class NlpSentenceChunking(ChunkingStrategy):
|
||||||
"""
|
"""
|
||||||
Chunking strategy that splits text into sentences using NLTK's sentence tokenizer.
|
Chunking strategy that splits text into sentences using NLTK's sentence tokenizer.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, **kwargs):
|
def __init__(self, **kwargs):
|
||||||
"""
|
"""
|
||||||
Initialize the NlpSentenceChunking object.
|
Initialize the NlpSentenceChunking object.
|
||||||
"""
|
"""
|
||||||
load_nltk_punkt()
|
load_nltk_punkt()
|
||||||
|
|
||||||
|
|
||||||
def chunk(self, text: str) -> list:
|
def chunk(self, text: str) -> list:
|
||||||
# Improved regex for sentence splitting
|
# Improved regex for sentence splitting
|
||||||
# sentence_endings = re.compile(
|
# sentence_endings = re.compile(
|
||||||
@@ -77,11 +82,13 @@ class NlpSentenceChunking(ChunkingStrategy):
|
|||||||
# sentences = sentence_endings.split(text)
|
# sentences = sentence_endings.split(text)
|
||||||
# sens = [sent.strip() for sent in sentences if sent]
|
# sens = [sent.strip() for sent in sentences if sent]
|
||||||
from nltk.tokenize import sent_tokenize
|
from nltk.tokenize import sent_tokenize
|
||||||
|
|
||||||
sentences = sent_tokenize(text)
|
sentences = sent_tokenize(text)
|
||||||
sens = [sent.strip() for sent in sentences]
|
sens = [sent.strip() for sent in sentences]
|
||||||
|
|
||||||
return list(set(sens))
|
return list(set(sens))
|
||||||
|
|
||||||
|
|
||||||
# Topic-based segmentation using TextTiling
|
# Topic-based segmentation using TextTiling
|
||||||
class TopicSegmentationChunking(ChunkingStrategy):
|
class TopicSegmentationChunking(ChunkingStrategy):
|
||||||
"""
|
"""
|
||||||
@@ -100,6 +107,7 @@ class TopicSegmentationChunking(ChunkingStrategy):
|
|||||||
num_keywords (int): The number of keywords to extract for each topic segment.
|
num_keywords (int): The number of keywords to extract for each topic segment.
|
||||||
"""
|
"""
|
||||||
import nltk as nl
|
import nltk as nl
|
||||||
|
|
||||||
self.tokenizer = nl.tokenize.TextTilingTokenizer()
|
self.tokenizer = nl.tokenize.TextTilingTokenizer()
|
||||||
self.num_keywords = num_keywords
|
self.num_keywords = num_keywords
|
||||||
|
|
||||||
@@ -111,8 +119,14 @@ class TopicSegmentationChunking(ChunkingStrategy):
|
|||||||
def extract_keywords(self, text: str) -> list:
|
def extract_keywords(self, text: str) -> list:
|
||||||
# Tokenize and remove stopwords and punctuation
|
# Tokenize and remove stopwords and punctuation
|
||||||
import nltk as nl
|
import nltk as nl
|
||||||
|
|
||||||
tokens = nl.toknize.word_tokenize(text)
|
tokens = nl.toknize.word_tokenize(text)
|
||||||
tokens = [token.lower() for token in tokens if token not in nl.corpus.stopwords.words('english') and token not in string.punctuation]
|
tokens = [
|
||||||
|
token.lower()
|
||||||
|
for token in tokens
|
||||||
|
if token not in nl.corpus.stopwords.words("english")
|
||||||
|
and token not in string.punctuation
|
||||||
|
]
|
||||||
|
|
||||||
# Calculate frequency distribution
|
# Calculate frequency distribution
|
||||||
freq_dist = Counter(tokens)
|
freq_dist = Counter(tokens)
|
||||||
@@ -123,9 +137,12 @@ class TopicSegmentationChunking(ChunkingStrategy):
|
|||||||
# Segment the text into topics
|
# Segment the text into topics
|
||||||
segments = self.chunk(text)
|
segments = self.chunk(text)
|
||||||
# Extract keywords for each topic segment
|
# Extract keywords for each topic segment
|
||||||
segments_with_topics = [(segment, self.extract_keywords(segment)) for segment in segments]
|
segments_with_topics = [
|
||||||
|
(segment, self.extract_keywords(segment)) for segment in segments
|
||||||
|
]
|
||||||
return segments_with_topics
|
return segments_with_topics
|
||||||
|
|
||||||
|
|
||||||
# Fixed-length word chunks
|
# Fixed-length word chunks
|
||||||
class FixedLengthWordChunking(ChunkingStrategy):
|
class FixedLengthWordChunking(ChunkingStrategy):
|
||||||
"""
|
"""
|
||||||
@@ -136,6 +153,7 @@ class FixedLengthWordChunking(ChunkingStrategy):
|
|||||||
2. Create chunks of fixed length
|
2. Create chunks of fixed length
|
||||||
3. Return the list of chunks
|
3. Return the list of chunks
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, chunk_size=100, **kwargs):
|
def __init__(self, chunk_size=100, **kwargs):
|
||||||
"""
|
"""
|
||||||
Initialize the fixed-length word chunking strategy with the given chunk size.
|
Initialize the fixed-length word chunking strategy with the given chunk size.
|
||||||
@@ -147,7 +165,11 @@ class FixedLengthWordChunking(ChunkingStrategy):
|
|||||||
|
|
||||||
def chunk(self, text: str) -> list:
|
def chunk(self, text: str) -> list:
|
||||||
words = text.split()
|
words = text.split()
|
||||||
return [' '.join(words[i:i + self.chunk_size]) for i in range(0, len(words), self.chunk_size)]
|
return [
|
||||||
|
" ".join(words[i : i + self.chunk_size])
|
||||||
|
for i in range(0, len(words), self.chunk_size)
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
# Sliding window chunking
|
# Sliding window chunking
|
||||||
class SlidingWindowChunking(ChunkingStrategy):
|
class SlidingWindowChunking(ChunkingStrategy):
|
||||||
@@ -159,6 +181,7 @@ class SlidingWindowChunking(ChunkingStrategy):
|
|||||||
2. Create chunks of fixed length
|
2. Create chunks of fixed length
|
||||||
3. Return the list of chunks
|
3. Return the list of chunks
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, window_size=100, step=50, **kwargs):
|
def __init__(self, window_size=100, step=50, **kwargs):
|
||||||
"""
|
"""
|
||||||
Initialize the sliding window chunking strategy with the given window size and
|
Initialize the sliding window chunking strategy with the given window size and
|
||||||
@@ -179,15 +202,16 @@ class SlidingWindowChunking(ChunkingStrategy):
|
|||||||
return [text]
|
return [text]
|
||||||
|
|
||||||
for i in range(0, len(words) - self.window_size + 1, self.step):
|
for i in range(0, len(words) - self.window_size + 1, self.step):
|
||||||
chunk = ' '.join(words[i:i + self.window_size])
|
chunk = " ".join(words[i : i + self.window_size])
|
||||||
chunks.append(chunk)
|
chunks.append(chunk)
|
||||||
|
|
||||||
# Handle the last chunk if it doesn't align perfectly
|
# Handle the last chunk if it doesn't align perfectly
|
||||||
if i + self.window_size < len(words):
|
if i + self.window_size < len(words):
|
||||||
chunks.append(' '.join(words[-self.window_size:]))
|
chunks.append(" ".join(words[-self.window_size :]))
|
||||||
|
|
||||||
return chunks
|
return chunks
|
||||||
|
|
||||||
|
|
||||||
class OverlappingWindowChunking(ChunkingStrategy):
|
class OverlappingWindowChunking(ChunkingStrategy):
|
||||||
"""
|
"""
|
||||||
Chunking strategy that splits text into overlapping word chunks.
|
Chunking strategy that splits text into overlapping word chunks.
|
||||||
@@ -198,6 +222,7 @@ class OverlappingWindowChunking(ChunkingStrategy):
|
|||||||
3. Slide the window by the overlap size
|
3. Slide the window by the overlap size
|
||||||
4. Return the list of chunks
|
4. Return the list of chunks
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, window_size=1000, overlap=100, **kwargs):
|
def __init__(self, window_size=1000, overlap=100, **kwargs):
|
||||||
"""
|
"""
|
||||||
Initialize the overlapping window chunking strategy with the given window size and
|
Initialize the overlapping window chunking strategy with the given window size and
|
||||||
@@ -220,7 +245,7 @@ class OverlappingWindowChunking(ChunkingStrategy):
|
|||||||
start = 0
|
start = 0
|
||||||
while start < len(words):
|
while start < len(words):
|
||||||
end = start + self.window_size
|
end = start + self.window_size
|
||||||
chunk = ' '.join(words[start:end])
|
chunk = " ".join(words[start:end])
|
||||||
chunks.append(chunk)
|
chunks.append(chunk)
|
||||||
|
|
||||||
if end >= len(words):
|
if end >= len(words):
|
||||||
|
|||||||
@@ -8,14 +8,21 @@ from .async_logger import AsyncLogger
|
|||||||
logger = AsyncLogger(verbose=True)
|
logger = AsyncLogger(verbose=True)
|
||||||
docs_manager = DocsManager(logger)
|
docs_manager = DocsManager(logger)
|
||||||
|
|
||||||
|
|
||||||
def print_table(headers: List[str], rows: List[List[str]], padding: int = 2):
|
def print_table(headers: List[str], rows: List[List[str]], padding: int = 2):
|
||||||
"""Print formatted table with headers and rows"""
|
"""Print formatted table with headers and rows"""
|
||||||
widths = [max(len(str(cell)) for cell in col) for col in zip(headers, *rows)]
|
widths = [max(len(str(cell)) for cell in col) for col in zip(headers, *rows)]
|
||||||
border = '+' + '+'.join('-' * (w + 2 * padding) for w in widths) + '+'
|
border = "+" + "+".join("-" * (w + 2 * padding) for w in widths) + "+"
|
||||||
|
|
||||||
def format_row(row):
|
def format_row(row):
|
||||||
return '|' + '|'.join(f"{' ' * padding}{str(cell):<{w}}{' ' * padding}"
|
return (
|
||||||
for cell, w in zip(row, widths)) + '|'
|
"|"
|
||||||
|
+ "|".join(
|
||||||
|
f"{' ' * padding}{str(cell):<{w}}{' ' * padding}"
|
||||||
|
for cell, w in zip(row, widths)
|
||||||
|
)
|
||||||
|
+ "|"
|
||||||
|
)
|
||||||
|
|
||||||
click.echo(border)
|
click.echo(border)
|
||||||
click.echo(format_row(headers))
|
click.echo(format_row(headers))
|
||||||
@@ -24,19 +31,24 @@ def print_table(headers: List[str], rows: List[List[str]], padding: int = 2):
|
|||||||
click.echo(format_row(row))
|
click.echo(format_row(row))
|
||||||
click.echo(border)
|
click.echo(border)
|
||||||
|
|
||||||
|
|
||||||
@click.group()
|
@click.group()
|
||||||
def cli():
|
def cli():
|
||||||
"""Crawl4AI Command Line Interface"""
|
"""Crawl4AI Command Line Interface"""
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
@cli.group()
|
@cli.group()
|
||||||
def docs():
|
def docs():
|
||||||
"""Documentation operations"""
|
"""Documentation operations"""
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
@docs.command()
|
@docs.command()
|
||||||
@click.argument('sections', nargs=-1)
|
@click.argument("sections", nargs=-1)
|
||||||
@click.option('--mode', type=click.Choice(['extended', 'condensed']), default='extended')
|
@click.option(
|
||||||
|
"--mode", type=click.Choice(["extended", "condensed"]), default="extended"
|
||||||
|
)
|
||||||
def combine(sections: tuple, mode: str):
|
def combine(sections: tuple, mode: str):
|
||||||
"""Combine documentation sections"""
|
"""Combine documentation sections"""
|
||||||
try:
|
try:
|
||||||
@@ -46,16 +58,17 @@ def combine(sections: tuple, mode: str):
|
|||||||
logger.error(str(e), tag="ERROR")
|
logger.error(str(e), tag="ERROR")
|
||||||
sys.exit(1)
|
sys.exit(1)
|
||||||
|
|
||||||
|
|
||||||
@docs.command()
|
@docs.command()
|
||||||
@click.argument('query')
|
@click.argument("query")
|
||||||
@click.option('--top-k', '-k', default=5)
|
@click.option("--top-k", "-k", default=5)
|
||||||
@click.option('--build-index', is_flag=True, help='Build index if missing')
|
@click.option("--build-index", is_flag=True, help="Build index if missing")
|
||||||
def search(query: str, top_k: int, build_index: bool):
|
def search(query: str, top_k: int, build_index: bool):
|
||||||
"""Search documentation"""
|
"""Search documentation"""
|
||||||
try:
|
try:
|
||||||
result = docs_manager.search(query, top_k)
|
result = docs_manager.search(query, top_k)
|
||||||
if result == "No search index available. Call build_search_index() first.":
|
if result == "No search index available. Call build_search_index() first.":
|
||||||
if build_index or click.confirm('No search index found. Build it now?'):
|
if build_index or click.confirm("No search index found. Build it now?"):
|
||||||
asyncio.run(docs_manager.llm_text.generate_index_files())
|
asyncio.run(docs_manager.llm_text.generate_index_files())
|
||||||
result = docs_manager.search(query, top_k)
|
result = docs_manager.search(query, top_k)
|
||||||
click.echo(result)
|
click.echo(result)
|
||||||
@@ -63,6 +76,7 @@ def search(query: str, top_k: int, build_index: bool):
|
|||||||
click.echo(f"Error: {str(e)}", err=True)
|
click.echo(f"Error: {str(e)}", err=True)
|
||||||
sys.exit(1)
|
sys.exit(1)
|
||||||
|
|
||||||
|
|
||||||
@docs.command()
|
@docs.command()
|
||||||
def update():
|
def update():
|
||||||
"""Update docs from GitHub"""
|
"""Update docs from GitHub"""
|
||||||
@@ -73,22 +87,25 @@ def update():
|
|||||||
click.echo(f"Error: {str(e)}", err=True)
|
click.echo(f"Error: {str(e)}", err=True)
|
||||||
sys.exit(1)
|
sys.exit(1)
|
||||||
|
|
||||||
|
|
||||||
@docs.command()
|
@docs.command()
|
||||||
@click.option('--force-facts', is_flag=True, help='Force regenerate fact files')
|
@click.option("--force-facts", is_flag=True, help="Force regenerate fact files")
|
||||||
@click.option('--clear-cache', is_flag=True, help='Clear BM25 cache')
|
@click.option("--clear-cache", is_flag=True, help="Clear BM25 cache")
|
||||||
def index(force_facts: bool, clear_cache: bool):
|
def index(force_facts: bool, clear_cache: bool):
|
||||||
"""Build or rebuild search indexes"""
|
"""Build or rebuild search indexes"""
|
||||||
try:
|
try:
|
||||||
asyncio.run(docs_manager.ensure_docs_exist())
|
asyncio.run(docs_manager.ensure_docs_exist())
|
||||||
asyncio.run(docs_manager.llm_text.generate_index_files(
|
asyncio.run(
|
||||||
force_generate_facts=force_facts,
|
docs_manager.llm_text.generate_index_files(
|
||||||
clear_bm25_cache=clear_cache
|
force_generate_facts=force_facts, clear_bm25_cache=clear_cache
|
||||||
))
|
)
|
||||||
|
)
|
||||||
click.echo("Search indexes built successfully")
|
click.echo("Search indexes built successfully")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
click.echo(f"Error: {str(e)}", err=True)
|
click.echo(f"Error: {str(e)}", err=True)
|
||||||
sys.exit(1)
|
sys.exit(1)
|
||||||
|
|
||||||
|
|
||||||
# Add docs list command
|
# Add docs list command
|
||||||
@docs.command()
|
@docs.command()
|
||||||
def list():
|
def list():
|
||||||
@@ -101,5 +118,6 @@ def list():
|
|||||||
click.echo(f"Error: {str(e)}", err=True)
|
click.echo(f"Error: {str(e)}", err=True)
|
||||||
sys.exit(1)
|
sys.exit(1)
|
||||||
|
|
||||||
if __name__ == '__main__':
|
|
||||||
|
if __name__ == "__main__":
|
||||||
cli()
|
cli()
|
||||||
@@ -30,18 +30,40 @@ WORD_TOKEN_RATE = 1.3
|
|||||||
MIN_WORD_THRESHOLD = 1
|
MIN_WORD_THRESHOLD = 1
|
||||||
IMAGE_DESCRIPTION_MIN_WORD_THRESHOLD = 1
|
IMAGE_DESCRIPTION_MIN_WORD_THRESHOLD = 1
|
||||||
|
|
||||||
IMPORTANT_ATTRS = ['src', 'href', 'alt', 'title', 'width', 'height']
|
IMPORTANT_ATTRS = ["src", "href", "alt", "title", "width", "height"]
|
||||||
ONLY_TEXT_ELIGIBLE_TAGS = ['b', 'i', 'u', 'span', 'del', 'ins', 'sub', 'sup', 'strong', 'em', 'code', 'kbd', 'var', 's', 'q', 'abbr', 'cite', 'dfn', 'time', 'small', 'mark']
|
ONLY_TEXT_ELIGIBLE_TAGS = [
|
||||||
|
"b",
|
||||||
|
"i",
|
||||||
|
"u",
|
||||||
|
"span",
|
||||||
|
"del",
|
||||||
|
"ins",
|
||||||
|
"sub",
|
||||||
|
"sup",
|
||||||
|
"strong",
|
||||||
|
"em",
|
||||||
|
"code",
|
||||||
|
"kbd",
|
||||||
|
"var",
|
||||||
|
"s",
|
||||||
|
"q",
|
||||||
|
"abbr",
|
||||||
|
"cite",
|
||||||
|
"dfn",
|
||||||
|
"time",
|
||||||
|
"small",
|
||||||
|
"mark",
|
||||||
|
]
|
||||||
SOCIAL_MEDIA_DOMAINS = [
|
SOCIAL_MEDIA_DOMAINS = [
|
||||||
'facebook.com',
|
"facebook.com",
|
||||||
'twitter.com',
|
"twitter.com",
|
||||||
'x.com',
|
"x.com",
|
||||||
'linkedin.com',
|
"linkedin.com",
|
||||||
'instagram.com',
|
"instagram.com",
|
||||||
'pinterest.com',
|
"pinterest.com",
|
||||||
'tiktok.com',
|
"tiktok.com",
|
||||||
'snapchat.com',
|
"snapchat.com",
|
||||||
'reddit.com',
|
"reddit.com",
|
||||||
]
|
]
|
||||||
|
|
||||||
# Threshold for the Image extraction - Range is 1 to 6
|
# Threshold for the Image extraction - Range is 1 to 6
|
||||||
|
|||||||
@@ -1,44 +1,85 @@
|
|||||||
import re
|
import re
|
||||||
from bs4 import BeautifulSoup, Tag
|
from bs4 import BeautifulSoup, Tag
|
||||||
from typing import List, Tuple, Dict
|
from typing import List, Tuple
|
||||||
from rank_bm25 import BM25Okapi
|
from rank_bm25 import BM25Okapi
|
||||||
from time import perf_counter
|
|
||||||
from collections import deque
|
from collections import deque
|
||||||
from bs4 import BeautifulSoup, NavigableString, Tag, Comment
|
from bs4 import NavigableString, Comment
|
||||||
from .utils import clean_tokens
|
from .utils import clean_tokens
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
import math
|
import math
|
||||||
from snowballstemmer import stemmer
|
from snowballstemmer import stemmer
|
||||||
|
|
||||||
|
|
||||||
class RelevantContentFilter(ABC):
|
class RelevantContentFilter(ABC):
|
||||||
"""Abstract base class for content filtering strategies"""
|
"""Abstract base class for content filtering strategies"""
|
||||||
|
|
||||||
def __init__(self, user_query: str = None):
|
def __init__(self, user_query: str = None):
|
||||||
self.user_query = user_query
|
self.user_query = user_query
|
||||||
self.included_tags = {
|
self.included_tags = {
|
||||||
# Primary structure
|
# Primary structure
|
||||||
'article', 'main', 'section', 'div',
|
"article",
|
||||||
|
"main",
|
||||||
|
"section",
|
||||||
|
"div",
|
||||||
# List structures
|
# List structures
|
||||||
'ul', 'ol', 'li', 'dl', 'dt', 'dd',
|
"ul",
|
||||||
|
"ol",
|
||||||
|
"li",
|
||||||
|
"dl",
|
||||||
|
"dt",
|
||||||
|
"dd",
|
||||||
# Text content
|
# Text content
|
||||||
'p', 'span', 'blockquote', 'pre', 'code',
|
"p",
|
||||||
|
"span",
|
||||||
|
"blockquote",
|
||||||
|
"pre",
|
||||||
|
"code",
|
||||||
# Headers
|
# Headers
|
||||||
'h1', 'h2', 'h3', 'h4', 'h5', 'h6',
|
"h1",
|
||||||
|
"h2",
|
||||||
|
"h3",
|
||||||
|
"h4",
|
||||||
|
"h5",
|
||||||
|
"h6",
|
||||||
# Tables
|
# Tables
|
||||||
'table', 'thead', 'tbody', 'tr', 'td', 'th',
|
"table",
|
||||||
|
"thead",
|
||||||
|
"tbody",
|
||||||
|
"tr",
|
||||||
|
"td",
|
||||||
|
"th",
|
||||||
# Other semantic elements
|
# Other semantic elements
|
||||||
'figure', 'figcaption', 'details', 'summary',
|
"figure",
|
||||||
|
"figcaption",
|
||||||
|
"details",
|
||||||
|
"summary",
|
||||||
# Text formatting
|
# Text formatting
|
||||||
'em', 'strong', 'b', 'i', 'mark', 'small',
|
"em",
|
||||||
|
"strong",
|
||||||
|
"b",
|
||||||
|
"i",
|
||||||
|
"mark",
|
||||||
|
"small",
|
||||||
# Rich content
|
# Rich content
|
||||||
'time', 'address', 'cite', 'q'
|
"time",
|
||||||
|
"address",
|
||||||
|
"cite",
|
||||||
|
"q",
|
||||||
}
|
}
|
||||||
self.excluded_tags = {
|
self.excluded_tags = {
|
||||||
'nav', 'footer', 'header', 'aside', 'script',
|
"nav",
|
||||||
'style', 'form', 'iframe', 'noscript'
|
"footer",
|
||||||
|
"header",
|
||||||
|
"aside",
|
||||||
|
"script",
|
||||||
|
"style",
|
||||||
|
"form",
|
||||||
|
"iframe",
|
||||||
|
"noscript",
|
||||||
}
|
}
|
||||||
self.header_tags = {'h1', 'h2', 'h3', 'h4', 'h5', 'h6'}
|
self.header_tags = {"h1", "h2", "h3", "h4", "h5", "h6"}
|
||||||
self.negative_patterns = re.compile(
|
self.negative_patterns = re.compile(
|
||||||
r'nav|footer|header|sidebar|ads|comment|promo|advert|social|share',
|
r"nav|footer|header|sidebar|ads|comment|promo|advert|social|share", re.I
|
||||||
re.I
|
|
||||||
)
|
)
|
||||||
self.min_word_count = 2
|
self.min_word_count = 2
|
||||||
|
|
||||||
@@ -62,28 +103,30 @@ class RelevantContentFilter(ABC):
|
|||||||
except Exception:
|
except Exception:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
if soup.find('h1'):
|
if soup.find("h1"):
|
||||||
query_parts.append(soup.find('h1').get_text())
|
query_parts.append(soup.find("h1").get_text())
|
||||||
|
|
||||||
# Meta tags
|
# Meta tags
|
||||||
temp = ""
|
temp = ""
|
||||||
for meta_name in ['keywords', 'description']:
|
for meta_name in ["keywords", "description"]:
|
||||||
meta = soup.find('meta', attrs={'name': meta_name})
|
meta = soup.find("meta", attrs={"name": meta_name})
|
||||||
if meta and meta.get('content'):
|
if meta and meta.get("content"):
|
||||||
query_parts.append(meta['content'])
|
query_parts.append(meta["content"])
|
||||||
temp += meta['content']
|
temp += meta["content"]
|
||||||
|
|
||||||
# If still empty, grab first significant paragraph
|
# If still empty, grab first significant paragraph
|
||||||
if not temp:
|
if not temp:
|
||||||
# Find the first <p> tag whose text contains more than 150 characters
|
# Find the first <p> tag whose text contains more than 150 characters
|
||||||
for p in body.find_all('p'):
|
for p in body.find_all("p"):
|
||||||
if len(p.get_text()) > 150:
|
if len(p.get_text()) > 150:
|
||||||
query_parts.append(p.get_text()[:150])
|
query_parts.append(p.get_text()[:150])
|
||||||
break
|
break
|
||||||
|
|
||||||
return ' '.join(filter(None, query_parts))
|
return " ".join(filter(None, query_parts))
|
||||||
|
|
||||||
def extract_text_chunks(self, body: Tag, min_word_threshold: int = None) -> List[Tuple[str, str]]:
|
def extract_text_chunks(
|
||||||
|
self, body: Tag, min_word_threshold: int = None
|
||||||
|
) -> List[Tuple[str, str]]:
|
||||||
"""
|
"""
|
||||||
Extracts text chunks from a BeautifulSoup body element while preserving order.
|
Extracts text chunks from a BeautifulSoup body element while preserving order.
|
||||||
Returns list of tuples (chunk_index, text, tag_type, element) for classification.
|
Returns list of tuples (chunk_index, text, tag_type, element) for classification.
|
||||||
@@ -96,14 +139,42 @@ class RelevantContentFilter(ABC):
|
|||||||
"""
|
"""
|
||||||
# Tags to ignore - inline elements that shouldn't break text flow
|
# Tags to ignore - inline elements that shouldn't break text flow
|
||||||
INLINE_TAGS = {
|
INLINE_TAGS = {
|
||||||
'a', 'abbr', 'acronym', 'b', 'bdo', 'big', 'br', 'button', 'cite', 'code',
|
"a",
|
||||||
'dfn', 'em', 'i', 'img', 'input', 'kbd', 'label', 'map', 'object', 'q',
|
"abbr",
|
||||||
'samp', 'script', 'select', 'small', 'span', 'strong', 'sub', 'sup',
|
"acronym",
|
||||||
'textarea', 'time', 'tt', 'var'
|
"b",
|
||||||
|
"bdo",
|
||||||
|
"big",
|
||||||
|
"br",
|
||||||
|
"button",
|
||||||
|
"cite",
|
||||||
|
"code",
|
||||||
|
"dfn",
|
||||||
|
"em",
|
||||||
|
"i",
|
||||||
|
"img",
|
||||||
|
"input",
|
||||||
|
"kbd",
|
||||||
|
"label",
|
||||||
|
"map",
|
||||||
|
"object",
|
||||||
|
"q",
|
||||||
|
"samp",
|
||||||
|
"script",
|
||||||
|
"select",
|
||||||
|
"small",
|
||||||
|
"span",
|
||||||
|
"strong",
|
||||||
|
"sub",
|
||||||
|
"sup",
|
||||||
|
"textarea",
|
||||||
|
"time",
|
||||||
|
"tt",
|
||||||
|
"var",
|
||||||
}
|
}
|
||||||
|
|
||||||
# Tags that typically contain meaningful headers
|
# Tags that typically contain meaningful headers
|
||||||
HEADER_TAGS = {'h1', 'h2', 'h3', 'h4', 'h5', 'h6', 'header'}
|
HEADER_TAGS = {"h1", "h2", "h3", "h4", "h5", "h6", "header"}
|
||||||
|
|
||||||
chunks = []
|
chunks = []
|
||||||
current_text = []
|
current_text = []
|
||||||
@@ -111,9 +182,8 @@ class RelevantContentFilter(ABC):
|
|||||||
|
|
||||||
def should_break_chunk(tag: Tag) -> bool:
|
def should_break_chunk(tag: Tag) -> bool:
|
||||||
"""Determine if a tag should cause a break in the current text chunk"""
|
"""Determine if a tag should cause a break in the current text chunk"""
|
||||||
return (
|
return tag.name not in INLINE_TAGS and not (
|
||||||
tag.name not in INLINE_TAGS
|
tag.name == "p" and len(current_text) == 0
|
||||||
and not (tag.name == 'p' and len(current_text) == 0)
|
|
||||||
)
|
)
|
||||||
|
|
||||||
# Use deque for efficient push/pop operations
|
# Use deque for efficient push/pop operations
|
||||||
@@ -125,9 +195,11 @@ class RelevantContentFilter(ABC):
|
|||||||
if visited:
|
if visited:
|
||||||
# End of block element - flush accumulated text
|
# End of block element - flush accumulated text
|
||||||
if current_text and should_break_chunk(element):
|
if current_text and should_break_chunk(element):
|
||||||
text = ' '.join(''.join(current_text).split())
|
text = " ".join("".join(current_text).split())
|
||||||
if text:
|
if text:
|
||||||
tag_type = 'header' if element.name in HEADER_TAGS else 'content'
|
tag_type = (
|
||||||
|
"header" if element.name in HEADER_TAGS else "content"
|
||||||
|
)
|
||||||
chunks.append((chunk_index, text, tag_type, element))
|
chunks.append((chunk_index, text, tag_type, element))
|
||||||
chunk_index += 1
|
chunk_index += 1
|
||||||
current_text = []
|
current_text = []
|
||||||
@@ -153,18 +225,23 @@ class RelevantContentFilter(ABC):
|
|||||||
|
|
||||||
# Handle any remaining text
|
# Handle any remaining text
|
||||||
if current_text:
|
if current_text:
|
||||||
text = ' '.join(''.join(current_text).split())
|
text = " ".join("".join(current_text).split())
|
||||||
if text:
|
if text:
|
||||||
chunks.append((chunk_index, text, 'content', body))
|
chunks.append((chunk_index, text, "content", body))
|
||||||
|
|
||||||
if min_word_threshold:
|
if min_word_threshold:
|
||||||
chunks = [chunk for chunk in chunks if len(chunk[1].split()) >= min_word_threshold]
|
chunks = [
|
||||||
|
chunk for chunk in chunks if len(chunk[1].split()) >= min_word_threshold
|
||||||
|
]
|
||||||
|
|
||||||
return chunks
|
return chunks
|
||||||
|
|
||||||
def _deprecated_extract_text_chunks(self, soup: BeautifulSoup) -> List[Tuple[int, str, Tag]]:
|
def _deprecated_extract_text_chunks(
|
||||||
|
self, soup: BeautifulSoup
|
||||||
|
) -> List[Tuple[int, str, Tag]]:
|
||||||
"""Common method for extracting text chunks"""
|
"""Common method for extracting text chunks"""
|
||||||
_text_cache = {}
|
_text_cache = {}
|
||||||
|
|
||||||
def fast_text(element: Tag) -> str:
|
def fast_text(element: Tag) -> str:
|
||||||
elem_id = id(element)
|
elem_id = id(element)
|
||||||
if elem_id in _text_cache:
|
if elem_id in _text_cache:
|
||||||
@@ -175,7 +252,7 @@ class RelevantContentFilter(ABC):
|
|||||||
text = content.strip()
|
text = content.strip()
|
||||||
if text:
|
if text:
|
||||||
texts.append(text)
|
texts.append(text)
|
||||||
result = ' '.join(texts)
|
result = " ".join(texts)
|
||||||
_text_cache[elem_id] = result
|
_text_cache[elem_id] = result
|
||||||
return result
|
return result
|
||||||
|
|
||||||
@@ -210,10 +287,9 @@ class RelevantContentFilter(ABC):
|
|||||||
"""Common method for exclusion logic"""
|
"""Common method for exclusion logic"""
|
||||||
if tag.name in self.excluded_tags:
|
if tag.name in self.excluded_tags:
|
||||||
return True
|
return True
|
||||||
class_id = ' '.join(filter(None, [
|
class_id = " ".join(
|
||||||
' '.join(tag.get('class', [])),
|
filter(None, [" ".join(tag.get("class", [])), tag.get("id", "")])
|
||||||
tag.get('id', '')
|
)
|
||||||
]))
|
|
||||||
return bool(self.negative_patterns.search(class_id))
|
return bool(self.negative_patterns.search(class_id))
|
||||||
|
|
||||||
def clean_element(self, tag: Tag) -> str:
|
def clean_element(self, tag: Tag) -> str:
|
||||||
@@ -221,8 +297,16 @@ class RelevantContentFilter(ABC):
|
|||||||
if not tag or not isinstance(tag, Tag):
|
if not tag or not isinstance(tag, Tag):
|
||||||
return ""
|
return ""
|
||||||
|
|
||||||
unwanted_tags = {'script', 'style', 'aside', 'form', 'iframe', 'noscript'}
|
unwanted_tags = {"script", "style", "aside", "form", "iframe", "noscript"}
|
||||||
unwanted_attrs = {'style', 'onclick', 'onmouseover', 'align', 'bgcolor', 'class', 'id'}
|
unwanted_attrs = {
|
||||||
|
"style",
|
||||||
|
"onclick",
|
||||||
|
"onmouseover",
|
||||||
|
"align",
|
||||||
|
"bgcolor",
|
||||||
|
"class",
|
||||||
|
"id",
|
||||||
|
}
|
||||||
|
|
||||||
# Use string builder pattern for better performance
|
# Use string builder pattern for better performance
|
||||||
builder = []
|
builder = []
|
||||||
@@ -237,28 +321,29 @@ class RelevantContentFilter(ABC):
|
|||||||
return
|
return
|
||||||
|
|
||||||
# Start tag
|
# Start tag
|
||||||
builder.append(f'<{elem.name}')
|
builder.append(f"<{elem.name}")
|
||||||
|
|
||||||
# Add cleaned attributes
|
# Add cleaned attributes
|
||||||
attrs = {k: v for k, v in elem.attrs.items() if k not in unwanted_attrs}
|
attrs = {k: v for k, v in elem.attrs.items() if k not in unwanted_attrs}
|
||||||
for key, value in attrs.items():
|
for key, value in attrs.items():
|
||||||
builder.append(f' {key}="{value}"')
|
builder.append(f' {key}="{value}"')
|
||||||
|
|
||||||
builder.append('>')
|
builder.append(">")
|
||||||
|
|
||||||
# Process children
|
# Process children
|
||||||
for child in elem.children:
|
for child in elem.children:
|
||||||
render_tag(child)
|
render_tag(child)
|
||||||
|
|
||||||
# Close tag
|
# Close tag
|
||||||
builder.append(f'</{elem.name}>')
|
builder.append(f"</{elem.name}>")
|
||||||
|
|
||||||
try:
|
try:
|
||||||
render_tag(tag)
|
render_tag(tag)
|
||||||
return ''.join(builder)
|
return "".join(builder)
|
||||||
except Exception:
|
except Exception:
|
||||||
return str(tag) # Fallback to original if anything fails
|
return str(tag) # Fallback to original if anything fails
|
||||||
|
|
||||||
|
|
||||||
class BM25ContentFilter(RelevantContentFilter):
|
class BM25ContentFilter(RelevantContentFilter):
|
||||||
"""
|
"""
|
||||||
Content filtering using BM25 algorithm with priority tag handling.
|
Content filtering using BM25 algorithm with priority tag handling.
|
||||||
@@ -280,7 +365,13 @@ class BM25ContentFilter(RelevantContentFilter):
|
|||||||
Methods:
|
Methods:
|
||||||
filter_content(self, html: str, min_word_threshold: int = None)
|
filter_content(self, html: str, min_word_threshold: int = None)
|
||||||
"""
|
"""
|
||||||
def __init__(self, user_query: str = None, bm25_threshold: float = 1.0, language: str = 'english'):
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
user_query: str = None,
|
||||||
|
bm25_threshold: float = 1.0,
|
||||||
|
language: str = "english",
|
||||||
|
):
|
||||||
"""
|
"""
|
||||||
Initializes the BM25ContentFilter class, if not provided, falls back to page metadata.
|
Initializes the BM25ContentFilter class, if not provided, falls back to page metadata.
|
||||||
|
|
||||||
@@ -295,17 +386,17 @@ class BM25ContentFilter(RelevantContentFilter):
|
|||||||
super().__init__(user_query=user_query)
|
super().__init__(user_query=user_query)
|
||||||
self.bm25_threshold = bm25_threshold
|
self.bm25_threshold = bm25_threshold
|
||||||
self.priority_tags = {
|
self.priority_tags = {
|
||||||
'h1': 5.0,
|
"h1": 5.0,
|
||||||
'h2': 4.0,
|
"h2": 4.0,
|
||||||
'h3': 3.0,
|
"h3": 3.0,
|
||||||
'title': 4.0,
|
"title": 4.0,
|
||||||
'strong': 2.0,
|
"strong": 2.0,
|
||||||
'b': 1.5,
|
"b": 1.5,
|
||||||
'em': 1.5,
|
"em": 1.5,
|
||||||
'blockquote': 2.0,
|
"blockquote": 2.0,
|
||||||
'code': 2.0,
|
"code": 2.0,
|
||||||
'pre': 1.5,
|
"pre": 1.5,
|
||||||
'th': 1.5, # Table headers
|
"th": 1.5, # Table headers
|
||||||
}
|
}
|
||||||
self.stemmer = stemmer(language)
|
self.stemmer = stemmer(language)
|
||||||
|
|
||||||
@@ -327,13 +418,13 @@ class BM25ContentFilter(RelevantContentFilter):
|
|||||||
if not html or not isinstance(html, str):
|
if not html or not isinstance(html, str):
|
||||||
return []
|
return []
|
||||||
|
|
||||||
soup = BeautifulSoup(html, 'lxml')
|
soup = BeautifulSoup(html, "lxml")
|
||||||
|
|
||||||
# Check if body is present
|
# Check if body is present
|
||||||
if not soup.body:
|
if not soup.body:
|
||||||
# Wrap in body tag if missing
|
# Wrap in body tag if missing
|
||||||
soup = BeautifulSoup(f'<body>{html}</body>', 'lxml')
|
soup = BeautifulSoup(f"<body>{html}</body>", "lxml")
|
||||||
body = soup.find('body')
|
body = soup.find("body")
|
||||||
|
|
||||||
query = self.extract_page_query(soup, body)
|
query = self.extract_page_query(soup, body)
|
||||||
|
|
||||||
@@ -354,9 +445,13 @@ class BM25ContentFilter(RelevantContentFilter):
|
|||||||
# for _, chunk, _, _ in candidates]
|
# for _, chunk, _, _ in candidates]
|
||||||
# tokenized_query = [ps.stem(word) for word in query.lower().split()]
|
# tokenized_query = [ps.stem(word) for word in query.lower().split()]
|
||||||
|
|
||||||
tokenized_corpus = [[self.stemmer.stemWord(word) for word in chunk.lower().split()]
|
tokenized_corpus = [
|
||||||
for _, chunk, _, _ in candidates]
|
[self.stemmer.stemWord(word) for word in chunk.lower().split()]
|
||||||
tokenized_query = [self.stemmer.stemWord(word) for word in query.lower().split()]
|
for _, chunk, _, _ in candidates
|
||||||
|
]
|
||||||
|
tokenized_query = [
|
||||||
|
self.stemmer.stemWord(word) for word in query.lower().split()
|
||||||
|
]
|
||||||
|
|
||||||
# tokenized_corpus = [[self.stemmer.stemWord(word) for word in tokenize_text(chunk.lower())]
|
# tokenized_corpus = [[self.stemmer.stemWord(word) for word in tokenize_text(chunk.lower())]
|
||||||
# for _, chunk, _, _ in candidates]
|
# for _, chunk, _, _ in candidates]
|
||||||
@@ -378,7 +473,8 @@ class BM25ContentFilter(RelevantContentFilter):
|
|||||||
|
|
||||||
# Filter candidates by threshold
|
# Filter candidates by threshold
|
||||||
selected_candidates = [
|
selected_candidates = [
|
||||||
(index, chunk, tag) for adjusted_score, index, chunk, tag in adjusted_candidates
|
(index, chunk, tag)
|
||||||
|
for adjusted_score, index, chunk, tag in adjusted_candidates
|
||||||
if adjusted_score >= self.bm25_threshold
|
if adjusted_score >= self.bm25_threshold
|
||||||
]
|
]
|
||||||
|
|
||||||
@@ -390,6 +486,7 @@ class BM25ContentFilter(RelevantContentFilter):
|
|||||||
|
|
||||||
return [self.clean_element(tag) for _, _, tag in selected_candidates]
|
return [self.clean_element(tag) for _, _, tag in selected_candidates]
|
||||||
|
|
||||||
|
|
||||||
class PruningContentFilter(RelevantContentFilter):
|
class PruningContentFilter(RelevantContentFilter):
|
||||||
"""
|
"""
|
||||||
Content filtering using pruning algorithm with dynamic threshold.
|
Content filtering using pruning algorithm with dynamic threshold.
|
||||||
@@ -411,8 +508,14 @@ class PruningContentFilter(RelevantContentFilter):
|
|||||||
Methods:
|
Methods:
|
||||||
filter_content(self, html: str, min_word_threshold: int = None):
|
filter_content(self, html: str, min_word_threshold: int = None):
|
||||||
"""
|
"""
|
||||||
def __init__(self, user_query: str = None, min_word_threshold: int = None,
|
|
||||||
threshold_type: str = 'fixed', threshold: float = 0.48):
|
def __init__(
|
||||||
|
self,
|
||||||
|
user_query: str = None,
|
||||||
|
min_word_threshold: int = None,
|
||||||
|
threshold_type: str = "fixed",
|
||||||
|
threshold: float = 0.48,
|
||||||
|
):
|
||||||
"""
|
"""
|
||||||
Initializes the PruningContentFilter class, if not provided, falls back to page metadata.
|
Initializes the PruningContentFilter class, if not provided, falls back to page metadata.
|
||||||
|
|
||||||
@@ -432,49 +535,49 @@ class PruningContentFilter(RelevantContentFilter):
|
|||||||
|
|
||||||
# Add tag importance for dynamic threshold
|
# Add tag importance for dynamic threshold
|
||||||
self.tag_importance = {
|
self.tag_importance = {
|
||||||
'article': 1.5,
|
"article": 1.5,
|
||||||
'main': 1.4,
|
"main": 1.4,
|
||||||
'section': 1.3,
|
"section": 1.3,
|
||||||
'p': 1.2,
|
"p": 1.2,
|
||||||
'h1': 1.4,
|
"h1": 1.4,
|
||||||
'h2': 1.3,
|
"h2": 1.3,
|
||||||
'h3': 1.2,
|
"h3": 1.2,
|
||||||
'div': 0.7,
|
"div": 0.7,
|
||||||
'span': 0.6
|
"span": 0.6,
|
||||||
}
|
}
|
||||||
|
|
||||||
# Metric configuration
|
# Metric configuration
|
||||||
self.metric_config = {
|
self.metric_config = {
|
||||||
'text_density': True,
|
"text_density": True,
|
||||||
'link_density': True,
|
"link_density": True,
|
||||||
'tag_weight': True,
|
"tag_weight": True,
|
||||||
'class_id_weight': True,
|
"class_id_weight": True,
|
||||||
'text_length': True,
|
"text_length": True,
|
||||||
}
|
}
|
||||||
|
|
||||||
self.metric_weights = {
|
self.metric_weights = {
|
||||||
'text_density': 0.4,
|
"text_density": 0.4,
|
||||||
'link_density': 0.2,
|
"link_density": 0.2,
|
||||||
'tag_weight': 0.2,
|
"tag_weight": 0.2,
|
||||||
'class_id_weight': 0.1,
|
"class_id_weight": 0.1,
|
||||||
'text_length': 0.1,
|
"text_length": 0.1,
|
||||||
}
|
}
|
||||||
|
|
||||||
self.tag_weights = {
|
self.tag_weights = {
|
||||||
'div': 0.5,
|
"div": 0.5,
|
||||||
'p': 1.0,
|
"p": 1.0,
|
||||||
'article': 1.5,
|
"article": 1.5,
|
||||||
'section': 1.0,
|
"section": 1.0,
|
||||||
'span': 0.3,
|
"span": 0.3,
|
||||||
'li': 0.5,
|
"li": 0.5,
|
||||||
'ul': 0.5,
|
"ul": 0.5,
|
||||||
'ol': 0.5,
|
"ol": 0.5,
|
||||||
'h1': 1.2,
|
"h1": 1.2,
|
||||||
'h2': 1.1,
|
"h2": 1.1,
|
||||||
'h3': 1.0,
|
"h3": 1.0,
|
||||||
'h4': 0.9,
|
"h4": 0.9,
|
||||||
'h5': 0.8,
|
"h5": 0.8,
|
||||||
'h6': 0.7,
|
"h6": 0.7,
|
||||||
}
|
}
|
||||||
|
|
||||||
def filter_content(self, html: str, min_word_threshold: int = None) -> List[str]:
|
def filter_content(self, html: str, min_word_threshold: int = None) -> List[str]:
|
||||||
@@ -495,22 +598,22 @@ class PruningContentFilter(RelevantContentFilter):
|
|||||||
if not html or not isinstance(html, str):
|
if not html or not isinstance(html, str):
|
||||||
return []
|
return []
|
||||||
|
|
||||||
soup = BeautifulSoup(html, 'lxml')
|
soup = BeautifulSoup(html, "lxml")
|
||||||
if not soup.body:
|
if not soup.body:
|
||||||
soup = BeautifulSoup(f'<body>{html}</body>', 'lxml')
|
soup = BeautifulSoup(f"<body>{html}</body>", "lxml")
|
||||||
|
|
||||||
# Remove comments and unwanted tags
|
# Remove comments and unwanted tags
|
||||||
self._remove_comments(soup)
|
self._remove_comments(soup)
|
||||||
self._remove_unwanted_tags(soup)
|
self._remove_unwanted_tags(soup)
|
||||||
|
|
||||||
# Prune tree starting from body
|
# Prune tree starting from body
|
||||||
body = soup.find('body')
|
body = soup.find("body")
|
||||||
self._prune_tree(body)
|
self._prune_tree(body)
|
||||||
|
|
||||||
# Extract remaining content as list of HTML strings
|
# Extract remaining content as list of HTML strings
|
||||||
content_blocks = []
|
content_blocks = []
|
||||||
for element in body.children:
|
for element in body.children:
|
||||||
if isinstance(element, str) or not hasattr(element, 'name'):
|
if isinstance(element, str) or not hasattr(element, "name"):
|
||||||
continue
|
continue
|
||||||
if len(element.get_text(strip=True)) > 0:
|
if len(element.get_text(strip=True)) > 0:
|
||||||
content_blocks.append(str(element))
|
content_blocks.append(str(element))
|
||||||
@@ -535,24 +638,28 @@ class PruningContentFilter(RelevantContentFilter):
|
|||||||
Args:
|
Args:
|
||||||
node (Tag): The node from which the pruning starts.
|
node (Tag): The node from which the pruning starts.
|
||||||
"""
|
"""
|
||||||
if not node or not hasattr(node, 'name') or node.name is None:
|
if not node or not hasattr(node, "name") or node.name is None:
|
||||||
return
|
return
|
||||||
|
|
||||||
text_len = len(node.get_text(strip=True))
|
text_len = len(node.get_text(strip=True))
|
||||||
tag_len = len(node.encode_contents().decode('utf-8'))
|
tag_len = len(node.encode_contents().decode("utf-8"))
|
||||||
link_text_len = sum(len(s.strip()) for s in (a.string for a in node.find_all('a', recursive=False)) if s)
|
link_text_len = sum(
|
||||||
|
len(s.strip())
|
||||||
|
for s in (a.string for a in node.find_all("a", recursive=False))
|
||||||
|
if s
|
||||||
|
)
|
||||||
|
|
||||||
metrics = {
|
metrics = {
|
||||||
'node': node,
|
"node": node,
|
||||||
'tag_name': node.name,
|
"tag_name": node.name,
|
||||||
'text_len': text_len,
|
"text_len": text_len,
|
||||||
'tag_len': tag_len,
|
"tag_len": tag_len,
|
||||||
'link_text_len': link_text_len
|
"link_text_len": link_text_len,
|
||||||
}
|
}
|
||||||
|
|
||||||
score = self._compute_composite_score(metrics, text_len, tag_len, link_text_len)
|
score = self._compute_composite_score(metrics, text_len, tag_len, link_text_len)
|
||||||
|
|
||||||
if self.threshold_type == 'fixed':
|
if self.threshold_type == "fixed":
|
||||||
should_remove = score < self.threshold
|
should_remove = score < self.threshold
|
||||||
else: # dynamic
|
else: # dynamic
|
||||||
tag_importance = self.tag_importance.get(node.name, 0.7)
|
tag_importance = self.tag_importance.get(node.name, 0.7)
|
||||||
@@ -572,7 +679,7 @@ class PruningContentFilter(RelevantContentFilter):
|
|||||||
if should_remove:
|
if should_remove:
|
||||||
node.decompose()
|
node.decompose()
|
||||||
else:
|
else:
|
||||||
children = [child for child in node.children if hasattr(child, 'name')]
|
children = [child for child in node.children if hasattr(child, "name")]
|
||||||
for child in children:
|
for child in children:
|
||||||
self._prune_tree(child)
|
self._prune_tree(child)
|
||||||
|
|
||||||
@@ -580,48 +687,48 @@ class PruningContentFilter(RelevantContentFilter):
|
|||||||
"""Computes the composite score"""
|
"""Computes the composite score"""
|
||||||
if self.min_word_threshold:
|
if self.min_word_threshold:
|
||||||
# Get raw text from metrics node - avoid extra processing
|
# Get raw text from metrics node - avoid extra processing
|
||||||
text = metrics['node'].get_text(strip=True)
|
text = metrics["node"].get_text(strip=True)
|
||||||
word_count = text.count(' ') + 1
|
word_count = text.count(" ") + 1
|
||||||
if word_count < self.min_word_threshold:
|
if word_count < self.min_word_threshold:
|
||||||
return -1.0 # Guaranteed removal
|
return -1.0 # Guaranteed removal
|
||||||
score = 0.0
|
score = 0.0
|
||||||
total_weight = 0.0
|
total_weight = 0.0
|
||||||
|
|
||||||
if self.metric_config['text_density']:
|
if self.metric_config["text_density"]:
|
||||||
density = text_len / tag_len if tag_len > 0 else 0
|
density = text_len / tag_len if tag_len > 0 else 0
|
||||||
score += self.metric_weights['text_density'] * density
|
score += self.metric_weights["text_density"] * density
|
||||||
total_weight += self.metric_weights['text_density']
|
total_weight += self.metric_weights["text_density"]
|
||||||
|
|
||||||
if self.metric_config['link_density']:
|
if self.metric_config["link_density"]:
|
||||||
density = 1 - (link_text_len / text_len if text_len > 0 else 0)
|
density = 1 - (link_text_len / text_len if text_len > 0 else 0)
|
||||||
score += self.metric_weights['link_density'] * density
|
score += self.metric_weights["link_density"] * density
|
||||||
total_weight += self.metric_weights['link_density']
|
total_weight += self.metric_weights["link_density"]
|
||||||
|
|
||||||
if self.metric_config['tag_weight']:
|
if self.metric_config["tag_weight"]:
|
||||||
tag_score = self.tag_weights.get(metrics['tag_name'], 0.5)
|
tag_score = self.tag_weights.get(metrics["tag_name"], 0.5)
|
||||||
score += self.metric_weights['tag_weight'] * tag_score
|
score += self.metric_weights["tag_weight"] * tag_score
|
||||||
total_weight += self.metric_weights['tag_weight']
|
total_weight += self.metric_weights["tag_weight"]
|
||||||
|
|
||||||
if self.metric_config['class_id_weight']:
|
if self.metric_config["class_id_weight"]:
|
||||||
class_score = self._compute_class_id_weight(metrics['node'])
|
class_score = self._compute_class_id_weight(metrics["node"])
|
||||||
score += self.metric_weights['class_id_weight'] * max(0, class_score)
|
score += self.metric_weights["class_id_weight"] * max(0, class_score)
|
||||||
total_weight += self.metric_weights['class_id_weight']
|
total_weight += self.metric_weights["class_id_weight"]
|
||||||
|
|
||||||
if self.metric_config['text_length']:
|
if self.metric_config["text_length"]:
|
||||||
score += self.metric_weights['text_length'] * math.log(text_len + 1)
|
score += self.metric_weights["text_length"] * math.log(text_len + 1)
|
||||||
total_weight += self.metric_weights['text_length']
|
total_weight += self.metric_weights["text_length"]
|
||||||
|
|
||||||
return score / total_weight if total_weight > 0 else 0
|
return score / total_weight if total_weight > 0 else 0
|
||||||
|
|
||||||
def _compute_class_id_weight(self, node):
|
def _compute_class_id_weight(self, node):
|
||||||
"""Computes the class ID weight"""
|
"""Computes the class ID weight"""
|
||||||
class_id_score = 0
|
class_id_score = 0
|
||||||
if 'class' in node.attrs:
|
if "class" in node.attrs:
|
||||||
classes = ' '.join(node['class'])
|
classes = " ".join(node["class"])
|
||||||
if self.negative_patterns.match(classes):
|
if self.negative_patterns.match(classes):
|
||||||
class_id_score -= 0.5
|
class_id_score -= 0.5
|
||||||
if 'id' in node.attrs:
|
if "id" in node.attrs:
|
||||||
element_id = node['id']
|
element_id = node["id"]
|
||||||
if self.negative_patterns.match(element_id):
|
if self.negative_patterns.match(element_id):
|
||||||
class_id_score -= 0.5
|
class_id_score -= 0.5
|
||||||
return class_id_score
|
return class_id_score
|
||||||
File diff suppressed because it is too large
Load Diff
@@ -15,32 +15,30 @@ import logging, time
|
|||||||
import base64
|
import base64
|
||||||
from PIL import Image, ImageDraw, ImageFont
|
from PIL import Image, ImageDraw, ImageFont
|
||||||
from io import BytesIO
|
from io import BytesIO
|
||||||
from typing import List, Callable
|
from typing import Callable
|
||||||
import requests
|
import requests
|
||||||
import os
|
import os
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from .utils import *
|
from .utils import *
|
||||||
|
|
||||||
logger = logging.getLogger('selenium.webdriver.remote.remote_connection')
|
logger = logging.getLogger("selenium.webdriver.remote.remote_connection")
|
||||||
logger.setLevel(logging.WARNING)
|
logger.setLevel(logging.WARNING)
|
||||||
|
|
||||||
logger_driver = logging.getLogger('selenium.webdriver.common.service')
|
logger_driver = logging.getLogger("selenium.webdriver.common.service")
|
||||||
logger_driver.setLevel(logging.WARNING)
|
logger_driver.setLevel(logging.WARNING)
|
||||||
|
|
||||||
urllib3_logger = logging.getLogger('urllib3.connectionpool')
|
urllib3_logger = logging.getLogger("urllib3.connectionpool")
|
||||||
urllib3_logger.setLevel(logging.WARNING)
|
urllib3_logger.setLevel(logging.WARNING)
|
||||||
|
|
||||||
# Disable http.client logging
|
# Disable http.client logging
|
||||||
http_client_logger = logging.getLogger('http.client')
|
http_client_logger = logging.getLogger("http.client")
|
||||||
http_client_logger.setLevel(logging.WARNING)
|
http_client_logger.setLevel(logging.WARNING)
|
||||||
|
|
||||||
# Disable driver_finder and service logging
|
# Disable driver_finder and service logging
|
||||||
driver_finder_logger = logging.getLogger('selenium.webdriver.common.driver_finder')
|
driver_finder_logger = logging.getLogger("selenium.webdriver.common.driver_finder")
|
||||||
driver_finder_logger.setLevel(logging.WARNING)
|
driver_finder_logger.setLevel(logging.WARNING)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
class CrawlerStrategy(ABC):
|
class CrawlerStrategy(ABC):
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def crawl(self, url: str, **kwargs) -> str:
|
def crawl(self, url: str, **kwargs) -> str:
|
||||||
@@ -58,6 +56,7 @@ class CrawlerStrategy(ABC):
|
|||||||
def set_hook(self, hook_type: str, hook: Callable):
|
def set_hook(self, hook_type: str, hook: Callable):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
class CloudCrawlerStrategy(CrawlerStrategy):
|
class CloudCrawlerStrategy(CrawlerStrategy):
|
||||||
def __init__(self, use_cached_html=False):
|
def __init__(self, use_cached_html=False):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
@@ -76,6 +75,7 @@ class CloudCrawlerStrategy(CrawlerStrategy):
|
|||||||
html = response["results"][0]["html"]
|
html = response["results"][0]["html"]
|
||||||
return sanitize_input_encode(html)
|
return sanitize_input_encode(html)
|
||||||
|
|
||||||
|
|
||||||
class LocalSeleniumCrawlerStrategy(CrawlerStrategy):
|
class LocalSeleniumCrawlerStrategy(CrawlerStrategy):
|
||||||
def __init__(self, use_cached_html=False, js_code=None, **kwargs):
|
def __init__(self, use_cached_html=False, js_code=None, **kwargs):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
@@ -87,9 +87,14 @@ class LocalSeleniumCrawlerStrategy(CrawlerStrategy):
|
|||||||
if kwargs.get("user_agent"):
|
if kwargs.get("user_agent"):
|
||||||
self.options.add_argument("--user-agent=" + kwargs.get("user_agent"))
|
self.options.add_argument("--user-agent=" + kwargs.get("user_agent"))
|
||||||
else:
|
else:
|
||||||
user_agent = kwargs.get("user_agent", "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.124 Safari/537.36")
|
user_agent = kwargs.get(
|
||||||
|
"user_agent",
|
||||||
|
"Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.124 Safari/537.36",
|
||||||
|
)
|
||||||
self.options.add_argument(f"--user-agent={user_agent}")
|
self.options.add_argument(f"--user-agent={user_agent}")
|
||||||
self.options.add_argument("user-agent=Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.124 Safari/537.36")
|
self.options.add_argument(
|
||||||
|
"user-agent=Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.124 Safari/537.36"
|
||||||
|
)
|
||||||
|
|
||||||
self.options.headless = kwargs.get("headless", True)
|
self.options.headless = kwargs.get("headless", True)
|
||||||
if self.options.headless:
|
if self.options.headless:
|
||||||
@@ -123,11 +128,11 @@ class LocalSeleniumCrawlerStrategy(CrawlerStrategy):
|
|||||||
|
|
||||||
# Hooks
|
# Hooks
|
||||||
self.hooks = {
|
self.hooks = {
|
||||||
'on_driver_created': None,
|
"on_driver_created": None,
|
||||||
'on_user_agent_updated': None,
|
"on_user_agent_updated": None,
|
||||||
'before_get_url': None,
|
"before_get_url": None,
|
||||||
'after_get_url': None,
|
"after_get_url": None,
|
||||||
'before_return_html': None
|
"before_return_html": None,
|
||||||
}
|
}
|
||||||
|
|
||||||
# chromedriver_autoinstaller.install()
|
# chromedriver_autoinstaller.install()
|
||||||
@@ -138,7 +143,6 @@ class LocalSeleniumCrawlerStrategy(CrawlerStrategy):
|
|||||||
# chromedriver_path = chromedriver_autoinstaller.utils.download_chromedriver()
|
# chromedriver_path = chromedriver_autoinstaller.utils.download_chromedriver()
|
||||||
# self.service = Service(chromedriver_autoinstaller.install())
|
# self.service = Service(chromedriver_autoinstaller.install())
|
||||||
|
|
||||||
|
|
||||||
# chromedriver_path = ChromeDriverManager().install()
|
# chromedriver_path = ChromeDriverManager().install()
|
||||||
# self.service = Service(chromedriver_path)
|
# self.service = Service(chromedriver_path)
|
||||||
# self.service.log_path = "NUL"
|
# self.service.log_path = "NUL"
|
||||||
@@ -148,14 +152,12 @@ class LocalSeleniumCrawlerStrategy(CrawlerStrategy):
|
|||||||
self.service = Service()
|
self.service = Service()
|
||||||
self.driver = webdriver.Chrome(options=self.options)
|
self.driver = webdriver.Chrome(options=self.options)
|
||||||
|
|
||||||
self.driver = self.execute_hook('on_driver_created', self.driver)
|
self.driver = self.execute_hook("on_driver_created", self.driver)
|
||||||
|
|
||||||
if kwargs.get("cookies"):
|
if kwargs.get("cookies"):
|
||||||
for cookie in kwargs.get("cookies"):
|
for cookie in kwargs.get("cookies"):
|
||||||
self.driver.add_cookie(cookie)
|
self.driver.add_cookie(cookie)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
def set_hook(self, hook_type: str, hook: Callable):
|
def set_hook(self, hook_type: str, hook: Callable):
|
||||||
if hook_type in self.hooks:
|
if hook_type in self.hooks:
|
||||||
self.hooks[hook_type] = hook
|
self.hooks[hook_type] = hook
|
||||||
@@ -170,7 +172,9 @@ class LocalSeleniumCrawlerStrategy(CrawlerStrategy):
|
|||||||
if isinstance(result, webdriver.Chrome):
|
if isinstance(result, webdriver.Chrome):
|
||||||
return result
|
return result
|
||||||
else:
|
else:
|
||||||
raise TypeError(f"Hook {hook_type} must return an instance of webdriver.Chrome or None.")
|
raise TypeError(
|
||||||
|
f"Hook {hook_type} must return an instance of webdriver.Chrome or None."
|
||||||
|
)
|
||||||
# If the hook returns None or there is no hook, return self.driver
|
# If the hook returns None or there is no hook, return self.driver
|
||||||
return self.driver
|
return self.driver
|
||||||
|
|
||||||
@@ -178,13 +182,13 @@ class LocalSeleniumCrawlerStrategy(CrawlerStrategy):
|
|||||||
self.options.add_argument(f"user-agent={user_agent}")
|
self.options.add_argument(f"user-agent={user_agent}")
|
||||||
self.driver.quit()
|
self.driver.quit()
|
||||||
self.driver = webdriver.Chrome(service=self.service, options=self.options)
|
self.driver = webdriver.Chrome(service=self.service, options=self.options)
|
||||||
self.driver = self.execute_hook('on_user_agent_updated', self.driver)
|
self.driver = self.execute_hook("on_user_agent_updated", self.driver)
|
||||||
|
|
||||||
def set_custom_headers(self, headers: dict):
|
def set_custom_headers(self, headers: dict):
|
||||||
# Enable Network domain for sending headers
|
# Enable Network domain for sending headers
|
||||||
self.driver.execute_cdp_cmd('Network.enable', {})
|
self.driver.execute_cdp_cmd("Network.enable", {})
|
||||||
# Set extra HTTP headers
|
# Set extra HTTP headers
|
||||||
self.driver.execute_cdp_cmd('Network.setExtraHTTPHeaders', {'headers': headers})
|
self.driver.execute_cdp_cmd("Network.setExtraHTTPHeaders", {"headers": headers})
|
||||||
|
|
||||||
def _ensure_page_load(self, max_checks=6, check_interval=0.01):
|
def _ensure_page_load(self, max_checks=6, check_interval=0.01):
|
||||||
initial_length = len(self.driver.page_source)
|
initial_length = len(self.driver.page_source)
|
||||||
@@ -202,36 +206,53 @@ class LocalSeleniumCrawlerStrategy(CrawlerStrategy):
|
|||||||
def crawl(self, url: str, **kwargs) -> str:
|
def crawl(self, url: str, **kwargs) -> str:
|
||||||
# Create md5 hash of the URL
|
# Create md5 hash of the URL
|
||||||
import hashlib
|
import hashlib
|
||||||
|
|
||||||
url_hash = hashlib.md5(url.encode()).hexdigest()
|
url_hash = hashlib.md5(url.encode()).hexdigest()
|
||||||
|
|
||||||
if self.use_cached_html:
|
if self.use_cached_html:
|
||||||
cache_file_path = os.path.join(os.getenv("CRAWL4_AI_BASE_DIRECTORY", Path.home()), ".crawl4ai", "cache", url_hash)
|
cache_file_path = os.path.join(
|
||||||
|
os.getenv("CRAWL4_AI_BASE_DIRECTORY", Path.home()),
|
||||||
|
".crawl4ai",
|
||||||
|
"cache",
|
||||||
|
url_hash,
|
||||||
|
)
|
||||||
if os.path.exists(cache_file_path):
|
if os.path.exists(cache_file_path):
|
||||||
with open(cache_file_path, "r") as f:
|
with open(cache_file_path, "r") as f:
|
||||||
return sanitize_input_encode(f.read())
|
return sanitize_input_encode(f.read())
|
||||||
|
|
||||||
try:
|
try:
|
||||||
self.driver = self.execute_hook('before_get_url', self.driver)
|
self.driver = self.execute_hook("before_get_url", self.driver)
|
||||||
if self.verbose:
|
if self.verbose:
|
||||||
print(f"[LOG] 🕸️ Crawling {url} using LocalSeleniumCrawlerStrategy...")
|
print(f"[LOG] 🕸️ Crawling {url} using LocalSeleniumCrawlerStrategy...")
|
||||||
self.driver.get(url) # <html><head></head><body></body></html>
|
self.driver.get(url) # <html><head></head><body></body></html>
|
||||||
|
|
||||||
WebDriverWait(self.driver, 20).until(
|
WebDriverWait(self.driver, 20).until(
|
||||||
lambda d: d.execute_script('return document.readyState') == 'complete'
|
lambda d: d.execute_script("return document.readyState") == "complete"
|
||||||
)
|
)
|
||||||
WebDriverWait(self.driver, 10).until(
|
WebDriverWait(self.driver, 10).until(
|
||||||
EC.presence_of_all_elements_located((By.TAG_NAME, "body"))
|
EC.presence_of_all_elements_located((By.TAG_NAME, "body"))
|
||||||
)
|
)
|
||||||
|
|
||||||
self.driver.execute_script("window.scrollTo(0, document.body.scrollHeight);")
|
self.driver.execute_script(
|
||||||
|
"window.scrollTo(0, document.body.scrollHeight);"
|
||||||
|
)
|
||||||
|
|
||||||
self.driver = self.execute_hook('after_get_url', self.driver)
|
self.driver = self.execute_hook("after_get_url", self.driver)
|
||||||
html = sanitize_input_encode(self._ensure_page_load()) # self.driver.page_source
|
html = sanitize_input_encode(
|
||||||
can_not_be_done_headless = False # Look at my creativity for naming variables
|
self._ensure_page_load()
|
||||||
|
) # self.driver.page_source
|
||||||
|
can_not_be_done_headless = (
|
||||||
|
False # Look at my creativity for naming variables
|
||||||
|
)
|
||||||
|
|
||||||
# TODO: Very ugly approach, but promise to change it!
|
# TODO: Very ugly approach, but promise to change it!
|
||||||
if kwargs.get('bypass_headless', False) or html == "<html><head></head><body></body></html>":
|
if (
|
||||||
print("[LOG] 🙌 Page could not be loaded in headless mode. Trying non-headless mode...")
|
kwargs.get("bypass_headless", False)
|
||||||
|
or html == "<html><head></head><body></body></html>"
|
||||||
|
):
|
||||||
|
print(
|
||||||
|
"[LOG] 🙌 Page could not be loaded in headless mode. Trying non-headless mode..."
|
||||||
|
)
|
||||||
can_not_be_done_headless = True
|
can_not_be_done_headless = True
|
||||||
options = Options()
|
options = Options()
|
||||||
options.headless = False
|
options.headless = False
|
||||||
@@ -239,7 +260,7 @@ class LocalSeleniumCrawlerStrategy(CrawlerStrategy):
|
|||||||
options.add_argument("--window-size=5,5")
|
options.add_argument("--window-size=5,5")
|
||||||
driver = webdriver.Chrome(service=self.service, options=options)
|
driver = webdriver.Chrome(service=self.service, options=options)
|
||||||
driver.get(url)
|
driver.get(url)
|
||||||
self.driver = self.execute_hook('after_get_url', driver)
|
self.driver = self.execute_hook("after_get_url", driver)
|
||||||
html = sanitize_input_encode(driver.page_source)
|
html = sanitize_input_encode(driver.page_source)
|
||||||
driver.quit()
|
driver.quit()
|
||||||
|
|
||||||
@@ -249,17 +270,21 @@ class LocalSeleniumCrawlerStrategy(CrawlerStrategy):
|
|||||||
self.driver.execute_script(self.js_code)
|
self.driver.execute_script(self.js_code)
|
||||||
# Optionally, wait for some condition after executing the JS code
|
# Optionally, wait for some condition after executing the JS code
|
||||||
WebDriverWait(self.driver, 10).until(
|
WebDriverWait(self.driver, 10).until(
|
||||||
lambda driver: driver.execute_script("return document.readyState") == "complete"
|
lambda driver: driver.execute_script("return document.readyState")
|
||||||
|
== "complete"
|
||||||
)
|
)
|
||||||
elif self.js_code and type(self.js_code) == list:
|
elif self.js_code and type(self.js_code) == list:
|
||||||
for js in self.js_code:
|
for js in self.js_code:
|
||||||
self.driver.execute_script(js)
|
self.driver.execute_script(js)
|
||||||
WebDriverWait(self.driver, 10).until(
|
WebDriverWait(self.driver, 10).until(
|
||||||
lambda driver: driver.execute_script("return document.readyState") == "complete"
|
lambda driver: driver.execute_script(
|
||||||
|
"return document.readyState"
|
||||||
|
)
|
||||||
|
== "complete"
|
||||||
)
|
)
|
||||||
|
|
||||||
# Optionally, wait for some condition after executing the JS code : Contributed by (https://github.com/jonymusky)
|
# Optionally, wait for some condition after executing the JS code : Contributed by (https://github.com/jonymusky)
|
||||||
wait_for = kwargs.get('wait_for', False)
|
wait_for = kwargs.get("wait_for", False)
|
||||||
if wait_for:
|
if wait_for:
|
||||||
if callable(wait_for):
|
if callable(wait_for):
|
||||||
print("[LOG] 🔄 Waiting for condition...")
|
print("[LOG] 🔄 Waiting for condition...")
|
||||||
@@ -272,10 +297,15 @@ class LocalSeleniumCrawlerStrategy(CrawlerStrategy):
|
|||||||
|
|
||||||
if not can_not_be_done_headless:
|
if not can_not_be_done_headless:
|
||||||
html = sanitize_input_encode(self.driver.page_source)
|
html = sanitize_input_encode(self.driver.page_source)
|
||||||
self.driver = self.execute_hook('before_return_html', self.driver, html)
|
self.driver = self.execute_hook("before_return_html", self.driver, html)
|
||||||
|
|
||||||
# Store in cache
|
# Store in cache
|
||||||
cache_file_path = os.path.join(os.getenv("CRAWL4_AI_BASE_DIRECTORY", Path.home()), ".crawl4ai", "cache", url_hash)
|
cache_file_path = os.path.join(
|
||||||
|
os.getenv("CRAWL4_AI_BASE_DIRECTORY", Path.home()),
|
||||||
|
".crawl4ai",
|
||||||
|
"cache",
|
||||||
|
url_hash,
|
||||||
|
)
|
||||||
with open(cache_file_path, "w", encoding="utf-8") as f:
|
with open(cache_file_path, "w", encoding="utf-8") as f:
|
||||||
f.write(html)
|
f.write(html)
|
||||||
|
|
||||||
@@ -284,16 +314,16 @@ class LocalSeleniumCrawlerStrategy(CrawlerStrategy):
|
|||||||
|
|
||||||
return html
|
return html
|
||||||
except InvalidArgumentException as e:
|
except InvalidArgumentException as e:
|
||||||
if not hasattr(e, 'msg'):
|
if not hasattr(e, "msg"):
|
||||||
e.msg = sanitize_input_encode(str(e))
|
e.msg = sanitize_input_encode(str(e))
|
||||||
raise InvalidArgumentException(f"Failed to crawl {url}: {e.msg}")
|
raise InvalidArgumentException(f"Failed to crawl {url}: {e.msg}")
|
||||||
except WebDriverException as e:
|
except WebDriverException as e:
|
||||||
# If e does nlt have msg attribute create it and set it to str(e)
|
# If e does nlt have msg attribute create it and set it to str(e)
|
||||||
if not hasattr(e, 'msg'):
|
if not hasattr(e, "msg"):
|
||||||
e.msg = sanitize_input_encode(str(e))
|
e.msg = sanitize_input_encode(str(e))
|
||||||
raise WebDriverException(f"Failed to crawl {url}: {e.msg}")
|
raise WebDriverException(f"Failed to crawl {url}: {e.msg}")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
if not hasattr(e, 'msg'):
|
if not hasattr(e, "msg"):
|
||||||
e.msg = sanitize_input_encode(str(e))
|
e.msg = sanitize_input_encode(str(e))
|
||||||
raise Exception(f"Failed to crawl {url}: {e.msg}")
|
raise Exception(f"Failed to crawl {url}: {e.msg}")
|
||||||
|
|
||||||
@@ -301,7 +331,9 @@ class LocalSeleniumCrawlerStrategy(CrawlerStrategy):
|
|||||||
try:
|
try:
|
||||||
# Get the dimensions of the page
|
# Get the dimensions of the page
|
||||||
total_width = self.driver.execute_script("return document.body.scrollWidth")
|
total_width = self.driver.execute_script("return document.body.scrollWidth")
|
||||||
total_height = self.driver.execute_script("return document.body.scrollHeight")
|
total_height = self.driver.execute_script(
|
||||||
|
"return document.body.scrollHeight"
|
||||||
|
)
|
||||||
|
|
||||||
# Set the window size to the dimensions of the page
|
# Set the window size to the dimensions of the page
|
||||||
self.driver.set_window_size(total_width, total_height)
|
self.driver.set_window_size(total_width, total_height)
|
||||||
@@ -313,23 +345,25 @@ class LocalSeleniumCrawlerStrategy(CrawlerStrategy):
|
|||||||
image = Image.open(BytesIO(screenshot))
|
image = Image.open(BytesIO(screenshot))
|
||||||
|
|
||||||
# Convert image to RGB mode (this will handle both RGB and RGBA images)
|
# Convert image to RGB mode (this will handle both RGB and RGBA images)
|
||||||
rgb_image = image.convert('RGB')
|
rgb_image = image.convert("RGB")
|
||||||
|
|
||||||
# Convert to JPEG and compress
|
# Convert to JPEG and compress
|
||||||
buffered = BytesIO()
|
buffered = BytesIO()
|
||||||
rgb_image.save(buffered, format="JPEG", quality=85)
|
rgb_image.save(buffered, format="JPEG", quality=85)
|
||||||
img_base64 = base64.b64encode(buffered.getvalue()).decode('utf-8')
|
img_base64 = base64.b64encode(buffered.getvalue()).decode("utf-8")
|
||||||
|
|
||||||
if self.verbose:
|
if self.verbose:
|
||||||
print(f"[LOG] 📸 Screenshot taken and converted to base64")
|
print("[LOG] 📸 Screenshot taken and converted to base64")
|
||||||
|
|
||||||
return img_base64
|
return img_base64
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
error_message = sanitize_input_encode(f"Failed to take screenshot: {str(e)}")
|
error_message = sanitize_input_encode(
|
||||||
|
f"Failed to take screenshot: {str(e)}"
|
||||||
|
)
|
||||||
print(error_message)
|
print(error_message)
|
||||||
|
|
||||||
# Generate an image with black background
|
# Generate an image with black background
|
||||||
img = Image.new('RGB', (800, 600), color='black')
|
img = Image.new("RGB", (800, 600), color="black")
|
||||||
draw = ImageDraw.Draw(img)
|
draw = ImageDraw.Draw(img)
|
||||||
|
|
||||||
# Load a font
|
# Load a font
|
||||||
@@ -352,7 +386,7 @@ class LocalSeleniumCrawlerStrategy(CrawlerStrategy):
|
|||||||
# Convert to base64
|
# Convert to base64
|
||||||
buffered = BytesIO()
|
buffered = BytesIO()
|
||||||
img.save(buffered, format="JPEG")
|
img.save(buffered, format="JPEG")
|
||||||
img_base64 = base64.b64encode(buffered.getvalue()).decode('utf-8')
|
img_base64 = base64.b64encode(buffered.getvalue()).decode("utf-8")
|
||||||
|
|
||||||
return img_base64
|
return img_base64
|
||||||
|
|
||||||
|
|||||||
@@ -7,11 +7,13 @@ DB_PATH = os.path.join(os.getenv("CRAWL4_AI_BASE_DIRECTORY", Path.home()), ".cra
|
|||||||
os.makedirs(DB_PATH, exist_ok=True)
|
os.makedirs(DB_PATH, exist_ok=True)
|
||||||
DB_PATH = os.path.join(DB_PATH, "crawl4ai.db")
|
DB_PATH = os.path.join(DB_PATH, "crawl4ai.db")
|
||||||
|
|
||||||
|
|
||||||
def init_db():
|
def init_db():
|
||||||
global DB_PATH
|
global DB_PATH
|
||||||
conn = sqlite3.connect(DB_PATH)
|
conn = sqlite3.connect(DB_PATH)
|
||||||
cursor = conn.cursor()
|
cursor = conn.cursor()
|
||||||
cursor.execute('''
|
cursor.execute(
|
||||||
|
"""
|
||||||
CREATE TABLE IF NOT EXISTS crawled_data (
|
CREATE TABLE IF NOT EXISTS crawled_data (
|
||||||
url TEXT PRIMARY KEY,
|
url TEXT PRIMARY KEY,
|
||||||
html TEXT,
|
html TEXT,
|
||||||
@@ -24,31 +26,42 @@ def init_db():
|
|||||||
metadata TEXT DEFAULT "{}",
|
metadata TEXT DEFAULT "{}",
|
||||||
screenshot TEXT DEFAULT ""
|
screenshot TEXT DEFAULT ""
|
||||||
)
|
)
|
||||||
''')
|
"""
|
||||||
|
)
|
||||||
conn.commit()
|
conn.commit()
|
||||||
conn.close()
|
conn.close()
|
||||||
|
|
||||||
|
|
||||||
def alter_db_add_screenshot(new_column: str = "media"):
|
def alter_db_add_screenshot(new_column: str = "media"):
|
||||||
check_db_path()
|
check_db_path()
|
||||||
try:
|
try:
|
||||||
conn = sqlite3.connect(DB_PATH)
|
conn = sqlite3.connect(DB_PATH)
|
||||||
cursor = conn.cursor()
|
cursor = conn.cursor()
|
||||||
cursor.execute(f'ALTER TABLE crawled_data ADD COLUMN {new_column} TEXT DEFAULT ""')
|
cursor.execute(
|
||||||
|
f'ALTER TABLE crawled_data ADD COLUMN {new_column} TEXT DEFAULT ""'
|
||||||
|
)
|
||||||
conn.commit()
|
conn.commit()
|
||||||
conn.close()
|
conn.close()
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"Error altering database to add screenshot column: {e}")
|
print(f"Error altering database to add screenshot column: {e}")
|
||||||
|
|
||||||
|
|
||||||
def check_db_path():
|
def check_db_path():
|
||||||
if not DB_PATH:
|
if not DB_PATH:
|
||||||
raise ValueError("Database path is not set or is empty.")
|
raise ValueError("Database path is not set or is empty.")
|
||||||
|
|
||||||
def get_cached_url(url: str) -> Optional[Tuple[str, str, str, str, str, str, str, bool, str]]:
|
|
||||||
|
def get_cached_url(
|
||||||
|
url: str,
|
||||||
|
) -> Optional[Tuple[str, str, str, str, str, str, str, bool, str]]:
|
||||||
check_db_path()
|
check_db_path()
|
||||||
try:
|
try:
|
||||||
conn = sqlite3.connect(DB_PATH)
|
conn = sqlite3.connect(DB_PATH)
|
||||||
cursor = conn.cursor()
|
cursor = conn.cursor()
|
||||||
cursor.execute('SELECT url, html, cleaned_html, markdown, extracted_content, success, media, links, metadata, screenshot FROM crawled_data WHERE url = ?', (url,))
|
cursor.execute(
|
||||||
|
"SELECT url, html, cleaned_html, markdown, extracted_content, success, media, links, metadata, screenshot FROM crawled_data WHERE url = ?",
|
||||||
|
(url,),
|
||||||
|
)
|
||||||
result = cursor.fetchone()
|
result = cursor.fetchone()
|
||||||
conn.close()
|
conn.close()
|
||||||
return result
|
return result
|
||||||
@@ -56,12 +69,25 @@ def get_cached_url(url: str) -> Optional[Tuple[str, str, str, str, str, str, str
|
|||||||
print(f"Error retrieving cached URL: {e}")
|
print(f"Error retrieving cached URL: {e}")
|
||||||
return None
|
return None
|
||||||
|
|
||||||
def cache_url(url: str, html: str, cleaned_html: str, markdown: str, extracted_content: str, success: bool, media : str = "{}", links : str = "{}", metadata : str = "{}", screenshot: str = ""):
|
|
||||||
|
def cache_url(
|
||||||
|
url: str,
|
||||||
|
html: str,
|
||||||
|
cleaned_html: str,
|
||||||
|
markdown: str,
|
||||||
|
extracted_content: str,
|
||||||
|
success: bool,
|
||||||
|
media: str = "{}",
|
||||||
|
links: str = "{}",
|
||||||
|
metadata: str = "{}",
|
||||||
|
screenshot: str = "",
|
||||||
|
):
|
||||||
check_db_path()
|
check_db_path()
|
||||||
try:
|
try:
|
||||||
conn = sqlite3.connect(DB_PATH)
|
conn = sqlite3.connect(DB_PATH)
|
||||||
cursor = conn.cursor()
|
cursor = conn.cursor()
|
||||||
cursor.execute('''
|
cursor.execute(
|
||||||
|
"""
|
||||||
INSERT INTO crawled_data (url, html, cleaned_html, markdown, extracted_content, success, media, links, metadata, screenshot)
|
INSERT INTO crawled_data (url, html, cleaned_html, markdown, extracted_content, success, media, links, metadata, screenshot)
|
||||||
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
|
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
|
||||||
ON CONFLICT(url) DO UPDATE SET
|
ON CONFLICT(url) DO UPDATE SET
|
||||||
@@ -74,18 +100,32 @@ def cache_url(url: str, html: str, cleaned_html: str, markdown: str, extracted_c
|
|||||||
links = excluded.links,
|
links = excluded.links,
|
||||||
metadata = excluded.metadata,
|
metadata = excluded.metadata,
|
||||||
screenshot = excluded.screenshot
|
screenshot = excluded.screenshot
|
||||||
''', (url, html, cleaned_html, markdown, extracted_content, success, media, links, metadata, screenshot))
|
""",
|
||||||
|
(
|
||||||
|
url,
|
||||||
|
html,
|
||||||
|
cleaned_html,
|
||||||
|
markdown,
|
||||||
|
extracted_content,
|
||||||
|
success,
|
||||||
|
media,
|
||||||
|
links,
|
||||||
|
metadata,
|
||||||
|
screenshot,
|
||||||
|
),
|
||||||
|
)
|
||||||
conn.commit()
|
conn.commit()
|
||||||
conn.close()
|
conn.close()
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"Error caching URL: {e}")
|
print(f"Error caching URL: {e}")
|
||||||
|
|
||||||
|
|
||||||
def get_total_count() -> int:
|
def get_total_count() -> int:
|
||||||
check_db_path()
|
check_db_path()
|
||||||
try:
|
try:
|
||||||
conn = sqlite3.connect(DB_PATH)
|
conn = sqlite3.connect(DB_PATH)
|
||||||
cursor = conn.cursor()
|
cursor = conn.cursor()
|
||||||
cursor.execute('SELECT COUNT(*) FROM crawled_data')
|
cursor.execute("SELECT COUNT(*) FROM crawled_data")
|
||||||
result = cursor.fetchone()
|
result = cursor.fetchone()
|
||||||
conn.close()
|
conn.close()
|
||||||
return result[0]
|
return result[0]
|
||||||
@@ -93,43 +133,48 @@ def get_total_count() -> int:
|
|||||||
print(f"Error getting total count: {e}")
|
print(f"Error getting total count: {e}")
|
||||||
return 0
|
return 0
|
||||||
|
|
||||||
|
|
||||||
def clear_db():
|
def clear_db():
|
||||||
check_db_path()
|
check_db_path()
|
||||||
try:
|
try:
|
||||||
conn = sqlite3.connect(DB_PATH)
|
conn = sqlite3.connect(DB_PATH)
|
||||||
cursor = conn.cursor()
|
cursor = conn.cursor()
|
||||||
cursor.execute('DELETE FROM crawled_data')
|
cursor.execute("DELETE FROM crawled_data")
|
||||||
conn.commit()
|
conn.commit()
|
||||||
conn.close()
|
conn.close()
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"Error clearing database: {e}")
|
print(f"Error clearing database: {e}")
|
||||||
|
|
||||||
|
|
||||||
def flush_db():
|
def flush_db():
|
||||||
check_db_path()
|
check_db_path()
|
||||||
try:
|
try:
|
||||||
conn = sqlite3.connect(DB_PATH)
|
conn = sqlite3.connect(DB_PATH)
|
||||||
cursor = conn.cursor()
|
cursor = conn.cursor()
|
||||||
cursor.execute('DROP TABLE crawled_data')
|
cursor.execute("DROP TABLE crawled_data")
|
||||||
conn.commit()
|
conn.commit()
|
||||||
conn.close()
|
conn.close()
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"Error flushing database: {e}")
|
print(f"Error flushing database: {e}")
|
||||||
|
|
||||||
|
|
||||||
def update_existing_records(new_column: str = "media", default_value: str = "{}"):
|
def update_existing_records(new_column: str = "media", default_value: str = "{}"):
|
||||||
check_db_path()
|
check_db_path()
|
||||||
try:
|
try:
|
||||||
conn = sqlite3.connect(DB_PATH)
|
conn = sqlite3.connect(DB_PATH)
|
||||||
cursor = conn.cursor()
|
cursor = conn.cursor()
|
||||||
cursor.execute(f'UPDATE crawled_data SET {new_column} = "{default_value}" WHERE screenshot IS NULL')
|
cursor.execute(
|
||||||
|
f'UPDATE crawled_data SET {new_column} = "{default_value}" WHERE screenshot IS NULL'
|
||||||
|
)
|
||||||
conn.commit()
|
conn.commit()
|
||||||
conn.close()
|
conn.close()
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"Error updating existing records: {e}")
|
print(f"Error updating existing records: {e}")
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
# Delete the existing database file
|
# Delete the existing database file
|
||||||
if os.path.exists(DB_PATH):
|
if os.path.exists(DB_PATH):
|
||||||
os.remove(DB_PATH)
|
os.remove(DB_PATH)
|
||||||
init_db()
|
init_db()
|
||||||
# alter_db_add_screenshot("COL_NAME")
|
# alter_db_add_screenshot("COL_NAME")
|
||||||
|
|
||||||
|
|||||||
@@ -4,6 +4,7 @@ from pathlib import Path
|
|||||||
from crawl4ai.async_logger import AsyncLogger
|
from crawl4ai.async_logger import AsyncLogger
|
||||||
from crawl4ai.llmtxt import AsyncLLMTextManager
|
from crawl4ai.llmtxt import AsyncLLMTextManager
|
||||||
|
|
||||||
|
|
||||||
class DocsManager:
|
class DocsManager:
|
||||||
def __init__(self, logger=None):
|
def __init__(self, logger=None):
|
||||||
self.docs_dir = Path.home() / ".crawl4ai" / "docs"
|
self.docs_dir = Path.home() / ".crawl4ai" / "docs"
|
||||||
@@ -21,7 +22,10 @@ class DocsManager:
|
|||||||
"""Copy from local docs or download from GitHub"""
|
"""Copy from local docs or download from GitHub"""
|
||||||
try:
|
try:
|
||||||
# Try local first
|
# Try local first
|
||||||
if self.local_docs.exists() and (any(self.local_docs.glob("*.md")) or any(self.local_docs.glob("*.tokens"))):
|
if self.local_docs.exists() and (
|
||||||
|
any(self.local_docs.glob("*.md"))
|
||||||
|
or any(self.local_docs.glob("*.tokens"))
|
||||||
|
):
|
||||||
# Empty the local docs directory
|
# Empty the local docs directory
|
||||||
for file_path in self.docs_dir.glob("*.md"):
|
for file_path in self.docs_dir.glob("*.md"):
|
||||||
file_path.unlink()
|
file_path.unlink()
|
||||||
@@ -36,14 +40,14 @@ class DocsManager:
|
|||||||
# Fallback to GitHub
|
# Fallback to GitHub
|
||||||
response = requests.get(
|
response = requests.get(
|
||||||
"https://api.github.com/repos/unclecode/crawl4ai/contents/docs/llm.txt",
|
"https://api.github.com/repos/unclecode/crawl4ai/contents/docs/llm.txt",
|
||||||
headers={'Accept': 'application/vnd.github.v3+json'}
|
headers={"Accept": "application/vnd.github.v3+json"},
|
||||||
)
|
)
|
||||||
response.raise_for_status()
|
response.raise_for_status()
|
||||||
|
|
||||||
for item in response.json():
|
for item in response.json():
|
||||||
if item['type'] == 'file' and item['name'].endswith('.md'):
|
if item["type"] == "file" and item["name"].endswith(".md"):
|
||||||
content = requests.get(item['download_url']).text
|
content = requests.get(item["download_url"]).text
|
||||||
with open(self.docs_dir / item['name'], 'w', encoding='utf-8') as f:
|
with open(self.docs_dir / item["name"], "w", encoding="utf-8") as f:
|
||||||
f.write(content)
|
f.write(content)
|
||||||
return True
|
return True
|
||||||
|
|
||||||
@@ -57,7 +61,11 @@ class DocsManager:
|
|||||||
# Remove [0-9]+_ prefix
|
# Remove [0-9]+_ prefix
|
||||||
names = [name.split("_", 1)[1] if name[0].isdigit() else name for name in names]
|
names = [name.split("_", 1)[1] if name[0].isdigit() else name for name in names]
|
||||||
# Exclude those end with .xs.md and .q.md
|
# Exclude those end with .xs.md and .q.md
|
||||||
names = [name for name in names if not name.endswith(".xs") and not name.endswith(".q")]
|
names = [
|
||||||
|
name
|
||||||
|
for name in names
|
||||||
|
if not name.endswith(".xs") and not name.endswith(".q")
|
||||||
|
]
|
||||||
return names
|
return names
|
||||||
|
|
||||||
def generate(self, sections, mode="extended"):
|
def generate(self, sections, mode="extended"):
|
||||||
|
|||||||
@@ -1,20 +1,48 @@
|
|||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from typing import Any, List, Dict, Optional, Union
|
from typing import Any, List, Dict, Optional
|
||||||
from concurrent.futures import ThreadPoolExecutor, as_completed
|
from concurrent.futures import ThreadPoolExecutor, as_completed
|
||||||
import json, time
|
import json
|
||||||
# from optimum.intel import IPEXModel
|
import time
|
||||||
from .prompts import *
|
import os
|
||||||
from .config import *
|
|
||||||
from .utils import *
|
from .prompts import PROMPT_EXTRACT_BLOCKS
|
||||||
from .models import *
|
from .config import (
|
||||||
|
DEFAULT_PROVIDER, PROVIDER_MODELS,
|
||||||
|
CHUNK_TOKEN_THRESHOLD,
|
||||||
|
OVERLAP_RATE,
|
||||||
|
WORD_TOKEN_RATE,
|
||||||
|
PROMPT_EXTRACT_BLOCKS_WITH_INSTRUCTION,
|
||||||
|
PROMPT_EXTRACT_SCHEMA_WITH_INSTRUCTION
|
||||||
|
)
|
||||||
|
from .utils import * # noqa: F403
|
||||||
|
|
||||||
|
from .utils import (
|
||||||
|
sanitize_html,
|
||||||
|
calculate_batch_size,
|
||||||
|
escape_json_string,
|
||||||
|
perform_completion_with_backoff,
|
||||||
|
extract_xml_data,
|
||||||
|
split_and_parse_json_objects,
|
||||||
|
sanitize_input_encode,
|
||||||
|
)
|
||||||
|
from .models import * # noqa: F403
|
||||||
|
|
||||||
|
from .models import TokenUsage
|
||||||
|
|
||||||
|
from .model_loader import * # noqa: F403
|
||||||
|
from .model_loader import (
|
||||||
|
get_device,
|
||||||
|
load_HF_embedding_model,
|
||||||
|
load_text_multilabel_classifier,
|
||||||
|
)
|
||||||
|
|
||||||
from functools import partial
|
from functools import partial
|
||||||
from .model_loader import *
|
|
||||||
import math
|
import math
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import re
|
import re
|
||||||
from bs4 import BeautifulSoup
|
from bs4 import BeautifulSoup
|
||||||
from lxml import html, etree
|
from lxml import html, etree
|
||||||
from dataclasses import dataclass
|
|
||||||
|
|
||||||
class ExtractionStrategy(ABC):
|
class ExtractionStrategy(ABC):
|
||||||
"""
|
"""
|
||||||
@@ -56,15 +84,20 @@ class ExtractionStrategy(ABC):
|
|||||||
"""
|
"""
|
||||||
extracted_content = []
|
extracted_content = []
|
||||||
with ThreadPoolExecutor() as executor:
|
with ThreadPoolExecutor() as executor:
|
||||||
futures = [executor.submit(self.extract, url, section, **kwargs) for section in sections]
|
futures = [
|
||||||
|
executor.submit(self.extract, url, section, **kwargs)
|
||||||
|
for section in sections
|
||||||
|
]
|
||||||
for future in as_completed(futures):
|
for future in as_completed(futures):
|
||||||
extracted_content.extend(future.result())
|
extracted_content.extend(future.result())
|
||||||
return extracted_content
|
return extracted_content
|
||||||
|
|
||||||
|
|
||||||
class NoExtractionStrategy(ExtractionStrategy):
|
class NoExtractionStrategy(ExtractionStrategy):
|
||||||
"""
|
"""
|
||||||
A strategy that does not extract any meaningful content from the HTML. It simply returns the entire HTML as a single block.
|
A strategy that does not extract any meaningful content from the HTML. It simply returns the entire HTML as a single block.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def extract(self, url: str, html: str, *q, **kwargs) -> List[Dict[str, Any]]:
|
def extract(self, url: str, html: str, *q, **kwargs) -> List[Dict[str, Any]]:
|
||||||
"""
|
"""
|
||||||
Extract meaningful blocks or chunks from the given HTML.
|
Extract meaningful blocks or chunks from the given HTML.
|
||||||
@@ -72,13 +105,17 @@ class NoExtractionStrategy(ExtractionStrategy):
|
|||||||
return [{"index": 0, "content": html}]
|
return [{"index": 0, "content": html}]
|
||||||
|
|
||||||
def run(self, url: str, sections: List[str], *q, **kwargs) -> List[Dict[str, Any]]:
|
def run(self, url: str, sections: List[str], *q, **kwargs) -> List[Dict[str, Any]]:
|
||||||
return [{"index": i, "tags": [], "content": section} for i, section in enumerate(sections)]
|
return [
|
||||||
|
{"index": i, "tags": [], "content": section}
|
||||||
|
for i, section in enumerate(sections)
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
#######################################################
|
#######################################################
|
||||||
# Strategies using clustering for text data extraction #
|
# Strategies using clustering for text data extraction #
|
||||||
#######################################################
|
#######################################################
|
||||||
|
|
||||||
|
|
||||||
class CosineStrategy(ExtractionStrategy):
|
class CosineStrategy(ExtractionStrategy):
|
||||||
"""
|
"""
|
||||||
Extract meaningful blocks or chunks from the given HTML using cosine similarity.
|
Extract meaningful blocks or chunks from the given HTML using cosine similarity.
|
||||||
@@ -99,7 +136,18 @@ class CosineStrategy(ExtractionStrategy):
|
|||||||
model_name (str): The name of the sentence-transformers model.
|
model_name (str): The name of the sentence-transformers model.
|
||||||
sim_threshold (float): The similarity threshold for clustering.
|
sim_threshold (float): The similarity threshold for clustering.
|
||||||
"""
|
"""
|
||||||
def __init__(self, semantic_filter = None, word_count_threshold=10, max_dist=0.2, linkage_method='ward', top_k=3, model_name = 'sentence-transformers/all-MiniLM-L6-v2', sim_threshold = 0.3, **kwargs):
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
semantic_filter=None,
|
||||||
|
word_count_threshold=10,
|
||||||
|
max_dist=0.2,
|
||||||
|
linkage_method="ward",
|
||||||
|
top_k=3,
|
||||||
|
model_name="sentence-transformers/all-MiniLM-L6-v2",
|
||||||
|
sim_threshold=0.3,
|
||||||
|
**kwargs,
|
||||||
|
):
|
||||||
"""
|
"""
|
||||||
Initialize the strategy with clustering parameters.
|
Initialize the strategy with clustering parameters.
|
||||||
|
|
||||||
@@ -162,7 +210,6 @@ class CosineStrategy(ExtractionStrategy):
|
|||||||
# self.tokenizer = self.model.tokenizer
|
# self.tokenizer = self.model.tokenizer
|
||||||
# self.get_embedding_method = "direct"
|
# self.get_embedding_method = "direct"
|
||||||
|
|
||||||
|
|
||||||
if self.verbose:
|
if self.verbose:
|
||||||
print(f"[LOG] Loading Multilabel Classifier for {self.device.type} device.")
|
print(f"[LOG] Loading Multilabel Classifier for {self.device.type} device.")
|
||||||
|
|
||||||
@@ -170,9 +217,15 @@ class CosineStrategy(ExtractionStrategy):
|
|||||||
# self.default_batch_size = 16 if self.device.type == 'cpu' else 64
|
# self.default_batch_size = 16 if self.device.type == 'cpu' else 64
|
||||||
|
|
||||||
if self.verbose:
|
if self.verbose:
|
||||||
print(f"[LOG] Model loaded {model_name}, models/reuters, took " + str(time.time() - self.timer) + " seconds")
|
print(
|
||||||
|
f"[LOG] Model loaded {model_name}, models/reuters, took "
|
||||||
|
+ str(time.time() - self.timer)
|
||||||
|
+ " seconds"
|
||||||
|
)
|
||||||
|
|
||||||
def filter_documents_embeddings(self, documents: List[str], semantic_filter: str, at_least_k: int = 20) -> List[str]:
|
def filter_documents_embeddings(
|
||||||
|
self, documents: List[str], semantic_filter: str, at_least_k: int = 20
|
||||||
|
) -> List[str]:
|
||||||
"""
|
"""
|
||||||
Filter and sort documents based on the cosine similarity of their embeddings with the semantic_filter embedding.
|
Filter and sort documents based on the cosine similarity of their embeddings with the semantic_filter embedding.
|
||||||
|
|
||||||
@@ -200,14 +253,24 @@ class CosineStrategy(ExtractionStrategy):
|
|||||||
document_embeddings = self.get_embeddings(documents)
|
document_embeddings = self.get_embeddings(documents)
|
||||||
|
|
||||||
# Calculate cosine similarity between the query embedding and document embeddings
|
# Calculate cosine similarity between the query embedding and document embeddings
|
||||||
similarities = cosine_similarity([query_embedding], document_embeddings).flatten()
|
similarities = cosine_similarity(
|
||||||
|
[query_embedding], document_embeddings
|
||||||
|
).flatten()
|
||||||
|
|
||||||
# Filter documents based on the similarity threshold
|
# Filter documents based on the similarity threshold
|
||||||
filtered_docs = [(doc, sim) for doc, sim in zip(documents, similarities) if sim >= self.sim_threshold]
|
filtered_docs = [
|
||||||
|
(doc, sim)
|
||||||
|
for doc, sim in zip(documents, similarities)
|
||||||
|
if sim >= self.sim_threshold
|
||||||
|
]
|
||||||
|
|
||||||
# If the number of filtered documents is less than at_least_k, sort remaining documents by similarity
|
# If the number of filtered documents is less than at_least_k, sort remaining documents by similarity
|
||||||
if len(filtered_docs) < at_least_k:
|
if len(filtered_docs) < at_least_k:
|
||||||
remaining_docs = [(doc, sim) for doc, sim in zip(documents, similarities) if sim < self.sim_threshold]
|
remaining_docs = [
|
||||||
|
(doc, sim)
|
||||||
|
for doc, sim in zip(documents, similarities)
|
||||||
|
if sim < self.sim_threshold
|
||||||
|
]
|
||||||
remaining_docs.sort(key=lambda x: x[1], reverse=True)
|
remaining_docs.sort(key=lambda x: x[1], reverse=True)
|
||||||
filtered_docs.extend(remaining_docs[: at_least_k - len(filtered_docs)])
|
filtered_docs.extend(remaining_docs[: at_least_k - len(filtered_docs)])
|
||||||
|
|
||||||
@@ -216,7 +279,9 @@ class CosineStrategy(ExtractionStrategy):
|
|||||||
|
|
||||||
return filtered_docs[:at_least_k]
|
return filtered_docs[:at_least_k]
|
||||||
|
|
||||||
def get_embeddings(self, sentences: List[str], batch_size=None, bypass_buffer=False):
|
def get_embeddings(
|
||||||
|
self, sentences: List[str], batch_size=None, bypass_buffer=False
|
||||||
|
):
|
||||||
"""
|
"""
|
||||||
Get BERT embeddings for a list of sentences.
|
Get BERT embeddings for a list of sentences.
|
||||||
|
|
||||||
@@ -231,6 +296,7 @@ class CosineStrategy(ExtractionStrategy):
|
|||||||
|
|
||||||
if self.device.type in ["cpu", "gpu", "cuda", "mps"]:
|
if self.device.type in ["cpu", "gpu", "cuda", "mps"]:
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
# Tokenize sentences and convert to tensor
|
# Tokenize sentences and convert to tensor
|
||||||
if batch_size is None:
|
if batch_size is None:
|
||||||
batch_size = self.default_batch_size
|
batch_size = self.default_batch_size
|
||||||
@@ -238,8 +304,12 @@ class CosineStrategy(ExtractionStrategy):
|
|||||||
all_embeddings = []
|
all_embeddings = []
|
||||||
for i in range(0, len(sentences), batch_size):
|
for i in range(0, len(sentences), batch_size):
|
||||||
batch_sentences = sentences[i : i + batch_size]
|
batch_sentences = sentences[i : i + batch_size]
|
||||||
encoded_input = self.tokenizer(batch_sentences, padding=True, truncation=True, return_tensors='pt')
|
encoded_input = self.tokenizer(
|
||||||
encoded_input = {key: tensor.to(self.device) for key, tensor in encoded_input.items()}
|
batch_sentences, padding=True, truncation=True, return_tensors="pt"
|
||||||
|
)
|
||||||
|
encoded_input = {
|
||||||
|
key: tensor.to(self.device) for key, tensor in encoded_input.items()
|
||||||
|
}
|
||||||
|
|
||||||
# Ensure no gradients are calculated
|
# Ensure no gradients are calculated
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
@@ -277,18 +347,21 @@ class CosineStrategy(ExtractionStrategy):
|
|||||||
# Get embeddings
|
# Get embeddings
|
||||||
from scipy.cluster.hierarchy import linkage, fcluster
|
from scipy.cluster.hierarchy import linkage, fcluster
|
||||||
from scipy.spatial.distance import pdist
|
from scipy.spatial.distance import pdist
|
||||||
|
|
||||||
self.timer = time.time()
|
self.timer = time.time()
|
||||||
embeddings = self.get_embeddings(sentences, bypass_buffer=True)
|
embeddings = self.get_embeddings(sentences, bypass_buffer=True)
|
||||||
# print(f"[LOG] 🚀 Embeddings computed in {time.time() - self.timer:.2f} seconds")
|
# print(f"[LOG] 🚀 Embeddings computed in {time.time() - self.timer:.2f} seconds")
|
||||||
# Compute pairwise cosine distances
|
# Compute pairwise cosine distances
|
||||||
distance_matrix = pdist(embeddings, 'cosine')
|
distance_matrix = pdist(embeddings, "cosine")
|
||||||
# Perform agglomerative clustering respecting order
|
# Perform agglomerative clustering respecting order
|
||||||
linked = linkage(distance_matrix, method=self.linkage_method)
|
linked = linkage(distance_matrix, method=self.linkage_method)
|
||||||
# Form flat clusters
|
# Form flat clusters
|
||||||
labels = fcluster(linked, self.max_dist, criterion='distance')
|
labels = fcluster(linked, self.max_dist, criterion="distance")
|
||||||
return labels
|
return labels
|
||||||
|
|
||||||
def filter_clusters_by_word_count(self, clusters: Dict[int, List[str]]) -> Dict[int, List[str]]:
|
def filter_clusters_by_word_count(
|
||||||
|
self, clusters: Dict[int, List[str]]
|
||||||
|
) -> Dict[int, List[str]]:
|
||||||
"""
|
"""
|
||||||
Filter clusters to remove those with a word count below the threshold.
|
Filter clusters to remove those with a word count below the threshold.
|
||||||
|
|
||||||
@@ -327,7 +400,9 @@ class CosineStrategy(ExtractionStrategy):
|
|||||||
text_chunks = html.split(self.DEL) # Split by lines or paragraphs as needed
|
text_chunks = html.split(self.DEL) # Split by lines or paragraphs as needed
|
||||||
|
|
||||||
# Pre-filter documents using embeddings and semantic_filter
|
# Pre-filter documents using embeddings and semantic_filter
|
||||||
text_chunks = self.filter_documents_embeddings(text_chunks, self.semantic_filter)
|
text_chunks = self.filter_documents_embeddings(
|
||||||
|
text_chunks, self.semantic_filter
|
||||||
|
)
|
||||||
|
|
||||||
if not text_chunks:
|
if not text_chunks:
|
||||||
return []
|
return []
|
||||||
@@ -346,16 +421,19 @@ class CosineStrategy(ExtractionStrategy):
|
|||||||
filtered_clusters = self.filter_clusters_by_word_count(clusters)
|
filtered_clusters = self.filter_clusters_by_word_count(clusters)
|
||||||
|
|
||||||
# Convert filtered clusters to a sorted list of dictionaries
|
# Convert filtered clusters to a sorted list of dictionaries
|
||||||
cluster_list = [{"index": int(idx), "tags" : [], "content": " ".join(filtered_clusters[idx])} for idx in sorted(filtered_clusters)]
|
cluster_list = [
|
||||||
|
{"index": int(idx), "tags": [], "content": " ".join(filtered_clusters[idx])}
|
||||||
|
for idx in sorted(filtered_clusters)
|
||||||
|
]
|
||||||
|
|
||||||
if self.verbose:
|
if self.verbose:
|
||||||
print(f"[LOG] 🚀 Assign tags using {self.device}")
|
print(f"[LOG] 🚀 Assign tags using {self.device}")
|
||||||
|
|
||||||
if self.device.type in ["gpu", "cuda", "mps", "cpu"]:
|
if self.device.type in ["gpu", "cuda", "mps", "cpu"]:
|
||||||
labels = self.nlp([cluster['content'] for cluster in cluster_list])
|
labels = self.nlp([cluster["content"] for cluster in cluster_list])
|
||||||
|
|
||||||
for cluster, label in zip(cluster_list, labels):
|
for cluster, label in zip(cluster_list, labels):
|
||||||
cluster['tags'] = label
|
cluster["tags"] = label
|
||||||
# elif self.device.type == "cpu":
|
# elif self.device.type == "cpu":
|
||||||
# # Process the text with the loaded model
|
# # Process the text with the loaded model
|
||||||
# texts = [cluster['content'] for cluster in cluster_list]
|
# texts = [cluster['content'] for cluster in cluster_list]
|
||||||
@@ -393,7 +471,6 @@ class CosineStrategy(ExtractionStrategy):
|
|||||||
return self.extract(url, self.DEL.join(sections), **kwargs)
|
return self.extract(url, self.DEL.join(sections), **kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
#######################################################
|
#######################################################
|
||||||
# Strategies using LLM-based extraction for text data #
|
# Strategies using LLM-based extraction for text data #
|
||||||
#######################################################
|
#######################################################
|
||||||
@@ -419,9 +496,15 @@ class LLMExtractionStrategy(ExtractionStrategy):
|
|||||||
total_usage: Accumulated token usage.
|
total_usage: Accumulated token usage.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self,
|
def __init__(
|
||||||
provider: str = DEFAULT_PROVIDER, api_token: Optional[str] = None,
|
self,
|
||||||
instruction:str = None, schema:Dict = None, extraction_type = "block", **kwargs):
|
provider: str = DEFAULT_PROVIDER,
|
||||||
|
api_token: Optional[str] = None,
|
||||||
|
instruction: str = None,
|
||||||
|
schema: Dict = None,
|
||||||
|
extraction_type="block",
|
||||||
|
**kwargs,
|
||||||
|
):
|
||||||
"""
|
"""
|
||||||
Initialize the strategy with clustering parameters.
|
Initialize the strategy with clustering parameters.
|
||||||
|
|
||||||
@@ -445,14 +528,20 @@ class LLMExtractionStrategy(ExtractionStrategy):
|
|||||||
"""
|
"""
|
||||||
super().__init__(**kwargs)
|
super().__init__(**kwargs)
|
||||||
self.provider = provider
|
self.provider = provider
|
||||||
self.api_token = api_token or PROVIDER_MODELS.get(provider, "no-token") or os.getenv("OPENAI_API_KEY")
|
self.api_token = (
|
||||||
|
api_token
|
||||||
|
or PROVIDER_MODELS.get(provider, "no-token")
|
||||||
|
or os.getenv("OPENAI_API_KEY")
|
||||||
|
)
|
||||||
self.instruction = instruction
|
self.instruction = instruction
|
||||||
self.extract_type = extraction_type
|
self.extract_type = extraction_type
|
||||||
self.schema = schema
|
self.schema = schema
|
||||||
if schema:
|
if schema:
|
||||||
self.extract_type = "schema"
|
self.extract_type = "schema"
|
||||||
|
|
||||||
self.chunk_token_threshold = kwargs.get("chunk_token_threshold", CHUNK_TOKEN_THRESHOLD)
|
self.chunk_token_threshold = kwargs.get(
|
||||||
|
"chunk_token_threshold", CHUNK_TOKEN_THRESHOLD
|
||||||
|
)
|
||||||
self.overlap_rate = kwargs.get("overlap_rate", OVERLAP_RATE)
|
self.overlap_rate = kwargs.get("overlap_rate", OVERLAP_RATE)
|
||||||
self.word_token_rate = kwargs.get("word_token_rate", WORD_TOKEN_RATE)
|
self.word_token_rate = kwargs.get("word_token_rate", WORD_TOKEN_RATE)
|
||||||
self.apply_chunking = kwargs.get("apply_chunking", True)
|
self.apply_chunking = kwargs.get("apply_chunking", True)
|
||||||
@@ -467,8 +556,9 @@ class LLMExtractionStrategy(ExtractionStrategy):
|
|||||||
self.total_usage = TokenUsage() # Accumulated usage
|
self.total_usage = TokenUsage() # Accumulated usage
|
||||||
|
|
||||||
if not self.api_token:
|
if not self.api_token:
|
||||||
raise ValueError("API token must be provided for LLMExtractionStrategy. Update the config.py or set OPENAI_API_KEY environment variable.")
|
raise ValueError(
|
||||||
|
"API token must be provided for LLMExtractionStrategy. Update the config.py or set OPENAI_API_KEY environment variable."
|
||||||
|
)
|
||||||
|
|
||||||
def extract(self, url: str, ix: int, html: str) -> List[Dict[str, Any]]:
|
def extract(self, url: str, ix: int, html: str) -> List[Dict[str, Any]]:
|
||||||
"""
|
"""
|
||||||
@@ -515,15 +605,19 @@ class LLMExtractionStrategy(ExtractionStrategy):
|
|||||||
prompt_with_variables,
|
prompt_with_variables,
|
||||||
self.api_token,
|
self.api_token,
|
||||||
base_url=self.api_base or self.base_url,
|
base_url=self.api_base or self.base_url,
|
||||||
extra_args = self.extra_args
|
extra_args=self.extra_args,
|
||||||
) # , json_response=self.extract_type == "schema")
|
) # , json_response=self.extract_type == "schema")
|
||||||
# Track usage
|
# Track usage
|
||||||
usage = TokenUsage(
|
usage = TokenUsage(
|
||||||
completion_tokens=response.usage.completion_tokens,
|
completion_tokens=response.usage.completion_tokens,
|
||||||
prompt_tokens=response.usage.prompt_tokens,
|
prompt_tokens=response.usage.prompt_tokens,
|
||||||
total_tokens=response.usage.total_tokens,
|
total_tokens=response.usage.total_tokens,
|
||||||
completion_tokens_details=response.usage.completion_tokens_details.__dict__ if response.usage.completion_tokens_details else {},
|
completion_tokens_details=response.usage.completion_tokens_details.__dict__
|
||||||
prompt_tokens_details=response.usage.prompt_tokens_details.__dict__ if response.usage.prompt_tokens_details else {}
|
if response.usage.completion_tokens_details
|
||||||
|
else {},
|
||||||
|
prompt_tokens_details=response.usage.prompt_tokens_details.__dict__
|
||||||
|
if response.usage.prompt_tokens_details
|
||||||
|
else {},
|
||||||
)
|
)
|
||||||
self.usages.append(usage)
|
self.usages.append(usage)
|
||||||
|
|
||||||
@@ -533,36 +627,44 @@ class LLMExtractionStrategy(ExtractionStrategy):
|
|||||||
self.total_usage.total_tokens += usage.total_tokens
|
self.total_usage.total_tokens += usage.total_tokens
|
||||||
|
|
||||||
try:
|
try:
|
||||||
blocks = extract_xml_data(["blocks"], response.choices[0].message.content)['blocks']
|
blocks = extract_xml_data(["blocks"], response.choices[0].message.content)[
|
||||||
|
"blocks"
|
||||||
|
]
|
||||||
blocks = json.loads(blocks)
|
blocks = json.loads(blocks)
|
||||||
for block in blocks:
|
for block in blocks:
|
||||||
block['error'] = False
|
block["error"] = False
|
||||||
except Exception as e:
|
except Exception:
|
||||||
parsed, unparsed = split_and_parse_json_objects(response.choices[0].message.content)
|
parsed, unparsed = split_and_parse_json_objects(
|
||||||
|
response.choices[0].message.content
|
||||||
|
)
|
||||||
blocks = parsed
|
blocks = parsed
|
||||||
if unparsed:
|
if unparsed:
|
||||||
blocks.append({
|
blocks.append(
|
||||||
"index": 0,
|
{"index": 0, "error": True, "tags": ["error"], "content": unparsed}
|
||||||
"error": True,
|
)
|
||||||
"tags": ["error"],
|
|
||||||
"content": unparsed
|
|
||||||
})
|
|
||||||
|
|
||||||
if self.verbose:
|
if self.verbose:
|
||||||
print("[LOG] Extracted", len(blocks), "blocks from URL:", url, "block index:", ix)
|
print(
|
||||||
|
"[LOG] Extracted",
|
||||||
|
len(blocks),
|
||||||
|
"blocks from URL:",
|
||||||
|
url,
|
||||||
|
"block index:",
|
||||||
|
ix,
|
||||||
|
)
|
||||||
return blocks
|
return blocks
|
||||||
|
|
||||||
def _merge(self, documents, chunk_token_threshold, overlap):
|
def _merge(self, documents, chunk_token_threshold, overlap):
|
||||||
"""
|
"""
|
||||||
Merge documents into sections based on chunk_token_threshold and overlap.
|
Merge documents into sections based on chunk_token_threshold and overlap.
|
||||||
"""
|
"""
|
||||||
chunks = []
|
# chunks = []
|
||||||
sections = []
|
sections = []
|
||||||
total_tokens = 0
|
total_tokens = 0
|
||||||
|
|
||||||
# Calculate the total tokens across all documents
|
# Calculate the total tokens across all documents
|
||||||
for document in documents:
|
for document in documents:
|
||||||
total_tokens += len(document.split(' ')) * self.word_token_rate
|
total_tokens += len(document.split(" ")) * self.word_token_rate
|
||||||
|
|
||||||
# Calculate the number of sections needed
|
# Calculate the number of sections needed
|
||||||
num_sections = math.floor(total_tokens / chunk_token_threshold)
|
num_sections = math.floor(total_tokens / chunk_token_threshold)
|
||||||
@@ -574,7 +676,7 @@ class LLMExtractionStrategy(ExtractionStrategy):
|
|||||||
current_chunk = []
|
current_chunk = []
|
||||||
|
|
||||||
for document in documents:
|
for document in documents:
|
||||||
tokens = document.split(' ')
|
tokens = document.split(" ")
|
||||||
token_count = len(tokens) * self.word_token_rate
|
token_count = len(tokens) * self.word_token_rate
|
||||||
|
|
||||||
if total_token_so_far + token_count <= adjusted_chunk_threshold:
|
if total_token_so_far + token_count <= adjusted_chunk_threshold:
|
||||||
@@ -591,17 +693,16 @@ class LLMExtractionStrategy(ExtractionStrategy):
|
|||||||
overlap_tokens = current_chunk[-overlap:]
|
overlap_tokens = current_chunk[-overlap:]
|
||||||
current_chunk.extend(overlap_tokens)
|
current_chunk.extend(overlap_tokens)
|
||||||
|
|
||||||
sections.append(' '.join(current_chunk))
|
sections.append(" ".join(current_chunk))
|
||||||
current_chunk = tokens
|
current_chunk = tokens
|
||||||
total_token_so_far = token_count
|
total_token_so_far = token_count
|
||||||
|
|
||||||
# Add the last chunk
|
# Add the last chunk
|
||||||
if current_chunk:
|
if current_chunk:
|
||||||
sections.append(' '.join(current_chunk))
|
sections.append(" ".join(current_chunk))
|
||||||
|
|
||||||
return sections
|
return sections
|
||||||
|
|
||||||
|
|
||||||
def run(self, url: str, sections: List[str]) -> List[Dict[str, Any]]:
|
def run(self, url: str, sections: List[str]) -> List[Dict[str, Any]]:
|
||||||
"""
|
"""
|
||||||
Process sections sequentially with a delay for rate limiting issues, specifically for LLMExtractionStrategy.
|
Process sections sequentially with a delay for rate limiting issues, specifically for LLMExtractionStrategy.
|
||||||
@@ -615,15 +716,18 @@ class LLMExtractionStrategy(ExtractionStrategy):
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
merged_sections = self._merge(
|
merged_sections = self._merge(
|
||||||
sections, self.chunk_token_threshold,
|
sections,
|
||||||
overlap= int(self.chunk_token_threshold * self.overlap_rate)
|
self.chunk_token_threshold,
|
||||||
|
overlap=int(self.chunk_token_threshold * self.overlap_rate),
|
||||||
)
|
)
|
||||||
extracted_content = []
|
extracted_content = []
|
||||||
if self.provider.startswith("groq/"):
|
if self.provider.startswith("groq/"):
|
||||||
# Sequential processing with a delay
|
# Sequential processing with a delay
|
||||||
for ix, section in enumerate(merged_sections):
|
for ix, section in enumerate(merged_sections):
|
||||||
extract_func = partial(self.extract, url)
|
extract_func = partial(self.extract, url)
|
||||||
extracted_content.extend(extract_func(ix, sanitize_input_encode(section)))
|
extracted_content.extend(
|
||||||
|
extract_func(ix, sanitize_input_encode(section))
|
||||||
|
)
|
||||||
time.sleep(0.5) # 500 ms delay between each processing
|
time.sleep(0.5) # 500 ms delay between each processing
|
||||||
else:
|
else:
|
||||||
# Parallel processing using ThreadPoolExecutor
|
# Parallel processing using ThreadPoolExecutor
|
||||||
@@ -633,7 +737,10 @@ class LLMExtractionStrategy(ExtractionStrategy):
|
|||||||
|
|
||||||
with ThreadPoolExecutor(max_workers=4) as executor:
|
with ThreadPoolExecutor(max_workers=4) as executor:
|
||||||
extract_func = partial(self.extract, url)
|
extract_func = partial(self.extract, url)
|
||||||
futures = [executor.submit(extract_func, ix, sanitize_input_encode(section)) for ix, section in enumerate(merged_sections)]
|
futures = [
|
||||||
|
executor.submit(extract_func, ix, sanitize_input_encode(section))
|
||||||
|
for ix, section in enumerate(merged_sections)
|
||||||
|
]
|
||||||
|
|
||||||
for future in as_completed(futures):
|
for future in as_completed(futures):
|
||||||
try:
|
try:
|
||||||
@@ -642,17 +749,17 @@ class LLMExtractionStrategy(ExtractionStrategy):
|
|||||||
if self.verbose:
|
if self.verbose:
|
||||||
print(f"Error in thread execution: {e}")
|
print(f"Error in thread execution: {e}")
|
||||||
# Add error information to extracted_content
|
# Add error information to extracted_content
|
||||||
extracted_content.append({
|
extracted_content.append(
|
||||||
|
{
|
||||||
"index": 0,
|
"index": 0,
|
||||||
"error": True,
|
"error": True,
|
||||||
"tags": ["error"],
|
"tags": ["error"],
|
||||||
"content": str(e)
|
"content": str(e),
|
||||||
})
|
}
|
||||||
|
)
|
||||||
|
|
||||||
return extracted_content
|
return extracted_content
|
||||||
|
|
||||||
|
|
||||||
def show_usage(self) -> None:
|
def show_usage(self) -> None:
|
||||||
"""Print a detailed token usage report showing total and per-request usage."""
|
"""Print a detailed token usage report showing total and per-request usage."""
|
||||||
print("\n=== Token Usage Summary ===")
|
print("\n=== Token Usage Summary ===")
|
||||||
@@ -666,14 +773,16 @@ class LLMExtractionStrategy(ExtractionStrategy):
|
|||||||
print(f"{'Request #':<10} {'Completion':>12} {'Prompt':>12} {'Total':>12}")
|
print(f"{'Request #':<10} {'Completion':>12} {'Prompt':>12} {'Total':>12}")
|
||||||
print("-" * 48)
|
print("-" * 48)
|
||||||
for i, usage in enumerate(self.usages, 1):
|
for i, usage in enumerate(self.usages, 1):
|
||||||
print(f"{i:<10} {usage.completion_tokens:>12,} {usage.prompt_tokens:>12,} {usage.total_tokens:>12,}")
|
print(
|
||||||
|
f"{i:<10} {usage.completion_tokens:>12,} {usage.prompt_tokens:>12,} {usage.total_tokens:>12,}"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
#######################################################
|
#######################################################
|
||||||
# New extraction strategies for JSON-based extraction #
|
# New extraction strategies for JSON-based extraction #
|
||||||
#######################################################
|
#######################################################
|
||||||
|
|
||||||
|
|
||||||
class JsonElementExtractionStrategy(ExtractionStrategy):
|
class JsonElementExtractionStrategy(ExtractionStrategy):
|
||||||
"""
|
"""
|
||||||
Abstract base class for extracting structured JSON from HTML content.
|
Abstract base class for extracting structured JSON from HTML content.
|
||||||
@@ -706,8 +815,7 @@ class JsonElementExtractionStrategy(ExtractionStrategy):
|
|||||||
_get_element_attribute(element, attribute): Extracts an attribute's value from an element.
|
_get_element_attribute(element, attribute): Extracts an attribute's value from an element.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
DEL = "\n"
|
||||||
DEL = '\n'
|
|
||||||
|
|
||||||
def __init__(self, schema: Dict[str, Any], **kwargs):
|
def __init__(self, schema: Dict[str, Any], **kwargs):
|
||||||
"""
|
"""
|
||||||
@@ -718,9 +826,11 @@ class JsonElementExtractionStrategy(ExtractionStrategy):
|
|||||||
"""
|
"""
|
||||||
super().__init__(**kwargs)
|
super().__init__(**kwargs)
|
||||||
self.schema = schema
|
self.schema = schema
|
||||||
self.verbose = kwargs.get('verbose', False)
|
self.verbose = kwargs.get("verbose", False)
|
||||||
|
|
||||||
def extract(self, url: str, html_content: str, *q, **kwargs) -> List[Dict[str, Any]]:
|
def extract(
|
||||||
|
self, url: str, html_content: str, *q, **kwargs
|
||||||
|
) -> List[Dict[str, Any]]:
|
||||||
"""
|
"""
|
||||||
Extract structured data from HTML content.
|
Extract structured data from HTML content.
|
||||||
|
|
||||||
@@ -740,20 +850,22 @@ class JsonElementExtractionStrategy(ExtractionStrategy):
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
parsed_html = self._parse_html(html_content)
|
parsed_html = self._parse_html(html_content)
|
||||||
base_elements = self._get_base_elements(parsed_html, self.schema['baseSelector'])
|
base_elements = self._get_base_elements(
|
||||||
|
parsed_html, self.schema["baseSelector"]
|
||||||
|
)
|
||||||
|
|
||||||
results = []
|
results = []
|
||||||
for element in base_elements:
|
for element in base_elements:
|
||||||
# Extract base element attributes
|
# Extract base element attributes
|
||||||
item = {}
|
item = {}
|
||||||
if 'baseFields' in self.schema:
|
if "baseFields" in self.schema:
|
||||||
for field in self.schema['baseFields']:
|
for field in self.schema["baseFields"]:
|
||||||
value = self._extract_single_field(element, field)
|
value = self._extract_single_field(element, field)
|
||||||
if value is not None:
|
if value is not None:
|
||||||
item[field['name']] = value
|
item[field["name"]] = value
|
||||||
|
|
||||||
# Extract child fields
|
# Extract child fields
|
||||||
field_data = self._extract_item(element, self.schema['fields'])
|
field_data = self._extract_item(element, self.schema["fields"])
|
||||||
item.update(field_data)
|
item.update(field_data)
|
||||||
|
|
||||||
if item:
|
if item:
|
||||||
@@ -778,24 +890,28 @@ class JsonElementExtractionStrategy(ExtractionStrategy):
|
|||||||
|
|
||||||
def _extract_field(self, element, field):
|
def _extract_field(self, element, field):
|
||||||
try:
|
try:
|
||||||
if field['type'] == 'nested':
|
if field["type"] == "nested":
|
||||||
nested_elements = self._get_elements(element, field['selector'])
|
nested_elements = self._get_elements(element, field["selector"])
|
||||||
nested_element = nested_elements[0] if nested_elements else None
|
nested_element = nested_elements[0] if nested_elements else None
|
||||||
return self._extract_item(nested_element, field['fields']) if nested_element else {}
|
return (
|
||||||
|
self._extract_item(nested_element, field["fields"])
|
||||||
|
if nested_element
|
||||||
|
else {}
|
||||||
|
)
|
||||||
|
|
||||||
if field['type'] == 'list':
|
if field["type"] == "list":
|
||||||
elements = self._get_elements(element, field['selector'])
|
elements = self._get_elements(element, field["selector"])
|
||||||
return [self._extract_list_item(el, field['fields']) for el in elements]
|
return [self._extract_list_item(el, field["fields"]) for el in elements]
|
||||||
|
|
||||||
if field['type'] == 'nested_list':
|
if field["type"] == "nested_list":
|
||||||
elements = self._get_elements(element, field['selector'])
|
elements = self._get_elements(element, field["selector"])
|
||||||
return [self._extract_item(el, field['fields']) for el in elements]
|
return [self._extract_item(el, field["fields"]) for el in elements]
|
||||||
|
|
||||||
return self._extract_single_field(element, field)
|
return self._extract_single_field(element, field)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
if self.verbose:
|
if self.verbose:
|
||||||
print(f"Error extracting field {field['name']}: {str(e)}")
|
print(f"Error extracting field {field['name']}: {str(e)}")
|
||||||
return field.get('default')
|
return field.get("default")
|
||||||
|
|
||||||
def _extract_single_field(self, element, field):
|
def _extract_single_field(self, element, field):
|
||||||
"""
|
"""
|
||||||
@@ -814,37 +930,37 @@ class JsonElementExtractionStrategy(ExtractionStrategy):
|
|||||||
Any: The extracted field value.
|
Any: The extracted field value.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
if 'selector' in field:
|
if "selector" in field:
|
||||||
selected = self._get_elements(element, field['selector'])
|
selected = self._get_elements(element, field["selector"])
|
||||||
if not selected:
|
if not selected:
|
||||||
return field.get('default')
|
return field.get("default")
|
||||||
selected = selected[0]
|
selected = selected[0]
|
||||||
else:
|
else:
|
||||||
selected = element
|
selected = element
|
||||||
|
|
||||||
value = None
|
value = None
|
||||||
if field['type'] == 'text':
|
if field["type"] == "text":
|
||||||
value = self._get_element_text(selected)
|
value = self._get_element_text(selected)
|
||||||
elif field['type'] == 'attribute':
|
elif field["type"] == "attribute":
|
||||||
value = self._get_element_attribute(selected, field['attribute'])
|
value = self._get_element_attribute(selected, field["attribute"])
|
||||||
elif field['type'] == 'html':
|
elif field["type"] == "html":
|
||||||
value = self._get_element_html(selected)
|
value = self._get_element_html(selected)
|
||||||
elif field['type'] == 'regex':
|
elif field["type"] == "regex":
|
||||||
text = self._get_element_text(selected)
|
text = self._get_element_text(selected)
|
||||||
match = re.search(field['pattern'], text)
|
match = re.search(field["pattern"], text)
|
||||||
value = match.group(1) if match else None
|
value = match.group(1) if match else None
|
||||||
|
|
||||||
if 'transform' in field:
|
if "transform" in field:
|
||||||
value = self._apply_transform(value, field['transform'])
|
value = self._apply_transform(value, field["transform"])
|
||||||
|
|
||||||
return value if value is not None else field.get('default')
|
return value if value is not None else field.get("default")
|
||||||
|
|
||||||
def _extract_list_item(self, element, fields):
|
def _extract_list_item(self, element, fields):
|
||||||
item = {}
|
item = {}
|
||||||
for field in fields:
|
for field in fields:
|
||||||
value = self._extract_single_field(element, field)
|
value = self._extract_single_field(element, field)
|
||||||
if value is not None:
|
if value is not None:
|
||||||
item[field['name']] = value
|
item[field["name"]] = value
|
||||||
return item
|
return item
|
||||||
|
|
||||||
def _extract_item(self, element, fields):
|
def _extract_item(self, element, fields):
|
||||||
@@ -866,12 +982,12 @@ class JsonElementExtractionStrategy(ExtractionStrategy):
|
|||||||
|
|
||||||
item = {}
|
item = {}
|
||||||
for field in fields:
|
for field in fields:
|
||||||
if field['type'] == 'computed':
|
if field["type"] == "computed":
|
||||||
value = self._compute_field(item, field)
|
value = self._compute_field(item, field)
|
||||||
else:
|
else:
|
||||||
value = self._extract_field(element, field)
|
value = self._extract_field(element, field)
|
||||||
if value is not None:
|
if value is not None:
|
||||||
item[field['name']] = value
|
item[field["name"]] = value
|
||||||
return item
|
return item
|
||||||
|
|
||||||
def _apply_transform(self, value, transform):
|
def _apply_transform(self, value, transform):
|
||||||
@@ -891,24 +1007,24 @@ class JsonElementExtractionStrategy(ExtractionStrategy):
|
|||||||
str: The transformed value.
|
str: The transformed value.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
if transform == 'lowercase':
|
if transform == "lowercase":
|
||||||
return value.lower()
|
return value.lower()
|
||||||
elif transform == 'uppercase':
|
elif transform == "uppercase":
|
||||||
return value.upper()
|
return value.upper()
|
||||||
elif transform == 'strip':
|
elif transform == "strip":
|
||||||
return value.strip()
|
return value.strip()
|
||||||
return value
|
return value
|
||||||
|
|
||||||
def _compute_field(self, item, field):
|
def _compute_field(self, item, field):
|
||||||
try:
|
try:
|
||||||
if 'expression' in field:
|
if "expression" in field:
|
||||||
return eval(field['expression'], {}, item)
|
return eval(field["expression"], {}, item)
|
||||||
elif 'function' in field:
|
elif "function" in field:
|
||||||
return field['function'](item)
|
return field["function"](item)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
if self.verbose:
|
if self.verbose:
|
||||||
print(f"Error computing field {field['name']}: {str(e)}")
|
print(f"Error computing field {field['name']}: {str(e)}")
|
||||||
return field.get('default')
|
return field.get("default")
|
||||||
|
|
||||||
def run(self, url: str, sections: List[str], *q, **kwargs) -> List[Dict[str, Any]]:
|
def run(self, url: str, sections: List[str], *q, **kwargs) -> List[Dict[str, Any]]:
|
||||||
"""
|
"""
|
||||||
@@ -946,6 +1062,7 @@ class JsonElementExtractionStrategy(ExtractionStrategy):
|
|||||||
"""Get attribute value from element"""
|
"""Get attribute value from element"""
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
class JsonCssExtractionStrategy(JsonElementExtractionStrategy):
|
class JsonCssExtractionStrategy(JsonElementExtractionStrategy):
|
||||||
"""
|
"""
|
||||||
Concrete implementation of `JsonElementExtractionStrategy` using CSS selectors.
|
Concrete implementation of `JsonElementExtractionStrategy` using CSS selectors.
|
||||||
@@ -969,11 +1086,11 @@ class JsonCssExtractionStrategy(JsonElementExtractionStrategy):
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, schema: Dict[str, Any], **kwargs):
|
def __init__(self, schema: Dict[str, Any], **kwargs):
|
||||||
kwargs['input_format'] = 'html' # Force HTML input
|
kwargs["input_format"] = "html" # Force HTML input
|
||||||
super().__init__(schema, **kwargs)
|
super().__init__(schema, **kwargs)
|
||||||
|
|
||||||
def _parse_html(self, html_content: str):
|
def _parse_html(self, html_content: str):
|
||||||
return BeautifulSoup(html_content, 'html.parser')
|
return BeautifulSoup(html_content, "html.parser")
|
||||||
|
|
||||||
def _get_base_elements(self, parsed_html, selector: str):
|
def _get_base_elements(self, parsed_html, selector: str):
|
||||||
return parsed_html.select(selector)
|
return parsed_html.select(selector)
|
||||||
@@ -992,6 +1109,7 @@ class JsonCssExtractionStrategy(JsonElementExtractionStrategy):
|
|||||||
def _get_element_attribute(self, element, attribute: str):
|
def _get_element_attribute(self, element, attribute: str):
|
||||||
return element.get(attribute)
|
return element.get(attribute)
|
||||||
|
|
||||||
|
|
||||||
class JsonXPathExtractionStrategy(JsonElementExtractionStrategy):
|
class JsonXPathExtractionStrategy(JsonElementExtractionStrategy):
|
||||||
"""
|
"""
|
||||||
Concrete implementation of `JsonElementExtractionStrategy` using XPath selectors.
|
Concrete implementation of `JsonElementExtractionStrategy` using XPath selectors.
|
||||||
@@ -1016,7 +1134,7 @@ class JsonXPathExtractionStrategy(JsonElementExtractionStrategy):
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, schema: Dict[str, Any], **kwargs):
|
def __init__(self, schema: Dict[str, Any], **kwargs):
|
||||||
kwargs['input_format'] = 'html' # Force HTML input
|
kwargs["input_format"] = "html" # Force HTML input
|
||||||
super().__init__(schema, **kwargs)
|
super().__init__(schema, **kwargs)
|
||||||
|
|
||||||
def _parse_html(self, html_content: str):
|
def _parse_html(self, html_content: str):
|
||||||
@@ -1027,31 +1145,31 @@ class JsonXPathExtractionStrategy(JsonElementExtractionStrategy):
|
|||||||
|
|
||||||
def _css_to_xpath(self, css_selector: str) -> str:
|
def _css_to_xpath(self, css_selector: str) -> str:
|
||||||
"""Convert CSS selector to XPath if needed"""
|
"""Convert CSS selector to XPath if needed"""
|
||||||
if '/' in css_selector: # Already an XPath
|
if "/" in css_selector: # Already an XPath
|
||||||
return css_selector
|
return css_selector
|
||||||
return self._basic_css_to_xpath(css_selector)
|
return self._basic_css_to_xpath(css_selector)
|
||||||
|
|
||||||
def _basic_css_to_xpath(self, css_selector: str) -> str:
|
def _basic_css_to_xpath(self, css_selector: str) -> str:
|
||||||
"""Basic CSS to XPath conversion for common cases"""
|
"""Basic CSS to XPath conversion for common cases"""
|
||||||
if ' > ' in css_selector:
|
if " > " in css_selector:
|
||||||
parts = css_selector.split(' > ')
|
parts = css_selector.split(" > ")
|
||||||
return '//' + '/'.join(parts)
|
return "//" + "/".join(parts)
|
||||||
if ' ' in css_selector:
|
if " " in css_selector:
|
||||||
parts = css_selector.split(' ')
|
parts = css_selector.split(" ")
|
||||||
return '//' + '//'.join(parts)
|
return "//" + "//".join(parts)
|
||||||
return '//' + css_selector
|
return "//" + css_selector
|
||||||
|
|
||||||
def _get_elements(self, element, selector: str):
|
def _get_elements(self, element, selector: str):
|
||||||
xpath = self._css_to_xpath(selector)
|
xpath = self._css_to_xpath(selector)
|
||||||
if not xpath.startswith('.'):
|
if not xpath.startswith("."):
|
||||||
xpath = '.' + xpath
|
xpath = "." + xpath
|
||||||
return element.xpath(xpath)
|
return element.xpath(xpath)
|
||||||
|
|
||||||
def _get_element_text(self, element) -> str:
|
def _get_element_text(self, element) -> str:
|
||||||
return ''.join(element.xpath('.//text()')).strip()
|
return "".join(element.xpath(".//text()")).strip()
|
||||||
|
|
||||||
def _get_element_html(self, element) -> str:
|
def _get_element_html(self, element) -> str:
|
||||||
return etree.tostring(element, encoding='unicode')
|
return etree.tostring(element, encoding="unicode")
|
||||||
|
|
||||||
def _get_element_attribute(self, element, attribute: str):
|
def _get_element_attribute(self, element, attribute: str):
|
||||||
return element.get(attribute)
|
return element.get(attribute)
|
||||||
|
|||||||
@@ -903,7 +903,13 @@ class HTML2Text(html.parser.HTMLParser):
|
|||||||
self.empty_link = False
|
self.empty_link = False
|
||||||
|
|
||||||
if not self.code and not self.pre and not entity_char:
|
if not self.code and not self.pre and not entity_char:
|
||||||
data = escape_md_section(data, snob=self.escape_snob, escape_dot=self.escape_dot, escape_plus=self.escape_plus, escape_dash=self.escape_dash)
|
data = escape_md_section(
|
||||||
|
data,
|
||||||
|
snob=self.escape_snob,
|
||||||
|
escape_dot=self.escape_dot,
|
||||||
|
escape_plus=self.escape_plus,
|
||||||
|
escape_dash=self.escape_dash,
|
||||||
|
)
|
||||||
self.preceding_data = data
|
self.preceding_data = data
|
||||||
self.o(data, puredata=True)
|
self.o(data, puredata=True)
|
||||||
|
|
||||||
@@ -1006,6 +1012,7 @@ class HTML2Text(html.parser.HTMLParser):
|
|||||||
newlines += 1
|
newlines += 1
|
||||||
return result
|
return result
|
||||||
|
|
||||||
|
|
||||||
def html2text(html: str, baseurl: str = "", bodywidth: Optional[int] = None) -> str:
|
def html2text(html: str, baseurl: str = "", bodywidth: Optional[int] = None) -> str:
|
||||||
if bodywidth is None:
|
if bodywidth is None:
|
||||||
bodywidth = config.BODY_WIDTH
|
bodywidth = config.BODY_WIDTH
|
||||||
@@ -1013,6 +1020,7 @@ def html2text(html: str, baseurl: str = "", bodywidth: Optional[int] = None) ->
|
|||||||
|
|
||||||
return h.handle(html)
|
return h.handle(html)
|
||||||
|
|
||||||
|
|
||||||
class CustomHTML2Text(HTML2Text):
|
class CustomHTML2Text(HTML2Text):
|
||||||
def __init__(self, *args, handle_code_in_pre=False, **kwargs):
|
def __init__(self, *args, handle_code_in_pre=False, **kwargs):
|
||||||
super().__init__(*args, **kwargs)
|
super().__init__(*args, **kwargs)
|
||||||
@@ -1041,9 +1049,9 @@ class CustomHTML2Text(HTML2Text):
|
|||||||
def update_params(self, **kwargs):
|
def update_params(self, **kwargs):
|
||||||
"""Update parameters and set preserved tags."""
|
"""Update parameters and set preserved tags."""
|
||||||
for key, value in kwargs.items():
|
for key, value in kwargs.items():
|
||||||
if key == 'preserve_tags':
|
if key == "preserve_tags":
|
||||||
self.preserve_tags = set(value)
|
self.preserve_tags = set(value)
|
||||||
elif key == 'handle_code_in_pre':
|
elif key == "handle_code_in_pre":
|
||||||
self.handle_code_in_pre = value
|
self.handle_code_in_pre = value
|
||||||
else:
|
else:
|
||||||
setattr(self, key, value)
|
setattr(self, key, value)
|
||||||
@@ -1056,17 +1064,19 @@ class CustomHTML2Text(HTML2Text):
|
|||||||
self.current_preserved_tag = tag
|
self.current_preserved_tag = tag
|
||||||
self.preserved_content = []
|
self.preserved_content = []
|
||||||
# Format opening tag with attributes
|
# Format opening tag with attributes
|
||||||
attr_str = ''.join(f' {k}="{v}"' for k, v in attrs.items() if v is not None)
|
attr_str = "".join(
|
||||||
self.preserved_content.append(f'<{tag}{attr_str}>')
|
f' {k}="{v}"' for k, v in attrs.items() if v is not None
|
||||||
|
)
|
||||||
|
self.preserved_content.append(f"<{tag}{attr_str}>")
|
||||||
self.preserve_depth += 1
|
self.preserve_depth += 1
|
||||||
return
|
return
|
||||||
else:
|
else:
|
||||||
self.preserve_depth -= 1
|
self.preserve_depth -= 1
|
||||||
if self.preserve_depth == 0:
|
if self.preserve_depth == 0:
|
||||||
self.preserved_content.append(f'</{tag}>')
|
self.preserved_content.append(f"</{tag}>")
|
||||||
# Output the preserved HTML block with proper spacing
|
# Output the preserved HTML block with proper spacing
|
||||||
preserved_html = ''.join(self.preserved_content)
|
preserved_html = "".join(self.preserved_content)
|
||||||
self.o('\n' + preserved_html + '\n')
|
self.o("\n" + preserved_html + "\n")
|
||||||
self.current_preserved_tag = None
|
self.current_preserved_tag = None
|
||||||
return
|
return
|
||||||
|
|
||||||
@@ -1074,29 +1084,31 @@ class CustomHTML2Text(HTML2Text):
|
|||||||
if self.preserve_depth > 0:
|
if self.preserve_depth > 0:
|
||||||
if start:
|
if start:
|
||||||
# Format nested tags with attributes
|
# Format nested tags with attributes
|
||||||
attr_str = ''.join(f' {k}="{v}"' for k, v in attrs.items() if v is not None)
|
attr_str = "".join(
|
||||||
self.preserved_content.append(f'<{tag}{attr_str}>')
|
f' {k}="{v}"' for k, v in attrs.items() if v is not None
|
||||||
|
)
|
||||||
|
self.preserved_content.append(f"<{tag}{attr_str}>")
|
||||||
else:
|
else:
|
||||||
self.preserved_content.append(f'</{tag}>')
|
self.preserved_content.append(f"</{tag}>")
|
||||||
return
|
return
|
||||||
|
|
||||||
# Handle pre tags
|
# Handle pre tags
|
||||||
if tag == 'pre':
|
if tag == "pre":
|
||||||
if start:
|
if start:
|
||||||
self.o('```\n') # Markdown code block start
|
self.o("```\n") # Markdown code block start
|
||||||
self.inside_pre = True
|
self.inside_pre = True
|
||||||
else:
|
else:
|
||||||
self.o('\n```\n') # Markdown code block end
|
self.o("\n```\n") # Markdown code block end
|
||||||
self.inside_pre = False
|
self.inside_pre = False
|
||||||
elif tag == 'code':
|
elif tag == "code":
|
||||||
if self.inside_pre and not self.handle_code_in_pre:
|
if self.inside_pre and not self.handle_code_in_pre:
|
||||||
# Ignore code tags inside pre blocks if handle_code_in_pre is False
|
# Ignore code tags inside pre blocks if handle_code_in_pre is False
|
||||||
return
|
return
|
||||||
if start:
|
if start:
|
||||||
self.o('`') # Markdown inline code start
|
self.o("`") # Markdown inline code start
|
||||||
self.inside_code = True
|
self.inside_code = True
|
||||||
else:
|
else:
|
||||||
self.o('`') # Markdown inline code end
|
self.o("`") # Markdown inline code end
|
||||||
self.inside_code = False
|
self.inside_code = False
|
||||||
else:
|
else:
|
||||||
super().handle_tag(tag, attrs, start)
|
super().handle_tag(tag, attrs, start)
|
||||||
@@ -1113,13 +1125,12 @@ class CustomHTML2Text(HTML2Text):
|
|||||||
return
|
return
|
||||||
if self.inside_code:
|
if self.inside_code:
|
||||||
# Inline code: no newlines allowed
|
# Inline code: no newlines allowed
|
||||||
self.o(data.replace('\n', ' '))
|
self.o(data.replace("\n", " "))
|
||||||
return
|
return
|
||||||
|
|
||||||
# Default behavior for other tags
|
# Default behavior for other tags
|
||||||
super().handle_data(data, entity_char)
|
super().handle_data(data, entity_char)
|
||||||
|
|
||||||
|
|
||||||
# # Handle pre tags
|
# # Handle pre tags
|
||||||
# if tag == 'pre':
|
# if tag == 'pre':
|
||||||
# if start:
|
# if start:
|
||||||
|
|||||||
@@ -1,2 +1,3 @@
|
|||||||
class OutCallback:
|
class OutCallback:
|
||||||
def __call__(self, s: str) -> None: ...
|
def __call__(self, s: str) -> None:
|
||||||
|
...
|
||||||
|
|||||||
@@ -210,7 +210,7 @@ def escape_md_section(
|
|||||||
snob: bool = False,
|
snob: bool = False,
|
||||||
escape_dot: bool = True,
|
escape_dot: bool = True,
|
||||||
escape_plus: bool = True,
|
escape_plus: bool = True,
|
||||||
escape_dash: bool = True
|
escape_dash: bool = True,
|
||||||
) -> str:
|
) -> str:
|
||||||
"""
|
"""
|
||||||
Escapes markdown-sensitive characters across whole document sections.
|
Escapes markdown-sensitive characters across whole document sections.
|
||||||
@@ -233,6 +233,7 @@ def escape_md_section(
|
|||||||
|
|
||||||
return text
|
return text
|
||||||
|
|
||||||
|
|
||||||
def reformat_table(lines: List[str], right_margin: int) -> List[str]:
|
def reformat_table(lines: List[str], right_margin: int) -> List[str]:
|
||||||
"""
|
"""
|
||||||
Given the lines of a table
|
Given the lines of a table
|
||||||
|
|||||||
@@ -6,6 +6,7 @@ from .async_logger import AsyncLogger, LogLevel
|
|||||||
# Initialize logger
|
# Initialize logger
|
||||||
logger = AsyncLogger(log_level=LogLevel.DEBUG, verbose=True)
|
logger = AsyncLogger(log_level=LogLevel.DEBUG, verbose=True)
|
||||||
|
|
||||||
|
|
||||||
def post_install():
|
def post_install():
|
||||||
"""Run all post-installation tasks"""
|
"""Run all post-installation tasks"""
|
||||||
logger.info("Running post-installation setup...", tag="INIT")
|
logger.info("Running post-installation setup...", tag="INIT")
|
||||||
@@ -13,18 +14,36 @@ def post_install():
|
|||||||
run_migration()
|
run_migration()
|
||||||
logger.success("Post-installation setup completed!", tag="COMPLETE")
|
logger.success("Post-installation setup completed!", tag="COMPLETE")
|
||||||
|
|
||||||
|
|
||||||
def install_playwright():
|
def install_playwright():
|
||||||
logger.info("Installing Playwright browsers...", tag="INIT")
|
logger.info("Installing Playwright browsers...", tag="INIT")
|
||||||
try:
|
try:
|
||||||
# subprocess.check_call([sys.executable, "-m", "playwright", "install", "--with-deps", "--force", "chrome"])
|
# subprocess.check_call([sys.executable, "-m", "playwright", "install", "--with-deps", "--force", "chrome"])
|
||||||
subprocess.check_call([sys.executable, "-m", "playwright", "install", "--with-deps", "--force", "chromium"])
|
subprocess.check_call(
|
||||||
logger.success("Playwright installation completed successfully.", tag="COMPLETE")
|
[
|
||||||
except subprocess.CalledProcessError as e:
|
sys.executable,
|
||||||
|
"-m",
|
||||||
|
"playwright",
|
||||||
|
"install",
|
||||||
|
"--with-deps",
|
||||||
|
"--force",
|
||||||
|
"chromium",
|
||||||
|
]
|
||||||
|
)
|
||||||
|
logger.success(
|
||||||
|
"Playwright installation completed successfully.", tag="COMPLETE"
|
||||||
|
)
|
||||||
|
except subprocess.CalledProcessError:
|
||||||
# logger.error(f"Error during Playwright installation: {e}", tag="ERROR")
|
# logger.error(f"Error during Playwright installation: {e}", tag="ERROR")
|
||||||
logger.warning(f"Please run '{sys.executable} -m playwright install --with-deps' manually after the installation.")
|
logger.warning(
|
||||||
except Exception as e:
|
f"Please run '{sys.executable} -m playwright install --with-deps' manually after the installation."
|
||||||
|
)
|
||||||
|
except Exception:
|
||||||
# logger.error(f"Unexpected error during Playwright installation: {e}", tag="ERROR")
|
# logger.error(f"Unexpected error during Playwright installation: {e}", tag="ERROR")
|
||||||
logger.warning(f"Please run '{sys.executable} -m playwright install --with-deps' manually after the installation.")
|
logger.warning(
|
||||||
|
f"Please run '{sys.executable} -m playwright install --with-deps' manually after the installation."
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def run_migration():
|
def run_migration():
|
||||||
"""Initialize database during installation"""
|
"""Initialize database during installation"""
|
||||||
@@ -33,18 +52,26 @@ def run_migration():
|
|||||||
from crawl4ai.async_database import async_db_manager
|
from crawl4ai.async_database import async_db_manager
|
||||||
|
|
||||||
asyncio.run(async_db_manager.initialize())
|
asyncio.run(async_db_manager.initialize())
|
||||||
logger.success("Database initialization completed successfully.", tag="COMPLETE")
|
logger.success(
|
||||||
|
"Database initialization completed successfully.", tag="COMPLETE"
|
||||||
|
)
|
||||||
except ImportError:
|
except ImportError:
|
||||||
logger.warning("Database module not found. Will initialize on first use.")
|
logger.warning("Database module not found. Will initialize on first use.")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.warning(f"Database initialization failed: {e}")
|
logger.warning(f"Database initialization failed: {e}")
|
||||||
logger.warning("Database will be initialized on first use")
|
logger.warning("Database will be initialized on first use")
|
||||||
|
|
||||||
|
|
||||||
async def run_doctor():
|
async def run_doctor():
|
||||||
"""Test if Crawl4AI is working properly"""
|
"""Test if Crawl4AI is working properly"""
|
||||||
logger.info("Running Crawl4AI health check...", tag="INIT")
|
logger.info("Running Crawl4AI health check...", tag="INIT")
|
||||||
try:
|
try:
|
||||||
from .async_webcrawler import AsyncWebCrawler, BrowserConfig, CrawlerRunConfig, CacheMode
|
from .async_webcrawler import (
|
||||||
|
AsyncWebCrawler,
|
||||||
|
BrowserConfig,
|
||||||
|
CrawlerRunConfig,
|
||||||
|
CacheMode,
|
||||||
|
)
|
||||||
|
|
||||||
browser_config = BrowserConfig(
|
browser_config = BrowserConfig(
|
||||||
headless=True,
|
headless=True,
|
||||||
@@ -52,7 +79,7 @@ async def run_doctor():
|
|||||||
ignore_https_errors=True,
|
ignore_https_errors=True,
|
||||||
light_mode=True,
|
light_mode=True,
|
||||||
viewport_width=1280,
|
viewport_width=1280,
|
||||||
viewport_height=720
|
viewport_height=720,
|
||||||
)
|
)
|
||||||
|
|
||||||
run_config = CrawlerRunConfig(
|
run_config = CrawlerRunConfig(
|
||||||
@@ -62,10 +89,7 @@ async def run_doctor():
|
|||||||
|
|
||||||
async with AsyncWebCrawler(config=browser_config) as crawler:
|
async with AsyncWebCrawler(config=browser_config) as crawler:
|
||||||
logger.info("Testing crawling capabilities...", tag="TEST")
|
logger.info("Testing crawling capabilities...", tag="TEST")
|
||||||
result = await crawler.arun(
|
result = await crawler.arun(url="https://crawl4ai.com", config=run_config)
|
||||||
url="https://crawl4ai.com",
|
|
||||||
config=run_config
|
|
||||||
)
|
|
||||||
|
|
||||||
if result and result.markdown:
|
if result and result.markdown:
|
||||||
logger.success("✅ Crawling test passed!", tag="COMPLETE")
|
logger.success("✅ Crawling test passed!", tag="COMPLETE")
|
||||||
@@ -77,7 +101,9 @@ async def run_doctor():
|
|||||||
logger.error(f"❌ Test failed: {e}", tag="ERROR")
|
logger.error(f"❌ Test failed: {e}", tag="ERROR")
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
|
||||||
def doctor():
|
def doctor():
|
||||||
"""Entry point for the doctor command"""
|
"""Entry point for the doctor command"""
|
||||||
import asyncio
|
import asyncio
|
||||||
|
|
||||||
return asyncio.run(run_doctor())
|
return asyncio.run(run_doctor())
|
||||||
|
|||||||
@@ -1,15 +1,18 @@
|
|||||||
import os, sys
|
import os
|
||||||
|
|
||||||
|
|
||||||
# Create a function get name of a js script, then load from the CURRENT folder of this script and return its content as string, make sure its error free
|
# Create a function get name of a js script, then load from the CURRENT folder of this script and return its content as string, make sure its error free
|
||||||
def load_js_script(script_name):
|
def load_js_script(script_name):
|
||||||
# Get the path of the current script
|
# Get the path of the current script
|
||||||
current_script_path = os.path.dirname(os.path.realpath(__file__))
|
current_script_path = os.path.dirname(os.path.realpath(__file__))
|
||||||
# Get the path of the script to load
|
# Get the path of the script to load
|
||||||
script_path = os.path.join(current_script_path, script_name + '.js')
|
script_path = os.path.join(current_script_path, script_name + ".js")
|
||||||
# Check if the script exists
|
# Check if the script exists
|
||||||
if not os.path.exists(script_path):
|
if not os.path.exists(script_path):
|
||||||
raise ValueError(f"Script {script_name} not found in the folder {current_script_path}")
|
raise ValueError(
|
||||||
|
f"Script {script_name} not found in the folder {current_script_path}"
|
||||||
|
)
|
||||||
# Load the content of the script
|
# Load the content of the script
|
||||||
with open(script_path, 'r') as f:
|
with open(script_path, "r") as f:
|
||||||
script_content = f.read()
|
script_content = f.read()
|
||||||
return script_content
|
return script_content
|
||||||
|
|||||||
@@ -11,16 +11,16 @@ from rank_bm25 import BM25Okapi
|
|||||||
from nltk.tokenize import word_tokenize
|
from nltk.tokenize import word_tokenize
|
||||||
from nltk.corpus import stopwords
|
from nltk.corpus import stopwords
|
||||||
from nltk.stem import WordNetLemmatizer
|
from nltk.stem import WordNetLemmatizer
|
||||||
from litellm import completion, batch_completion
|
from litellm import batch_completion
|
||||||
from .async_logger import AsyncLogger
|
from .async_logger import AsyncLogger
|
||||||
import litellm
|
import litellm
|
||||||
import pickle
|
import pickle
|
||||||
import hashlib # <--- ADDED for file-hash
|
import hashlib # <--- ADDED for file-hash
|
||||||
from fnmatch import fnmatch
|
|
||||||
import glob
|
import glob
|
||||||
|
|
||||||
litellm.set_verbose = False
|
litellm.set_verbose = False
|
||||||
|
|
||||||
|
|
||||||
def _compute_file_hash(file_path: Path) -> str:
|
def _compute_file_hash(file_path: Path) -> str:
|
||||||
"""Compute MD5 hash for the file's entire content."""
|
"""Compute MD5 hash for the file's entire content."""
|
||||||
hash_md5 = hashlib.md5()
|
hash_md5 = hashlib.md5()
|
||||||
@@ -29,13 +29,14 @@ def _compute_file_hash(file_path: Path) -> str:
|
|||||||
hash_md5.update(chunk)
|
hash_md5.update(chunk)
|
||||||
return hash_md5.hexdigest()
|
return hash_md5.hexdigest()
|
||||||
|
|
||||||
|
|
||||||
class AsyncLLMTextManager:
|
class AsyncLLMTextManager:
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
docs_dir: Path,
|
docs_dir: Path,
|
||||||
logger: Optional[AsyncLogger] = None,
|
logger: Optional[AsyncLogger] = None,
|
||||||
max_concurrent_calls: int = 5,
|
max_concurrent_calls: int = 5,
|
||||||
batch_size: int = 3
|
batch_size: int = 3,
|
||||||
) -> None:
|
) -> None:
|
||||||
self.docs_dir = docs_dir
|
self.docs_dir = docs_dir
|
||||||
self.logger = logger
|
self.logger = logger
|
||||||
@@ -51,7 +52,7 @@ class AsyncLLMTextManager:
|
|||||||
contents = []
|
contents = []
|
||||||
for file_path in doc_batch:
|
for file_path in doc_batch:
|
||||||
try:
|
try:
|
||||||
with open(file_path, 'r', encoding='utf-8') as f:
|
with open(file_path, "r", encoding="utf-8") as f:
|
||||||
contents.append(f.read())
|
contents.append(f.read())
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
self.logger.error(f"Error reading {file_path}: {str(e)}")
|
self.logger.error(f"Error reading {file_path}: {str(e)}")
|
||||||
@@ -77,43 +78,53 @@ Wrap your response in <index>...</index> tags.
|
|||||||
# Prepare messages for batch processing
|
# Prepare messages for batch processing
|
||||||
messages_list = [
|
messages_list = [
|
||||||
[
|
[
|
||||||
{"role": "user", "content": f"{prompt}\n\nGenerate index for this documentation:\n\n{content}"}
|
{
|
||||||
|
"role": "user",
|
||||||
|
"content": f"{prompt}\n\nGenerate index for this documentation:\n\n{content}",
|
||||||
|
}
|
||||||
]
|
]
|
||||||
for content in contents if content
|
for content in contents
|
||||||
|
if content
|
||||||
]
|
]
|
||||||
|
|
||||||
try:
|
try:
|
||||||
responses = batch_completion(
|
responses = batch_completion(
|
||||||
model="anthropic/claude-3-5-sonnet-latest",
|
model="anthropic/claude-3-5-sonnet-latest",
|
||||||
messages=messages_list,
|
messages=messages_list,
|
||||||
logger_fn=None
|
logger_fn=None,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Process responses and save index files
|
# Process responses and save index files
|
||||||
for response, file_path in zip(responses, doc_batch):
|
for response, file_path in zip(responses, doc_batch):
|
||||||
try:
|
try:
|
||||||
index_content_match = re.search(
|
index_content_match = re.search(
|
||||||
r'<index>(.*?)</index>',
|
r"<index>(.*?)</index>",
|
||||||
response.choices[0].message.content,
|
response.choices[0].message.content,
|
||||||
re.DOTALL
|
re.DOTALL,
|
||||||
)
|
)
|
||||||
if not index_content_match:
|
if not index_content_match:
|
||||||
self.logger.warning(f"No <index>...</index> content found for {file_path}")
|
self.logger.warning(
|
||||||
|
f"No <index>...</index> content found for {file_path}"
|
||||||
|
)
|
||||||
continue
|
continue
|
||||||
|
|
||||||
index_content = re.sub(
|
index_content = re.sub(
|
||||||
r"\n\s*\n", "\n", index_content_match.group(1)
|
r"\n\s*\n", "\n", index_content_match.group(1)
|
||||||
).strip()
|
).strip()
|
||||||
if index_content:
|
if index_content:
|
||||||
index_file = file_path.with_suffix('.q.md')
|
index_file = file_path.with_suffix(".q.md")
|
||||||
with open(index_file, 'w', encoding='utf-8') as f:
|
with open(index_file, "w", encoding="utf-8") as f:
|
||||||
f.write(index_content)
|
f.write(index_content)
|
||||||
self.logger.info(f"Created index file: {index_file}")
|
self.logger.info(f"Created index file: {index_file}")
|
||||||
else:
|
else:
|
||||||
self.logger.warning(f"No index content found in response for {file_path}")
|
self.logger.warning(
|
||||||
|
f"No index content found in response for {file_path}"
|
||||||
|
)
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
self.logger.error(f"Error processing response for {file_path}: {str(e)}")
|
self.logger.error(
|
||||||
|
f"Error processing response for {file_path}: {str(e)}"
|
||||||
|
)
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
self.logger.error(f"Error in batch completion: {str(e)}")
|
self.logger.error(f"Error in batch completion: {str(e)}")
|
||||||
@@ -171,7 +182,12 @@ Wrap your response in <index>...</index> tags.
|
|||||||
|
|
||||||
lemmatizer = WordNetLemmatizer()
|
lemmatizer = WordNetLemmatizer()
|
||||||
stop_words = set(stopwords.words("english")) - {
|
stop_words = set(stopwords.words("english")) - {
|
||||||
"how", "what", "when", "where", "why", "which",
|
"how",
|
||||||
|
"what",
|
||||||
|
"when",
|
||||||
|
"where",
|
||||||
|
"why",
|
||||||
|
"which",
|
||||||
}
|
}
|
||||||
|
|
||||||
tokens = []
|
tokens = []
|
||||||
@@ -222,7 +238,9 @@ Wrap your response in <index>...</index> tags.
|
|||||||
self.logger.info("Checking which .q.md files need (re)indexing...")
|
self.logger.info("Checking which .q.md files need (re)indexing...")
|
||||||
|
|
||||||
# Gather all .q.md files
|
# Gather all .q.md files
|
||||||
q_files = [self.docs_dir / f for f in os.listdir(self.docs_dir) if f.endswith(".q.md")]
|
q_files = [
|
||||||
|
self.docs_dir / f for f in os.listdir(self.docs_dir) if f.endswith(".q.md")
|
||||||
|
]
|
||||||
|
|
||||||
# We'll store known (unchanged) facts in these lists
|
# We'll store known (unchanged) facts in these lists
|
||||||
existing_facts: List[str] = []
|
existing_facts: List[str] = []
|
||||||
@@ -243,7 +261,9 @@ Wrap your response in <index>...</index> tags.
|
|||||||
# Otherwise, load the existing cache and compare hash
|
# Otherwise, load the existing cache and compare hash
|
||||||
cache = self._load_or_create_token_cache(qf)
|
cache = self._load_or_create_token_cache(qf)
|
||||||
# If the .q.tokens was out of date (i.e. changed hash), we reindex
|
# If the .q.tokens was out of date (i.e. changed hash), we reindex
|
||||||
if len(cache["facts"]) == 0 or cache.get("content_hash") != _compute_file_hash(qf):
|
if len(cache["facts"]) == 0 or cache.get(
|
||||||
|
"content_hash"
|
||||||
|
) != _compute_file_hash(qf):
|
||||||
needSet.append(qf)
|
needSet.append(qf)
|
||||||
else:
|
else:
|
||||||
# File is unchanged → retrieve cached token data
|
# File is unchanged → retrieve cached token data
|
||||||
@@ -255,20 +275,29 @@ Wrap your response in <index>...</index> tags.
|
|||||||
if not needSet and not clear_cache:
|
if not needSet and not clear_cache:
|
||||||
# If no file needs reindexing, try loading existing index
|
# If no file needs reindexing, try loading existing index
|
||||||
if self.maybe_load_bm25_index(clear_cache=False):
|
if self.maybe_load_bm25_index(clear_cache=False):
|
||||||
self.logger.info("No new/changed .q.md files found. Using existing BM25 index.")
|
self.logger.info(
|
||||||
|
"No new/changed .q.md files found. Using existing BM25 index."
|
||||||
|
)
|
||||||
return
|
return
|
||||||
else:
|
else:
|
||||||
# If there's no existing index, we must build a fresh index from the old caches
|
# If there's no existing index, we must build a fresh index from the old caches
|
||||||
self.logger.info("No existing BM25 index found. Building from cached facts.")
|
self.logger.info(
|
||||||
|
"No existing BM25 index found. Building from cached facts."
|
||||||
|
)
|
||||||
if existing_facts:
|
if existing_facts:
|
||||||
self.logger.info(f"Building BM25 index with {len(existing_facts)} cached facts.")
|
self.logger.info(
|
||||||
|
f"Building BM25 index with {len(existing_facts)} cached facts."
|
||||||
|
)
|
||||||
self.bm25_index = BM25Okapi(existing_tokens)
|
self.bm25_index = BM25Okapi(existing_tokens)
|
||||||
self.tokenized_facts = existing_facts
|
self.tokenized_facts = existing_facts
|
||||||
with open(self.bm25_index_file, "wb") as f:
|
with open(self.bm25_index_file, "wb") as f:
|
||||||
pickle.dump({
|
pickle.dump(
|
||||||
|
{
|
||||||
"bm25_index": self.bm25_index,
|
"bm25_index": self.bm25_index,
|
||||||
"tokenized_facts": self.tokenized_facts
|
"tokenized_facts": self.tokenized_facts,
|
||||||
}, f)
|
},
|
||||||
|
f,
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
self.logger.warning("No facts found at all. Index remains empty.")
|
self.logger.warning("No facts found at all. Index remains empty.")
|
||||||
return
|
return
|
||||||
@@ -311,7 +340,9 @@ Wrap your response in <index>...</index> tags.
|
|||||||
self._save_token_cache(file, fresh_cache)
|
self._save_token_cache(file, fresh_cache)
|
||||||
|
|
||||||
mem_usage = process.memory_info().rss / 1024 / 1024
|
mem_usage = process.memory_info().rss / 1024 / 1024
|
||||||
self.logger.debug(f"Memory usage after {file.name}: {mem_usage:.2f}MB")
|
self.logger.debug(
|
||||||
|
f"Memory usage after {file.name}: {mem_usage:.2f}MB"
|
||||||
|
)
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
self.logger.error(f"Error processing {file}: {str(e)}")
|
self.logger.error(f"Error processing {file}: {str(e)}")
|
||||||
@@ -328,21 +359,28 @@ Wrap your response in <index>...</index> tags.
|
|||||||
all_tokens = existing_tokens + new_tokens
|
all_tokens = existing_tokens + new_tokens
|
||||||
|
|
||||||
# 3) Build BM25 index from combined facts
|
# 3) Build BM25 index from combined facts
|
||||||
self.logger.info(f"Building BM25 index with {len(all_facts)} total facts (old + new).")
|
self.logger.info(
|
||||||
|
f"Building BM25 index with {len(all_facts)} total facts (old + new)."
|
||||||
|
)
|
||||||
self.bm25_index = BM25Okapi(all_tokens)
|
self.bm25_index = BM25Okapi(all_tokens)
|
||||||
self.tokenized_facts = all_facts
|
self.tokenized_facts = all_facts
|
||||||
|
|
||||||
# 4) Save the updated BM25 index to disk
|
# 4) Save the updated BM25 index to disk
|
||||||
with open(self.bm25_index_file, "wb") as f:
|
with open(self.bm25_index_file, "wb") as f:
|
||||||
pickle.dump({
|
pickle.dump(
|
||||||
|
{
|
||||||
"bm25_index": self.bm25_index,
|
"bm25_index": self.bm25_index,
|
||||||
"tokenized_facts": self.tokenized_facts
|
"tokenized_facts": self.tokenized_facts,
|
||||||
}, f)
|
},
|
||||||
|
f,
|
||||||
|
)
|
||||||
|
|
||||||
final_mem = process.memory_info().rss / 1024 / 1024
|
final_mem = process.memory_info().rss / 1024 / 1024
|
||||||
self.logger.info(f"Search index updated. Final memory usage: {final_mem:.2f}MB")
|
self.logger.info(f"Search index updated. Final memory usage: {final_mem:.2f}MB")
|
||||||
|
|
||||||
async def generate_index_files(self, force_generate_facts: bool = False, clear_bm25_cache: bool = False) -> None:
|
async def generate_index_files(
|
||||||
|
self, force_generate_facts: bool = False, clear_bm25_cache: bool = False
|
||||||
|
) -> None:
|
||||||
"""
|
"""
|
||||||
Generate index files for all documents in parallel batches
|
Generate index files for all documents in parallel batches
|
||||||
|
|
||||||
@@ -353,15 +391,17 @@ Wrap your response in <index>...</index> tags.
|
|||||||
self.logger.info("Starting index generation for documentation files.")
|
self.logger.info("Starting index generation for documentation files.")
|
||||||
|
|
||||||
md_files = [
|
md_files = [
|
||||||
self.docs_dir / f for f in os.listdir(self.docs_dir)
|
self.docs_dir / f
|
||||||
if f.endswith('.md') and not any(f.endswith(x) for x in ['.q.md', '.xs.md'])
|
for f in os.listdir(self.docs_dir)
|
||||||
|
if f.endswith(".md") and not any(f.endswith(x) for x in [".q.md", ".xs.md"])
|
||||||
]
|
]
|
||||||
|
|
||||||
# Filter out files that already have .q files unless force=True
|
# Filter out files that already have .q files unless force=True
|
||||||
if not force_generate_facts:
|
if not force_generate_facts:
|
||||||
md_files = [
|
md_files = [
|
||||||
f for f in md_files
|
f
|
||||||
if not (self.docs_dir / f.name.replace('.md', '.q.md')).exists()
|
for f in md_files
|
||||||
|
if not (self.docs_dir / f.name.replace(".md", ".q.md")).exists()
|
||||||
]
|
]
|
||||||
|
|
||||||
if not md_files:
|
if not md_files:
|
||||||
@@ -370,7 +410,9 @@ Wrap your response in <index>...</index> tags.
|
|||||||
# Process documents in batches
|
# Process documents in batches
|
||||||
for i in range(0, len(md_files), self.batch_size):
|
for i in range(0, len(md_files), self.batch_size):
|
||||||
batch = md_files[i : i + self.batch_size]
|
batch = md_files[i : i + self.batch_size]
|
||||||
self.logger.info(f"Processing batch {i//self.batch_size + 1}/{(len(md_files)//self.batch_size) + 1}")
|
self.logger.info(
|
||||||
|
f"Processing batch {i//self.batch_size + 1}/{(len(md_files)//self.batch_size) + 1}"
|
||||||
|
)
|
||||||
await self._process_document_batch(batch)
|
await self._process_document_batch(batch)
|
||||||
|
|
||||||
self.logger.info("Index generation complete, building/updating search index.")
|
self.logger.info("Index generation complete, building/updating search index.")
|
||||||
@@ -378,21 +420,31 @@ Wrap your response in <index>...</index> tags.
|
|||||||
|
|
||||||
def generate(self, sections: List[str], mode: str = "extended") -> str:
|
def generate(self, sections: List[str], mode: str = "extended") -> str:
|
||||||
# Get all markdown files
|
# Get all markdown files
|
||||||
all_files = glob.glob(str(self.docs_dir / "[0-9]*.md")) + \
|
all_files = glob.glob(str(self.docs_dir / "[0-9]*.md")) + glob.glob(
|
||||||
glob.glob(str(self.docs_dir / "[0-9]*.xs.md"))
|
str(self.docs_dir / "[0-9]*.xs.md")
|
||||||
|
)
|
||||||
|
|
||||||
# Extract base names without extensions
|
# Extract base names without extensions
|
||||||
base_docs = {Path(f).name.split('.')[0] for f in all_files
|
base_docs = {
|
||||||
if not Path(f).name.endswith('.q.md')}
|
Path(f).name.split(".")[0]
|
||||||
|
for f in all_files
|
||||||
|
if not Path(f).name.endswith(".q.md")
|
||||||
|
}
|
||||||
|
|
||||||
# Filter by sections if provided
|
# Filter by sections if provided
|
||||||
if sections:
|
if sections:
|
||||||
base_docs = {doc for doc in base_docs
|
base_docs = {
|
||||||
if any(section.lower() in doc.lower() for section in sections)}
|
doc
|
||||||
|
for doc in base_docs
|
||||||
|
if any(section.lower() in doc.lower() for section in sections)
|
||||||
|
}
|
||||||
|
|
||||||
# Get file paths based on mode
|
# Get file paths based on mode
|
||||||
files = []
|
files = []
|
||||||
for doc in sorted(base_docs, key=lambda x: int(x.split('_')[0]) if x.split('_')[0].isdigit() else 999999):
|
for doc in sorted(
|
||||||
|
base_docs,
|
||||||
|
key=lambda x: int(x.split("_")[0]) if x.split("_")[0].isdigit() else 999999,
|
||||||
|
):
|
||||||
if mode == "condensed":
|
if mode == "condensed":
|
||||||
xs_file = self.docs_dir / f"{doc}.xs.md"
|
xs_file = self.docs_dir / f"{doc}.xs.md"
|
||||||
regular_file = self.docs_dir / f"{doc}.md"
|
regular_file = self.docs_dir / f"{doc}.md"
|
||||||
@@ -404,7 +456,7 @@ Wrap your response in <index>...</index> tags.
|
|||||||
content = []
|
content = []
|
||||||
for file in files:
|
for file in files:
|
||||||
try:
|
try:
|
||||||
with open(file, 'r', encoding='utf-8') as f:
|
with open(file, "r", encoding="utf-8") as f:
|
||||||
fname = Path(file).name
|
fname = Path(file).name
|
||||||
content.append(f"{'#'*20}\n# {fname}\n{'#'*20}\n\n{f.read()}")
|
content.append(f"{'#'*20}\n# {fname}\n{'#'*20}\n\n{f.read()}")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
@@ -443,15 +495,9 @@ Wrap your response in <index>...</index> tags.
|
|||||||
for file, _ in ranked_files:
|
for file, _ in ranked_files:
|
||||||
main_doc = str(file).replace(".q.md", ".md")
|
main_doc = str(file).replace(".q.md", ".md")
|
||||||
if os.path.exists(self.docs_dir / main_doc):
|
if os.path.exists(self.docs_dir / main_doc):
|
||||||
with open(self.docs_dir / main_doc, "r", encoding='utf-8') as f:
|
with open(self.docs_dir / main_doc, "r", encoding="utf-8") as f:
|
||||||
only_file_name = main_doc.split("/")[-1]
|
only_file_name = main_doc.split("/")[-1]
|
||||||
content = [
|
content = ["#" * 20, f"# {only_file_name}", "#" * 20, "", f.read()]
|
||||||
"#" * 20,
|
|
||||||
f"# {only_file_name}",
|
|
||||||
"#" * 20,
|
|
||||||
"",
|
|
||||||
f.read()
|
|
||||||
]
|
|
||||||
results.append("\n".join(content))
|
results.append("\n".join(content))
|
||||||
|
|
||||||
return "\n\n---\n\n".join(results)
|
return "\n\n---\n\n".join(results)
|
||||||
@@ -482,7 +528,9 @@ Wrap your response in <index>...</index> tags.
|
|||||||
if len(components) == 3:
|
if len(components) == 3:
|
||||||
code_ref = components[2].strip()
|
code_ref = components[2].strip()
|
||||||
code_tokens = self.preprocess_text(code_ref)
|
code_tokens = self.preprocess_text(code_ref)
|
||||||
code_match_score = len(set(query_tokens) & set(code_tokens)) / len(query_tokens)
|
code_match_score = len(set(query_tokens) & set(code_tokens)) / len(
|
||||||
|
query_tokens
|
||||||
|
)
|
||||||
|
|
||||||
file_data[file_path]["total_score"] += score
|
file_data[file_path]["total_score"] += score
|
||||||
file_data[file_path]["match_count"] += 1
|
file_data[file_path]["match_count"] += 1
|
||||||
|
|||||||
@@ -2,41 +2,51 @@ from abc import ABC, abstractmethod
|
|||||||
from typing import Optional, Dict, Any, Tuple
|
from typing import Optional, Dict, Any, Tuple
|
||||||
from .models import MarkdownGenerationResult
|
from .models import MarkdownGenerationResult
|
||||||
from .html2text import CustomHTML2Text
|
from .html2text import CustomHTML2Text
|
||||||
from .content_filter_strategy import RelevantContentFilter, BM25ContentFilter
|
from .content_filter_strategy import RelevantContentFilter
|
||||||
import re
|
import re
|
||||||
from urllib.parse import urljoin
|
from urllib.parse import urljoin
|
||||||
|
|
||||||
# Pre-compile the regex pattern
|
# Pre-compile the regex pattern
|
||||||
LINK_PATTERN = re.compile(r'!?\[([^\]]+)\]\(([^)]+?)(?:\s+"([^"]*)")?\)')
|
LINK_PATTERN = re.compile(r'!?\[([^\]]+)\]\(([^)]+?)(?:\s+"([^"]*)")?\)')
|
||||||
|
|
||||||
|
|
||||||
def fast_urljoin(base: str, url: str) -> str:
|
def fast_urljoin(base: str, url: str) -> str:
|
||||||
"""Fast URL joining for common cases."""
|
"""Fast URL joining for common cases."""
|
||||||
if url.startswith(('http://', 'https://', 'mailto:', '//')):
|
if url.startswith(("http://", "https://", "mailto:", "//")):
|
||||||
return url
|
return url
|
||||||
if url.startswith('/'):
|
if url.startswith("/"):
|
||||||
# Handle absolute paths
|
# Handle absolute paths
|
||||||
if base.endswith('/'):
|
if base.endswith("/"):
|
||||||
return base[:-1] + url
|
return base[:-1] + url
|
||||||
return base + url
|
return base + url
|
||||||
return urljoin(base, url)
|
return urljoin(base, url)
|
||||||
|
|
||||||
|
|
||||||
class MarkdownGenerationStrategy(ABC):
    """Interface implemented by every markdown generation strategy.

    Instances keep an optional content filter and an options mapping;
    concrete subclasses supply ``generate_markdown``.
    """

    def __init__(
        self,
        content_filter: Optional[RelevantContentFilter] = None,
        options: Optional[Dict[str, Any]] = None,
    ):
        # A missing (or empty) options mapping is normalised to a fresh dict.
        self.options = options if options else {}
        self.content_filter = content_filter

    @abstractmethod
    def generate_markdown(
        self,
        cleaned_html: str,
        base_url: str = "",
        html2text_options: Optional[Dict[str, Any]] = None,
        content_filter: Optional[RelevantContentFilter] = None,
        citations: bool = True,
        **kwargs,
    ) -> MarkdownGenerationResult:
        """Produce a ``MarkdownGenerationResult`` from cleaned HTML.

        Subclasses must override this method; the base implementation does
        nothing and returns ``None``.
        """
|
||||||
|
|
||||||
|
|
||||||
class DefaultMarkdownGenerator(MarkdownGenerationStrategy):
|
class DefaultMarkdownGenerator(MarkdownGenerationStrategy):
|
||||||
"""
|
"""
|
||||||
Default implementation of markdown generation strategy.
|
Default implementation of markdown generation strategy.
|
||||||
@@ -54,10 +64,17 @@ class DefaultMarkdownGenerator(MarkdownGenerationStrategy):
|
|||||||
Returns:
|
Returns:
|
||||||
MarkdownGenerationResult: Result containing raw markdown, fit markdown, fit HTML, and references markdown.
|
MarkdownGenerationResult: Result containing raw markdown, fit markdown, fit HTML, and references markdown.
|
||||||
"""
|
"""
|
||||||
def __init__(self, content_filter: Optional[RelevantContentFilter] = None, options: Optional[Dict[str, Any]] = None):
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
content_filter: Optional[RelevantContentFilter] = None,
|
||||||
|
options: Optional[Dict[str, Any]] = None,
|
||||||
|
):
|
||||||
super().__init__(content_filter, options)
|
super().__init__(content_filter, options)
|
||||||
|
|
||||||
def convert_links_to_citations(self, markdown: str, base_url: str = "") -> Tuple[str, str]:
|
def convert_links_to_citations(
|
||||||
|
self, markdown: str, base_url: str = ""
|
||||||
|
) -> Tuple[str, str]:
|
||||||
"""
|
"""
|
||||||
Convert links in markdown to citations.
|
Convert links in markdown to citations.
|
||||||
|
|
||||||
@@ -87,24 +104,30 @@ class DefaultMarkdownGenerator(MarkdownGenerationStrategy):
|
|||||||
text, url, title = match.groups()
|
text, url, title = match.groups()
|
||||||
|
|
||||||
# Use cached URL if available, otherwise compute and cache
|
# Use cached URL if available, otherwise compute and cache
|
||||||
if base_url and not url.startswith(('http://', 'https://', 'mailto:')):
|
if base_url and not url.startswith(("http://", "https://", "mailto:")):
|
||||||
if url not in url_cache:
|
if url not in url_cache:
|
||||||
url_cache[url] = fast_urljoin(base_url, url)
|
url_cache[url] = fast_urljoin(base_url, url)
|
||||||
url = url_cache[url]
|
url = url_cache[url]
|
||||||
|
|
||||||
if url not in link_map:
|
if url not in link_map:
|
||||||
desc = []
|
desc = []
|
||||||
if title: desc.append(title)
|
if title:
|
||||||
if text and text != title: desc.append(text)
|
desc.append(title)
|
||||||
|
if text and text != title:
|
||||||
|
desc.append(text)
|
||||||
link_map[url] = (counter, ": " + " - ".join(desc) if desc else "")
|
link_map[url] = (counter, ": " + " - ".join(desc) if desc else "")
|
||||||
counter += 1
|
counter += 1
|
||||||
|
|
||||||
num = link_map[url][0]
|
num = link_map[url][0]
|
||||||
parts.append(f"{text}⟨{num}⟩" if not match.group(0).startswith('!') else f"![{text}⟨{num}⟩]")
|
parts.append(
|
||||||
|
f"{text}⟨{num}⟩"
|
||||||
|
if not match.group(0).startswith("!")
|
||||||
|
else f"![{text}⟨{num}⟩]"
|
||||||
|
)
|
||||||
last_end = match.end()
|
last_end = match.end()
|
||||||
|
|
||||||
parts.append(markdown[last_end:])
|
parts.append(markdown[last_end:])
|
||||||
converted_text = ''.join(parts)
|
converted_text = "".join(parts)
|
||||||
|
|
||||||
# Pre-build reference strings
|
# Pre-build reference strings
|
||||||
references = ["\n\n## References\n\n"]
|
references = ["\n\n## References\n\n"]
|
||||||
@@ -113,16 +136,18 @@ class DefaultMarkdownGenerator(MarkdownGenerationStrategy):
|
|||||||
for url, (num, desc) in sorted(link_map.items(), key=lambda x: x[1][0])
|
for url, (num, desc) in sorted(link_map.items(), key=lambda x: x[1][0])
|
||||||
)
|
)
|
||||||
|
|
||||||
return converted_text, ''.join(references)
|
return converted_text, "".join(references)
|
||||||
|
|
||||||
def generate_markdown(self,
|
def generate_markdown(
|
||||||
|
self,
|
||||||
cleaned_html: str,
|
cleaned_html: str,
|
||||||
base_url: str = "",
|
base_url: str = "",
|
||||||
html2text_options: Optional[Dict[str, Any]] = None,
|
html2text_options: Optional[Dict[str, Any]] = None,
|
||||||
options: Optional[Dict[str, Any]] = None,
|
options: Optional[Dict[str, Any]] = None,
|
||||||
content_filter: Optional[RelevantContentFilter] = None,
|
content_filter: Optional[RelevantContentFilter] = None,
|
||||||
citations: bool = True,
|
citations: bool = True,
|
||||||
**kwargs) -> MarkdownGenerationResult:
|
**kwargs,
|
||||||
|
) -> MarkdownGenerationResult:
|
||||||
"""
|
"""
|
||||||
Generate markdown with citations from cleaned HTML.
|
Generate markdown with citations from cleaned HTML.
|
||||||
|
|
||||||
@@ -147,14 +172,14 @@ class DefaultMarkdownGenerator(MarkdownGenerationStrategy):
|
|||||||
# Initialize HTML2Text with default options for better conversion
|
# Initialize HTML2Text with default options for better conversion
|
||||||
h = CustomHTML2Text(baseurl=base_url)
|
h = CustomHTML2Text(baseurl=base_url)
|
||||||
default_options = {
|
default_options = {
|
||||||
'body_width': 0, # Disable text wrapping
|
"body_width": 0, # Disable text wrapping
|
||||||
'ignore_emphasis': False,
|
"ignore_emphasis": False,
|
||||||
'ignore_links': False,
|
"ignore_links": False,
|
||||||
'ignore_images': False,
|
"ignore_images": False,
|
||||||
'protect_links': True,
|
"protect_links": True,
|
||||||
'single_line_break': True,
|
"single_line_break": True,
|
||||||
'mark_code': True,
|
"mark_code": True,
|
||||||
'escape_snob': False
|
"escape_snob": False,
|
||||||
}
|
}
|
||||||
|
|
||||||
# Update with custom options if provided
|
# Update with custom options if provided
|
||||||
@@ -179,16 +204,17 @@ class DefaultMarkdownGenerator(MarkdownGenerationStrategy):
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
raw_markdown = f"Error converting HTML to markdown: {str(e)}"
|
raw_markdown = f"Error converting HTML to markdown: {str(e)}"
|
||||||
|
|
||||||
raw_markdown = raw_markdown.replace(' ```', '```')
|
raw_markdown = raw_markdown.replace(" ```", "```")
|
||||||
|
|
||||||
# Convert links to citations
|
# Convert links to citations
|
||||||
markdown_with_citations: str = raw_markdown
|
markdown_with_citations: str = raw_markdown
|
||||||
references_markdown: str = ""
|
references_markdown: str = ""
|
||||||
if citations:
|
if citations:
|
||||||
try:
|
try:
|
||||||
markdown_with_citations, references_markdown = self.convert_links_to_citations(
|
(
|
||||||
raw_markdown, base_url
|
markdown_with_citations,
|
||||||
)
|
references_markdown,
|
||||||
|
) = self.convert_links_to_citations(raw_markdown, base_url)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
markdown_with_citations = raw_markdown
|
markdown_with_citations = raw_markdown
|
||||||
references_markdown = f"Error generating citations: {str(e)}"
|
references_markdown = f"Error generating citations: {str(e)}"
|
||||||
@@ -200,7 +226,9 @@ class DefaultMarkdownGenerator(MarkdownGenerationStrategy):
|
|||||||
try:
|
try:
|
||||||
content_filter = content_filter or self.content_filter
|
content_filter = content_filter or self.content_filter
|
||||||
filtered_html = content_filter.filter_content(cleaned_html)
|
filtered_html = content_filter.filter_content(cleaned_html)
|
||||||
filtered_html = '\n'.join('<div>{}</div>'.format(s) for s in filtered_html)
|
filtered_html = "\n".join(
|
||||||
|
"<div>{}</div>".format(s) for s in filtered_html
|
||||||
|
)
|
||||||
fit_markdown = h.handle(filtered_html)
|
fit_markdown = h.handle(filtered_html)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
fit_markdown = f"Error generating fit markdown: {str(e)}"
|
fit_markdown = f"Error generating fit markdown: {str(e)}"
|
||||||
|
|||||||
@@ -1,13 +1,11 @@
|
|||||||
import os
|
import os
|
||||||
import asyncio
|
import asyncio
|
||||||
import logging
|
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
import aiosqlite
|
import aiosqlite
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
import xxhash
|
import xxhash
|
||||||
import aiofiles
|
import aiofiles
|
||||||
import shutil
|
import shutil
|
||||||
import time
|
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from .async_logger import AsyncLogger, LogLevel
|
from .async_logger import AsyncLogger, LogLevel
|
||||||
|
|
||||||
@@ -17,6 +15,7 @@ logger = AsyncLogger(log_level=LogLevel.DEBUG, verbose=True)
|
|||||||
# logging.basicConfig(level=logging.INFO)
|
# logging.basicConfig(level=logging.INFO)
|
||||||
# logger = logging.getLogger(__name__)
|
# logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class DatabaseMigration:
|
class DatabaseMigration:
|
||||||
def __init__(self, db_path: str):
|
def __init__(self, db_path: str):
|
||||||
self.db_path = db_path
|
self.db_path = db_path
|
||||||
@@ -24,11 +23,11 @@ class DatabaseMigration:
|
|||||||
|
|
||||||
def _ensure_content_dirs(self, base_path: str) -> dict:
|
def _ensure_content_dirs(self, base_path: str) -> dict:
|
||||||
dirs = {
|
dirs = {
|
||||||
'html': 'html_content',
|
"html": "html_content",
|
||||||
'cleaned': 'cleaned_html',
|
"cleaned": "cleaned_html",
|
||||||
'markdown': 'markdown_content',
|
"markdown": "markdown_content",
|
||||||
'extracted': 'extracted_content',
|
"extracted": "extracted_content",
|
||||||
'screenshots': 'screenshots'
|
"screenshots": "screenshots",
|
||||||
}
|
}
|
||||||
content_paths = {}
|
content_paths = {}
|
||||||
for key, dirname in dirs.items():
|
for key, dirname in dirs.items():
|
||||||
@@ -52,7 +51,7 @@ class DatabaseMigration:
|
|||||||
file_path = os.path.join(self.content_paths[content_type], content_hash)
|
file_path = os.path.join(self.content_paths[content_type], content_hash)
|
||||||
|
|
||||||
if not os.path.exists(file_path):
|
if not os.path.exists(file_path):
|
||||||
async with aiofiles.open(file_path, 'w', encoding='utf-8') as f:
|
async with aiofiles.open(file_path, "w", encoding="utf-8") as f:
|
||||||
await f.write(content)
|
await f.write(content)
|
||||||
|
|
||||||
return content_hash
|
return content_hash
|
||||||
@@ -66,24 +65,36 @@ class DatabaseMigration:
|
|||||||
async with aiosqlite.connect(self.db_path) as db:
|
async with aiosqlite.connect(self.db_path) as db:
|
||||||
# Get all rows
|
# Get all rows
|
||||||
async with db.execute(
|
async with db.execute(
|
||||||
'''SELECT url, html, cleaned_html, markdown,
|
"""SELECT url, html, cleaned_html, markdown,
|
||||||
extracted_content, screenshot FROM crawled_data'''
|
extracted_content, screenshot FROM crawled_data"""
|
||||||
) as cursor:
|
) as cursor:
|
||||||
rows = await cursor.fetchall()
|
rows = await cursor.fetchall()
|
||||||
|
|
||||||
migrated_count = 0
|
migrated_count = 0
|
||||||
for row in rows:
|
for row in rows:
|
||||||
url, html, cleaned_html, markdown, extracted_content, screenshot = row
|
(
|
||||||
|
url,
|
||||||
|
html,
|
||||||
|
cleaned_html,
|
||||||
|
markdown,
|
||||||
|
extracted_content,
|
||||||
|
screenshot,
|
||||||
|
) = row
|
||||||
|
|
||||||
# Store content in files and get hashes
|
# Store content in files and get hashes
|
||||||
html_hash = await self._store_content(html, 'html')
|
html_hash = await self._store_content(html, "html")
|
||||||
cleaned_hash = await self._store_content(cleaned_html, 'cleaned')
|
cleaned_hash = await self._store_content(cleaned_html, "cleaned")
|
||||||
markdown_hash = await self._store_content(markdown, 'markdown')
|
markdown_hash = await self._store_content(markdown, "markdown")
|
||||||
extracted_hash = await self._store_content(extracted_content, 'extracted')
|
extracted_hash = await self._store_content(
|
||||||
screenshot_hash = await self._store_content(screenshot, 'screenshots')
|
extracted_content, "extracted"
|
||||||
|
)
|
||||||
|
screenshot_hash = await self._store_content(
|
||||||
|
screenshot, "screenshots"
|
||||||
|
)
|
||||||
|
|
||||||
# Update database with hashes
|
# Update database with hashes
|
||||||
await db.execute('''
|
await db.execute(
|
||||||
|
"""
|
||||||
UPDATE crawled_data
|
UPDATE crawled_data
|
||||||
SET html = ?,
|
SET html = ?,
|
||||||
cleaned_html = ?,
|
cleaned_html = ?,
|
||||||
@@ -91,26 +102,37 @@ class DatabaseMigration:
|
|||||||
extracted_content = ?,
|
extracted_content = ?,
|
||||||
screenshot = ?
|
screenshot = ?
|
||||||
WHERE url = ?
|
WHERE url = ?
|
||||||
''', (html_hash, cleaned_hash, markdown_hash,
|
""",
|
||||||
extracted_hash, screenshot_hash, url))
|
(
|
||||||
|
html_hash,
|
||||||
|
cleaned_hash,
|
||||||
|
markdown_hash,
|
||||||
|
extracted_hash,
|
||||||
|
screenshot_hash,
|
||||||
|
url,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
migrated_count += 1
|
migrated_count += 1
|
||||||
if migrated_count % 100 == 0:
|
if migrated_count % 100 == 0:
|
||||||
logger.info(f"Migrated {migrated_count} records...", tag="INIT")
|
logger.info(f"Migrated {migrated_count} records...", tag="INIT")
|
||||||
|
|
||||||
|
|
||||||
await db.commit()
|
await db.commit()
|
||||||
logger.success(f"Migration completed. {migrated_count} records processed.", tag="COMPLETE")
|
logger.success(
|
||||||
|
f"Migration completed. {migrated_count} records processed.",
|
||||||
|
tag="COMPLETE",
|
||||||
|
)
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
# logger.error(f"Migration failed: {e}")
|
# logger.error(f"Migration failed: {e}")
|
||||||
logger.error(
|
logger.error(
|
||||||
message="Migration failed: {error}",
|
message="Migration failed: {error}",
|
||||||
tag="ERROR",
|
tag="ERROR",
|
||||||
params={"error": str(e)}
|
params={"error": str(e)},
|
||||||
)
|
)
|
||||||
raise e
|
raise e
|
||||||
|
|
||||||
|
|
||||||
async def backup_database(db_path: str) -> str:
|
async def backup_database(db_path: str) -> str:
|
||||||
"""Create backup of existing database"""
|
"""Create backup of existing database"""
|
||||||
if not os.path.exists(db_path):
|
if not os.path.exists(db_path):
|
||||||
@@ -118,7 +140,7 @@ async def backup_database(db_path: str) -> str:
|
|||||||
return None
|
return None
|
||||||
|
|
||||||
# Create backup with timestamp
|
# Create backup with timestamp
|
||||||
timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
|
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
|
||||||
backup_path = f"{db_path}.backup_{timestamp}"
|
backup_path = f"{db_path}.backup_{timestamp}"
|
||||||
|
|
||||||
try:
|
try:
|
||||||
@@ -132,12 +154,11 @@ async def backup_database(db_path: str) -> str:
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
# logger.error(f"Backup failed: {e}")
|
# logger.error(f"Backup failed: {e}")
|
||||||
logger.error(
|
logger.error(
|
||||||
message="Migration failed: {error}",
|
message="Migration failed: {error}", tag="ERROR", params={"error": str(e)}
|
||||||
tag="ERROR",
|
|
||||||
params={"error": str(e)}
|
|
||||||
)
|
)
|
||||||
raise e
|
raise e
|
||||||
|
|
||||||
|
|
||||||
async def run_migration(db_path: Optional[str] = None):
|
async def run_migration(db_path: Optional[str] = None):
|
||||||
"""Run database migration"""
|
"""Run database migration"""
|
||||||
if db_path is None:
|
if db_path is None:
|
||||||
@@ -155,14 +176,19 @@ async def run_migration(db_path: Optional[str] = None):
|
|||||||
migration = DatabaseMigration(db_path)
|
migration = DatabaseMigration(db_path)
|
||||||
await migration.migrate_database()
|
await migration.migrate_database()
|
||||||
|
|
||||||
|
|
||||||
def main():
    """Command-line entry point: parse ``--db-path`` and run the migration."""
    import argparse

    cli = argparse.ArgumentParser(
        description="Migrate Crawl4AI database to file-based storage"
    )
    cli.add_argument("--db-path", help="Custom database path")
    db_path = cli.parse_args().db_path

    # run_migration is a coroutine function; drive it to completion here.
    asyncio.run(run_migration(db_path))
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
main()
|
main()
|
||||||
@@ -2,30 +2,32 @@ from functools import lru_cache
|
|||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
import subprocess, os
|
import subprocess, os
|
||||||
import shutil
|
import shutil
|
||||||
import tarfile
|
|
||||||
from .model_loader import *
|
from .model_loader import *
|
||||||
import argparse
|
import argparse
|
||||||
import urllib.request
|
|
||||||
from crawl4ai.config import MODEL_REPO_BRANCH
|
from crawl4ai.config import MODEL_REPO_BRANCH
|
||||||
|
|
||||||
__location__ = os.path.realpath(os.path.join(os.getcwd(), os.path.dirname(__file__)))
|
__location__ = os.path.realpath(os.path.join(os.getcwd(), os.path.dirname(__file__)))
|
||||||
|
|
||||||
|
|
||||||
@lru_cache()
def get_available_memory(device):
    """Return the memory figure, in bytes, used for ``device``.

    Args:
        device: A ``torch.device``-like object. ``device.type`` selects the
            branch; for CUDA the device is also passed to
            ``torch.cuda.get_device_properties``.

    Returns:
        The CUDA device's ``total_memory``, a fixed 48 GiB figure for
        ``mps``, and 0 for any other device type.
    """
    import torch

    # No query is made for MPS; a fixed figure is used instead.
    # NOTE(review): the previous comment described this as "8GB ... a
    # conservative estimate" while the code returns 48 GiB. The value is kept
    # unchanged here - confirm which one was intended.
    mps_assumed_bytes = 48 * 1024**3

    if device.type == "cuda":
        return torch.cuda.get_device_properties(device).total_memory
    if device.type == "mps":
        return mps_assumed_bytes
    return 0
|
||||||
|
|
||||||
|
|
||||||
@lru_cache()
|
@lru_cache()
|
||||||
def calculate_batch_size(device):
|
def calculate_batch_size(device):
|
||||||
available_memory = get_available_memory(device)
|
available_memory = get_available_memory(device)
|
||||||
|
|
||||||
if device.type == 'cpu':
|
if device.type == "cpu":
|
||||||
return 16
|
return 16
|
||||||
elif device.type in ['cuda', 'mps']:
|
elif device.type in ["cuda", "mps"]:
|
||||||
# Adjust these thresholds based on your model size and available memory
|
# Adjust these thresholds based on your model size and available memory
|
||||||
if available_memory >= 31 * 1024**3: # > 32GB
|
if available_memory >= 31 * 1024**3: # > 32GB
|
||||||
return 256
|
return 256
|
||||||
@@ -38,39 +40,48 @@ def calculate_batch_size(device):
|
|||||||
else:
|
else:
|
||||||
return 16 # Default batch size
|
return 16 # Default batch size
|
||||||
|
|
||||||
|
|
||||||
@lru_cache()
def get_device():
    """Pick the torch device to use: CUDA first, then MPS, then CPU.

    The result is cached, so the availability checks run once per process.

    Returns:
        torch.device: ``cuda``, ``mps`` or ``cpu``.
    """
    import torch

    if torch.cuda.is_available():
        return torch.device("cuda")

    # torch.backends.mps does not exist on every torch build; treat a missing
    # backend as "not available" instead of raising AttributeError.
    mps_backend = getattr(torch.backends, "mps", None)
    if mps_backend is not None and mps_backend.is_available():
        return torch.device("mps")

    return torch.device("cpu")
|
||||||
|
|
||||||
|
|
||||||
def set_model_device(model):
    """Move ``model`` to the device chosen by ``get_device``.

    Args:
        model: Any object exposing ``.to(device)``.

    Returns:
        tuple: ``(model, device)`` - the moved model and the device used.
    """
    device = get_device()
    # Use what ``.to`` hands back: for objects whose ``.to`` returns a moved
    # copy rather than moving in place, the original reference would still
    # point at the unmoved object. Fall back to ``model`` itself if ``.to``
    # returns nothing.
    moved = model.to(device)
    if moved is not None:
        model = moved
    return model, device
|
||||||
|
|
||||||
|
|
||||||
@lru_cache()
def get_home_folder():
    """Create (if needed) and return the ``.crawl4ai`` working directory.

    The parent directory is taken from the ``CRAWL4_AI_BASE_DIRECTORY``
    environment variable, falling back to the user's home directory. The
    ``cache`` and ``models`` sub-directories are created as well.

    Returns:
        str: path of the ``.crawl4ai`` directory.
    """
    base_dir = os.getenv("CRAWL4_AI_BASE_DIRECTORY", Path.home())
    home_folder = os.path.join(base_dir, ".crawl4ai")

    os.makedirs(home_folder, exist_ok=True)
    for sub_dir in ("cache", "models"):
        os.makedirs(f"{home_folder}/{sub_dir}", exist_ok=True)

    return home_folder
|
||||||
|
|
||||||
|
|
||||||
@lru_cache()
def load_bert_base_uncased():
    """Load the ``bert-base-uncased`` tokenizer and model.

    The model is switched to eval mode and moved via ``set_model_device``.
    The result is cached, so the load happens once per process.

    Returns:
        tuple: ``(tokenizer, model)``.
    """
    from transformers import BertTokenizer, BertModel

    tokenizer = BertTokenizer.from_pretrained("bert-base-uncased", resume_download=None)
    model = BertModel.from_pretrained("bert-base-uncased", resume_download=None)
    model.eval()
    # set_model_device also returns the device; it is not needed here.
    model, _ = set_model_device(model)
    return tokenizer, model
|
||||||
|
|
||||||
|
|
||||||
@lru_cache()
|
@lru_cache()
|
||||||
def load_HF_embedding_model(model_name="BAAI/bge-small-en-v1.5") -> tuple:
|
def load_HF_embedding_model(model_name="BAAI/bge-small-en-v1.5") -> tuple:
|
||||||
"""Load the Hugging Face model for embedding.
|
"""Load the Hugging Face model for embedding.
|
||||||
@@ -81,30 +92,35 @@ def load_HF_embedding_model(model_name="BAAI/bge-small-en-v1.5") -> tuple:
|
|||||||
Returns:
|
Returns:
|
||||||
tuple: The tokenizer and model.
|
tuple: The tokenizer and model.
|
||||||
"""
|
"""
|
||||||
from transformers import BertTokenizer, BertModel, AutoTokenizer, AutoModel
|
from transformers import AutoTokenizer, AutoModel
|
||||||
|
|
||||||
tokenizer = AutoTokenizer.from_pretrained(model_name, resume_download=None)
|
tokenizer = AutoTokenizer.from_pretrained(model_name, resume_download=None)
|
||||||
model = AutoModel.from_pretrained(model_name, resume_download=None)
|
model = AutoModel.from_pretrained(model_name, resume_download=None)
|
||||||
model.eval()
|
model.eval()
|
||||||
model, device = set_model_device(model)
|
model, device = set_model_device(model)
|
||||||
return tokenizer, model
|
return tokenizer, model
|
||||||
|
|
||||||
|
|
||||||
@lru_cache()
def load_text_classifier():
    """Build a ``text-classification`` pipeline for the NYT topic model.

    The tokenizer and model are loaded from
    ``dstefa/roberta-base_topic_classification_nyt_news``; the model is put
    in eval mode and moved via ``set_model_device``. The result is cached.

    Returns:
        The ``transformers`` pipeline object.
    """
    from transformers import AutoTokenizer, AutoModelForSequenceClassification
    from transformers import pipeline

    # Single source for the checkpoint id used by both loaders below.
    model_name = "dstefa/roberta-base_topic_classification_nyt_news"

    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForSequenceClassification.from_pretrained(model_name)
    model.eval()
    # set_model_device also returns the device; it is not needed here.
    model, _ = set_model_device(model)
    return pipeline("text-classification", model=model, tokenizer=tokenizer)
|
||||||
|
|
||||||
|
|
||||||
@lru_cache()
|
@lru_cache()
|
||||||
def load_text_multilabel_classifier():
|
def load_text_multilabel_classifier():
|
||||||
from transformers import AutoModelForSequenceClassification, AutoTokenizer
|
from transformers import AutoModelForSequenceClassification, AutoTokenizer
|
||||||
import numpy as np
|
|
||||||
from scipy.special import expit
|
from scipy.special import expit
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
@@ -117,17 +133,26 @@ def load_text_multilabel_classifier():
|
|||||||
# device = torch.device("cpu")
|
# device = torch.device("cpu")
|
||||||
# # return load_spacy_model(), torch.device("cpu")
|
# # return load_spacy_model(), torch.device("cpu")
|
||||||
|
|
||||||
|
|
||||||
MODEL = "cardiffnlp/tweet-topic-21-multi"
|
MODEL = "cardiffnlp/tweet-topic-21-multi"
|
||||||
tokenizer = AutoTokenizer.from_pretrained(MODEL, resume_download=None)
|
tokenizer = AutoTokenizer.from_pretrained(MODEL, resume_download=None)
|
||||||
model = AutoModelForSequenceClassification.from_pretrained(MODEL, resume_download=None)
|
model = AutoModelForSequenceClassification.from_pretrained(
|
||||||
|
MODEL, resume_download=None
|
||||||
|
)
|
||||||
model.eval()
|
model.eval()
|
||||||
model, device = set_model_device(model)
|
model, device = set_model_device(model)
|
||||||
class_mapping = model.config.id2label
|
class_mapping = model.config.id2label
|
||||||
|
|
||||||
def _classifier(texts, threshold=0.5, max_length=64):
|
def _classifier(texts, threshold=0.5, max_length=64):
|
||||||
tokens = tokenizer(texts, return_tensors='pt', padding=True, truncation=True, max_length=max_length)
|
tokens = tokenizer(
|
||||||
tokens = {key: val.to(device) for key, val in tokens.items()} # Move tokens to the selected device
|
texts,
|
||||||
|
return_tensors="pt",
|
||||||
|
padding=True,
|
||||||
|
truncation=True,
|
||||||
|
max_length=max_length,
|
||||||
|
)
|
||||||
|
tokens = {
|
||||||
|
key: val.to(device) for key, val in tokens.items()
|
||||||
|
} # Move tokens to the selected device
|
||||||
|
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
output = model(**tokens)
|
output = model(**tokens)
|
||||||
@@ -138,25 +163,31 @@ def load_text_multilabel_classifier():
|
|||||||
|
|
||||||
batch_labels = []
|
batch_labels = []
|
||||||
for prediction in predictions:
|
for prediction in predictions:
|
||||||
labels = [class_mapping[i] for i, value in enumerate(prediction) if value == 1]
|
labels = [
|
||||||
|
class_mapping[i] for i, value in enumerate(prediction) if value == 1
|
||||||
|
]
|
||||||
batch_labels.append(labels)
|
batch_labels.append(labels)
|
||||||
|
|
||||||
return batch_labels
|
return batch_labels
|
||||||
|
|
||||||
return _classifier, device
|
return _classifier, device
|
||||||
|
|
||||||
|
|
||||||
@lru_cache()
def load_nltk_punkt():
    """Locate the NLTK ``punkt`` tokenizer data, downloading it if missing.

    Returns:
        Whatever ``nltk.data.find("tokenizers/punkt")`` returns once the
        resource is present.
    """
    import nltk

    resource = "tokenizers/punkt"
    try:
        nltk.data.find(resource)
    except LookupError:
        # Not installed yet: fetch it, then look it up again below.
        nltk.download("punkt")
    return nltk.data.find(resource)
|
||||||
|
|
||||||
|
|
||||||
@lru_cache()
|
@lru_cache()
|
||||||
def load_spacy_model():
|
def load_spacy_model():
|
||||||
import spacy
|
import spacy
|
||||||
|
|
||||||
name = "models/reuters"
|
name = "models/reuters"
|
||||||
home_folder = get_home_folder()
|
home_folder = get_home_folder()
|
||||||
model_folder = Path(home_folder) / name
|
model_folder = Path(home_folder) / name
|
||||||
@@ -176,7 +207,9 @@ def load_spacy_model():
|
|||||||
if model_folder.exists():
|
if model_folder.exists():
|
||||||
shutil.rmtree(model_folder)
|
shutil.rmtree(model_folder)
|
||||||
except PermissionError:
|
except PermissionError:
|
||||||
print("[WARNING] Unable to remove existing folders. Please manually delete the following folders and try again:")
|
print(
|
||||||
|
"[WARNING] Unable to remove existing folders. Please manually delete the following folders and try again:"
|
||||||
|
)
|
||||||
print(f"- {repo_folder}")
|
print(f"- {repo_folder}")
|
||||||
print(f"- {model_folder}")
|
print(f"- {model_folder}")
|
||||||
return None
|
return None
|
||||||
@@ -187,7 +220,7 @@ def load_spacy_model():
|
|||||||
["git", "clone", "-b", branch, repo_url, str(repo_folder)],
|
["git", "clone", "-b", branch, repo_url, str(repo_folder)],
|
||||||
stdout=subprocess.DEVNULL,
|
stdout=subprocess.DEVNULL,
|
||||||
stderr=subprocess.DEVNULL,
|
stderr=subprocess.DEVNULL,
|
||||||
check=True
|
check=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Create the models directory if it doesn't exist
|
# Create the models directory if it doesn't exist
|
||||||
@@ -215,6 +248,7 @@ def load_spacy_model():
|
|||||||
print(f"Error loading spacy model: {e}")
|
print(f"Error loading spacy model: {e}")
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
||||||
def download_all_models(remove_existing=False):
|
def download_all_models(remove_existing=False):
|
||||||
"""Download all models required for Crawl4AI."""
|
"""Download all models required for Crawl4AI."""
|
||||||
if remove_existing:
|
if remove_existing:
|
||||||
@@ -243,14 +277,20 @@ def download_all_models(remove_existing=False):
|
|||||||
load_nltk_punkt()
|
load_nltk_punkt()
|
||||||
print("[LOG] ✅ All models downloaded successfully.")
|
print("[LOG] ✅ All models downloaded successfully.")
|
||||||
|
|
||||||
|
|
||||||
def main():
    """Command-line entry point for the model downloader.

    Prints a short banner, parses ``--remove-existing`` and hands the flag to
    ``download_all_models``.
    """
    banner = (
        "[LOG] Welcome to the Crawl4AI Model Downloader!",
        "[LOG] This script will download all the models required for Crawl4AI.",
    )
    for line in banner:
        print(line)

    cli = argparse.ArgumentParser(description="Crawl4AI Model Downloader")
    cli.add_argument(
        "--remove-existing",
        action="store_true",
        help="Remove existing models before downloading",
    )
    remove_existing = cli.parse_args().remove_existing

    download_all_models(remove_existing=remove_existing)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
main()
|
main()
|
||||||
|
|||||||
@@ -1,18 +1,12 @@
|
|||||||
from pydantic import BaseModel, HttpUrl
|
from pydantic import BaseModel, HttpUrl
|
||||||
from typing import List, Dict, Optional, Callable, Awaitable, Union, Tuple, Any
|
from typing import List, Dict, Optional, Callable, Awaitable, Union, Any
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
from dataclasses import dataclass, field
|
|
||||||
from .ssl_certificate import SSLCertificate
|
|
||||||
|
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
|
from .ssl_certificate import SSLCertificate
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from enum import Enum
|
|
||||||
from typing import Optional
|
|
||||||
|
|
||||||
from datetime import timedelta
|
from datetime import timedelta
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
###############################
|
###############################
|
||||||
# Dispatcher Models
|
# Dispatcher Models
|
||||||
###############################
|
###############################
|
||||||
@@ -22,6 +16,7 @@ class DomainState:
|
|||||||
current_delay: float = 0
|
current_delay: float = 0
|
||||||
fail_count: int = 0
|
fail_count: int = 0
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class CrawlerTaskResult:
|
class CrawlerTaskResult:
|
||||||
task_id: str
|
task_id: str
|
||||||
@@ -33,12 +28,14 @@ class CrawlerTaskResult:
|
|||||||
end_time: datetime
|
end_time: datetime
|
||||||
error_message: str = ""
|
error_message: str = ""
|
||||||
|
|
||||||
|
|
||||||
class CrawlStatus(Enum):
    """Closed set of crawl status values; each member's value is its own name."""

    QUEUED = "QUEUED"
    IN_PROGRESS = "IN_PROGRESS"
    COMPLETED = "COMPLETED"
    FAILED = "FAILED"
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class CrawlStats:
|
class CrawlStats:
|
||||||
task_id: str
|
task_id: str
|
||||||
@@ -58,10 +55,12 @@ class CrawlStats:
|
|||||||
duration = end - self.start_time
|
duration = end - self.start_time
|
||||||
return str(timedelta(seconds=int(duration.total_seconds())))
|
return str(timedelta(seconds=int(duration.total_seconds())))
|
||||||
|
|
||||||
|
|
||||||
class DisplayMode(Enum):
|
class DisplayMode(Enum):
|
||||||
DETAILED = "DETAILED"
|
DETAILED = "DETAILED"
|
||||||
AGGREGATED = "AGGREGATED"
|
AGGREGATED = "AGGREGATED"
|
||||||
|
|
||||||
|
|
||||||
###############################
|
###############################
|
||||||
# Crawler Models
|
# Crawler Models
|
||||||
###############################
|
###############################
|
||||||
@@ -78,6 +77,7 @@ class UrlModel(BaseModel):
|
|||||||
url: HttpUrl
|
url: HttpUrl
|
||||||
forced: bool = False
|
forced: bool = False
|
||||||
|
|
||||||
|
|
||||||
class MarkdownGenerationResult(BaseModel):
|
class MarkdownGenerationResult(BaseModel):
|
||||||
raw_markdown: str
|
raw_markdown: str
|
||||||
markdown_with_citations: str
|
markdown_with_citations: str
|
||||||
@@ -85,6 +85,7 @@ class MarkdownGenerationResult(BaseModel):
|
|||||||
fit_markdown: Optional[str] = None
|
fit_markdown: Optional[str] = None
|
||||||
fit_html: Optional[str] = None
|
fit_html: Optional[str] = None
|
||||||
|
|
||||||
|
|
||||||
class DispatchResult(BaseModel):
|
class DispatchResult(BaseModel):
|
||||||
task_id: str
|
task_id: str
|
||||||
memory_usage: float
|
memory_usage: float
|
||||||
@@ -92,6 +93,8 @@ class DispatchResult(BaseModel):
|
|||||||
start_time: datetime
|
start_time: datetime
|
||||||
end_time: datetime
|
end_time: datetime
|
||||||
error_message: str = ""
|
error_message: str = ""
|
||||||
|
|
||||||
|
|
||||||
class CrawlResult(BaseModel):
|
class CrawlResult(BaseModel):
|
||||||
url: str
|
url: str
|
||||||
html: str
|
html: str
|
||||||
@@ -114,9 +117,11 @@ class CrawlResult(BaseModel):
|
|||||||
status_code: Optional[int] = None
|
status_code: Optional[int] = None
|
||||||
ssl_certificate: Optional[SSLCertificate] = None
|
ssl_certificate: Optional[SSLCertificate] = None
|
||||||
dispatch_result: Optional[DispatchResult] = None
|
dispatch_result: Optional[DispatchResult] = None
|
||||||
|
|
||||||
class Config:
|
class Config:
|
||||||
arbitrary_types_allowed = True
|
arbitrary_types_allowed = True
|
||||||
|
|
||||||
|
|
||||||
class AsyncCrawlResponse(BaseModel):
|
class AsyncCrawlResponse(BaseModel):
|
||||||
html: str
|
html: str
|
||||||
response_headers: Dict[str, str]
|
response_headers: Dict[str, str]
|
||||||
@@ -130,6 +135,7 @@ class AsyncCrawlResponse(BaseModel):
|
|||||||
class Config:
|
class Config:
|
||||||
arbitrary_types_allowed = True
|
arbitrary_types_allowed = True
|
||||||
|
|
||||||
|
|
||||||
###############################
|
###############################
|
||||||
# Scraping Models
|
# Scraping Models
|
||||||
###############################
|
###############################
|
||||||
@@ -143,21 +149,29 @@ class MediaItem(BaseModel):
|
|||||||
format: Optional[str] = None
|
format: Optional[str] = None
|
||||||
width: Optional[int] = None
|
width: Optional[int] = None
|
||||||
|
|
||||||
|
|
||||||
class Link(BaseModel):
|
class Link(BaseModel):
|
||||||
href: str
|
href: str
|
||||||
text: str
|
text: str
|
||||||
title: Optional[str] = None
|
title: Optional[str] = None
|
||||||
base_domain: str
|
base_domain: str
|
||||||
|
|
||||||
|
|
||||||
class Media(BaseModel):
|
class Media(BaseModel):
|
||||||
images: List[MediaItem] = []
|
images: List[MediaItem] = []
|
||||||
videos: List[MediaItem] = [] # Using MediaItem model for now, can be extended with Video model if needed
|
videos: List[
|
||||||
audios: List[MediaItem] = [] # Using MediaItem model for now, can be extended with Audio model if needed
|
MediaItem
|
||||||
|
] = [] # Using MediaItem model for now, can be extended with Video model if needed
|
||||||
|
audios: List[
|
||||||
|
MediaItem
|
||||||
|
] = [] # Using MediaItem model for now, can be extended with Audio model if needed
|
||||||
|
|
||||||
|
|
||||||
class Links(BaseModel):
|
class Links(BaseModel):
|
||||||
internal: List[Link] = []
|
internal: List[Link] = []
|
||||||
external: List[Link] = []
|
external: List[Link] = []
|
||||||
|
|
||||||
|
|
||||||
class ScrapingResult(BaseModel):
|
class ScrapingResult(BaseModel):
|
||||||
cleaned_html: str
|
cleaned_html: str
|
||||||
success: bool
|
success: bool
|
||||||
|
|||||||
@@ -26,11 +26,12 @@ class SSLCertificate:
|
|||||||
export_as_json() -> Dict[str, Any]: Export the certificate as JSON format.
|
export_as_json() -> Dict[str, Any]: Export the certificate as JSON format.
|
||||||
export_as_text() -> str: Export the certificate as text format.
|
export_as_text() -> str: Export the certificate as text format.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, cert_info: Dict[str, Any]):
|
def __init__(self, cert_info: Dict[str, Any]):
|
||||||
self._cert_info = self._decode_cert_data(cert_info)
|
self._cert_info = self._decode_cert_data(cert_info)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def from_url(url: str, timeout: int = 10) -> Optional['SSLCertificate']:
|
def from_url(url: str, timeout: int = 10) -> Optional["SSLCertificate"]:
|
||||||
"""
|
"""
|
||||||
Create SSLCertificate instance from a URL.
|
Create SSLCertificate instance from a URL.
|
||||||
|
|
||||||
@@ -43,14 +44,16 @@ class SSLCertificate:
|
|||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
hostname = urlparse(url).netloc
|
hostname = urlparse(url).netloc
|
||||||
if ':' in hostname:
|
if ":" in hostname:
|
||||||
hostname = hostname.split(':')[0]
|
hostname = hostname.split(":")[0]
|
||||||
|
|
||||||
context = ssl.create_default_context()
|
context = ssl.create_default_context()
|
||||||
with socket.create_connection((hostname, 443), timeout=timeout) as sock:
|
with socket.create_connection((hostname, 443), timeout=timeout) as sock:
|
||||||
with context.wrap_socket(sock, server_hostname=hostname) as ssock:
|
with context.wrap_socket(sock, server_hostname=hostname) as ssock:
|
||||||
cert_binary = ssock.getpeercert(binary_form=True)
|
cert_binary = ssock.getpeercert(binary_form=True)
|
||||||
x509 = OpenSSL.crypto.load_certificate(OpenSSL.crypto.FILETYPE_ASN1, cert_binary)
|
x509 = OpenSSL.crypto.load_certificate(
|
||||||
|
OpenSSL.crypto.FILETYPE_ASN1, cert_binary
|
||||||
|
)
|
||||||
|
|
||||||
cert_info = {
|
cert_info = {
|
||||||
"subject": dict(x509.get_subject().get_components()),
|
"subject": dict(x509.get_subject().get_components()),
|
||||||
@@ -61,32 +64,33 @@ class SSLCertificate:
|
|||||||
"not_after": x509.get_notAfter(),
|
"not_after": x509.get_notAfter(),
|
||||||
"fingerprint": x509.digest("sha256").hex(),
|
"fingerprint": x509.digest("sha256").hex(),
|
||||||
"signature_algorithm": x509.get_signature_algorithm(),
|
"signature_algorithm": x509.get_signature_algorithm(),
|
||||||
"raw_cert": base64.b64encode(cert_binary)
|
"raw_cert": base64.b64encode(cert_binary),
|
||||||
}
|
}
|
||||||
|
|
||||||
# Add extensions
|
# Add extensions
|
||||||
extensions = []
|
extensions = []
|
||||||
for i in range(x509.get_extension_count()):
|
for i in range(x509.get_extension_count()):
|
||||||
ext = x509.get_extension(i)
|
ext = x509.get_extension(i)
|
||||||
extensions.append({
|
extensions.append(
|
||||||
"name": ext.get_short_name(),
|
{"name": ext.get_short_name(), "value": str(ext)}
|
||||||
"value": str(ext)
|
)
|
||||||
})
|
|
||||||
cert_info["extensions"] = extensions
|
cert_info["extensions"] = extensions
|
||||||
|
|
||||||
return SSLCertificate(cert_info)
|
return SSLCertificate(cert_info)
|
||||||
|
|
||||||
except Exception as e:
|
except Exception:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _decode_cert_data(data: Any) -> Any:
|
def _decode_cert_data(data: Any) -> Any:
|
||||||
"""Helper method to decode bytes in certificate data."""
|
"""Helper method to decode bytes in certificate data."""
|
||||||
if isinstance(data, bytes):
|
if isinstance(data, bytes):
|
||||||
return data.decode('utf-8')
|
return data.decode("utf-8")
|
||||||
elif isinstance(data, dict):
|
elif isinstance(data, dict):
|
||||||
return {
|
return {
|
||||||
(k.decode('utf-8') if isinstance(k, bytes) else k): SSLCertificate._decode_cert_data(v)
|
(
|
||||||
|
k.decode("utf-8") if isinstance(k, bytes) else k
|
||||||
|
): SSLCertificate._decode_cert_data(v)
|
||||||
for k, v in data.items()
|
for k, v in data.items()
|
||||||
}
|
}
|
||||||
elif isinstance(data, list):
|
elif isinstance(data, list):
|
||||||
@@ -105,7 +109,7 @@ class SSLCertificate:
|
|||||||
"""
|
"""
|
||||||
json_str = json.dumps(self._cert_info, indent=2, ensure_ascii=False)
|
json_str = json.dumps(self._cert_info, indent=2, ensure_ascii=False)
|
||||||
if filepath:
|
if filepath:
|
||||||
Path(filepath).write_text(json_str, encoding='utf-8')
|
Path(filepath).write_text(json_str, encoding="utf-8")
|
||||||
return None
|
return None
|
||||||
return json_str
|
return json_str
|
||||||
|
|
||||||
@@ -122,18 +126,17 @@ class SSLCertificate:
|
|||||||
try:
|
try:
|
||||||
x509 = OpenSSL.crypto.load_certificate(
|
x509 = OpenSSL.crypto.load_certificate(
|
||||||
OpenSSL.crypto.FILETYPE_ASN1,
|
OpenSSL.crypto.FILETYPE_ASN1,
|
||||||
base64.b64decode(self._cert_info['raw_cert'])
|
base64.b64decode(self._cert_info["raw_cert"]),
|
||||||
)
|
)
|
||||||
pem_data = OpenSSL.crypto.dump_certificate(
|
pem_data = OpenSSL.crypto.dump_certificate(
|
||||||
OpenSSL.crypto.FILETYPE_PEM,
|
OpenSSL.crypto.FILETYPE_PEM, x509
|
||||||
x509
|
).decode("utf-8")
|
||||||
).decode('utf-8')
|
|
||||||
|
|
||||||
if filepath:
|
if filepath:
|
||||||
Path(filepath).write_text(pem_data, encoding='utf-8')
|
Path(filepath).write_text(pem_data, encoding="utf-8")
|
||||||
return None
|
return None
|
||||||
return pem_data
|
return pem_data
|
||||||
except Exception as e:
|
except Exception:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
def to_der(self, filepath: Optional[str] = None) -> Optional[bytes]:
|
def to_der(self, filepath: Optional[str] = None) -> Optional[bytes]:
|
||||||
@@ -147,7 +150,7 @@ class SSLCertificate:
|
|||||||
Optional[bytes]: DER bytes if successful, None otherwise.
|
Optional[bytes]: DER bytes if successful, None otherwise.
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
der_data = base64.b64decode(self._cert_info['raw_cert'])
|
der_data = base64.b64decode(self._cert_info["raw_cert"])
|
||||||
if filepath:
|
if filepath:
|
||||||
Path(filepath).write_bytes(der_data)
|
Path(filepath).write_bytes(der_data)
|
||||||
return None
|
return None
|
||||||
@@ -158,24 +161,24 @@ class SSLCertificate:
|
|||||||
@property
|
@property
|
||||||
def issuer(self) -> Dict[str, str]:
|
def issuer(self) -> Dict[str, str]:
|
||||||
"""Get certificate issuer information."""
|
"""Get certificate issuer information."""
|
||||||
return self._cert_info.get('issuer', {})
|
return self._cert_info.get("issuer", {})
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def subject(self) -> Dict[str, str]:
|
def subject(self) -> Dict[str, str]:
|
||||||
"""Get certificate subject information."""
|
"""Get certificate subject information."""
|
||||||
return self._cert_info.get('subject', {})
|
return self._cert_info.get("subject", {})
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def valid_from(self) -> str:
|
def valid_from(self) -> str:
|
||||||
"""Get certificate validity start date."""
|
"""Get certificate validity start date."""
|
||||||
return self._cert_info.get('not_before', '')
|
return self._cert_info.get("not_before", "")
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def valid_until(self) -> str:
|
def valid_until(self) -> str:
|
||||||
"""Get certificate validity end date."""
|
"""Get certificate validity end date."""
|
||||||
return self._cert_info.get('not_after', '')
|
return self._cert_info.get("not_after", "")
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def fingerprint(self) -> str:
|
def fingerprint(self) -> str:
|
||||||
"""Get certificate fingerprint."""
|
"""Get certificate fingerprint."""
|
||||||
return self._cert_info.get('fingerprint', '')
|
return self._cert_info.get("fingerprint", "")
|
||||||
|
|||||||
@@ -32,6 +32,7 @@ class UserAgentGenerator:
|
|||||||
android_version: Optional[str] = None
|
android_version: Optional[str] = None
|
||||||
): Generates a random user agent string based on the specified parameters.
|
): Generates a random user agent string based on the specified parameters.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
# Previous platform definitions remain the same...
|
# Previous platform definitions remain the same...
|
||||||
self.desktop_platforms = {
|
self.desktop_platforms = {
|
||||||
@@ -47,7 +48,7 @@ class UserAgentGenerator:
|
|||||||
"generic": "(X11; Linux x86_64)",
|
"generic": "(X11; Linux x86_64)",
|
||||||
"ubuntu": "(X11; Ubuntu; Linux x86_64)",
|
"ubuntu": "(X11; Ubuntu; Linux x86_64)",
|
||||||
"chrome_os": "(X11; CrOS x86_64 14541.0.0)",
|
"chrome_os": "(X11; CrOS x86_64 14541.0.0)",
|
||||||
}
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
self.mobile_platforms = {
|
self.mobile_platforms = {
|
||||||
@@ -60,26 +61,14 @@ class UserAgentGenerator:
|
|||||||
"ios": {
|
"ios": {
|
||||||
"iphone": "(iPhone; CPU iPhone OS 16_5 like Mac OS X)",
|
"iphone": "(iPhone; CPU iPhone OS 16_5 like Mac OS X)",
|
||||||
"ipad": "(iPad; CPU OS 16_5 like Mac OS X)",
|
"ipad": "(iPad; CPU OS 16_5 like Mac OS X)",
|
||||||
}
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
# Browser Combinations
|
# Browser Combinations
|
||||||
self.browser_combinations = {
|
self.browser_combinations = {
|
||||||
1: [
|
1: [["chrome"], ["firefox"], ["safari"], ["edge"]],
|
||||||
["chrome"],
|
2: [["gecko", "firefox"], ["chrome", "safari"], ["webkit", "safari"]],
|
||||||
["firefox"],
|
3: [["chrome", "safari", "edge"], ["webkit", "chrome", "safari"]],
|
||||||
["safari"],
|
|
||||||
["edge"]
|
|
||||||
],
|
|
||||||
2: [
|
|
||||||
["gecko", "firefox"],
|
|
||||||
["chrome", "safari"],
|
|
||||||
["webkit", "safari"]
|
|
||||||
],
|
|
||||||
3: [
|
|
||||||
["chrome", "safari", "edge"],
|
|
||||||
["webkit", "chrome", "safari"]
|
|
||||||
]
|
|
||||||
}
|
}
|
||||||
|
|
||||||
# Rendering Engines with versions
|
# Rendering Engines with versions
|
||||||
@@ -90,7 +79,7 @@ class UserAgentGenerator:
|
|||||||
"Gecko/20100101",
|
"Gecko/20100101",
|
||||||
"Gecko/20100101", # Firefox usually uses this constant version
|
"Gecko/20100101", # Firefox usually uses this constant version
|
||||||
"Gecko/2010010",
|
"Gecko/2010010",
|
||||||
]
|
],
|
||||||
}
|
}
|
||||||
|
|
||||||
# Browser Versions
|
# Browser Versions
|
||||||
@@ -170,12 +159,14 @@ class UserAgentGenerator:
|
|||||||
|
|
||||||
return browser_stack
|
return browser_stack
|
||||||
|
|
||||||
def generate(self,
|
def generate(
|
||||||
device_type: Optional[Literal['desktop', 'mobile']] = None,
|
self,
|
||||||
|
device_type: Optional[Literal["desktop", "mobile"]] = None,
|
||||||
os_type: Optional[str] = None,
|
os_type: Optional[str] = None,
|
||||||
device_brand: Optional[str] = None,
|
device_brand: Optional[str] = None,
|
||||||
browser_type: Optional[Literal['chrome', 'edge', 'safari', 'firefox']] = None,
|
browser_type: Optional[Literal["chrome", "edge", "safari", "firefox"]] = None,
|
||||||
num_browsers: int = 3) -> str:
|
num_browsers: int = 3,
|
||||||
|
) -> str:
|
||||||
"""
|
"""
|
||||||
Generate a random user agent with specified constraints.
|
Generate a random user agent with specified constraints.
|
||||||
|
|
||||||
@@ -215,9 +206,13 @@ class UserAgentGenerator:
|
|||||||
|
|
||||||
def get_random_platform(self, device_type, os_type, device_brand):
|
def get_random_platform(self, device_type, os_type, device_brand):
|
||||||
"""Helper method to get random platform based on constraints"""
|
"""Helper method to get random platform based on constraints"""
|
||||||
platforms = self.desktop_platforms if device_type == 'desktop' else \
|
platforms = (
|
||||||
self.mobile_platforms if device_type == 'mobile' else \
|
self.desktop_platforms
|
||||||
{**self.desktop_platforms, **self.mobile_platforms}
|
if device_type == "desktop"
|
||||||
|
else self.mobile_platforms
|
||||||
|
if device_type == "mobile"
|
||||||
|
else {**self.desktop_platforms, **self.mobile_platforms}
|
||||||
|
)
|
||||||
|
|
||||||
if os_type:
|
if os_type:
|
||||||
for platform_group in [self.desktop_platforms, self.mobile_platforms]:
|
for platform_group in [self.desktop_platforms, self.mobile_platforms]:
|
||||||
@@ -233,10 +228,10 @@ class UserAgentGenerator:
|
|||||||
def parse_user_agent(self, user_agent: str) -> Dict[str, str]:
|
def parse_user_agent(self, user_agent: str) -> Dict[str, str]:
|
||||||
"""Parse a user agent string to extract browser and version information"""
|
"""Parse a user agent string to extract browser and version information"""
|
||||||
browsers = {
|
browsers = {
|
||||||
'chrome': r'Chrome/(\d+)',
|
"chrome": r"Chrome/(\d+)",
|
||||||
'edge': r'Edg/(\d+)',
|
"edge": r"Edg/(\d+)",
|
||||||
'safari': r'Version/(\d+)',
|
"safari": r"Version/(\d+)",
|
||||||
'firefox': r'Firefox/(\d+)'
|
"firefox": r"Firefox/(\d+)",
|
||||||
}
|
}
|
||||||
|
|
||||||
result = {}
|
result = {}
|
||||||
@@ -255,25 +250,26 @@ class UserAgentGenerator:
|
|||||||
hints = []
|
hints = []
|
||||||
|
|
||||||
# Handle different browser combinations
|
# Handle different browser combinations
|
||||||
if 'chrome' in browsers:
|
if "chrome" in browsers:
|
||||||
hints.append(f'"Chromium";v="{browsers["chrome"]}"')
|
hints.append(f'"Chromium";v="{browsers["chrome"]}"')
|
||||||
hints.append('"Not_A Brand";v="8"')
|
hints.append('"Not_A Brand";v="8"')
|
||||||
|
|
||||||
if 'edge' in browsers:
|
if "edge" in browsers:
|
||||||
hints.append(f'"Microsoft Edge";v="{browsers["edge"]}"')
|
hints.append(f'"Microsoft Edge";v="{browsers["edge"]}"')
|
||||||
else:
|
else:
|
||||||
hints.append(f'"Google Chrome";v="{browsers["chrome"]}"')
|
hints.append(f'"Google Chrome";v="{browsers["chrome"]}"')
|
||||||
|
|
||||||
elif 'firefox' in browsers:
|
elif "firefox" in browsers:
|
||||||
# Firefox doesn't typically send Sec-CH-UA
|
# Firefox doesn't typically send Sec-CH-UA
|
||||||
return '""'
|
return '""'
|
||||||
|
|
||||||
elif 'safari' in browsers:
|
elif "safari" in browsers:
|
||||||
# Safari's format for client hints
|
# Safari's format for client hints
|
||||||
hints.append(f'"Safari";v="{browsers["safari"]}"')
|
hints.append(f'"Safari";v="{browsers["safari"]}"')
|
||||||
hints.append('"Not_A Brand";v="8"')
|
hints.append('"Not_A Brand";v="8"')
|
||||||
|
|
||||||
return ', '.join(hints)
|
return ", ".join(hints)
|
||||||
|
|
||||||
|
|
||||||
# Example usage:
|
# Example usage:
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
@@ -281,7 +277,7 @@ if __name__ == "__main__":
|
|||||||
print(generator.generate())
|
print(generator.generate())
|
||||||
|
|
||||||
print("\nSingle browser (Chrome):")
|
print("\nSingle browser (Chrome):")
|
||||||
print(generator.generate(num_browsers=1, browser_type='chrome'))
|
print(generator.generate(num_browsers=1, browser_type="chrome"))
|
||||||
|
|
||||||
print("\nTwo browsers (Gecko/Firefox):")
|
print("\nTwo browsers (Gecko/Firefox):")
|
||||||
print(generator.generate(num_browsers=2))
|
print(generator.generate(num_browsers=2))
|
||||||
@@ -290,16 +286,14 @@ if __name__ == "__main__":
|
|||||||
print(generator.generate(num_browsers=3))
|
print(generator.generate(num_browsers=3))
|
||||||
|
|
||||||
print("\nFirefox on Linux:")
|
print("\nFirefox on Linux:")
|
||||||
print(generator.generate(
|
print(
|
||||||
device_type='desktop',
|
generator.generate(
|
||||||
os_type='linux',
|
device_type="desktop",
|
||||||
browser_type='firefox',
|
os_type="linux",
|
||||||
num_browsers=2
|
browser_type="firefox",
|
||||||
))
|
num_browsers=2,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
print("\nChrome/Safari/Edge on Windows:")
|
print("\nChrome/Safari/Edge on Windows:")
|
||||||
print(generator.generate(
|
print(generator.generate(device_type="desktop", os_type="windows", num_browsers=3))
|
||||||
device_type='desktop',
|
|
||||||
os_type='windows',
|
|
||||||
num_browsers=3
|
|
||||||
))
|
|
||||||
|
|||||||
1036
crawl4ai/utils.py
1036
crawl4ai/utils.py
File diff suppressed because it is too large
Load Diff
@@ -1,9 +1,9 @@
|
|||||||
# version_manager.py
|
# version_manager.py
|
||||||
import os
|
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from packaging import version
|
from packaging import version
|
||||||
from . import __version__
|
from . import __version__
|
||||||
|
|
||||||
|
|
||||||
class VersionManager:
|
class VersionManager:
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
self.home_dir = Path.home() / ".crawl4ai"
|
self.home_dir = Path.home() / ".crawl4ai"
|
||||||
@@ -27,4 +27,3 @@ class VersionManager:
|
|||||||
installed = self.get_installed_version()
|
installed = self.get_installed_version()
|
||||||
current = version.parse(__version__.__version__)
|
current = version.parse(__version__.__version__)
|
||||||
return installed is None or installed < current
|
return installed is None or installed < current
|
||||||
|
|
||||||
|
|||||||
@@ -1,9 +1,10 @@
|
|||||||
import os, time
|
import os, time
|
||||||
|
|
||||||
os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
from .models import UrlModel, CrawlResult
|
from .models import UrlModel, CrawlResult
|
||||||
from .database import init_db, get_cached_url, cache_url, DB_PATH, flush_db
|
from .database import init_db, get_cached_url, cache_url
|
||||||
from .utils import *
|
from .utils import *
|
||||||
from .chunking_strategy import *
|
from .chunking_strategy import *
|
||||||
from .extraction_strategy import *
|
from .extraction_strategy import *
|
||||||
@@ -14,14 +15,27 @@ from .content_scraping_strategy import WebScrapingStrategy
|
|||||||
from .config import *
|
from .config import *
|
||||||
import warnings
|
import warnings
|
||||||
import json
|
import json
|
||||||
warnings.filterwarnings("ignore", message='Field "model_name" has conflict with protected namespace "model_".')
|
|
||||||
|
warnings.filterwarnings(
|
||||||
|
"ignore",
|
||||||
|
message='Field "model_name" has conflict with protected namespace "model_".',
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class WebCrawler:
|
class WebCrawler:
|
||||||
def __init__(self, crawler_strategy: CrawlerStrategy = None, always_by_pass_cache: bool = False, verbose: bool = False):
|
def __init__(
|
||||||
self.crawler_strategy = crawler_strategy or LocalSeleniumCrawlerStrategy(verbose=verbose)
|
self,
|
||||||
|
crawler_strategy: CrawlerStrategy = None,
|
||||||
|
always_by_pass_cache: bool = False,
|
||||||
|
verbose: bool = False,
|
||||||
|
):
|
||||||
|
self.crawler_strategy = crawler_strategy or LocalSeleniumCrawlerStrategy(
|
||||||
|
verbose=verbose
|
||||||
|
)
|
||||||
self.always_by_pass_cache = always_by_pass_cache
|
self.always_by_pass_cache = always_by_pass_cache
|
||||||
self.crawl4ai_folder = os.path.join(os.getenv("CRAWL4_AI_BASE_DIRECTORY", Path.home()), ".crawl4ai")
|
self.crawl4ai_folder = os.path.join(
|
||||||
|
os.getenv("CRAWL4_AI_BASE_DIRECTORY", Path.home()), ".crawl4ai"
|
||||||
|
)
|
||||||
os.makedirs(self.crawl4ai_folder, exist_ok=True)
|
os.makedirs(self.crawl4ai_folder, exist_ok=True)
|
||||||
os.makedirs(f"{self.crawl4ai_folder}/cache", exist_ok=True)
|
os.makedirs(f"{self.crawl4ai_folder}/cache", exist_ok=True)
|
||||||
init_db()
|
init_db()
|
||||||
@@ -30,11 +44,11 @@ class WebCrawler:
|
|||||||
def warmup(self):
|
def warmup(self):
|
||||||
print("[LOG] 🌤️ Warming up the WebCrawler")
|
print("[LOG] 🌤️ Warming up the WebCrawler")
|
||||||
self.run(
|
self.run(
|
||||||
url='https://google.com/',
|
url="https://google.com/",
|
||||||
word_count_threshold=5,
|
word_count_threshold=5,
|
||||||
extraction_strategy=NoExtractionStrategy(),
|
extraction_strategy=NoExtractionStrategy(),
|
||||||
bypass_cache=False,
|
bypass_cache=False,
|
||||||
verbose=False
|
verbose=False,
|
||||||
)
|
)
|
||||||
self.ready = True
|
self.ready = True
|
||||||
print("[LOG] 🌞 WebCrawler is ready to crawl")
|
print("[LOG] 🌞 WebCrawler is ready to crawl")
|
||||||
@@ -80,6 +94,7 @@ class WebCrawler:
|
|||||||
**kwargs,
|
**kwargs,
|
||||||
) -> List[CrawlResult]:
|
) -> List[CrawlResult]:
|
||||||
extraction_strategy = extraction_strategy or NoExtractionStrategy()
|
extraction_strategy = extraction_strategy or NoExtractionStrategy()
|
||||||
|
|
||||||
def fetch_page_wrapper(url_model, *args, **kwargs):
|
def fetch_page_wrapper(url_model, *args, **kwargs):
|
||||||
return self.fetch_page(url_model, *args, **kwargs)
|
return self.fetch_page(url_model, *args, **kwargs)
|
||||||
|
|
||||||
@@ -150,12 +165,25 @@ class WebCrawler:
|
|||||||
html = sanitize_input_encode(self.crawler_strategy.crawl(url, **kwargs))
|
html = sanitize_input_encode(self.crawler_strategy.crawl(url, **kwargs))
|
||||||
t2 = time.time()
|
t2 = time.time()
|
||||||
if verbose:
|
if verbose:
|
||||||
print(f"[LOG] 🚀 Crawling done for {url}, success: {bool(html)}, time taken: {t2 - t1:.2f} seconds")
|
print(
|
||||||
|
f"[LOG] 🚀 Crawling done for {url}, success: {bool(html)}, time taken: {t2 - t1:.2f} seconds"
|
||||||
|
)
|
||||||
if screenshot:
|
if screenshot:
|
||||||
screenshot_data = self.crawler_strategy.take_screenshot()
|
screenshot_data = self.crawler_strategy.take_screenshot()
|
||||||
|
|
||||||
|
crawl_result = self.process_html(
|
||||||
crawl_result = self.process_html(url, html, extracted_content, word_count_threshold, extraction_strategy, chunking_strategy, css_selector, screenshot_data, verbose, bool(cached), **kwargs)
|
url,
|
||||||
|
html,
|
||||||
|
extracted_content,
|
||||||
|
word_count_threshold,
|
||||||
|
extraction_strategy,
|
||||||
|
chunking_strategy,
|
||||||
|
css_selector,
|
||||||
|
screenshot_data,
|
||||||
|
verbose,
|
||||||
|
bool(cached),
|
||||||
|
**kwargs,
|
||||||
|
)
|
||||||
crawl_result.success = bool(html)
|
crawl_result.success = bool(html)
|
||||||
return crawl_result
|
return crawl_result
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
@@ -183,7 +211,11 @@ class WebCrawler:
|
|||||||
try:
|
try:
|
||||||
t1 = time.time()
|
t1 = time.time()
|
||||||
scrapping_strategy = WebScrapingStrategy()
|
scrapping_strategy = WebScrapingStrategy()
|
||||||
extra_params = {k: v for k, v in kwargs.items() if k not in ["only_text", "image_description_min_word_threshold"]}
|
extra_params = {
|
||||||
|
k: v
|
||||||
|
for k, v in kwargs.items()
|
||||||
|
if k not in ["only_text", "image_description_min_word_threshold"]
|
||||||
|
}
|
||||||
result = scrapping_strategy.scrap(
|
result = scrapping_strategy.scrap(
|
||||||
url,
|
url,
|
||||||
html,
|
html,
|
||||||
@@ -191,14 +223,17 @@ class WebCrawler:
|
|||||||
css_selector=css_selector,
|
css_selector=css_selector,
|
||||||
only_text=kwargs.get("only_text", False),
|
only_text=kwargs.get("only_text", False),
|
||||||
image_description_min_word_threshold=kwargs.get(
|
image_description_min_word_threshold=kwargs.get(
|
||||||
"image_description_min_word_threshold", IMAGE_DESCRIPTION_MIN_WORD_THRESHOLD
|
"image_description_min_word_threshold",
|
||||||
|
IMAGE_DESCRIPTION_MIN_WORD_THRESHOLD,
|
||||||
),
|
),
|
||||||
**extra_params,
|
**extra_params,
|
||||||
)
|
)
|
||||||
|
|
||||||
# result = get_content_of_website_optimized(url, html, word_count_threshold, css_selector=css_selector, only_text=kwargs.get("only_text", False))
|
# result = get_content_of_website_optimized(url, html, word_count_threshold, css_selector=css_selector, only_text=kwargs.get("only_text", False))
|
||||||
if verbose:
|
if verbose:
|
||||||
print(f"[LOG] 🚀 Content extracted for {url}, success: True, time taken: {time.time() - t1:.2f} seconds")
|
print(
|
||||||
|
f"[LOG] 🚀 Content extracted for {url}, success: True, time taken: {time.time() - t1:.2f} seconds"
|
||||||
|
)
|
||||||
|
|
||||||
if result is None:
|
if result is None:
|
||||||
raise ValueError(f"Failed to extract content from the website: {url}")
|
raise ValueError(f"Failed to extract content from the website: {url}")
|
||||||
@@ -213,14 +248,20 @@ class WebCrawler:
|
|||||||
|
|
||||||
if extracted_content is None:
|
if extracted_content is None:
|
||||||
if verbose:
|
if verbose:
|
||||||
print(f"[LOG] 🔥 Extracting semantic blocks for {url}, Strategy: {extraction_strategy.name}")
|
print(
|
||||||
|
f"[LOG] 🔥 Extracting semantic blocks for {url}, Strategy: {extraction_strategy.name}"
|
||||||
|
)
|
||||||
|
|
||||||
sections = chunking_strategy.chunk(markdown)
|
sections = chunking_strategy.chunk(markdown)
|
||||||
extracted_content = extraction_strategy.run(url, sections)
|
extracted_content = extraction_strategy.run(url, sections)
|
||||||
extracted_content = json.dumps(extracted_content, indent=4, default=str, ensure_ascii=False)
|
extracted_content = json.dumps(
|
||||||
|
extracted_content, indent=4, default=str, ensure_ascii=False
|
||||||
|
)
|
||||||
|
|
||||||
if verbose:
|
if verbose:
|
||||||
print(f"[LOG] 🚀 Extraction done for {url}, time taken: {time.time() - t:.2f} seconds.")
|
print(
|
||||||
|
f"[LOG] 🚀 Extraction done for {url}, time taken: {time.time() - t:.2f} seconds."
|
||||||
|
)
|
||||||
|
|
||||||
screenshot = None if not screenshot else screenshot
|
screenshot = None if not screenshot else screenshot
|
||||||
|
|
||||||
|
|||||||
@@ -9,12 +9,10 @@ from crawl4ai.extraction_strategy import JsonCssExtractionStrategy
|
|||||||
from crawl4ai.async_configs import BrowserConfig, CrawlerRunConfig
|
from crawl4ai.async_configs import BrowserConfig, CrawlerRunConfig
|
||||||
import json
|
import json
|
||||||
|
|
||||||
|
|
||||||
async def extract_amazon_products():
|
async def extract_amazon_products():
|
||||||
# Initialize browser config
|
# Initialize browser config
|
||||||
browser_config = BrowserConfig(
|
browser_config = BrowserConfig(browser_type="chromium", headless=True)
|
||||||
browser_type="chromium",
|
|
||||||
headless=True
|
|
||||||
)
|
|
||||||
|
|
||||||
# Initialize crawler config with JSON CSS extraction strategy
|
# Initialize crawler config with JSON CSS extraction strategy
|
||||||
crawler_config = CrawlerRunConfig(
|
crawler_config = CrawlerRunConfig(
|
||||||
@@ -27,57 +25,53 @@ async def extract_amazon_products():
|
|||||||
"name": "asin",
|
"name": "asin",
|
||||||
"selector": "",
|
"selector": "",
|
||||||
"type": "attribute",
|
"type": "attribute",
|
||||||
"attribute": "data-asin"
|
"attribute": "data-asin",
|
||||||
},
|
|
||||||
{
|
|
||||||
"name": "title",
|
|
||||||
"selector": "h2 a span",
|
|
||||||
"type": "text"
|
|
||||||
},
|
},
|
||||||
|
{"name": "title", "selector": "h2 a span", "type": "text"},
|
||||||
{
|
{
|
||||||
"name": "url",
|
"name": "url",
|
||||||
"selector": "h2 a",
|
"selector": "h2 a",
|
||||||
"type": "attribute",
|
"type": "attribute",
|
||||||
"attribute": "href"
|
"attribute": "href",
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"name": "image",
|
"name": "image",
|
||||||
"selector": ".s-image",
|
"selector": ".s-image",
|
||||||
"type": "attribute",
|
"type": "attribute",
|
||||||
"attribute": "src"
|
"attribute": "src",
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"name": "rating",
|
"name": "rating",
|
||||||
"selector": ".a-icon-star-small .a-icon-alt",
|
"selector": ".a-icon-star-small .a-icon-alt",
|
||||||
"type": "text"
|
"type": "text",
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"name": "reviews_count",
|
"name": "reviews_count",
|
||||||
"selector": "[data-csa-c-func-deps='aui-da-a-popover'] ~ span span",
|
"selector": "[data-csa-c-func-deps='aui-da-a-popover'] ~ span span",
|
||||||
"type": "text"
|
"type": "text",
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"name": "price",
|
"name": "price",
|
||||||
"selector": ".a-price .a-offscreen",
|
"selector": ".a-price .a-offscreen",
|
||||||
"type": "text"
|
"type": "text",
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"name": "original_price",
|
"name": "original_price",
|
||||||
"selector": ".a-price.a-text-price .a-offscreen",
|
"selector": ".a-price.a-text-price .a-offscreen",
|
||||||
"type": "text"
|
"type": "text",
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"name": "sponsored",
|
"name": "sponsored",
|
||||||
"selector": ".puis-sponsored-label-text",
|
"selector": ".puis-sponsored-label-text",
|
||||||
"type": "exists"
|
"type": "exists",
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"name": "delivery_info",
|
"name": "delivery_info",
|
||||||
"selector": "[data-cy='delivery-recipe'] .a-color-base",
|
"selector": "[data-cy='delivery-recipe'] .a-color-base",
|
||||||
"type": "text",
|
"type": "text",
|
||||||
"multiple": True
|
"multiple": True,
|
||||||
}
|
},
|
||||||
]
|
],
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
@@ -105,10 +99,12 @@ async def extract_amazon_products():
|
|||||||
print(f"Rating: {product.get('rating')}")
|
print(f"Rating: {product.get('rating')}")
|
||||||
print(f"Reviews: {product.get('reviews_count')}")
|
print(f"Reviews: {product.get('reviews_count')}")
|
||||||
print(f"Sponsored: {'Yes' if product.get('sponsored') else 'No'}")
|
print(f"Sponsored: {'Yes' if product.get('sponsored') else 'No'}")
|
||||||
if product.get('delivery_info'):
|
if product.get("delivery_info"):
|
||||||
print(f"Delivery: {' '.join(product['delivery_info'])}")
|
print(f"Delivery: {' '.join(product['delivery_info'])}")
|
||||||
print("-" * 80)
|
print("-" * 80)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
import asyncio
|
import asyncio
|
||||||
|
|
||||||
asyncio.run(extract_amazon_products())
|
asyncio.run(extract_amazon_products())
|
||||||
|
|||||||
@@ -10,6 +10,7 @@ from crawl4ai.async_configs import BrowserConfig, CrawlerRunConfig
|
|||||||
import json
|
import json
|
||||||
from playwright.async_api import Page, BrowserContext
|
from playwright.async_api import Page, BrowserContext
|
||||||
|
|
||||||
|
|
||||||
async def extract_amazon_products():
|
async def extract_amazon_products():
|
||||||
# Initialize browser config
|
# Initialize browser config
|
||||||
browser_config = BrowserConfig(
|
browser_config = BrowserConfig(
|
||||||
@@ -20,7 +21,6 @@ async def extract_amazon_products():
|
|||||||
# Initialize crawler config with JSON CSS extraction strategy nav-search-submit-button
|
# Initialize crawler config with JSON CSS extraction strategy nav-search-submit-button
|
||||||
crawler_config = CrawlerRunConfig(
|
crawler_config = CrawlerRunConfig(
|
||||||
cache_mode=CacheMode.BYPASS,
|
cache_mode=CacheMode.BYPASS,
|
||||||
|
|
||||||
extraction_strategy=JsonCssExtractionStrategy(
|
extraction_strategy=JsonCssExtractionStrategy(
|
||||||
schema={
|
schema={
|
||||||
"name": "Amazon Product Search Results",
|
"name": "Amazon Product Search Results",
|
||||||
@@ -30,82 +30,86 @@ async def extract_amazon_products():
|
|||||||
"name": "asin",
|
"name": "asin",
|
||||||
"selector": "",
|
"selector": "",
|
||||||
"type": "attribute",
|
"type": "attribute",
|
||||||
"attribute": "data-asin"
|
"attribute": "data-asin",
|
||||||
},
|
|
||||||
{
|
|
||||||
"name": "title",
|
|
||||||
"selector": "h2 a span",
|
|
||||||
"type": "text"
|
|
||||||
},
|
},
|
||||||
|
{"name": "title", "selector": "h2 a span", "type": "text"},
|
||||||
{
|
{
|
||||||
"name": "url",
|
"name": "url",
|
||||||
"selector": "h2 a",
|
"selector": "h2 a",
|
||||||
"type": "attribute",
|
"type": "attribute",
|
||||||
"attribute": "href"
|
"attribute": "href",
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"name": "image",
|
"name": "image",
|
||||||
"selector": ".s-image",
|
"selector": ".s-image",
|
||||||
"type": "attribute",
|
"type": "attribute",
|
||||||
"attribute": "src"
|
"attribute": "src",
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"name": "rating",
|
"name": "rating",
|
||||||
"selector": ".a-icon-star-small .a-icon-alt",
|
"selector": ".a-icon-star-small .a-icon-alt",
|
||||||
"type": "text"
|
"type": "text",
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"name": "reviews_count",
|
"name": "reviews_count",
|
||||||
"selector": "[data-csa-c-func-deps='aui-da-a-popover'] ~ span span",
|
"selector": "[data-csa-c-func-deps='aui-da-a-popover'] ~ span span",
|
||||||
"type": "text"
|
"type": "text",
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"name": "price",
|
"name": "price",
|
||||||
"selector": ".a-price .a-offscreen",
|
"selector": ".a-price .a-offscreen",
|
||||||
"type": "text"
|
"type": "text",
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"name": "original_price",
|
"name": "original_price",
|
||||||
"selector": ".a-price.a-text-price .a-offscreen",
|
"selector": ".a-price.a-text-price .a-offscreen",
|
||||||
"type": "text"
|
"type": "text",
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"name": "sponsored",
|
"name": "sponsored",
|
||||||
"selector": ".puis-sponsored-label-text",
|
"selector": ".puis-sponsored-label-text",
|
||||||
"type": "exists"
|
"type": "exists",
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"name": "delivery_info",
|
"name": "delivery_info",
|
||||||
"selector": "[data-cy='delivery-recipe'] .a-color-base",
|
"selector": "[data-cy='delivery-recipe'] .a-color-base",
|
||||||
"type": "text",
|
"type": "text",
|
||||||
"multiple": True
|
"multiple": True,
|
||||||
|
},
|
||||||
|
],
|
||||||
}
|
}
|
||||||
]
|
),
|
||||||
}
|
|
||||||
)
|
|
||||||
)
|
)
|
||||||
|
|
||||||
url = "https://www.amazon.com/"
|
url = "https://www.amazon.com/"
|
||||||
|
|
||||||
async def after_goto(page: Page, context: BrowserContext, url: str, response: dict, **kwargs):
|
async def after_goto(
|
||||||
|
page: Page, context: BrowserContext, url: str, response: dict, **kwargs
|
||||||
|
):
|
||||||
"""Hook called after navigating to each URL"""
|
"""Hook called after navigating to each URL"""
|
||||||
print(f"[HOOK] after_goto - Successfully loaded: {url}")
|
print(f"[HOOK] after_goto - Successfully loaded: {url}")
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# Wait for search box to be available
|
# Wait for search box to be available
|
||||||
search_box = await page.wait_for_selector('#twotabsearchtextbox', timeout=1000)
|
search_box = await page.wait_for_selector(
|
||||||
|
"#twotabsearchtextbox", timeout=1000
|
||||||
|
)
|
||||||
|
|
||||||
# Type the search query
|
# Type the search query
|
||||||
await search_box.fill('Samsung Galaxy Tab')
|
await search_box.fill("Samsung Galaxy Tab")
|
||||||
|
|
||||||
# Get the search button and prepare for navigation
|
# Get the search button and prepare for navigation
|
||||||
search_button = await page.wait_for_selector('#nav-search-submit-button', timeout=1000)
|
search_button = await page.wait_for_selector(
|
||||||
|
"#nav-search-submit-button", timeout=1000
|
||||||
|
)
|
||||||
|
|
||||||
# Click with navigation waiting
|
# Click with navigation waiting
|
||||||
await search_button.click()
|
await search_button.click()
|
||||||
|
|
||||||
# Wait for search results to load
|
# Wait for search results to load
|
||||||
await page.wait_for_selector('[data-component-type="s-search-result"]', timeout=10000)
|
await page.wait_for_selector(
|
||||||
|
'[data-component-type="s-search-result"]', timeout=10000
|
||||||
|
)
|
||||||
print("[HOOK] Search completed and results loaded!")
|
print("[HOOK] Search completed and results loaded!")
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
@@ -115,7 +119,6 @@ async def extract_amazon_products():
|
|||||||
|
|
||||||
# Use context manager for proper resource handling
|
# Use context manager for proper resource handling
|
||||||
async with AsyncWebCrawler(config=browser_config) as crawler:
|
async with AsyncWebCrawler(config=browser_config) as crawler:
|
||||||
|
|
||||||
crawler.crawler_strategy.set_hook("after_goto", after_goto)
|
crawler.crawler_strategy.set_hook("after_goto", after_goto)
|
||||||
|
|
||||||
# Extract the data
|
# Extract the data
|
||||||
@@ -136,10 +139,12 @@ async def extract_amazon_products():
|
|||||||
print(f"Rating: {product.get('rating')}")
|
print(f"Rating: {product.get('rating')}")
|
||||||
print(f"Reviews: {product.get('reviews_count')}")
|
print(f"Reviews: {product.get('reviews_count')}")
|
||||||
print(f"Sponsored: {'Yes' if product.get('sponsored') else 'No'}")
|
print(f"Sponsored: {'Yes' if product.get('sponsored') else 'No'}")
|
||||||
if product.get('delivery_info'):
|
if product.get("delivery_info"):
|
||||||
print(f"Delivery: {' '.join(product['delivery_info'])}")
|
print(f"Delivery: {' '.join(product['delivery_info'])}")
|
||||||
print("-" * 80)
|
print("-" * 80)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
import asyncio
|
import asyncio
|
||||||
|
|
||||||
asyncio.run(extract_amazon_products())
|
asyncio.run(extract_amazon_products())
|
||||||
|
|||||||
@@ -8,7 +8,7 @@ from crawl4ai import AsyncWebCrawler, CacheMode
|
|||||||
from crawl4ai.extraction_strategy import JsonCssExtractionStrategy
|
from crawl4ai.extraction_strategy import JsonCssExtractionStrategy
|
||||||
from crawl4ai.async_configs import BrowserConfig, CrawlerRunConfig
|
from crawl4ai.async_configs import BrowserConfig, CrawlerRunConfig
|
||||||
import json
|
import json
|
||||||
from playwright.async_api import Page, BrowserContext
|
|
||||||
|
|
||||||
async def extract_amazon_products():
|
async def extract_amazon_products():
|
||||||
# Initialize browser config
|
# Initialize browser config
|
||||||
@@ -41,65 +41,60 @@ async def extract_amazon_products():
|
|||||||
"name": "asin",
|
"name": "asin",
|
||||||
"selector": "",
|
"selector": "",
|
||||||
"type": "attribute",
|
"type": "attribute",
|
||||||
"attribute": "data-asin"
|
"attribute": "data-asin",
|
||||||
},
|
|
||||||
{
|
|
||||||
"name": "title",
|
|
||||||
"selector": "h2 a span",
|
|
||||||
"type": "text"
|
|
||||||
},
|
},
|
||||||
|
{"name": "title", "selector": "h2 a span", "type": "text"},
|
||||||
{
|
{
|
||||||
"name": "url",
|
"name": "url",
|
||||||
"selector": "h2 a",
|
"selector": "h2 a",
|
||||||
"type": "attribute",
|
"type": "attribute",
|
||||||
"attribute": "href"
|
"attribute": "href",
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"name": "image",
|
"name": "image",
|
||||||
"selector": ".s-image",
|
"selector": ".s-image",
|
||||||
"type": "attribute",
|
"type": "attribute",
|
||||||
"attribute": "src"
|
"attribute": "src",
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"name": "rating",
|
"name": "rating",
|
||||||
"selector": ".a-icon-star-small .a-icon-alt",
|
"selector": ".a-icon-star-small .a-icon-alt",
|
||||||
"type": "text"
|
"type": "text",
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"name": "reviews_count",
|
"name": "reviews_count",
|
||||||
"selector": "[data-csa-c-func-deps='aui-da-a-popover'] ~ span span",
|
"selector": "[data-csa-c-func-deps='aui-da-a-popover'] ~ span span",
|
||||||
"type": "text"
|
"type": "text",
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"name": "price",
|
"name": "price",
|
||||||
"selector": ".a-price .a-offscreen",
|
"selector": ".a-price .a-offscreen",
|
||||||
"type": "text"
|
"type": "text",
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"name": "original_price",
|
"name": "original_price",
|
||||||
"selector": ".a-price.a-text-price .a-offscreen",
|
"selector": ".a-price.a-text-price .a-offscreen",
|
||||||
"type": "text"
|
"type": "text",
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"name": "sponsored",
|
"name": "sponsored",
|
||||||
"selector": ".puis-sponsored-label-text",
|
"selector": ".puis-sponsored-label-text",
|
||||||
"type": "exists"
|
"type": "exists",
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"name": "delivery_info",
|
"name": "delivery_info",
|
||||||
"selector": "[data-cy='delivery-recipe'] .a-color-base",
|
"selector": "[data-cy='delivery-recipe'] .a-color-base",
|
||||||
"type": "text",
|
"type": "text",
|
||||||
"multiple": True
|
"multiple": True,
|
||||||
|
},
|
||||||
|
],
|
||||||
}
|
}
|
||||||
]
|
),
|
||||||
}
|
|
||||||
)
|
|
||||||
)
|
)
|
||||||
|
|
||||||
# Example search URL (you should replace with your actual Amazon URL)
|
# Example search URL (you should replace with your actual Amazon URL)
|
||||||
url = "https://www.amazon.com/"
|
url = "https://www.amazon.com/"
|
||||||
|
|
||||||
|
|
||||||
# Use context manager for proper resource handling
|
# Use context manager for proper resource handling
|
||||||
async with AsyncWebCrawler(config=browser_config) as crawler:
|
async with AsyncWebCrawler(config=browser_config) as crawler:
|
||||||
# Extract the data
|
# Extract the data
|
||||||
@@ -120,10 +115,12 @@ async def extract_amazon_products():
|
|||||||
print(f"Rating: {product.get('rating')}")
|
print(f"Rating: {product.get('rating')}")
|
||||||
print(f"Reviews: {product.get('reviews_count')}")
|
print(f"Reviews: {product.get('reviews_count')}")
|
||||||
print(f"Sponsored: {'Yes' if product.get('sponsored') else 'No'}")
|
print(f"Sponsored: {'Yes' if product.get('sponsored') else 'No'}")
|
||||||
if product.get('delivery_info'):
|
if product.get("delivery_info"):
|
||||||
print(f"Delivery: {' '.join(product['delivery_info'])}")
|
print(f"Delivery: {' '.join(product['delivery_info'])}")
|
||||||
print("-" * 80)
|
print("-" * 80)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
import asyncio
|
import asyncio
|
||||||
|
|
||||||
asyncio.run(extract_amazon_products())
|
asyncio.run(extract_amazon_products())
|
||||||
|
|||||||
@@ -1,12 +1,16 @@
|
|||||||
# File: async_webcrawler_multiple_urls_example.py
|
# File: async_webcrawler_multiple_urls_example.py
|
||||||
import os, sys
|
import os, sys
|
||||||
|
|
||||||
# append 2 parent directories to sys.path to import crawl4ai
|
# append 2 parent directories to sys.path to import crawl4ai
|
||||||
parent_dir = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
parent_dir = os.path.dirname(
|
||||||
|
os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
|
||||||
|
)
|
||||||
sys.path.append(parent_dir)
|
sys.path.append(parent_dir)
|
||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
from crawl4ai import AsyncWebCrawler
|
from crawl4ai import AsyncWebCrawler
|
||||||
|
|
||||||
|
|
||||||
async def main():
|
async def main():
|
||||||
# Initialize the AsyncWebCrawler
|
# Initialize the AsyncWebCrawler
|
||||||
async with AsyncWebCrawler(verbose=True) as crawler:
|
async with AsyncWebCrawler(verbose=True) as crawler:
|
||||||
@@ -16,7 +20,7 @@ async def main():
|
|||||||
"https://python.org",
|
"https://python.org",
|
||||||
"https://github.com",
|
"https://github.com",
|
||||||
"https://stackoverflow.com",
|
"https://stackoverflow.com",
|
||||||
"https://news.ycombinator.com"
|
"https://news.ycombinator.com",
|
||||||
]
|
]
|
||||||
|
|
||||||
# Set up crawling parameters
|
# Set up crawling parameters
|
||||||
@@ -27,7 +31,7 @@ async def main():
|
|||||||
urls=urls,
|
urls=urls,
|
||||||
word_count_threshold=word_count_threshold,
|
word_count_threshold=word_count_threshold,
|
||||||
bypass_cache=True,
|
bypass_cache=True,
|
||||||
verbose=True
|
verbose=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Process the results
|
# Process the results
|
||||||
@@ -36,7 +40,9 @@ async def main():
|
|||||||
print(f"Successfully crawled: {result.url}")
|
print(f"Successfully crawled: {result.url}")
|
||||||
print(f"Title: {result.metadata.get('title', 'N/A')}")
|
print(f"Title: {result.metadata.get('title', 'N/A')}")
|
||||||
print(f"Word count: {len(result.markdown.split())}")
|
print(f"Word count: {len(result.markdown.split())}")
|
||||||
print(f"Number of links: {len(result.links.get('internal', [])) + len(result.links.get('external', []))}")
|
print(
|
||||||
|
f"Number of links: {len(result.links.get('internal', [])) + len(result.links.get('external', []))}"
|
||||||
|
)
|
||||||
print(f"Number of images: {len(result.media.get('images', []))}")
|
print(f"Number of images: {len(result.media.get('images', []))}")
|
||||||
print("---")
|
print("---")
|
||||||
else:
|
else:
|
||||||
@@ -44,5 +50,6 @@ async def main():
|
|||||||
print(f"Error: {result.error_message}")
|
print(f"Error: {result.error_message}")
|
||||||
print("---")
|
print("---")
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
asyncio.run(main())
|
asyncio.run(main())
|
||||||
@@ -6,10 +6,8 @@ This example demonstrates optimal browser usage patterns in Crawl4AI:
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
import os
|
|
||||||
from typing import List
|
from typing import List
|
||||||
from crawl4ai import AsyncWebCrawler, BrowserConfig, CrawlerRunConfig
|
from crawl4ai import AsyncWebCrawler, BrowserConfig, CrawlerRunConfig
|
||||||
from crawl4ai.content_filter_strategy import PruningContentFilter
|
|
||||||
from crawl4ai.markdown_generation_strategy import DefaultMarkdownGenerator
|
from crawl4ai.markdown_generation_strategy import DefaultMarkdownGenerator
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -1,31 +1,32 @@
|
|||||||
import os, time
|
import os, time
|
||||||
|
|
||||||
# append the path to the root of the project
|
# append the path to the root of the project
|
||||||
import sys
|
import sys
|
||||||
import asyncio
|
import asyncio
|
||||||
sys.path.append(os.path.join(os.path.dirname(__file__), '..', '..'))
|
|
||||||
|
sys.path.append(os.path.join(os.path.dirname(__file__), "..", ".."))
|
||||||
from firecrawl import FirecrawlApp
|
from firecrawl import FirecrawlApp
|
||||||
from crawl4ai import AsyncWebCrawler
|
from crawl4ai import AsyncWebCrawler
|
||||||
__data__ = os.path.join(os.path.dirname(__file__), '..', '..') + '/.data'
|
|
||||||
|
__data__ = os.path.join(os.path.dirname(__file__), "..", "..") + "/.data"
|
||||||
|
|
||||||
|
|
||||||
async def compare():
|
async def compare():
|
||||||
app = FirecrawlApp(api_key=os.environ['FIRECRAWL_API_KEY'])
|
app = FirecrawlApp(api_key=os.environ["FIRECRAWL_API_KEY"])
|
||||||
|
|
||||||
# Test Firecrawl with a simple crawl
|
# Test Firecrawl with a simple crawl
|
||||||
start = time.time()
|
start = time.time()
|
||||||
scrape_status = app.scrape_url(
|
scrape_status = app.scrape_url(
|
||||||
'https://www.nbcnews.com/business',
|
"https://www.nbcnews.com/business", params={"formats": ["markdown", "html"]}
|
||||||
params={'formats': ['markdown', 'html']}
|
|
||||||
)
|
)
|
||||||
end = time.time()
|
end = time.time()
|
||||||
print(f"Time taken: {end - start} seconds")
|
print(f"Time taken: {end - start} seconds")
|
||||||
print(len(scrape_status['markdown']))
|
print(len(scrape_status["markdown"]))
|
||||||
# save the markdown content with provider name
|
# save the markdown content with provider name
|
||||||
with open(f"{__data__}/firecrawl_simple.md", "w") as f:
|
with open(f"{__data__}/firecrawl_simple.md", "w") as f:
|
||||||
f.write(scrape_status['markdown'])
|
f.write(scrape_status["markdown"])
|
||||||
# Count how many "cldnry.s-nbcnews.com" are in the markdown
|
# Count how many "cldnry.s-nbcnews.com" are in the markdown
|
||||||
print(scrape_status['markdown'].count("cldnry.s-nbcnews.com"))
|
print(scrape_status["markdown"].count("cldnry.s-nbcnews.com"))
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
async with AsyncWebCrawler() as crawler:
|
async with AsyncWebCrawler() as crawler:
|
||||||
start = time.time()
|
start = time.time()
|
||||||
@@ -34,7 +35,7 @@ async def compare():
|
|||||||
# js_code=["const loadMoreButton = Array.from(document.querySelectorAll('button')).find(button => button.textContent.includes('Load More')); loadMoreButton && loadMoreButton.click();"],
|
# js_code=["const loadMoreButton = Array.from(document.querySelectorAll('button')).find(button => button.textContent.includes('Load More')); loadMoreButton && loadMoreButton.click();"],
|
||||||
word_count_threshold=0,
|
word_count_threshold=0,
|
||||||
bypass_cache=True,
|
bypass_cache=True,
|
||||||
verbose=False
|
verbose=False,
|
||||||
)
|
)
|
||||||
end = time.time()
|
end = time.time()
|
||||||
print(f"Time taken: {end - start} seconds")
|
print(f"Time taken: {end - start} seconds")
|
||||||
@@ -48,10 +49,12 @@ async def compare():
|
|||||||
start = time.time()
|
start = time.time()
|
||||||
result = await crawler.arun(
|
result = await crawler.arun(
|
||||||
url="https://www.nbcnews.com/business",
|
url="https://www.nbcnews.com/business",
|
||||||
js_code=["const loadMoreButton = Array.from(document.querySelectorAll('button')).find(button => button.textContent.includes('Load More')); loadMoreButton && loadMoreButton.click();"],
|
js_code=[
|
||||||
|
"const loadMoreButton = Array.from(document.querySelectorAll('button')).find(button => button.textContent.includes('Load More')); loadMoreButton && loadMoreButton.click();"
|
||||||
|
],
|
||||||
word_count_threshold=0,
|
word_count_threshold=0,
|
||||||
bypass_cache=True,
|
bypass_cache=True,
|
||||||
verbose=False
|
verbose=False,
|
||||||
)
|
)
|
||||||
end = time.time()
|
end = time.time()
|
||||||
print(f"Time taken: {end - start} seconds")
|
print(f"Time taken: {end - start} seconds")
|
||||||
@@ -62,6 +65,6 @@ async def compare():
|
|||||||
# count how many "cldnry.s-nbcnews.com" are in the markdown
|
# count how many "cldnry.s-nbcnews.com" are in the markdown
|
||||||
print(result.markdown.count("cldnry.s-nbcnews.com"))
|
print(result.markdown.count("cldnry.s-nbcnews.com"))
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
asyncio.run(compare())
|
asyncio.run(compare())
|
||||||
|
|
||||||
@@ -3,11 +3,18 @@ import time
|
|||||||
from rich import print
|
from rich import print
|
||||||
from rich.table import Table
|
from rich.table import Table
|
||||||
from crawl4ai import (
|
from crawl4ai import (
|
||||||
AsyncWebCrawler, BrowserConfig, CrawlerRunConfig,
|
AsyncWebCrawler,
|
||||||
MemoryAdaptiveDispatcher, SemaphoreDispatcher,
|
BrowserConfig,
|
||||||
RateLimiter, CrawlerMonitor, DisplayMode, CacheMode
|
CrawlerRunConfig,
|
||||||
|
MemoryAdaptiveDispatcher,
|
||||||
|
SemaphoreDispatcher,
|
||||||
|
RateLimiter,
|
||||||
|
CrawlerMonitor,
|
||||||
|
DisplayMode,
|
||||||
|
CacheMode,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
async def memory_adaptive(urls, browser_config, run_config):
|
async def memory_adaptive(urls, browser_config, run_config):
|
||||||
"""Memory adaptive crawler with monitoring"""
|
"""Memory adaptive crawler with monitoring"""
|
||||||
start = time.perf_counter()
|
start = time.perf_counter()
|
||||||
@@ -16,14 +23,16 @@ async def memory_adaptive(urls, browser_config, run_config):
|
|||||||
memory_threshold_percent=70.0,
|
memory_threshold_percent=70.0,
|
||||||
max_session_permit=10,
|
max_session_permit=10,
|
||||||
monitor=CrawlerMonitor(
|
monitor=CrawlerMonitor(
|
||||||
max_visible_rows=15,
|
max_visible_rows=15, display_mode=DisplayMode.DETAILED
|
||||||
display_mode=DisplayMode.DETAILED
|
),
|
||||||
)
|
)
|
||||||
|
results = await crawler.arun_many(
|
||||||
|
urls, config=run_config, dispatcher=dispatcher
|
||||||
)
|
)
|
||||||
results = await crawler.arun_many(urls, config=run_config, dispatcher=dispatcher)
|
|
||||||
duration = time.perf_counter() - start
|
duration = time.perf_counter() - start
|
||||||
return len(results), duration
|
return len(results), duration
|
||||||
|
|
||||||
|
|
||||||
async def memory_adaptive_with_rate_limit(urls, browser_config, run_config):
|
async def memory_adaptive_with_rate_limit(urls, browser_config, run_config):
|
||||||
"""Memory adaptive crawler with rate limiting"""
|
"""Memory adaptive crawler with rate limiting"""
|
||||||
start = time.perf_counter()
|
start = time.perf_counter()
|
||||||
@@ -32,19 +41,19 @@ async def memory_adaptive_with_rate_limit(urls, browser_config, run_config):
|
|||||||
memory_threshold_percent=70.0,
|
memory_threshold_percent=70.0,
|
||||||
max_session_permit=10,
|
max_session_permit=10,
|
||||||
rate_limiter=RateLimiter(
|
rate_limiter=RateLimiter(
|
||||||
base_delay=(1.0, 2.0),
|
base_delay=(1.0, 2.0), max_delay=30.0, max_retries=2
|
||||||
max_delay=30.0,
|
|
||||||
max_retries=2
|
|
||||||
),
|
),
|
||||||
monitor=CrawlerMonitor(
|
monitor=CrawlerMonitor(
|
||||||
max_visible_rows=15,
|
max_visible_rows=15, display_mode=DisplayMode.DETAILED
|
||||||
display_mode=DisplayMode.DETAILED
|
),
|
||||||
)
|
)
|
||||||
|
results = await crawler.arun_many(
|
||||||
|
urls, config=run_config, dispatcher=dispatcher
|
||||||
)
|
)
|
||||||
results = await crawler.arun_many(urls, config=run_config, dispatcher=dispatcher)
|
|
||||||
duration = time.perf_counter() - start
|
duration = time.perf_counter() - start
|
||||||
return len(results), duration
|
return len(results), duration
|
||||||
|
|
||||||
|
|
||||||
async def semaphore(urls, browser_config, run_config):
|
async def semaphore(urls, browser_config, run_config):
|
||||||
"""Basic semaphore crawler"""
|
"""Basic semaphore crawler"""
|
||||||
start = time.perf_counter()
|
start = time.perf_counter()
|
||||||
@@ -52,14 +61,16 @@ async def semaphore(urls, browser_config, run_config):
|
|||||||
dispatcher = SemaphoreDispatcher(
|
dispatcher = SemaphoreDispatcher(
|
||||||
semaphore_count=5,
|
semaphore_count=5,
|
||||||
monitor=CrawlerMonitor(
|
monitor=CrawlerMonitor(
|
||||||
max_visible_rows=15,
|
max_visible_rows=15, display_mode=DisplayMode.DETAILED
|
||||||
display_mode=DisplayMode.DETAILED
|
),
|
||||||
)
|
)
|
||||||
|
results = await crawler.arun_many(
|
||||||
|
urls, config=run_config, dispatcher=dispatcher
|
||||||
)
|
)
|
||||||
results = await crawler.arun_many(urls, config=run_config, dispatcher=dispatcher)
|
|
||||||
duration = time.perf_counter() - start
|
duration = time.perf_counter() - start
|
||||||
return len(results), duration
|
return len(results), duration
|
||||||
|
|
||||||
|
|
||||||
async def semaphore_with_rate_limit(urls, browser_config, run_config):
|
async def semaphore_with_rate_limit(urls, browser_config, run_config):
|
||||||
"""Semaphore crawler with rate limiting"""
|
"""Semaphore crawler with rate limiting"""
|
||||||
start = time.perf_counter()
|
start = time.perf_counter()
|
||||||
@@ -67,19 +78,19 @@ async def semaphore_with_rate_limit(urls, browser_config, run_config):
|
|||||||
dispatcher = SemaphoreDispatcher(
|
dispatcher = SemaphoreDispatcher(
|
||||||
semaphore_count=5,
|
semaphore_count=5,
|
||||||
rate_limiter=RateLimiter(
|
rate_limiter=RateLimiter(
|
||||||
base_delay=(1.0, 2.0),
|
base_delay=(1.0, 2.0), max_delay=30.0, max_retries=2
|
||||||
max_delay=30.0,
|
|
||||||
max_retries=2
|
|
||||||
),
|
),
|
||||||
monitor=CrawlerMonitor(
|
monitor=CrawlerMonitor(
|
||||||
max_visible_rows=15,
|
max_visible_rows=15, display_mode=DisplayMode.DETAILED
|
||||||
display_mode=DisplayMode.DETAILED
|
),
|
||||||
)
|
)
|
||||||
|
results = await crawler.arun_many(
|
||||||
|
urls, config=run_config, dispatcher=dispatcher
|
||||||
)
|
)
|
||||||
results = await crawler.arun_many(urls, config=run_config, dispatcher=dispatcher)
|
|
||||||
duration = time.perf_counter() - start
|
duration = time.perf_counter() - start
|
||||||
return len(results), duration
|
return len(results), duration
|
||||||
|
|
||||||
|
|
||||||
def create_performance_table(results):
|
def create_performance_table(results):
|
||||||
"""Creates a rich table showing performance results"""
|
"""Creates a rich table showing performance results"""
|
||||||
table = Table(title="Crawler Strategy Performance Comparison")
|
table = Table(title="Crawler Strategy Performance Comparison")
|
||||||
@@ -93,14 +104,12 @@ def create_performance_table(results):
|
|||||||
for strategy, (urls_crawled, duration) in sorted_results:
|
for strategy, (urls_crawled, duration) in sorted_results:
|
||||||
urls_per_second = urls_crawled / duration
|
urls_per_second = urls_crawled / duration
|
||||||
table.add_row(
|
table.add_row(
|
||||||
strategy,
|
strategy, str(urls_crawled), f"{duration:.2f}", f"{urls_per_second:.2f}"
|
||||||
str(urls_crawled),
|
|
||||||
f"{duration:.2f}",
|
|
||||||
f"{urls_per_second:.2f}"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
return table
|
return table
|
||||||
|
|
||||||
|
|
||||||
async def main():
|
async def main():
|
||||||
urls = [f"https://example.com/page{i}" for i in range(1, 20)]
|
urls = [f"https://example.com/page{i}" for i in range(1, 20)]
|
||||||
browser_config = BrowserConfig(headless=True, verbose=False)
|
browser_config = BrowserConfig(headless=True, verbose=False)
|
||||||
@@ -108,14 +117,19 @@ async def main():
|
|||||||
|
|
||||||
results = {
|
results = {
|
||||||
"Memory Adaptive": await memory_adaptive(urls, browser_config, run_config),
|
"Memory Adaptive": await memory_adaptive(urls, browser_config, run_config),
|
||||||
"Memory Adaptive + Rate Limit": await memory_adaptive_with_rate_limit(urls, browser_config, run_config),
|
"Memory Adaptive + Rate Limit": await memory_adaptive_with_rate_limit(
|
||||||
|
urls, browser_config, run_config
|
||||||
|
),
|
||||||
"Semaphore": await semaphore(urls, browser_config, run_config),
|
"Semaphore": await semaphore(urls, browser_config, run_config),
|
||||||
"Semaphore + Rate Limit": await semaphore_with_rate_limit(urls, browser_config, run_config),
|
"Semaphore + Rate Limit": await semaphore_with_rate_limit(
|
||||||
|
urls, browser_config, run_config
|
||||||
|
),
|
||||||
}
|
}
|
||||||
|
|
||||||
table = create_performance_table(results)
|
table = create_performance_table(results)
|
||||||
print("\nPerformance Summary:")
|
print("\nPerformance Summary:")
|
||||||
print(table)
|
print(table)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
asyncio.run(main())
|
asyncio.run(main())
|
||||||
@@ -6,15 +6,24 @@ import base64
|
|||||||
import os
|
import os
|
||||||
from typing import Dict, Any
|
from typing import Dict, Any
|
||||||
|
|
||||||
|
|
||||||
class Crawl4AiTester:
|
class Crawl4AiTester:
|
||||||
def __init__(self, base_url: str = "http://localhost:11235", api_token: str = None):
|
def __init__(self, base_url: str = "http://localhost:11235", api_token: str = None):
|
||||||
self.base_url = base_url
|
self.base_url = base_url
|
||||||
self.api_token = api_token or os.getenv('CRAWL4AI_API_TOKEN') or "test_api_code" # Check environment variable as fallback
|
self.api_token = (
|
||||||
self.headers = {'Authorization': f'Bearer {self.api_token}'} if self.api_token else {}
|
api_token or os.getenv("CRAWL4AI_API_TOKEN") or "test_api_code"
|
||||||
|
) # Check environment variable as fallback
|
||||||
|
self.headers = (
|
||||||
|
{"Authorization": f"Bearer {self.api_token}"} if self.api_token else {}
|
||||||
|
)
|
||||||
|
|
||||||
def submit_and_wait(self, request_data: Dict[str, Any], timeout: int = 300) -> Dict[str, Any]:
|
def submit_and_wait(
|
||||||
|
self, request_data: Dict[str, Any], timeout: int = 300
|
||||||
|
) -> Dict[str, Any]:
|
||||||
# Submit crawl job
|
# Submit crawl job
|
||||||
response = requests.post(f"{self.base_url}/crawl", json=request_data, headers=self.headers)
|
response = requests.post(
|
||||||
|
f"{self.base_url}/crawl", json=request_data, headers=self.headers
|
||||||
|
)
|
||||||
if response.status_code == 403:
|
if response.status_code == 403:
|
||||||
raise Exception("API token is invalid or missing")
|
raise Exception("API token is invalid or missing")
|
||||||
task_id = response.json()["task_id"]
|
task_id = response.json()["task_id"]
|
||||||
@@ -24,9 +33,13 @@ class Crawl4AiTester:
|
|||||||
start_time = time.time()
|
start_time = time.time()
|
||||||
while True:
|
while True:
|
||||||
if time.time() - start_time > timeout:
|
if time.time() - start_time > timeout:
|
||||||
raise TimeoutError(f"Task {task_id} did not complete within {timeout} seconds")
|
raise TimeoutError(
|
||||||
|
f"Task {task_id} did not complete within {timeout} seconds"
|
||||||
|
)
|
||||||
|
|
||||||
result = requests.get(f"{self.base_url}/task/{task_id}", headers=self.headers)
|
result = requests.get(
|
||||||
|
f"{self.base_url}/task/{task_id}", headers=self.headers
|
||||||
|
)
|
||||||
status = result.json()
|
status = result.json()
|
||||||
|
|
||||||
if status["status"] == "failed":
|
if status["status"] == "failed":
|
||||||
@@ -39,7 +52,12 @@ class Crawl4AiTester:
|
|||||||
time.sleep(2)
|
time.sleep(2)
|
||||||
|
|
||||||
def submit_sync(self, request_data: Dict[str, Any]) -> Dict[str, Any]:
|
def submit_sync(self, request_data: Dict[str, Any]) -> Dict[str, Any]:
|
||||||
response = requests.post(f"{self.base_url}/crawl_sync", json=request_data, headers=self.headers, timeout=60)
|
response = requests.post(
|
||||||
|
f"{self.base_url}/crawl_sync",
|
||||||
|
json=request_data,
|
||||||
|
headers=self.headers,
|
||||||
|
timeout=60,
|
||||||
|
)
|
||||||
if response.status_code == 408:
|
if response.status_code == 408:
|
||||||
raise TimeoutError("Task did not complete within server timeout")
|
raise TimeoutError("Task did not complete within server timeout")
|
||||||
response.raise_for_status()
|
response.raise_for_status()
|
||||||
@@ -48,13 +66,12 @@ class Crawl4AiTester:
|
|||||||
def crawl_direct(self, request_data: Dict[str, Any]) -> Dict[str, Any]:
|
def crawl_direct(self, request_data: Dict[str, Any]) -> Dict[str, Any]:
|
||||||
"""Directly crawl without using task queue"""
|
"""Directly crawl without using task queue"""
|
||||||
response = requests.post(
|
response = requests.post(
|
||||||
f"{self.base_url}/crawl_direct",
|
f"{self.base_url}/crawl_direct", json=request_data, headers=self.headers
|
||||||
json=request_data,
|
|
||||||
headers=self.headers
|
|
||||||
)
|
)
|
||||||
response.raise_for_status()
|
response.raise_for_status()
|
||||||
return response.json()
|
return response.json()
|
||||||
|
|
||||||
|
|
||||||
def test_docker_deployment(version="basic"):
|
def test_docker_deployment(version="basic"):
|
||||||
tester = Crawl4AiTester(
|
tester = Crawl4AiTester(
|
||||||
base_url="http://localhost:11235",
|
base_url="http://localhost:11235",
|
||||||
@@ -70,7 +87,7 @@ def test_docker_deployment(version="basic"):
|
|||||||
health = requests.get(f"{tester.base_url}/health", timeout=10)
|
health = requests.get(f"{tester.base_url}/health", timeout=10)
|
||||||
print("Health check:", health.json())
|
print("Health check:", health.json())
|
||||||
break
|
break
|
||||||
except requests.exceptions.RequestException as e:
|
except requests.exceptions.RequestException:
|
||||||
if i == max_retries - 1:
|
if i == max_retries - 1:
|
||||||
print(f"Failed to connect after {max_retries} attempts")
|
print(f"Failed to connect after {max_retries} attempts")
|
||||||
sys.exit(1)
|
sys.exit(1)
|
||||||
@@ -99,7 +116,7 @@ def test_basic_crawl(tester: Crawl4AiTester):
|
|||||||
request = {
|
request = {
|
||||||
"urls": "https://www.nbcnews.com/business",
|
"urls": "https://www.nbcnews.com/business",
|
||||||
"priority": 10,
|
"priority": 10,
|
||||||
"session_id": "test"
|
"session_id": "test",
|
||||||
}
|
}
|
||||||
|
|
||||||
result = tester.submit_and_wait(request)
|
result = tester.submit_and_wait(request)
|
||||||
@@ -107,19 +124,21 @@ def test_basic_crawl(tester: Crawl4AiTester):
|
|||||||
assert result["result"]["success"]
|
assert result["result"]["success"]
|
||||||
assert len(result["result"]["markdown"]) > 0
|
assert len(result["result"]["markdown"]) > 0
|
||||||
|
|
||||||
|
|
||||||
def test_basic_crawl_sync(tester: Crawl4AiTester):
|
def test_basic_crawl_sync(tester: Crawl4AiTester):
|
||||||
print("\n=== Testing Basic Crawl (Sync) ===")
|
print("\n=== Testing Basic Crawl (Sync) ===")
|
||||||
request = {
|
request = {
|
||||||
"urls": "https://www.nbcnews.com/business",
|
"urls": "https://www.nbcnews.com/business",
|
||||||
"priority": 10,
|
"priority": 10,
|
||||||
"session_id": "test"
|
"session_id": "test",
|
||||||
}
|
}
|
||||||
|
|
||||||
result = tester.submit_sync(request)
|
result = tester.submit_sync(request)
|
||||||
print(f"Basic crawl result length: {len(result['result']['markdown'])}")
|
print(f"Basic crawl result length: {len(result['result']['markdown'])}")
|
||||||
assert result['status'] == 'completed'
|
assert result["status"] == "completed"
|
||||||
assert result['result']['success']
|
assert result["result"]["success"]
|
||||||
assert len(result['result']['markdown']) > 0
|
assert len(result["result"]["markdown"]) > 0
|
||||||
|
|
||||||
|
|
||||||
def test_basic_crawl_direct(tester: Crawl4AiTester):
|
def test_basic_crawl_direct(tester: Crawl4AiTester):
|
||||||
print("\n=== Testing Basic Crawl (Direct) ===")
|
print("\n=== Testing Basic Crawl (Direct) ===")
|
||||||
@@ -127,13 +146,14 @@ def test_basic_crawl_direct(tester: Crawl4AiTester):
|
|||||||
"urls": "https://www.nbcnews.com/business",
|
"urls": "https://www.nbcnews.com/business",
|
||||||
"priority": 10,
|
"priority": 10,
|
||||||
# "session_id": "test"
|
# "session_id": "test"
|
||||||
"cache_mode": "bypass" # or "enabled", "disabled", "read_only", "write_only"
|
"cache_mode": "bypass", # or "enabled", "disabled", "read_only", "write_only"
|
||||||
}
|
}
|
||||||
|
|
||||||
result = tester.crawl_direct(request)
|
result = tester.crawl_direct(request)
|
||||||
print(f"Basic crawl result length: {len(result['result']['markdown'])}")
|
print(f"Basic crawl result length: {len(result['result']['markdown'])}")
|
||||||
assert result['result']['success']
|
assert result["result"]["success"]
|
||||||
assert len(result['result']['markdown']) > 0
|
assert len(result["result"]["markdown"]) > 0
|
||||||
|
|
||||||
|
|
||||||
def test_js_execution(tester: Crawl4AiTester):
|
def test_js_execution(tester: Crawl4AiTester):
|
||||||
print("\n=== Testing JS Execution ===")
|
print("\n=== Testing JS Execution ===")
|
||||||
@@ -144,32 +164,29 @@ def test_js_execution(tester: Crawl4AiTester):
|
|||||||
"const loadMoreButton = Array.from(document.querySelectorAll('button')).find(button => button.textContent.includes('Load More')); loadMoreButton && loadMoreButton.click();"
|
"const loadMoreButton = Array.from(document.querySelectorAll('button')).find(button => button.textContent.includes('Load More')); loadMoreButton && loadMoreButton.click();"
|
||||||
],
|
],
|
||||||
"wait_for": "article.tease-card:nth-child(10)",
|
"wait_for": "article.tease-card:nth-child(10)",
|
||||||
"crawler_params": {
|
"crawler_params": {"headless": True},
|
||||||
"headless": True
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
result = tester.submit_and_wait(request)
|
result = tester.submit_and_wait(request)
|
||||||
print(f"JS execution result length: {len(result['result']['markdown'])}")
|
print(f"JS execution result length: {len(result['result']['markdown'])}")
|
||||||
assert result["result"]["success"]
|
assert result["result"]["success"]
|
||||||
|
|
||||||
|
|
||||||
def test_css_selector(tester: Crawl4AiTester):
|
def test_css_selector(tester: Crawl4AiTester):
|
||||||
print("\n=== Testing CSS Selector ===")
|
print("\n=== Testing CSS Selector ===")
|
||||||
request = {
|
request = {
|
||||||
"urls": "https://www.nbcnews.com/business",
|
"urls": "https://www.nbcnews.com/business",
|
||||||
"priority": 7,
|
"priority": 7,
|
||||||
"css_selector": ".wide-tease-item__description",
|
"css_selector": ".wide-tease-item__description",
|
||||||
"crawler_params": {
|
"crawler_params": {"headless": True},
|
||||||
"headless": True
|
"extra": {"word_count_threshold": 10},
|
||||||
},
|
|
||||||
"extra": {"word_count_threshold": 10}
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
result = tester.submit_and_wait(request)
|
result = tester.submit_and_wait(request)
|
||||||
print(f"CSS selector result length: {len(result['result']['markdown'])}")
|
print(f"CSS selector result length: {len(result['result']['markdown'])}")
|
||||||
assert result["result"]["success"]
|
assert result["result"]["success"]
|
||||||
|
|
||||||
|
|
||||||
def test_structured_extraction(tester: Crawl4AiTester):
|
def test_structured_extraction(tester: Crawl4AiTester):
|
||||||
print("\n=== Testing Structured Extraction ===")
|
print("\n=== Testing Structured Extraction ===")
|
||||||
schema = {
|
schema = {
|
||||||
@@ -190,19 +207,14 @@ def test_structured_extraction(tester: Crawl4AiTester):
|
|||||||
"name": "price",
|
"name": "price",
|
||||||
"selector": "td:nth-child(2)",
|
"selector": "td:nth-child(2)",
|
||||||
"type": "text",
|
"type": "text",
|
||||||
}
|
},
|
||||||
],
|
],
|
||||||
}
|
}
|
||||||
|
|
||||||
request = {
|
request = {
|
||||||
"urls": "https://www.coinbase.com/explore",
|
"urls": "https://www.coinbase.com/explore",
|
||||||
"priority": 9,
|
"priority": 9,
|
||||||
"extraction_config": {
|
"extraction_config": {"type": "json_css", "params": {"schema": schema}},
|
||||||
"type": "json_css",
|
|
||||||
"params": {
|
|
||||||
"schema": schema
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
result = tester.submit_and_wait(request)
|
result = tester.submit_and_wait(request)
|
||||||
@@ -212,6 +224,7 @@ def test_structured_extraction(tester: Crawl4AiTester):
|
|||||||
assert result["result"]["success"]
|
assert result["result"]["success"]
|
||||||
assert len(extracted) > 0
|
assert len(extracted) > 0
|
||||||
|
|
||||||
|
|
||||||
def test_llm_extraction(tester: Crawl4AiTester):
|
def test_llm_extraction(tester: Crawl4AiTester):
|
||||||
print("\n=== Testing LLM Extraction ===")
|
print("\n=== Testing LLM Extraction ===")
|
||||||
schema = {
|
schema = {
|
||||||
@@ -219,18 +232,18 @@ def test_llm_extraction(tester: Crawl4AiTester):
|
|||||||
"properties": {
|
"properties": {
|
||||||
"model_name": {
|
"model_name": {
|
||||||
"type": "string",
|
"type": "string",
|
||||||
"description": "Name of the OpenAI model."
|
"description": "Name of the OpenAI model.",
|
||||||
},
|
},
|
||||||
"input_fee": {
|
"input_fee": {
|
||||||
"type": "string",
|
"type": "string",
|
||||||
"description": "Fee for input token for the OpenAI model."
|
"description": "Fee for input token for the OpenAI model.",
|
||||||
},
|
},
|
||||||
"output_fee": {
|
"output_fee": {
|
||||||
"type": "string",
|
"type": "string",
|
||||||
"description": "Fee for output token for the OpenAI model."
|
"description": "Fee for output token for the OpenAI model.",
|
||||||
}
|
|
||||||
},
|
},
|
||||||
"required": ["model_name", "input_fee", "output_fee"]
|
},
|
||||||
|
"required": ["model_name", "input_fee", "output_fee"],
|
||||||
}
|
}
|
||||||
|
|
||||||
request = {
|
request = {
|
||||||
@@ -243,10 +256,10 @@ def test_llm_extraction(tester: Crawl4AiTester):
|
|||||||
"api_token": os.getenv("OPENAI_API_KEY"),
|
"api_token": os.getenv("OPENAI_API_KEY"),
|
||||||
"schema": schema,
|
"schema": schema,
|
||||||
"extraction_type": "schema",
|
"extraction_type": "schema",
|
||||||
"instruction": """From the crawled content, extract all mentioned model names along with their fees for input and output tokens."""
|
"instruction": """From the crawled content, extract all mentioned model names along with their fees for input and output tokens.""",
|
||||||
}
|
|
||||||
},
|
},
|
||||||
"crawler_params": {"word_count_threshold": 1}
|
},
|
||||||
|
"crawler_params": {"word_count_threshold": 1},
|
||||||
}
|
}
|
||||||
|
|
||||||
try:
|
try:
|
||||||
@@ -258,6 +271,7 @@ def test_llm_extraction(tester: Crawl4AiTester):
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"LLM extraction test failed (might be due to missing API key): {str(e)}")
|
print(f"LLM extraction test failed (might be due to missing API key): {str(e)}")
|
||||||
|
|
||||||
|
|
||||||
def test_llm_with_ollama(tester: Crawl4AiTester):
|
def test_llm_with_ollama(tester: Crawl4AiTester):
|
||||||
print("\n=== Testing LLM with Ollama ===")
|
print("\n=== Testing LLM with Ollama ===")
|
||||||
schema = {
|
schema = {
|
||||||
@@ -265,18 +279,18 @@ def test_llm_with_ollama(tester: Crawl4AiTester):
|
|||||||
"properties": {
|
"properties": {
|
||||||
"article_title": {
|
"article_title": {
|
||||||
"type": "string",
|
"type": "string",
|
||||||
"description": "The main title of the news article"
|
"description": "The main title of the news article",
|
||||||
},
|
},
|
||||||
"summary": {
|
"summary": {
|
||||||
"type": "string",
|
"type": "string",
|
||||||
"description": "A brief summary of the article content"
|
"description": "A brief summary of the article content",
|
||||||
},
|
},
|
||||||
"main_topics": {
|
"main_topics": {
|
||||||
"type": "array",
|
"type": "array",
|
||||||
"items": {"type": "string"},
|
"items": {"type": "string"},
|
||||||
"description": "Main topics or themes discussed in the article"
|
"description": "Main topics or themes discussed in the article",
|
||||||
}
|
},
|
||||||
}
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
request = {
|
request = {
|
||||||
@@ -288,11 +302,11 @@ def test_llm_with_ollama(tester: Crawl4AiTester):
|
|||||||
"provider": "ollama/llama2",
|
"provider": "ollama/llama2",
|
||||||
"schema": schema,
|
"schema": schema,
|
||||||
"extraction_type": "schema",
|
"extraction_type": "schema",
|
||||||
"instruction": "Extract the main article information including title, summary, and main topics."
|
"instruction": "Extract the main article information including title, summary, and main topics.",
|
||||||
}
|
},
|
||||||
},
|
},
|
||||||
"extra": {"word_count_threshold": 1},
|
"extra": {"word_count_threshold": 1},
|
||||||
"crawler_params": {"verbose": True}
|
"crawler_params": {"verbose": True},
|
||||||
}
|
}
|
||||||
|
|
||||||
try:
|
try:
|
||||||
@@ -303,6 +317,7 @@ def test_llm_with_ollama(tester: Crawl4AiTester):
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"Ollama extraction test failed: {str(e)}")
|
print(f"Ollama extraction test failed: {str(e)}")
|
||||||
|
|
||||||
|
|
||||||
def test_cosine_extraction(tester: Crawl4AiTester):
|
def test_cosine_extraction(tester: Crawl4AiTester):
|
||||||
print("\n=== Testing Cosine Extraction ===")
|
print("\n=== Testing Cosine Extraction ===")
|
||||||
request = {
|
request = {
|
||||||
@@ -314,9 +329,9 @@ def test_cosine_extraction(tester: Crawl4AiTester):
|
|||||||
"semantic_filter": "business finance economy",
|
"semantic_filter": "business finance economy",
|
||||||
"word_count_threshold": 10,
|
"word_count_threshold": 10,
|
||||||
"max_dist": 0.2,
|
"max_dist": 0.2,
|
||||||
"top_k": 3
|
"top_k": 3,
|
||||||
}
|
},
|
||||||
}
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
try:
|
try:
|
||||||
@@ -328,15 +343,14 @@ def test_cosine_extraction(tester: Crawl4AiTester):
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"Cosine extraction test failed: {str(e)}")
|
print(f"Cosine extraction test failed: {str(e)}")
|
||||||
|
|
||||||
|
|
||||||
def test_screenshot(tester: Crawl4AiTester):
|
def test_screenshot(tester: Crawl4AiTester):
|
||||||
print("\n=== Testing Screenshot ===")
|
print("\n=== Testing Screenshot ===")
|
||||||
request = {
|
request = {
|
||||||
"urls": "https://www.nbcnews.com/business",
|
"urls": "https://www.nbcnews.com/business",
|
||||||
"priority": 5,
|
"priority": 5,
|
||||||
"screenshot": True,
|
"screenshot": True,
|
||||||
"crawler_params": {
|
"crawler_params": {"headless": True},
|
||||||
"headless": True
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
result = tester.submit_and_wait(request)
|
result = tester.submit_and_wait(request)
|
||||||
@@ -351,6 +365,7 @@ def test_screenshot(tester: Crawl4AiTester):
|
|||||||
|
|
||||||
assert result["result"]["success"]
|
assert result["result"]["success"]
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
version = sys.argv[1] if len(sys.argv) > 1 else "basic"
|
version = sys.argv[1] if len(sys.argv) > 1 else "basic"
|
||||||
# version = "full"
|
# version = "full"
|
||||||
|
|||||||
@@ -9,18 +9,17 @@ This example shows how to:
|
|||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
import os
|
import os
|
||||||
from typing import Dict, Any
|
|
||||||
|
|
||||||
from crawl4ai import AsyncWebCrawler, BrowserConfig, CrawlerRunConfig, CacheMode
|
from crawl4ai import AsyncWebCrawler, BrowserConfig, CrawlerRunConfig, CacheMode
|
||||||
from crawl4ai.extraction_strategy import (
|
from crawl4ai.extraction_strategy import (
|
||||||
LLMExtractionStrategy,
|
LLMExtractionStrategy,
|
||||||
JsonCssExtractionStrategy,
|
JsonCssExtractionStrategy,
|
||||||
JsonXPathExtractionStrategy
|
JsonXPathExtractionStrategy,
|
||||||
)
|
)
|
||||||
from crawl4ai.chunking_strategy import RegexChunking, IdentityChunking
|
|
||||||
from crawl4ai.content_filter_strategy import PruningContentFilter
|
from crawl4ai.content_filter_strategy import PruningContentFilter
|
||||||
from crawl4ai.markdown_generation_strategy import DefaultMarkdownGenerator
|
from crawl4ai.markdown_generation_strategy import DefaultMarkdownGenerator
|
||||||
|
|
||||||
|
|
||||||
async def run_extraction(crawler: AsyncWebCrawler, url: str, strategy, name: str):
|
async def run_extraction(crawler: AsyncWebCrawler, url: str, strategy, name: str):
|
||||||
"""Helper function to run extraction with proper configuration"""
|
"""Helper function to run extraction with proper configuration"""
|
||||||
try:
|
try:
|
||||||
@@ -30,7 +29,7 @@ async def run_extraction(crawler: AsyncWebCrawler, url: str, strategy, name: str
|
|||||||
extraction_strategy=strategy,
|
extraction_strategy=strategy,
|
||||||
markdown_generator=DefaultMarkdownGenerator(
|
markdown_generator=DefaultMarkdownGenerator(
|
||||||
content_filter=PruningContentFilter() # For fit_markdown support
|
content_filter=PruningContentFilter() # For fit_markdown support
|
||||||
)
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
# Run the crawler
|
# Run the crawler
|
||||||
@@ -40,22 +39,22 @@ async def run_extraction(crawler: AsyncWebCrawler, url: str, strategy, name: str
|
|||||||
print(f"\n=== {name} Results ===")
|
print(f"\n=== {name} Results ===")
|
||||||
print(f"Extracted Content: {result.extracted_content}")
|
print(f"Extracted Content: {result.extracted_content}")
|
||||||
print(f"Raw Markdown Length: {len(result.markdown_v2.raw_markdown)}")
|
print(f"Raw Markdown Length: {len(result.markdown_v2.raw_markdown)}")
|
||||||
print(f"Citations Markdown Length: {len(result.markdown_v2.markdown_with_citations)}")
|
print(
|
||||||
|
f"Citations Markdown Length: {len(result.markdown_v2.markdown_with_citations)}"
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
print(f"Error in {name}: Crawl failed")
|
print(f"Error in {name}: Crawl failed")
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"Error in {name}: {str(e)}")
|
print(f"Error in {name}: {str(e)}")
|
||||||
|
|
||||||
|
|
||||||
async def main():
|
async def main():
|
||||||
# Example URL (replace with actual URL)
|
# Example URL (replace with actual URL)
|
||||||
url = "https://example.com/product-page"
|
url = "https://example.com/product-page"
|
||||||
|
|
||||||
# Configure browser settings
|
# Configure browser settings
|
||||||
browser_config = BrowserConfig(
|
browser_config = BrowserConfig(headless=True, verbose=True)
|
||||||
headless=True,
|
|
||||||
verbose=True
|
|
||||||
)
|
|
||||||
|
|
||||||
# Initialize extraction strategies
|
# Initialize extraction strategies
|
||||||
|
|
||||||
@@ -63,21 +62,21 @@ async def main():
|
|||||||
markdown_strategy = LLMExtractionStrategy(
|
markdown_strategy = LLMExtractionStrategy(
|
||||||
provider="openai/gpt-4o-mini",
|
provider="openai/gpt-4o-mini",
|
||||||
api_token=os.getenv("OPENAI_API_KEY"),
|
api_token=os.getenv("OPENAI_API_KEY"),
|
||||||
instruction="Extract product information including name, price, and description"
|
instruction="Extract product information including name, price, and description",
|
||||||
)
|
)
|
||||||
|
|
||||||
html_strategy = LLMExtractionStrategy(
|
html_strategy = LLMExtractionStrategy(
|
||||||
input_format="html",
|
input_format="html",
|
||||||
provider="openai/gpt-4o-mini",
|
provider="openai/gpt-4o-mini",
|
||||||
api_token=os.getenv("OPENAI_API_KEY"),
|
api_token=os.getenv("OPENAI_API_KEY"),
|
||||||
instruction="Extract product information from HTML including structured data"
|
instruction="Extract product information from HTML including structured data",
|
||||||
)
|
)
|
||||||
|
|
||||||
fit_markdown_strategy = LLMExtractionStrategy(
|
fit_markdown_strategy = LLMExtractionStrategy(
|
||||||
input_format="fit_markdown",
|
input_format="fit_markdown",
|
||||||
provider="openai/gpt-4o-mini",
|
provider="openai/gpt-4o-mini",
|
||||||
api_token=os.getenv("OPENAI_API_KEY"),
|
api_token=os.getenv("OPENAI_API_KEY"),
|
||||||
instruction="Extract product information from cleaned markdown"
|
instruction="Extract product information from cleaned markdown",
|
||||||
)
|
)
|
||||||
|
|
||||||
# 2. JSON CSS Extraction (automatically uses HTML input)
|
# 2. JSON CSS Extraction (automatically uses HTML input)
|
||||||
@@ -86,8 +85,8 @@ async def main():
|
|||||||
"fields": [
|
"fields": [
|
||||||
{"name": "title", "selector": "h1.product-title", "type": "text"},
|
{"name": "title", "selector": "h1.product-title", "type": "text"},
|
||||||
{"name": "price", "selector": ".price", "type": "text"},
|
{"name": "price", "selector": ".price", "type": "text"},
|
||||||
{"name": "description", "selector": ".description", "type": "text"}
|
{"name": "description", "selector": ".description", "type": "text"},
|
||||||
]
|
],
|
||||||
}
|
}
|
||||||
css_strategy = JsonCssExtractionStrategy(schema=css_schema)
|
css_strategy = JsonCssExtractionStrategy(schema=css_schema)
|
||||||
|
|
||||||
@@ -95,10 +94,22 @@ async def main():
|
|||||||
xpath_schema = {
|
xpath_schema = {
|
||||||
"baseSelector": "//div[@class='product']",
|
"baseSelector": "//div[@class='product']",
|
||||||
"fields": [
|
"fields": [
|
||||||
{"name": "title", "selector": ".//h1[@class='product-title']/text()", "type": "text"},
|
{
|
||||||
{"name": "price", "selector": ".//span[@class='price']/text()", "type": "text"},
|
"name": "title",
|
||||||
{"name": "description", "selector": ".//div[@class='description']/text()", "type": "text"}
|
"selector": ".//h1[@class='product-title']/text()",
|
||||||
]
|
"type": "text",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "price",
|
||||||
|
"selector": ".//span[@class='price']/text()",
|
||||||
|
"type": "text",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "description",
|
||||||
|
"selector": ".//div[@class='description']/text()",
|
||||||
|
"type": "text",
|
||||||
|
},
|
||||||
|
],
|
||||||
}
|
}
|
||||||
xpath_strategy = JsonXPathExtractionStrategy(schema=xpath_schema)
|
xpath_strategy = JsonXPathExtractionStrategy(schema=xpath_schema)
|
||||||
|
|
||||||
@@ -111,5 +122,6 @@ async def main():
|
|||||||
await run_extraction(crawler, url, css_strategy, "CSS Extraction")
|
await run_extraction(crawler, url, css_strategy, "CSS Extraction")
|
||||||
await run_extraction(crawler, url, xpath_strategy, "XPath Extraction")
|
await run_extraction(crawler, url, xpath_strategy, "XPath Extraction")
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
asyncio.run(main())
|
asyncio.run(main())
|
||||||
|
|||||||
@@ -1,20 +1,23 @@
|
|||||||
import asyncio
|
import asyncio
|
||||||
from crawl4ai import *
|
from crawl4ai import *
|
||||||
|
|
||||||
|
|
||||||
async def main():
|
async def main():
|
||||||
browser_config = BrowserConfig(headless=True, verbose=True)
|
browser_config = BrowserConfig(headless=True, verbose=True)
|
||||||
async with AsyncWebCrawler(config=browser_config) as crawler:
|
async with AsyncWebCrawler(config=browser_config) as crawler:
|
||||||
crawler_config = CrawlerRunConfig(
|
crawler_config = CrawlerRunConfig(
|
||||||
cache_mode=CacheMode.BYPASS,
|
cache_mode=CacheMode.BYPASS,
|
||||||
markdown_generator=DefaultMarkdownGenerator(
|
markdown_generator=DefaultMarkdownGenerator(
|
||||||
content_filter=PruningContentFilter(threshold=0.48, threshold_type="fixed", min_word_threshold=0)
|
content_filter=PruningContentFilter(
|
||||||
|
threshold=0.48, threshold_type="fixed", min_word_threshold=0
|
||||||
)
|
)
|
||||||
|
),
|
||||||
)
|
)
|
||||||
result = await crawler.arun(
|
result = await crawler.arun(
|
||||||
url="https://www.helloworld.org",
|
url="https://www.helloworld.org", config=crawler_config
|
||||||
config=crawler_config
|
|
||||||
)
|
)
|
||||||
print(result.markdown_v2.raw_markdown[:500])
|
print(result.markdown_v2.raw_markdown[:500])
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
asyncio.run(main())
|
asyncio.run(main())
|
||||||
@@ -1,19 +1,18 @@
|
|||||||
from crawl4ai import AsyncWebCrawler, BrowserConfig, CrawlerRunConfig, CacheMode
|
from crawl4ai import AsyncWebCrawler, BrowserConfig, CrawlerRunConfig, CacheMode
|
||||||
from playwright.async_api import Page, BrowserContext
|
from playwright.async_api import Page, BrowserContext
|
||||||
|
|
||||||
|
|
||||||
async def main():
|
async def main():
|
||||||
print("🔗 Hooks Example: Demonstrating different hook use cases")
|
print("🔗 Hooks Example: Demonstrating different hook use cases")
|
||||||
|
|
||||||
# Configure browser settings
|
# Configure browser settings
|
||||||
browser_config = BrowserConfig(
|
browser_config = BrowserConfig(headless=True)
|
||||||
headless=True
|
|
||||||
)
|
|
||||||
|
|
||||||
# Configure crawler settings
|
# Configure crawler settings
|
||||||
crawler_run_config = CrawlerRunConfig(
|
crawler_run_config = CrawlerRunConfig(
|
||||||
js_code="window.scrollTo(0, document.body.scrollHeight);",
|
js_code="window.scrollTo(0, document.body.scrollHeight);",
|
||||||
wait_for="body",
|
wait_for="body",
|
||||||
cache_mode=CacheMode.BYPASS
|
cache_mode=CacheMode.BYPASS,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Create crawler instance
|
# Create crawler instance
|
||||||
@@ -30,16 +29,22 @@ async def main():
|
|||||||
"""Hook called after a new page and context are created"""
|
"""Hook called after a new page and context are created"""
|
||||||
print("[HOOK] on_page_context_created - New page created!")
|
print("[HOOK] on_page_context_created - New page created!")
|
||||||
# Example: Set default viewport size
|
# Example: Set default viewport size
|
||||||
await context.add_cookies([{
|
await context.add_cookies(
|
||||||
'name': 'session_id',
|
[
|
||||||
'value': 'example_session',
|
{
|
||||||
'domain': '.example.com',
|
"name": "session_id",
|
||||||
'path': '/'
|
"value": "example_session",
|
||||||
}])
|
"domain": ".example.com",
|
||||||
|
"path": "/",
|
||||||
|
}
|
||||||
|
]
|
||||||
|
)
|
||||||
await page.set_viewport_size({"width": 1080, "height": 800})
|
await page.set_viewport_size({"width": 1080, "height": 800})
|
||||||
return page
|
return page
|
||||||
|
|
||||||
async def on_user_agent_updated(page: Page, context: BrowserContext, user_agent: str, **kwargs):
|
async def on_user_agent_updated(
|
||||||
|
page: Page, context: BrowserContext, user_agent: str, **kwargs
|
||||||
|
):
|
||||||
"""Hook called when the user agent is updated"""
|
"""Hook called when the user agent is updated"""
|
||||||
print(f"[HOOK] on_user_agent_updated - New user agent: {user_agent}")
|
print(f"[HOOK] on_user_agent_updated - New user agent: {user_agent}")
|
||||||
return page
|
return page
|
||||||
@@ -53,17 +58,17 @@ async def main():
|
|||||||
"""Hook called before navigating to each URL"""
|
"""Hook called before navigating to each URL"""
|
||||||
print(f"[HOOK] before_goto - About to visit: {url}")
|
print(f"[HOOK] before_goto - About to visit: {url}")
|
||||||
# Example: Add custom headers for the request
|
# Example: Add custom headers for the request
|
||||||
await page.set_extra_http_headers({
|
await page.set_extra_http_headers({"Custom-Header": "my-value"})
|
||||||
"Custom-Header": "my-value"
|
|
||||||
})
|
|
||||||
return page
|
return page
|
||||||
|
|
||||||
async def after_goto(page: Page, context: BrowserContext, url: str, response: dict, **kwargs):
|
async def after_goto(
|
||||||
|
page: Page, context: BrowserContext, url: str, response: dict, **kwargs
|
||||||
|
):
|
||||||
"""Hook called after navigating to each URL"""
|
"""Hook called after navigating to each URL"""
|
||||||
print(f"[HOOK] after_goto - Successfully loaded: {url}")
|
print(f"[HOOK] after_goto - Successfully loaded: {url}")
|
||||||
# Example: Wait for a specific element to be loaded
|
# Example: Wait for a specific element to be loaded
|
||||||
try:
|
try:
|
||||||
await page.wait_for_selector('.content', timeout=1000)
|
await page.wait_for_selector(".content", timeout=1000)
|
||||||
print("Content element found!")
|
print("Content element found!")
|
||||||
except:
|
except:
|
||||||
print("Content element not found, continuing anyway")
|
print("Content element not found, continuing anyway")
|
||||||
@@ -76,7 +81,9 @@ async def main():
|
|||||||
await page.evaluate("window.scrollTo(0, document.body.scrollHeight);")
|
await page.evaluate("window.scrollTo(0, document.body.scrollHeight);")
|
||||||
return page
|
return page
|
||||||
|
|
||||||
async def before_return_html(page: Page, context: BrowserContext, html:str, **kwargs):
|
async def before_return_html(
|
||||||
|
page: Page, context: BrowserContext, html: str, **kwargs
|
||||||
|
):
|
||||||
"""Hook called before returning the HTML content"""
|
"""Hook called before returning the HTML content"""
|
||||||
print(f"[HOOK] before_return_html - Got HTML content (length: {len(html)})")
|
print(f"[HOOK] before_return_html - Got HTML content (length: {len(html)})")
|
||||||
# Example: You could modify the HTML content here if needed
|
# Example: You could modify the HTML content here if needed
|
||||||
@@ -84,7 +91,9 @@ async def main():
|
|||||||
|
|
||||||
# Set all the hooks
|
# Set all the hooks
|
||||||
crawler.crawler_strategy.set_hook("on_browser_created", on_browser_created)
|
crawler.crawler_strategy.set_hook("on_browser_created", on_browser_created)
|
||||||
crawler.crawler_strategy.set_hook("on_page_context_created", on_page_context_created)
|
crawler.crawler_strategy.set_hook(
|
||||||
|
"on_page_context_created", on_page_context_created
|
||||||
|
)
|
||||||
crawler.crawler_strategy.set_hook("on_user_agent_updated", on_user_agent_updated)
|
crawler.crawler_strategy.set_hook("on_user_agent_updated", on_user_agent_updated)
|
||||||
crawler.crawler_strategy.set_hook("on_execution_started", on_execution_started)
|
crawler.crawler_strategy.set_hook("on_execution_started", on_execution_started)
|
||||||
crawler.crawler_strategy.set_hook("before_goto", before_goto)
|
crawler.crawler_strategy.set_hook("before_goto", before_goto)
|
||||||
@@ -95,13 +104,15 @@ async def main():
|
|||||||
await crawler.start()
|
await crawler.start()
|
||||||
|
|
||||||
# Example usage: crawl a simple website
|
# Example usage: crawl a simple website
|
||||||
url = 'https://example.com'
|
url = "https://example.com"
|
||||||
result = await crawler.arun(url, config=crawler_run_config)
|
result = await crawler.arun(url, config=crawler_run_config)
|
||||||
print(f"\nCrawled URL: {result.url}")
|
print(f"\nCrawled URL: {result.url}")
|
||||||
print(f"HTML length: {len(result.html)}")
|
print(f"HTML length: {len(result.html)}")
|
||||||
|
|
||||||
await crawler.close()
|
await crawler.close()
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
import asyncio
|
import asyncio
|
||||||
|
|
||||||
asyncio.run(main())
|
asyncio.run(main())
|
||||||
@@ -1,6 +1,7 @@
|
|||||||
import asyncio
|
import asyncio
|
||||||
from crawl4ai import AsyncWebCrawler, AsyncPlaywrightCrawlerStrategy
|
from crawl4ai import AsyncWebCrawler, AsyncPlaywrightCrawlerStrategy
|
||||||
|
|
||||||
|
|
||||||
async def main():
|
async def main():
|
||||||
# Example 1: Setting language when creating the crawler
|
# Example 1: Setting language when creating the crawler
|
||||||
crawler1 = AsyncWebCrawler(
|
crawler1 = AsyncWebCrawler(
|
||||||
@@ -9,11 +10,15 @@ async def main():
|
|||||||
)
|
)
|
||||||
)
|
)
|
||||||
result1 = await crawler1.arun("https://www.example.com")
|
result1 = await crawler1.arun("https://www.example.com")
|
||||||
print("Example 1 result:", result1.extracted_content[:100]) # Print first 100 characters
|
print(
|
||||||
|
"Example 1 result:", result1.extracted_content[:100]
|
||||||
|
) # Print first 100 characters
|
||||||
|
|
||||||
# Example 2: Setting language before crawling
|
# Example 2: Setting language before crawling
|
||||||
crawler2 = AsyncWebCrawler()
|
crawler2 = AsyncWebCrawler()
|
||||||
crawler2.crawler_strategy.headers["Accept-Language"] = "es-ES,es;q=0.9,en-US;q=0.8,en;q=0.7"
|
crawler2.crawler_strategy.headers[
|
||||||
|
"Accept-Language"
|
||||||
|
] = "es-ES,es;q=0.9,en-US;q=0.8,en;q=0.7"
|
||||||
result2 = await crawler2.arun("https://www.example.com")
|
result2 = await crawler2.arun("https://www.example.com")
|
||||||
print("Example 2 result:", result2.extracted_content[:100])
|
print("Example 2 result:", result2.extracted_content[:100])
|
||||||
|
|
||||||
@@ -21,7 +26,7 @@ async def main():
|
|||||||
crawler3 = AsyncWebCrawler()
|
crawler3 = AsyncWebCrawler()
|
||||||
result3 = await crawler3.arun(
|
result3 = await crawler3.arun(
|
||||||
"https://www.example.com",
|
"https://www.example.com",
|
||||||
headers={"Accept-Language": "de-DE,de;q=0.9,en-US;q=0.8,en;q=0.7"}
|
headers={"Accept-Language": "de-DE,de;q=0.9,en-US;q=0.8,en;q=0.7"},
|
||||||
)
|
)
|
||||||
print("Example 3 result:", result3.extracted_content[:100])
|
print("Example 3 result:", result3.extracted_content[:100])
|
||||||
|
|
||||||
@@ -33,13 +38,13 @@ async def main():
|
|||||||
]
|
]
|
||||||
|
|
||||||
crawler4 = AsyncWebCrawler()
|
crawler4 = AsyncWebCrawler()
|
||||||
results = await asyncio.gather(*[
|
results = await asyncio.gather(
|
||||||
crawler4.arun(url, headers={"Accept-Language": lang})
|
*[crawler4.arun(url, headers={"Accept-Language": lang}) for url, lang in urls]
|
||||||
for url, lang in urls
|
)
|
||||||
])
|
|
||||||
|
|
||||||
for url, result in zip([u for u, _ in urls], results):
|
for url, result in zip([u for u, _ in urls], results):
|
||||||
print(f"Result for {url}:", result.extracted_content[:100])
|
print(f"Result for {url}:", result.extracted_content[:100])
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
asyncio.run(main())
|
asyncio.run(main())
|
||||||
@@ -3,15 +3,20 @@ from crawl4ai.crawler_strategy import *
|
|||||||
import asyncio
|
import asyncio
|
||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
url = r'https://openai.com/api/pricing/'
|
url = r"https://openai.com/api/pricing/"
|
||||||
|
|
||||||
|
|
||||||
class OpenAIModelFee(BaseModel):
|
class OpenAIModelFee(BaseModel):
|
||||||
model_name: str = Field(..., description="Name of the OpenAI model.")
|
model_name: str = Field(..., description="Name of the OpenAI model.")
|
||||||
input_fee: str = Field(..., description="Fee for input token for the OpenAI model.")
|
input_fee: str = Field(..., description="Fee for input token for the OpenAI model.")
|
||||||
output_fee: str = Field(..., description="Fee for output token for the OpenAI model.")
|
output_fee: str = Field(
|
||||||
|
..., description="Fee for output token for the OpenAI model."
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
from crawl4ai import AsyncWebCrawler
|
from crawl4ai import AsyncWebCrawler
|
||||||
|
|
||||||
|
|
||||||
async def main():
|
async def main():
|
||||||
# Use AsyncWebCrawler
|
# Use AsyncWebCrawler
|
||||||
async with AsyncWebCrawler() as crawler:
|
async with AsyncWebCrawler() as crawler:
|
||||||
@@ -20,15 +25,15 @@ async def main():
|
|||||||
word_count_threshold=1,
|
word_count_threshold=1,
|
||||||
extraction_strategy=LLMExtractionStrategy(
|
extraction_strategy=LLMExtractionStrategy(
|
||||||
# provider= "openai/gpt-4o", api_token = os.getenv('OPENAI_API_KEY'),
|
# provider= "openai/gpt-4o", api_token = os.getenv('OPENAI_API_KEY'),
|
||||||
provider= "groq/llama-3.1-70b-versatile", api_token = os.getenv('GROQ_API_KEY'),
|
provider="groq/llama-3.1-70b-versatile",
|
||||||
|
api_token=os.getenv("GROQ_API_KEY"),
|
||||||
schema=OpenAIModelFee.model_json_schema(),
|
schema=OpenAIModelFee.model_json_schema(),
|
||||||
extraction_type="schema",
|
extraction_type="schema",
|
||||||
instruction="From the crawled content, extract all mentioned model names along with their " \
|
instruction="From the crawled content, extract all mentioned model names along with their "
|
||||||
"fees for input and output tokens. Make sure not to miss anything in the entire content. " \
|
"fees for input and output tokens. Make sure not to miss anything in the entire content. "
|
||||||
'One extracted model JSON format should look like this: ' \
|
"One extracted model JSON format should look like this: "
|
||||||
'{ "model_name": "GPT-4", "input_fee": "US$10.00 / 1M tokens", "output_fee": "US$30.00 / 1M tokens" }'
|
'{ "model_name": "GPT-4", "input_fee": "US$10.00 / 1M tokens", "output_fee": "US$30.00 / 1M tokens" }',
|
||||||
),
|
),
|
||||||
|
|
||||||
)
|
)
|
||||||
print("Success:", result.success)
|
print("Success:", result.success)
|
||||||
model_fees = json.loads(result.extracted_content)
|
model_fees = json.loads(result.extracted_content)
|
||||||
@@ -37,4 +42,5 @@ async def main():
|
|||||||
with open(".data/data.json", "w", encoding="utf-8") as f:
|
with open(".data/data.json", "w", encoding="utf-8") as f:
|
||||||
f.write(result.extracted_content)
|
f.write(result.extracted_content)
|
||||||
|
|
||||||
|
|
||||||
asyncio.run(main())
|
asyncio.run(main())
|
||||||
|
|||||||
@@ -8,12 +8,12 @@ import asyncio
|
|||||||
import time
|
import time
|
||||||
import json
|
import json
|
||||||
import re
|
import re
|
||||||
from typing import Dict, List
|
from typing import Dict
|
||||||
from bs4 import BeautifulSoup
|
from bs4 import BeautifulSoup
|
||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, Field
|
||||||
from crawl4ai import AsyncWebCrawler, CacheMode, BrowserConfig, CrawlerRunConfig
|
from crawl4ai import AsyncWebCrawler, CacheMode, BrowserConfig, CrawlerRunConfig
|
||||||
from crawl4ai.markdown_generation_strategy import DefaultMarkdownGenerator
|
from crawl4ai.markdown_generation_strategy import DefaultMarkdownGenerator
|
||||||
from crawl4ai.content_filter_strategy import BM25ContentFilter, PruningContentFilter
|
from crawl4ai.content_filter_strategy import PruningContentFilter
|
||||||
from crawl4ai.extraction_strategy import (
|
from crawl4ai.extraction_strategy import (
|
||||||
JsonCssExtractionStrategy,
|
JsonCssExtractionStrategy,
|
||||||
LLMExtractionStrategy,
|
LLMExtractionStrategy,
|
||||||
@@ -62,6 +62,7 @@ async def clean_content():
|
|||||||
print(f"Full Markdown Length: {full_markdown_length}")
|
print(f"Full Markdown Length: {full_markdown_length}")
|
||||||
print(f"Fit Markdown Length: {fit_markdown_length}")
|
print(f"Fit Markdown Length: {fit_markdown_length}")
|
||||||
|
|
||||||
|
|
||||||
async def link_analysis():
|
async def link_analysis():
|
||||||
crawler_config = CrawlerRunConfig(
|
crawler_config = CrawlerRunConfig(
|
||||||
cache_mode=CacheMode.ENABLED,
|
cache_mode=CacheMode.ENABLED,
|
||||||
@@ -76,9 +77,10 @@ async def link_analysis():
|
|||||||
print(f"Found {len(result.links['internal'])} internal links")
|
print(f"Found {len(result.links['internal'])} internal links")
|
||||||
print(f"Found {len(result.links['external'])} external links")
|
print(f"Found {len(result.links['external'])} external links")
|
||||||
|
|
||||||
for link in result.links['internal'][:5]:
|
for link in result.links["internal"][:5]:
|
||||||
print(f"Href: {link['href']}\nText: {link['text']}\n")
|
print(f"Href: {link['href']}\nText: {link['text']}\n")
|
||||||
|
|
||||||
|
|
||||||
# JavaScript Execution Example
|
# JavaScript Execution Example
|
||||||
async def simple_example_with_running_js_code():
|
async def simple_example_with_running_js_code():
|
||||||
print("\n--- Executing JavaScript and Using CSS Selectors ---")
|
print("\n--- Executing JavaScript and Using CSS Selectors ---")
|
||||||
@@ -112,25 +114,29 @@ async def simple_example_with_css_selector():
|
|||||||
)
|
)
|
||||||
print(result.markdown[:500])
|
print(result.markdown[:500])
|
||||||
|
|
||||||
|
|
||||||
async def media_handling():
|
async def media_handling():
|
||||||
crawler_config = CrawlerRunConfig(cache_mode=CacheMode.BYPASS, exclude_external_images=True, screenshot=True)
|
crawler_config = CrawlerRunConfig(
|
||||||
|
cache_mode=CacheMode.BYPASS, exclude_external_images=True, screenshot=True
|
||||||
|
)
|
||||||
async with AsyncWebCrawler() as crawler:
|
async with AsyncWebCrawler() as crawler:
|
||||||
result = await crawler.arun(
|
result = await crawler.arun(
|
||||||
url="https://www.nbcnews.com/business",
|
url="https://www.nbcnews.com/business", config=crawler_config
|
||||||
config=crawler_config
|
|
||||||
)
|
)
|
||||||
for img in result.media['images'][:5]:
|
for img in result.media["images"][:5]:
|
||||||
print(f"Image URL: {img['src']}, Alt: {img['alt']}, Score: {img['score']}")
|
print(f"Image URL: {img['src']}, Alt: {img['alt']}, Score: {img['score']}")
|
||||||
|
|
||||||
|
|
||||||
async def custom_hook_workflow(verbose=True):
|
async def custom_hook_workflow(verbose=True):
|
||||||
async with AsyncWebCrawler() as crawler:
|
async with AsyncWebCrawler() as crawler:
|
||||||
# Set a 'before_goto' hook to run custom code just before navigation
|
# Set a 'before_goto' hook to run custom code just before navigation
|
||||||
crawler.crawler_strategy.set_hook("before_goto", lambda page, context: print("[Hook] Preparing to navigate..."))
|
crawler.crawler_strategy.set_hook(
|
||||||
|
"before_goto",
|
||||||
|
lambda page, context: print("[Hook] Preparing to navigate..."),
|
||||||
|
)
|
||||||
|
|
||||||
# Perform the crawl operation
|
# Perform the crawl operation
|
||||||
result = await crawler.arun(
|
result = await crawler.arun(url="https://crawl4ai.com")
|
||||||
url="https://crawl4ai.com"
|
|
||||||
)
|
|
||||||
print(result.markdown_v2.raw_markdown[:500].replace("\n", " -- "))
|
print(result.markdown_v2.raw_markdown[:500].replace("\n", " -- "))
|
||||||
|
|
||||||
|
|
||||||
@@ -417,16 +423,17 @@ async def cosine_similarity_extraction():
|
|||||||
top_k=3, # Number of top keywords to extract
|
top_k=3, # Number of top keywords to extract
|
||||||
sim_threshold=0.3, # Similarity threshold for clustering
|
sim_threshold=0.3, # Similarity threshold for clustering
|
||||||
semantic_filter="McDonald's economic impact, American consumer trends", # Keywords to filter the content semantically using embeddings
|
semantic_filter="McDonald's economic impact, American consumer trends", # Keywords to filter the content semantically using embeddings
|
||||||
verbose=True
|
verbose=True,
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
async with AsyncWebCrawler() as crawler:
|
async with AsyncWebCrawler() as crawler:
|
||||||
result = await crawler.arun(
|
result = await crawler.arun(
|
||||||
url="https://www.nbcnews.com/business/consumer/how-mcdonalds-e-coli-crisis-inflation-politics-reflect-american-story-rcna177156",
|
url="https://www.nbcnews.com/business/consumer/how-mcdonalds-e-coli-crisis-inflation-politics-reflect-american-story-rcna177156",
|
||||||
config=crawl_config
|
config=crawl_config,
|
||||||
)
|
)
|
||||||
print(json.loads(result.extracted_content)[:5])
|
print(json.loads(result.extracted_content)[:5])
|
||||||
|
|
||||||
|
|
||||||
# Browser Comparison
|
# Browser Comparison
|
||||||
async def crawl_custom_browser_type():
|
async def crawl_custom_browser_type():
|
||||||
print("\n--- Browser Comparison ---")
|
print("\n--- Browser Comparison ---")
|
||||||
@@ -484,18 +491,16 @@ async def crawl_with_user_simulation():
|
|||||||
result = await crawler.arun(url="YOUR-URL-HERE", config=crawler_config)
|
result = await crawler.arun(url="YOUR-URL-HERE", config=crawler_config)
|
||||||
print(result.markdown)
|
print(result.markdown)
|
||||||
|
|
||||||
|
|
||||||
async def ssl_certification():
|
async def ssl_certification():
|
||||||
# Configure crawler to fetch SSL certificate
|
# Configure crawler to fetch SSL certificate
|
||||||
config = CrawlerRunConfig(
|
config = CrawlerRunConfig(
|
||||||
fetch_ssl_certificate=True,
|
fetch_ssl_certificate=True,
|
||||||
cache_mode=CacheMode.BYPASS # Bypass cache to always get fresh certificates
|
cache_mode=CacheMode.BYPASS, # Bypass cache to always get fresh certificates
|
||||||
)
|
)
|
||||||
|
|
||||||
async with AsyncWebCrawler() as crawler:
|
async with AsyncWebCrawler() as crawler:
|
||||||
result = await crawler.arun(
|
result = await crawler.arun(url="https://example.com", config=config)
|
||||||
url='https://example.com',
|
|
||||||
config=config
|
|
||||||
)
|
|
||||||
|
|
||||||
if result.success and result.ssl_certificate:
|
if result.success and result.ssl_certificate:
|
||||||
cert = result.ssl_certificate
|
cert = result.ssl_certificate
|
||||||
@@ -511,12 +516,17 @@ async def ssl_certification():
|
|||||||
print("\nCertificate exported to:")
|
print("\nCertificate exported to:")
|
||||||
print(f"- JSON: {os.path.join(tmp_dir, 'certificate.json')}")
|
print(f"- JSON: {os.path.join(tmp_dir, 'certificate.json')}")
|
||||||
|
|
||||||
pem_data = cert.to_pem(os.path.join(tmp_dir, "certificate.pem")) # For web servers
|
pem_data = cert.to_pem(
|
||||||
|
os.path.join(tmp_dir, "certificate.pem")
|
||||||
|
) # For web servers
|
||||||
print(f"- PEM: {os.path.join(tmp_dir, 'certificate.pem')}")
|
print(f"- PEM: {os.path.join(tmp_dir, 'certificate.pem')}")
|
||||||
|
|
||||||
der_data = cert.to_der(os.path.join(tmp_dir, "certificate.der")) # For Java apps
|
der_data = cert.to_der(
|
||||||
|
os.path.join(tmp_dir, "certificate.der")
|
||||||
|
) # For Java apps
|
||||||
print(f"- DER: {os.path.join(tmp_dir, 'certificate.der')}")
|
print(f"- DER: {os.path.join(tmp_dir, 'certificate.der')}")
|
||||||
|
|
||||||
|
|
||||||
# Speed Comparison
|
# Speed Comparison
|
||||||
async def speed_comparison():
|
async def speed_comparison():
|
||||||
print("\n--- Speed Comparison ---")
|
print("\n--- Speed Comparison ---")
|
||||||
|
|||||||
@@ -1,6 +1,10 @@
|
|||||||
import os, sys
|
import os, sys
|
||||||
|
|
||||||
# append parent directory to system path
|
# append parent directory to system path
|
||||||
sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))); os.environ['FIRECRAWL_API_KEY'] = "fc-84b370ccfad44beabc686b38f1769692";
|
sys.path.append(
|
||||||
|
os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||||
|
)
|
||||||
|
os.environ["FIRECRAWL_API_KEY"] = "fc-84b370ccfad44beabc686b38f1769692"
|
||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
# import nest_asyncio
|
# import nest_asyncio
|
||||||
@@ -15,7 +19,7 @@ from bs4 import BeautifulSoup
|
|||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, Field
|
||||||
from crawl4ai import AsyncWebCrawler, CacheMode
|
from crawl4ai import AsyncWebCrawler, CacheMode
|
||||||
from crawl4ai.markdown_generation_strategy import DefaultMarkdownGenerator
|
from crawl4ai.markdown_generation_strategy import DefaultMarkdownGenerator
|
||||||
from crawl4ai.content_filter_strategy import BM25ContentFilter, PruningContentFilter
|
from crawl4ai.content_filter_strategy import PruningContentFilter
|
||||||
from crawl4ai.extraction_strategy import (
|
from crawl4ai.extraction_strategy import (
|
||||||
JsonCssExtractionStrategy,
|
JsonCssExtractionStrategy,
|
||||||
LLMExtractionStrategy,
|
LLMExtractionStrategy,
|
||||||
@@ -32,9 +36,12 @@ print("Website: https://crawl4ai.com")
|
|||||||
async def simple_crawl():
|
async def simple_crawl():
|
||||||
print("\n--- Basic Usage ---")
|
print("\n--- Basic Usage ---")
|
||||||
async with AsyncWebCrawler(verbose=True) as crawler:
|
async with AsyncWebCrawler(verbose=True) as crawler:
|
||||||
result = await crawler.arun(url="https://www.nbcnews.com/business", cache_mode= CacheMode.BYPASS)
|
result = await crawler.arun(
|
||||||
|
url="https://www.nbcnews.com/business", cache_mode=CacheMode.BYPASS
|
||||||
|
)
|
||||||
print(result.markdown[:500]) # Print first 500 characters
|
print(result.markdown[:500]) # Print first 500 characters
|
||||||
|
|
||||||
|
|
||||||
async def simple_example_with_running_js_code():
|
async def simple_example_with_running_js_code():
|
||||||
print("\n--- Executing JavaScript and Using CSS Selectors ---")
|
print("\n--- Executing JavaScript and Using CSS Selectors ---")
|
||||||
# New code to handle the wait_for parameter
|
# New code to handle the wait_for parameter
|
||||||
@@ -57,6 +64,7 @@ async def simple_example_with_running_js_code():
|
|||||||
)
|
)
|
||||||
print(result.markdown[:500]) # Print first 500 characters
|
print(result.markdown[:500]) # Print first 500 characters
|
||||||
|
|
||||||
|
|
||||||
async def simple_example_with_css_selector():
|
async def simple_example_with_css_selector():
|
||||||
print("\n--- Using CSS Selectors ---")
|
print("\n--- Using CSS Selectors ---")
|
||||||
async with AsyncWebCrawler(verbose=True) as crawler:
|
async with AsyncWebCrawler(verbose=True) as crawler:
|
||||||
@@ -67,26 +75,27 @@ async def simple_example_with_css_selector():
|
|||||||
)
|
)
|
||||||
print(result.markdown[:500]) # Print first 500 characters
|
print(result.markdown[:500]) # Print first 500 characters
|
||||||
|
|
||||||
|
|
||||||
async def use_proxy():
|
async def use_proxy():
|
||||||
print("\n--- Using a Proxy ---")
|
print("\n--- Using a Proxy ---")
|
||||||
print(
|
print(
|
||||||
"Note: Replace 'http://your-proxy-url:port' with a working proxy to run this example."
|
"Note: Replace 'http://your-proxy-url:port' with a working proxy to run this example."
|
||||||
)
|
)
|
||||||
# Uncomment and modify the following lines to use a proxy
|
# Uncomment and modify the following lines to use a proxy
|
||||||
async with AsyncWebCrawler(verbose=True, proxy="http://your-proxy-url:port") as crawler:
|
async with AsyncWebCrawler(
|
||||||
|
verbose=True, proxy="http://your-proxy-url:port"
|
||||||
|
) as crawler:
|
||||||
result = await crawler.arun(
|
result = await crawler.arun(
|
||||||
url="https://www.nbcnews.com/business",
|
url="https://www.nbcnews.com/business", cache_mode=CacheMode.BYPASS
|
||||||
cache_mode= CacheMode.BYPASS
|
|
||||||
)
|
)
|
||||||
if result.success:
|
if result.success:
|
||||||
print(result.markdown[:500]) # Print first 500 characters
|
print(result.markdown[:500]) # Print first 500 characters
|
||||||
|
|
||||||
|
|
||||||
async def capture_and_save_screenshot(url: str, output_path: str):
|
async def capture_and_save_screenshot(url: str, output_path: str):
|
||||||
async with AsyncWebCrawler(verbose=True) as crawler:
|
async with AsyncWebCrawler(verbose=True) as crawler:
|
||||||
result = await crawler.arun(
|
result = await crawler.arun(
|
||||||
url=url,
|
url=url, screenshot=True, cache_mode=CacheMode.BYPASS
|
||||||
screenshot=True,
|
|
||||||
cache_mode= CacheMode.BYPASS
|
|
||||||
)
|
)
|
||||||
|
|
||||||
if result.success and result.screenshot:
|
if result.success and result.screenshot:
|
||||||
@@ -96,13 +105,14 @@ async def capture_and_save_screenshot(url: str, output_path: str):
|
|||||||
screenshot_data = base64.b64decode(result.screenshot)
|
screenshot_data = base64.b64decode(result.screenshot)
|
||||||
|
|
||||||
# Save the screenshot as a JPEG file
|
# Save the screenshot as a JPEG file
|
||||||
with open(output_path, 'wb') as f:
|
with open(output_path, "wb") as f:
|
||||||
f.write(screenshot_data)
|
f.write(screenshot_data)
|
||||||
|
|
||||||
print(f"Screenshot saved successfully to {output_path}")
|
print(f"Screenshot saved successfully to {output_path}")
|
||||||
else:
|
else:
|
||||||
print("Failed to capture screenshot")
|
print("Failed to capture screenshot")
|
||||||
|
|
||||||
|
|
||||||
class OpenAIModelFee(BaseModel):
|
class OpenAIModelFee(BaseModel):
|
||||||
model_name: str = Field(..., description="Name of the OpenAI model.")
|
model_name: str = Field(..., description="Name of the OpenAI model.")
|
||||||
input_fee: str = Field(..., description="Fee for input token for the OpenAI model.")
|
input_fee: str = Field(..., description="Fee for input token for the OpenAI model.")
|
||||||
@@ -110,7 +120,10 @@ class OpenAIModelFee(BaseModel):
|
|||||||
..., description="Fee for output token for the OpenAI model."
|
..., description="Fee for output token for the OpenAI model."
|
||||||
)
|
)
|
||||||
|
|
||||||
async def extract_structured_data_using_llm(provider: str, api_token: str = None, extra_headers: Dict[str, str] = None):
|
|
||||||
|
async def extract_structured_data_using_llm(
|
||||||
|
provider: str, api_token: str = None, extra_headers: Dict[str, str] = None
|
||||||
|
):
|
||||||
print(f"\n--- Extracting Structured Data with {provider} ---")
|
print(f"\n--- Extracting Structured Data with {provider} ---")
|
||||||
|
|
||||||
if api_token is None and provider != "ollama":
|
if api_token is None and provider != "ollama":
|
||||||
@@ -139,12 +152,13 @@ async def extract_structured_data_using_llm(provider: str, api_token: str = None
|
|||||||
instruction="""From the crawled content, extract all mentioned model names along with their fees for input and output tokens.
|
instruction="""From the crawled content, extract all mentioned model names along with their fees for input and output tokens.
|
||||||
Do not miss any models in the entire content. One extracted model JSON format should look like this:
|
Do not miss any models in the entire content. One extracted model JSON format should look like this:
|
||||||
{"model_name": "GPT-4", "input_fee": "US$10.00 / 1M tokens", "output_fee": "US$30.00 / 1M tokens"}.""",
|
{"model_name": "GPT-4", "input_fee": "US$10.00 / 1M tokens", "output_fee": "US$30.00 / 1M tokens"}.""",
|
||||||
extra_args=extra_args
|
extra_args=extra_args,
|
||||||
),
|
),
|
||||||
cache_mode=CacheMode.BYPASS,
|
cache_mode=CacheMode.BYPASS,
|
||||||
)
|
)
|
||||||
print(result.extracted_content)
|
print(result.extracted_content)
|
||||||
|
|
||||||
|
|
||||||
async def extract_structured_data_using_css_extractor():
|
async def extract_structured_data_using_css_extractor():
|
||||||
print("\n--- Using JsonCssExtractionStrategy for Fast Structured Output ---")
|
print("\n--- Using JsonCssExtractionStrategy for Fast Structured Output ---")
|
||||||
schema = {
|
schema = {
|
||||||
@@ -175,16 +189,12 @@ async def extract_structured_data_using_css_extractor():
|
|||||||
"name": "course_icon",
|
"name": "course_icon",
|
||||||
"selector": ".image-92",
|
"selector": ".image-92",
|
||||||
"type": "attribute",
|
"type": "attribute",
|
||||||
"attribute": "src"
|
"attribute": "src",
|
||||||
}
|
},
|
||||||
]
|
],
|
||||||
}
|
}
|
||||||
|
|
||||||
async with AsyncWebCrawler(
|
async with AsyncWebCrawler(headless=True, verbose=True) as crawler:
|
||||||
headless=True,
|
|
||||||
verbose=True
|
|
||||||
) as crawler:
|
|
||||||
|
|
||||||
# Create the JavaScript that handles clicking multiple times
|
# Create the JavaScript that handles clicking multiple times
|
||||||
js_click_tabs = """
|
js_click_tabs = """
|
||||||
(async () => {
|
(async () => {
|
||||||
@@ -204,13 +214,14 @@ async def extract_structured_data_using_css_extractor():
|
|||||||
url="https://www.kidocode.com/degrees/technology",
|
url="https://www.kidocode.com/degrees/technology",
|
||||||
extraction_strategy=JsonCssExtractionStrategy(schema, verbose=True),
|
extraction_strategy=JsonCssExtractionStrategy(schema, verbose=True),
|
||||||
js_code=[js_click_tabs],
|
js_code=[js_click_tabs],
|
||||||
cache_mode=CacheMode.BYPASS
|
cache_mode=CacheMode.BYPASS,
|
||||||
)
|
)
|
||||||
|
|
||||||
companies = json.loads(result.extracted_content)
|
companies = json.loads(result.extracted_content)
|
||||||
print(f"Successfully extracted {len(companies)} companies")
|
print(f"Successfully extracted {len(companies)} companies")
|
||||||
print(json.dumps(companies[0], indent=2))
|
print(json.dumps(companies[0], indent=2))
|
||||||
|
|
||||||
|
|
||||||
# Advanced Session-Based Crawling with Dynamic Content 🔄
|
# Advanced Session-Based Crawling with Dynamic Content 🔄
|
||||||
async def crawl_dynamic_content_pages_method_1():
|
async def crawl_dynamic_content_pages_method_1():
|
||||||
print("\n--- Advanced Multi-Page Crawling with JavaScript Execution ---")
|
print("\n--- Advanced Multi-Page Crawling with JavaScript Execution ---")
|
||||||
@@ -267,6 +278,7 @@ async def crawl_dynamic_content_pages_method_1():
|
|||||||
await crawler.crawler_strategy.kill_session(session_id)
|
await crawler.crawler_strategy.kill_session(session_id)
|
||||||
print(f"Successfully crawled {len(all_commits)} commits across 3 pages")
|
print(f"Successfully crawled {len(all_commits)} commits across 3 pages")
|
||||||
|
|
||||||
|
|
||||||
async def crawl_dynamic_content_pages_method_2():
|
async def crawl_dynamic_content_pages_method_2():
|
||||||
print("\n--- Advanced Multi-Page Crawling with JavaScript Execution ---")
|
print("\n--- Advanced Multi-Page Crawling with JavaScript Execution ---")
|
||||||
|
|
||||||
@@ -334,8 +346,11 @@ async def crawl_dynamic_content_pages_method_2():
|
|||||||
await crawler.crawler_strategy.kill_session(session_id)
|
await crawler.crawler_strategy.kill_session(session_id)
|
||||||
print(f"Successfully crawled {len(all_commits)} commits across 3 pages")
|
print(f"Successfully crawled {len(all_commits)} commits across 3 pages")
|
||||||
|
|
||||||
|
|
||||||
async def crawl_dynamic_content_pages_method_3():
|
async def crawl_dynamic_content_pages_method_3():
|
||||||
print("\n--- Advanced Multi-Page Crawling with JavaScript Execution using `wait_for` ---")
|
print(
|
||||||
|
"\n--- Advanced Multi-Page Crawling with JavaScript Execution using `wait_for` ---"
|
||||||
|
)
|
||||||
|
|
||||||
async with AsyncWebCrawler(verbose=True) as crawler:
|
async with AsyncWebCrawler(verbose=True) as crawler:
|
||||||
url = "https://github.com/microsoft/TypeScript/commits/main"
|
url = "https://github.com/microsoft/TypeScript/commits/main"
|
||||||
@@ -395,28 +410,40 @@ async def crawl_dynamic_content_pages_method_3():
|
|||||||
await crawler.crawler_strategy.kill_session(session_id)
|
await crawler.crawler_strategy.kill_session(session_id)
|
||||||
print(f"Successfully crawled {len(all_commits)} commits across 3 pages")
|
print(f"Successfully crawled {len(all_commits)} commits across 3 pages")
|
||||||
|
|
||||||
|
|
||||||
async def crawl_custom_browser_type():
|
async def crawl_custom_browser_type():
|
||||||
# Use Firefox
|
# Use Firefox
|
||||||
start = time.time()
|
start = time.time()
|
||||||
async with AsyncWebCrawler(browser_type="firefox", verbose=True, headless = True) as crawler:
|
async with AsyncWebCrawler(
|
||||||
result = await crawler.arun(url="https://www.example.com", cache_mode= CacheMode.BYPASS)
|
browser_type="firefox", verbose=True, headless=True
|
||||||
|
) as crawler:
|
||||||
|
result = await crawler.arun(
|
||||||
|
url="https://www.example.com", cache_mode=CacheMode.BYPASS
|
||||||
|
)
|
||||||
print(result.markdown[:500])
|
print(result.markdown[:500])
|
||||||
print("Time taken: ", time.time() - start)
|
print("Time taken: ", time.time() - start)
|
||||||
|
|
||||||
# Use WebKit
|
# Use WebKit
|
||||||
start = time.time()
|
start = time.time()
|
||||||
async with AsyncWebCrawler(browser_type="webkit", verbose=True, headless = True) as crawler:
|
async with AsyncWebCrawler(
|
||||||
result = await crawler.arun(url="https://www.example.com", cache_mode= CacheMode.BYPASS)
|
browser_type="webkit", verbose=True, headless=True
|
||||||
|
) as crawler:
|
||||||
|
result = await crawler.arun(
|
||||||
|
url="https://www.example.com", cache_mode=CacheMode.BYPASS
|
||||||
|
)
|
||||||
print(result.markdown[:500])
|
print(result.markdown[:500])
|
||||||
print("Time taken: ", time.time() - start)
|
print("Time taken: ", time.time() - start)
|
||||||
|
|
||||||
# Use Chromium (default)
|
# Use Chromium (default)
|
||||||
start = time.time()
|
start = time.time()
|
||||||
async with AsyncWebCrawler(verbose=True, headless=True) as crawler:
|
async with AsyncWebCrawler(verbose=True, headless=True) as crawler:
|
||||||
result = await crawler.arun(url="https://www.example.com", cache_mode= CacheMode.BYPASS)
|
result = await crawler.arun(
|
||||||
|
url="https://www.example.com", cache_mode=CacheMode.BYPASS
|
||||||
|
)
|
||||||
print(result.markdown[:500])
|
print(result.markdown[:500])
|
||||||
print("Time taken: ", time.time() - start)
|
print("Time taken: ", time.time() - start)
|
||||||
|
|
||||||
|
|
||||||
async def crawl_with_user_simultion():
|
async def crawl_with_user_simultion():
|
||||||
async with AsyncWebCrawler(verbose=True, headless=True) as crawler:
|
async with AsyncWebCrawler(verbose=True, headless=True) as crawler:
|
||||||
url = "YOUR-URL-HERE"
|
url = "YOUR-URL-HERE"
|
||||||
@@ -430,6 +457,7 @@ async def crawl_with_user_simultion():
|
|||||||
|
|
||||||
print(result.markdown)
|
print(result.markdown)
|
||||||
|
|
||||||
|
|
||||||
async def speed_comparison():
|
async def speed_comparison():
|
||||||
# print("\n--- Speed Comparison ---")
|
# print("\n--- Speed Comparison ---")
|
||||||
# print("Firecrawl (simulated):")
|
# print("Firecrawl (simulated):")
|
||||||
@@ -439,11 +467,11 @@ async def speed_comparison():
|
|||||||
# print()
|
# print()
|
||||||
# Simulated Firecrawl performance
|
# Simulated Firecrawl performance
|
||||||
from firecrawl import FirecrawlApp
|
from firecrawl import FirecrawlApp
|
||||||
app = FirecrawlApp(api_key=os.environ['FIRECRAWL_API_KEY'])
|
|
||||||
|
app = FirecrawlApp(api_key=os.environ["FIRECRAWL_API_KEY"])
|
||||||
start = time.time()
|
start = time.time()
|
||||||
scrape_status = app.scrape_url(
|
scrape_status = app.scrape_url(
|
||||||
'https://www.nbcnews.com/business',
|
"https://www.nbcnews.com/business", params={"formats": ["markdown", "html"]}
|
||||||
params={'formats': ['markdown', 'html']}
|
|
||||||
)
|
)
|
||||||
end = time.time()
|
end = time.time()
|
||||||
print("Firecrawl:")
|
print("Firecrawl:")
|
||||||
@@ -474,7 +502,9 @@ async def speed_comparison():
|
|||||||
url="https://www.nbcnews.com/business",
|
url="https://www.nbcnews.com/business",
|
||||||
word_count_threshold=0,
|
word_count_threshold=0,
|
||||||
markdown_generator=DefaultMarkdownGenerator(
|
markdown_generator=DefaultMarkdownGenerator(
|
||||||
content_filter = PruningContentFilter(threshold=0.48, threshold_type="fixed", min_word_threshold=0)
|
content_filter=PruningContentFilter(
|
||||||
|
threshold=0.48, threshold_type="fixed", min_word_threshold=0
|
||||||
|
)
|
||||||
# content_filter=BM25ContentFilter(user_query=None, bm25_threshold=1.0)
|
# content_filter=BM25ContentFilter(user_query=None, bm25_threshold=1.0)
|
||||||
),
|
),
|
||||||
cache_mode=CacheMode.BYPASS,
|
cache_mode=CacheMode.BYPASS,
|
||||||
@@ -498,7 +528,9 @@ async def speed_comparison():
|
|||||||
word_count_threshold=0,
|
word_count_threshold=0,
|
||||||
cache_mode=CacheMode.BYPASS,
|
cache_mode=CacheMode.BYPASS,
|
||||||
markdown_generator=DefaultMarkdownGenerator(
|
markdown_generator=DefaultMarkdownGenerator(
|
||||||
content_filter = PruningContentFilter(threshold=0.48, threshold_type="fixed", min_word_threshold=0)
|
content_filter=PruningContentFilter(
|
||||||
|
threshold=0.48, threshold_type="fixed", min_word_threshold=0
|
||||||
|
)
|
||||||
# content_filter=BM25ContentFilter(user_query=None, bm25_threshold=1.0)
|
# content_filter=BM25ContentFilter(user_query=None, bm25_threshold=1.0)
|
||||||
),
|
),
|
||||||
verbose=False,
|
verbose=False,
|
||||||
@@ -520,6 +552,7 @@ async def speed_comparison():
|
|||||||
print("If you run these tests in an environment with better network conditions,")
|
print("If you run these tests in an environment with better network conditions,")
|
||||||
print("you may observe an even more significant speed advantage for Crawl4AI.")
|
print("you may observe an even more significant speed advantage for Crawl4AI.")
|
||||||
|
|
||||||
|
|
||||||
async def generate_knowledge_graph():
|
async def generate_knowledge_graph():
|
||||||
class Entity(BaseModel):
|
class Entity(BaseModel):
|
||||||
name: str
|
name: str
|
||||||
@@ -536,11 +569,11 @@ async def generate_knowledge_graph():
|
|||||||
relationships: List[Relationship]
|
relationships: List[Relationship]
|
||||||
|
|
||||||
extraction_strategy = LLMExtractionStrategy(
|
extraction_strategy = LLMExtractionStrategy(
|
||||||
provider='openai/gpt-4o-mini', # Or any other provider, including Ollama and open source models
|
provider="openai/gpt-4o-mini", # Or any other provider, including Ollama and open source models
|
||||||
api_token=os.getenv('OPENAI_API_KEY'), # In case of Ollama just pass "no-token"
|
api_token=os.getenv("OPENAI_API_KEY"), # In case of Ollama just pass "no-token"
|
||||||
schema=KnowledgeGraph.model_json_schema(),
|
schema=KnowledgeGraph.model_json_schema(),
|
||||||
extraction_type="schema",
|
extraction_type="schema",
|
||||||
instruction="""Extract entities and relationships from the given text."""
|
instruction="""Extract entities and relationships from the given text.""",
|
||||||
)
|
)
|
||||||
async with AsyncWebCrawler() as crawler:
|
async with AsyncWebCrawler() as crawler:
|
||||||
url = "https://paulgraham.com/love.html"
|
url = "https://paulgraham.com/love.html"
|
||||||
@@ -554,27 +587,22 @@ async def generate_knowledge_graph():
|
|||||||
with open(os.path.join(__location__, "kb.json"), "w") as f:
|
with open(os.path.join(__location__, "kb.json"), "w") as f:
|
||||||
f.write(result.extracted_content)
|
f.write(result.extracted_content)
|
||||||
|
|
||||||
async def fit_markdown_remove_overlay():
|
|
||||||
|
|
||||||
|
async def fit_markdown_remove_overlay():
|
||||||
async with AsyncWebCrawler(
|
async with AsyncWebCrawler(
|
||||||
headless=True, # Set to False to see what is happening
|
headless=True, # Set to False to see what is happening
|
||||||
verbose=True,
|
verbose=True,
|
||||||
user_agent_mode="random",
|
user_agent_mode="random",
|
||||||
user_agent_generator_config={
|
user_agent_generator_config={"device_type": "mobile", "os_type": "android"},
|
||||||
"device_type": "mobile",
|
|
||||||
"os_type": "android"
|
|
||||||
},
|
|
||||||
) as crawler:
|
) as crawler:
|
||||||
result = await crawler.arun(
|
result = await crawler.arun(
|
||||||
url='https://www.kidocode.com/degrees/technology',
|
url="https://www.kidocode.com/degrees/technology",
|
||||||
cache_mode=CacheMode.BYPASS,
|
cache_mode=CacheMode.BYPASS,
|
||||||
markdown_generator=DefaultMarkdownGenerator(
|
markdown_generator=DefaultMarkdownGenerator(
|
||||||
content_filter=PruningContentFilter(
|
content_filter=PruningContentFilter(
|
||||||
threshold=0.48, threshold_type="fixed", min_word_threshold=0
|
threshold=0.48, threshold_type="fixed", min_word_threshold=0
|
||||||
),
|
),
|
||||||
options={
|
options={"ignore_links": True},
|
||||||
"ignore_links": True
|
|
||||||
}
|
|
||||||
),
|
),
|
||||||
# markdown_generator=DefaultMarkdownGenerator(
|
# markdown_generator=DefaultMarkdownGenerator(
|
||||||
# content_filter=BM25ContentFilter(user_query="", bm25_threshold=1.0),
|
# content_filter=BM25ContentFilter(user_query="", bm25_threshold=1.0),
|
||||||
@@ -593,13 +621,20 @@ async def fit_markdown_remove_overlay():
|
|||||||
with open(os.path.join(__location__, "output/cleaned_html.html"), "w") as f:
|
with open(os.path.join(__location__, "output/cleaned_html.html"), "w") as f:
|
||||||
f.write(result.cleaned_html)
|
f.write(result.cleaned_html)
|
||||||
|
|
||||||
with open(os.path.join(__location__, "output/output_raw_markdown.md"), "w") as f:
|
with open(
|
||||||
|
os.path.join(__location__, "output/output_raw_markdown.md"), "w"
|
||||||
|
) as f:
|
||||||
f.write(result.markdown_v2.raw_markdown)
|
f.write(result.markdown_v2.raw_markdown)
|
||||||
|
|
||||||
with open(os.path.join(__location__, "output/output_markdown_with_citations.md"), "w") as f:
|
with open(
|
||||||
|
os.path.join(__location__, "output/output_markdown_with_citations.md"),
|
||||||
|
"w",
|
||||||
|
) as f:
|
||||||
f.write(result.markdown_v2.markdown_with_citations)
|
f.write(result.markdown_v2.markdown_with_citations)
|
||||||
|
|
||||||
with open(os.path.join(__location__, "output/output_fit_markdown.md"), "w") as f:
|
with open(
|
||||||
|
os.path.join(__location__, "output/output_fit_markdown.md"), "w"
|
||||||
|
) as f:
|
||||||
f.write(result.markdown_v2.fit_markdown)
|
f.write(result.markdown_v2.fit_markdown)
|
||||||
|
|
||||||
print("Done")
|
print("Done")
|
||||||
|
|||||||
@@ -10,15 +10,17 @@ from functools import lru_cache
|
|||||||
|
|
||||||
console = Console()
|
console = Console()
|
||||||
|
|
||||||
|
|
||||||
@lru_cache()
|
@lru_cache()
|
||||||
def create_crawler():
|
def create_crawler():
|
||||||
crawler = WebCrawler(verbose=True)
|
crawler = WebCrawler(verbose=True)
|
||||||
crawler.warmup()
|
crawler.warmup()
|
||||||
return crawler
|
return crawler
|
||||||
|
|
||||||
|
|
||||||
def print_result(result):
|
def print_result(result):
|
||||||
# Print each key in one line and just the first 10 characters of each one's value and three dots
|
# Print each key in one line and just the first 10 characters of each one's value and three dots
|
||||||
console.print(f"\t[bold]Result:[/bold]")
|
console.print("\t[bold]Result:[/bold]")
|
||||||
for key, value in result.model_dump().items():
|
for key, value in result.model_dump().items():
|
||||||
if isinstance(value, str) and value:
|
if isinstance(value, str) and value:
|
||||||
console.print(f"\t{key}: [green]{value[:20]}...[/green]")
|
console.print(f"\t{key}: [green]{value[:20]}...[/green]")
|
||||||
@@ -33,18 +35,27 @@ def cprint(message, press_any_key=False):
|
|||||||
console.print("Press any key to continue...", style="")
|
console.print("Press any key to continue...", style="")
|
||||||
input()
|
input()
|
||||||
|
|
||||||
|
|
||||||
def basic_usage(crawler):
|
def basic_usage(crawler):
|
||||||
cprint("🛠️ [bold cyan]Basic Usage: Simply provide a URL and let Crawl4ai do the magic![/bold cyan]")
|
cprint(
|
||||||
|
"🛠️ [bold cyan]Basic Usage: Simply provide a URL and let Crawl4ai do the magic![/bold cyan]"
|
||||||
|
)
|
||||||
result = crawler.run(url="https://www.nbcnews.com/business", only_text=True)
|
result = crawler.run(url="https://www.nbcnews.com/business", only_text=True)
|
||||||
cprint("[LOG] 📦 [bold yellow]Basic crawl result:[/bold yellow]")
|
cprint("[LOG] 📦 [bold yellow]Basic crawl result:[/bold yellow]")
|
||||||
print_result(result)
|
print_result(result)
|
||||||
|
|
||||||
|
|
||||||
def basic_usage_some_params(crawler):
|
def basic_usage_some_params(crawler):
|
||||||
cprint("🛠️ [bold cyan]Basic Usage: Simply provide a URL and let Crawl4ai do the magic![/bold cyan]")
|
cprint(
|
||||||
result = crawler.run(url="https://www.nbcnews.com/business", word_count_threshold=1, only_text = True)
|
"🛠️ [bold cyan]Basic Usage: Simply provide a URL and let Crawl4ai do the magic![/bold cyan]"
|
||||||
|
)
|
||||||
|
result = crawler.run(
|
||||||
|
url="https://www.nbcnews.com/business", word_count_threshold=1, only_text=True
|
||||||
|
)
|
||||||
cprint("[LOG] 📦 [bold yellow]Basic crawl result:[/bold yellow]")
|
cprint("[LOG] 📦 [bold yellow]Basic crawl result:[/bold yellow]")
|
||||||
print_result(result)
|
print_result(result)
|
||||||
|
|
||||||
|
|
||||||
def screenshot_usage(crawler):
|
def screenshot_usage(crawler):
|
||||||
cprint("\n📸 [bold cyan]Let's take a screenshot of the page![/bold cyan]")
|
cprint("\n📸 [bold cyan]Let's take a screenshot of the page![/bold cyan]")
|
||||||
result = crawler.run(url="https://www.nbcnews.com/business", screenshot=True)
|
result = crawler.run(url="https://www.nbcnews.com/business", screenshot=True)
|
||||||
@@ -55,16 +66,23 @@ def screenshot_usage(crawler):
|
|||||||
cprint("Screenshot saved to 'screenshot.png'!")
|
cprint("Screenshot saved to 'screenshot.png'!")
|
||||||
print_result(result)
|
print_result(result)
|
||||||
|
|
||||||
|
|
||||||
def understanding_parameters(crawler):
|
def understanding_parameters(crawler):
|
||||||
cprint("\n🧠 [bold cyan]Understanding 'bypass_cache' and 'include_raw_html' parameters:[/bold cyan]")
|
cprint(
|
||||||
cprint("By default, Crawl4ai caches the results of your crawls. This means that subsequent crawls of the same URL will be much faster! Let's see this in action.")
|
"\n🧠 [bold cyan]Understanding 'bypass_cache' and 'include_raw_html' parameters:[/bold cyan]"
|
||||||
|
)
|
||||||
|
cprint(
|
||||||
|
"By default, Crawl4ai caches the results of your crawls. This means that subsequent crawls of the same URL will be much faster! Let's see this in action."
|
||||||
|
)
|
||||||
|
|
||||||
# First crawl (reads from cache)
|
# First crawl (reads from cache)
|
||||||
cprint("1️⃣ First crawl (caches the result):", True)
|
cprint("1️⃣ First crawl (caches the result):", True)
|
||||||
start_time = time.time()
|
start_time = time.time()
|
||||||
result = crawler.run(url="https://www.nbcnews.com/business")
|
result = crawler.run(url="https://www.nbcnews.com/business")
|
||||||
end_time = time.time()
|
end_time = time.time()
|
||||||
cprint(f"[LOG] 📦 [bold yellow]First crawl took {end_time - start_time} seconds and result (from cache):[/bold yellow]")
|
cprint(
|
||||||
|
f"[LOG] 📦 [bold yellow]First crawl took {end_time - start_time} seconds and result (from cache):[/bold yellow]"
|
||||||
|
)
|
||||||
print_result(result)
|
print_result(result)
|
||||||
|
|
||||||
# Force to crawl again
|
# Force to crawl again
|
||||||
@@ -72,132 +90,194 @@ def understanding_parameters(crawler):
|
|||||||
start_time = time.time()
|
start_time = time.time()
|
||||||
result = crawler.run(url="https://www.nbcnews.com/business", bypass_cache=True)
|
result = crawler.run(url="https://www.nbcnews.com/business", bypass_cache=True)
|
||||||
end_time = time.time()
|
end_time = time.time()
|
||||||
cprint(f"[LOG] 📦 [bold yellow]Second crawl took {end_time - start_time} seconds and result (forced to crawl):[/bold yellow]")
|
cprint(
|
||||||
|
f"[LOG] 📦 [bold yellow]Second crawl took {end_time - start_time} seconds and result (forced to crawl):[/bold yellow]"
|
||||||
|
)
|
||||||
print_result(result)
|
print_result(result)
|
||||||
|
|
||||||
|
|
||||||
def add_chunking_strategy(crawler):
|
def add_chunking_strategy(crawler):
|
||||||
# Adding a chunking strategy: RegexChunking
|
# Adding a chunking strategy: RegexChunking
|
||||||
cprint("\n🧩 [bold cyan]Let's add a chunking strategy: RegexChunking![/bold cyan]", True)
|
cprint(
|
||||||
cprint("RegexChunking is a simple chunking strategy that splits the text based on a given regex pattern. Let's see it in action!")
|
"\n🧩 [bold cyan]Let's add a chunking strategy: RegexChunking![/bold cyan]",
|
||||||
|
True,
|
||||||
|
)
|
||||||
|
cprint(
|
||||||
|
"RegexChunking is a simple chunking strategy that splits the text based on a given regex pattern. Let's see it in action!"
|
||||||
|
)
|
||||||
result = crawler.run(
|
result = crawler.run(
|
||||||
url="https://www.nbcnews.com/business",
|
url="https://www.nbcnews.com/business",
|
||||||
chunking_strategy=RegexChunking(patterns=["\n\n"])
|
chunking_strategy=RegexChunking(patterns=["\n\n"]),
|
||||||
)
|
)
|
||||||
cprint("[LOG] 📦 [bold yellow]RegexChunking result:[/bold yellow]")
|
cprint("[LOG] 📦 [bold yellow]RegexChunking result:[/bold yellow]")
|
||||||
print_result(result)
|
print_result(result)
|
||||||
|
|
||||||
# Adding another chunking strategy: NlpSentenceChunking
|
# Adding another chunking strategy: NlpSentenceChunking
|
||||||
cprint("\n🔍 [bold cyan]Time to explore another chunking strategy: NlpSentenceChunking![/bold cyan]", True)
|
cprint(
|
||||||
cprint("NlpSentenceChunking uses NLP techniques to split the text into sentences. Let's see how it performs!")
|
"\n🔍 [bold cyan]Time to explore another chunking strategy: NlpSentenceChunking![/bold cyan]",
|
||||||
|
True,
|
||||||
|
)
|
||||||
|
cprint(
|
||||||
|
"NlpSentenceChunking uses NLP techniques to split the text into sentences. Let's see how it performs!"
|
||||||
|
)
|
||||||
result = crawler.run(
|
result = crawler.run(
|
||||||
url="https://www.nbcnews.com/business",
|
url="https://www.nbcnews.com/business", chunking_strategy=NlpSentenceChunking()
|
||||||
chunking_strategy=NlpSentenceChunking()
|
|
||||||
)
|
)
|
||||||
cprint("[LOG] 📦 [bold yellow]NlpSentenceChunking result:[/bold yellow]")
|
cprint("[LOG] 📦 [bold yellow]NlpSentenceChunking result:[/bold yellow]")
|
||||||
print_result(result)
|
print_result(result)
|
||||||
|
|
||||||
|
|
||||||
def add_extraction_strategy(crawler):
|
def add_extraction_strategy(crawler):
|
||||||
# Adding an extraction strategy: CosineStrategy
|
# Adding an extraction strategy: CosineStrategy
|
||||||
cprint("\n🧠 [bold cyan]Let's get smarter with an extraction strategy: CosineStrategy![/bold cyan]", True)
|
cprint(
|
||||||
cprint("CosineStrategy uses cosine similarity to extract semantically similar blocks of text. Let's see it in action!")
|
"\n🧠 [bold cyan]Let's get smarter with an extraction strategy: CosineStrategy![/bold cyan]",
|
||||||
|
True,
|
||||||
|
)
|
||||||
|
cprint(
|
||||||
|
"CosineStrategy uses cosine similarity to extract semantically similar blocks of text. Let's see it in action!"
|
||||||
|
)
|
||||||
result = crawler.run(
|
result = crawler.run(
|
||||||
url="https://www.nbcnews.com/business",
|
url="https://www.nbcnews.com/business",
|
||||||
extraction_strategy=CosineStrategy(word_count_threshold=10, max_dist=0.2, linkage_method="ward", top_k=3, sim_threshold = 0.3, verbose=True)
|
extraction_strategy=CosineStrategy(
|
||||||
|
word_count_threshold=10,
|
||||||
|
max_dist=0.2,
|
||||||
|
linkage_method="ward",
|
||||||
|
top_k=3,
|
||||||
|
sim_threshold=0.3,
|
||||||
|
verbose=True,
|
||||||
|
),
|
||||||
)
|
)
|
||||||
cprint("[LOG] 📦 [bold yellow]CosineStrategy result:[/bold yellow]")
|
cprint("[LOG] 📦 [bold yellow]CosineStrategy result:[/bold yellow]")
|
||||||
print_result(result)
|
print_result(result)
|
||||||
|
|
||||||
# Using semantic_filter with CosineStrategy
|
# Using semantic_filter with CosineStrategy
|
||||||
cprint("You can pass other parameters like 'semantic_filter' to the CosineStrategy to extract semantically similar blocks of text. Let's see it in action!")
|
cprint(
|
||||||
|
"You can pass other parameters like 'semantic_filter' to the CosineStrategy to extract semantically similar blocks of text. Let's see it in action!"
|
||||||
|
)
|
||||||
result = crawler.run(
|
result = crawler.run(
|
||||||
url="https://www.nbcnews.com/business",
|
url="https://www.nbcnews.com/business",
|
||||||
extraction_strategy=CosineStrategy(
|
extraction_strategy=CosineStrategy(
|
||||||
semantic_filter="inflation rent prices",
|
semantic_filter="inflation rent prices",
|
||||||
|
),
|
||||||
)
|
)
|
||||||
|
cprint(
|
||||||
|
"[LOG] 📦 [bold yellow]CosineStrategy result with semantic filter:[/bold yellow]"
|
||||||
)
|
)
|
||||||
cprint("[LOG] 📦 [bold yellow]CosineStrategy result with semantic filter:[/bold yellow]")
|
|
||||||
print_result(result)
|
print_result(result)
|
||||||
|
|
||||||
|
|
||||||
def add_llm_extraction_strategy(crawler):
|
def add_llm_extraction_strategy(crawler):
|
||||||
# Adding an LLM extraction strategy without instructions
|
# Adding an LLM extraction strategy without instructions
|
||||||
cprint("\n🤖 [bold cyan]Time to bring in the big guns: LLMExtractionStrategy without instructions![/bold cyan]", True)
|
cprint(
|
||||||
cprint("LLMExtractionStrategy uses a large language model to extract relevant information from the web page. Let's see it in action!")
|
"\n🤖 [bold cyan]Time to bring in the big guns: LLMExtractionStrategy without instructions![/bold cyan]",
|
||||||
|
True,
|
||||||
|
)
|
||||||
|
cprint(
|
||||||
|
"LLMExtractionStrategy uses a large language model to extract relevant information from the web page. Let's see it in action!"
|
||||||
|
)
|
||||||
result = crawler.run(
|
result = crawler.run(
|
||||||
url="https://www.nbcnews.com/business",
|
url="https://www.nbcnews.com/business",
|
||||||
extraction_strategy=LLMExtractionStrategy(provider="openai/gpt-4o", api_token=os.getenv('OPENAI_API_KEY'))
|
extraction_strategy=LLMExtractionStrategy(
|
||||||
|
provider="openai/gpt-4o", api_token=os.getenv("OPENAI_API_KEY")
|
||||||
|
),
|
||||||
|
)
|
||||||
|
cprint(
|
||||||
|
"[LOG] 📦 [bold yellow]LLMExtractionStrategy (no instructions) result:[/bold yellow]"
|
||||||
)
|
)
|
||||||
cprint("[LOG] 📦 [bold yellow]LLMExtractionStrategy (no instructions) result:[/bold yellow]")
|
|
||||||
print_result(result)
|
print_result(result)
|
||||||
|
|
||||||
# Adding an LLM extraction strategy with instructions
|
# Adding an LLM extraction strategy with instructions
|
||||||
cprint("\n📜 [bold cyan]Let's make it even more interesting: LLMExtractionStrategy with instructions![/bold cyan]", True)
|
cprint(
|
||||||
cprint("Let's say we are only interested in financial news. Let's see how LLMExtractionStrategy performs with instructions!")
|
"\n📜 [bold cyan]Let's make it even more interesting: LLMExtractionStrategy with instructions![/bold cyan]",
|
||||||
|
True,
|
||||||
|
)
|
||||||
|
cprint(
|
||||||
|
"Let's say we are only interested in financial news. Let's see how LLMExtractionStrategy performs with instructions!"
|
||||||
|
)
|
||||||
result = crawler.run(
|
result = crawler.run(
|
||||||
url="https://www.nbcnews.com/business",
|
url="https://www.nbcnews.com/business",
|
||||||
extraction_strategy=LLMExtractionStrategy(
|
extraction_strategy=LLMExtractionStrategy(
|
||||||
provider="openai/gpt-4o",
|
provider="openai/gpt-4o",
|
||||||
api_token=os.getenv('OPENAI_API_KEY'),
|
api_token=os.getenv("OPENAI_API_KEY"),
|
||||||
instruction="I am interested in only financial news"
|
instruction="I am interested in only financial news",
|
||||||
|
),
|
||||||
)
|
)
|
||||||
|
cprint(
|
||||||
|
"[LOG] 📦 [bold yellow]LLMExtractionStrategy (with instructions) result:[/bold yellow]"
|
||||||
)
|
)
|
||||||
cprint("[LOG] 📦 [bold yellow]LLMExtractionStrategy (with instructions) result:[/bold yellow]")
|
|
||||||
print_result(result)
|
print_result(result)
|
||||||
|
|
||||||
result = crawler.run(
|
result = crawler.run(
|
||||||
url="https://www.nbcnews.com/business",
|
url="https://www.nbcnews.com/business",
|
||||||
extraction_strategy=LLMExtractionStrategy(
|
extraction_strategy=LLMExtractionStrategy(
|
||||||
provider="openai/gpt-4o",
|
provider="openai/gpt-4o",
|
||||||
api_token=os.getenv('OPENAI_API_KEY'),
|
api_token=os.getenv("OPENAI_API_KEY"),
|
||||||
instruction="Extract only content related to technology"
|
instruction="Extract only content related to technology",
|
||||||
|
),
|
||||||
)
|
)
|
||||||
|
cprint(
|
||||||
|
"[LOG] 📦 [bold yellow]LLMExtractionStrategy (with technology instruction) result:[/bold yellow]"
|
||||||
)
|
)
|
||||||
cprint("[LOG] 📦 [bold yellow]LLMExtractionStrategy (with technology instruction) result:[/bold yellow]")
|
|
||||||
print_result(result)
|
print_result(result)
|
||||||
|
|
||||||
|
|
||||||
def targeted_extraction(crawler):
|
def targeted_extraction(crawler):
|
||||||
# Using a CSS selector to extract only H2 tags
|
# Using a CSS selector to extract only H2 tags
|
||||||
cprint("\n🎯 [bold cyan]Targeted extraction: Let's use a CSS selector to extract only H2 tags![/bold cyan]", True)
|
cprint(
|
||||||
result = crawler.run(
|
"\n🎯 [bold cyan]Targeted extraction: Let's use a CSS selector to extract only H2 tags![/bold cyan]",
|
||||||
url="https://www.nbcnews.com/business",
|
True,
|
||||||
css_selector="h2"
|
|
||||||
)
|
)
|
||||||
|
result = crawler.run(url="https://www.nbcnews.com/business", css_selector="h2")
|
||||||
cprint("[LOG] 📦 [bold yellow]CSS Selector (H2 tags) result:[/bold yellow]")
|
cprint("[LOG] 📦 [bold yellow]CSS Selector (H2 tags) result:[/bold yellow]")
|
||||||
print_result(result)
|
print_result(result)
|
||||||
|
|
||||||
|
|
||||||
def interactive_extraction(crawler):
|
def interactive_extraction(crawler):
|
||||||
# Passing JavaScript code to interact with the page
|
# Passing JavaScript code to interact with the page
|
||||||
cprint("\n🖱️ [bold cyan]Let's get interactive: Passing JavaScript code to click 'Load More' button![/bold cyan]", True)
|
cprint(
|
||||||
cprint("In this example we try to click the 'Load More' button on the page using JavaScript code.")
|
"\n🖱️ [bold cyan]Let's get interactive: Passing JavaScript code to click 'Load More' button![/bold cyan]",
|
||||||
|
True,
|
||||||
|
)
|
||||||
|
cprint(
|
||||||
|
"In this example we try to click the 'Load More' button on the page using JavaScript code."
|
||||||
|
)
|
||||||
js_code = """
|
js_code = """
|
||||||
const loadMoreButton = Array.from(document.querySelectorAll('button')).find(button => button.textContent.includes('Load More'));
|
const loadMoreButton = Array.from(document.querySelectorAll('button')).find(button => button.textContent.includes('Load More'));
|
||||||
loadMoreButton && loadMoreButton.click();
|
loadMoreButton && loadMoreButton.click();
|
||||||
"""
|
"""
|
||||||
# crawler_strategy = LocalSeleniumCrawlerStrategy(js_code=js_code)
|
# crawler_strategy = LocalSeleniumCrawlerStrategy(js_code=js_code)
|
||||||
# crawler = WebCrawler(crawler_strategy=crawler_strategy, always_by_pass_cache=True)
|
# crawler = WebCrawler(crawler_strategy=crawler_strategy, always_by_pass_cache=True)
|
||||||
result = crawler.run(
|
result = crawler.run(url="https://www.nbcnews.com/business", js=js_code)
|
||||||
url="https://www.nbcnews.com/business",
|
cprint(
|
||||||
js = js_code
|
"[LOG] 📦 [bold yellow]JavaScript Code (Load More button) result:[/bold yellow]"
|
||||||
)
|
)
|
||||||
cprint("[LOG] 📦 [bold yellow]JavaScript Code (Load More button) result:[/bold yellow]")
|
|
||||||
print_result(result)
|
print_result(result)
|
||||||
|
|
||||||
|
|
||||||
def multiple_scrip(crawler):
|
def multiple_scrip(crawler):
|
||||||
# Passing JavaScript code to interact with the page
|
# Passing JavaScript code to interact with the page
|
||||||
cprint("\n🖱️ [bold cyan]Let's get interactive: Passing JavaScript code to click 'Load More' button![/bold cyan]", True)
|
cprint(
|
||||||
cprint("In this example we try to click the 'Load More' button on the page using JavaScript code.")
|
"\n🖱️ [bold cyan]Let's get interactive: Passing JavaScript code to click 'Load More' button![/bold cyan]",
|
||||||
js_code = ["""
|
True,
|
||||||
|
)
|
||||||
|
cprint(
|
||||||
|
"In this example we try to click the 'Load More' button on the page using JavaScript code."
|
||||||
|
)
|
||||||
|
js_code = [
|
||||||
|
"""
|
||||||
const loadMoreButton = Array.from(document.querySelectorAll('button')).find(button => button.textContent.includes('Load More'));
|
const loadMoreButton = Array.from(document.querySelectorAll('button')).find(button => button.textContent.includes('Load More'));
|
||||||
loadMoreButton && loadMoreButton.click();
|
loadMoreButton && loadMoreButton.click();
|
||||||
"""] * 2
|
"""
|
||||||
|
] * 2
|
||||||
# crawler_strategy = LocalSeleniumCrawlerStrategy(js_code=js_code)
|
# crawler_strategy = LocalSeleniumCrawlerStrategy(js_code=js_code)
|
||||||
# crawler = WebCrawler(crawler_strategy=crawler_strategy, always_by_pass_cache=True)
|
# crawler = WebCrawler(crawler_strategy=crawler_strategy, always_by_pass_cache=True)
|
||||||
result = crawler.run(
|
result = crawler.run(url="https://www.nbcnews.com/business", js=js_code)
|
||||||
url="https://www.nbcnews.com/business",
|
cprint(
|
||||||
js = js_code
|
"[LOG] 📦 [bold yellow]JavaScript Code (Load More button) result:[/bold yellow]"
|
||||||
)
|
)
|
||||||
cprint("[LOG] 📦 [bold yellow]JavaScript Code (Load More button) result:[/bold yellow]")
|
|
||||||
print_result(result)
|
print_result(result)
|
||||||
|
|
||||||
|
|
||||||
def using_crawler_hooks(crawler):
|
def using_crawler_hooks(crawler):
|
||||||
# Example usage of the hooks for authentication and setting a cookie
|
# Example usage of the hooks for authentication and setting a cookie
|
||||||
def on_driver_created(driver):
|
def on_driver_created(driver):
|
||||||
@@ -206,33 +286,34 @@ def using_crawler_hooks(crawler):
|
|||||||
driver.maximize_window()
|
driver.maximize_window()
|
||||||
|
|
||||||
# Example customization: logging in to a hypothetical website
|
# Example customization: logging in to a hypothetical website
|
||||||
driver.get('https://example.com/login')
|
driver.get("https://example.com/login")
|
||||||
|
|
||||||
from selenium.webdriver.support.ui import WebDriverWait
|
from selenium.webdriver.support.ui import WebDriverWait
|
||||||
from selenium.webdriver.common.by import By
|
from selenium.webdriver.common.by import By
|
||||||
from selenium.webdriver.support import expected_conditions as EC
|
from selenium.webdriver.support import expected_conditions as EC
|
||||||
|
|
||||||
WebDriverWait(driver, 10).until(
|
WebDriverWait(driver, 10).until(
|
||||||
EC.presence_of_element_located((By.NAME, 'username'))
|
EC.presence_of_element_located((By.NAME, "username"))
|
||||||
)
|
)
|
||||||
driver.find_element(By.NAME, 'username').send_keys('testuser')
|
driver.find_element(By.NAME, "username").send_keys("testuser")
|
||||||
driver.find_element(By.NAME, 'password').send_keys('password123')
|
driver.find_element(By.NAME, "password").send_keys("password123")
|
||||||
driver.find_element(By.NAME, 'login').click()
|
driver.find_element(By.NAME, "login").click()
|
||||||
WebDriverWait(driver, 10).until(
|
WebDriverWait(driver, 10).until(
|
||||||
EC.presence_of_element_located((By.ID, 'welcome'))
|
EC.presence_of_element_located((By.ID, "welcome"))
|
||||||
)
|
)
|
||||||
# Add a custom cookie
|
# Add a custom cookie
|
||||||
driver.add_cookie({'name': 'test_cookie', 'value': 'cookie_value'})
|
driver.add_cookie({"name": "test_cookie", "value": "cookie_value"})
|
||||||
return driver
|
return driver
|
||||||
|
|
||||||
|
|
||||||
def before_get_url(driver):
|
def before_get_url(driver):
|
||||||
print("[HOOK] before_get_url")
|
print("[HOOK] before_get_url")
|
||||||
# Example customization: add a custom header
|
# Example customization: add a custom header
|
||||||
# Enable Network domain for sending headers
|
# Enable Network domain for sending headers
|
||||||
driver.execute_cdp_cmd('Network.enable', {})
|
driver.execute_cdp_cmd("Network.enable", {})
|
||||||
# Add a custom header
|
# Add a custom header
|
||||||
driver.execute_cdp_cmd('Network.setExtraHTTPHeaders', {'headers': {'X-Test-Header': 'test'}})
|
driver.execute_cdp_cmd(
|
||||||
|
"Network.setExtraHTTPHeaders", {"headers": {"X-Test-Header": "test"}}
|
||||||
|
)
|
||||||
return driver
|
return driver
|
||||||
|
|
||||||
def after_get_url(driver):
|
def after_get_url(driver):
|
||||||
@@ -247,13 +328,16 @@ def using_crawler_hooks(crawler):
|
|||||||
print(len(html))
|
print(len(html))
|
||||||
return driver
|
return driver
|
||||||
|
|
||||||
cprint("\n🔗 [bold cyan]Using Crawler Hooks: Let's see how we can customize the crawler using hooks![/bold cyan]", True)
|
cprint(
|
||||||
|
"\n🔗 [bold cyan]Using Crawler Hooks: Let's see how we can customize the crawler using hooks![/bold cyan]",
|
||||||
|
True,
|
||||||
|
)
|
||||||
|
|
||||||
crawler_strategy = LocalSeleniumCrawlerStrategy(verbose=True)
|
crawler_strategy = LocalSeleniumCrawlerStrategy(verbose=True)
|
||||||
crawler_strategy.set_hook('on_driver_created', on_driver_created)
|
crawler_strategy.set_hook("on_driver_created", on_driver_created)
|
||||||
crawler_strategy.set_hook('before_get_url', before_get_url)
|
crawler_strategy.set_hook("before_get_url", before_get_url)
|
||||||
crawler_strategy.set_hook('after_get_url', after_get_url)
|
crawler_strategy.set_hook("after_get_url", after_get_url)
|
||||||
crawler_strategy.set_hook('before_return_html', before_return_html)
|
crawler_strategy.set_hook("before_return_html", before_return_html)
|
||||||
|
|
||||||
crawler = WebCrawler(verbose=True, crawler_strategy=crawler_strategy)
|
crawler = WebCrawler(verbose=True, crawler_strategy=crawler_strategy)
|
||||||
crawler.warmup()
|
crawler.warmup()
|
||||||
@@ -262,6 +346,7 @@ def using_crawler_hooks(crawler):
|
|||||||
cprint("[LOG] 📦 [bold yellow]Crawler Hooks result:[/bold yellow]")
|
cprint("[LOG] 📦 [bold yellow]Crawler Hooks result:[/bold yellow]")
|
||||||
print_result(result=result)
|
print_result(result=result)
|
||||||
|
|
||||||
|
|
||||||
def using_crawler_hooks_dleay_example(crawler):
|
def using_crawler_hooks_dleay_example(crawler):
|
||||||
def delay(driver):
|
def delay(driver):
|
||||||
print("Delaying for 5 seconds...")
|
print("Delaying for 5 seconds...")
|
||||||
@@ -270,12 +355,14 @@ def using_crawler_hooks_dleay_example(crawler):
|
|||||||
|
|
||||||
def create_crawler():
|
def create_crawler():
|
||||||
crawler_strategy = LocalSeleniumCrawlerStrategy(verbose=True)
|
crawler_strategy = LocalSeleniumCrawlerStrategy(verbose=True)
|
||||||
crawler_strategy.set_hook('after_get_url', delay)
|
crawler_strategy.set_hook("after_get_url", delay)
|
||||||
crawler = WebCrawler(verbose=True, crawler_strategy=crawler_strategy)
|
crawler = WebCrawler(verbose=True, crawler_strategy=crawler_strategy)
|
||||||
crawler.warmup()
|
crawler.warmup()
|
||||||
return crawler
|
return crawler
|
||||||
|
|
||||||
cprint("\n🔗 [bold cyan]Using Crawler Hooks: Let's add a delay after fetching the url to make sure entire page is fetched.[/bold cyan]")
|
cprint(
|
||||||
|
"\n🔗 [bold cyan]Using Crawler Hooks: Let's add a delay after fetching the url to make sure entire page is fetched.[/bold cyan]"
|
||||||
|
)
|
||||||
crawler = create_crawler()
|
crawler = create_crawler()
|
||||||
result = crawler.run(url="https://google.com", bypass_cache=True)
|
result = crawler.run(url="https://google.com", bypass_cache=True)
|
||||||
|
|
||||||
@@ -283,11 +370,16 @@ def using_crawler_hooks_dleay_example(crawler):
|
|||||||
print_result(result)
|
print_result(result)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
cprint("🌟 [bold green]Welcome to the Crawl4ai Quickstart Guide! Let's dive into some web crawling fun! 🌐[/bold green]")
|
cprint(
|
||||||
cprint("⛳️ [bold cyan]First Step: Create an instance of WebCrawler and call the `warmup()` function.[/bold cyan]")
|
"🌟 [bold green]Welcome to the Crawl4ai Quickstart Guide! Let's dive into some web crawling fun! 🌐[/bold green]"
|
||||||
cprint("If this is the first time you're running Crawl4ai, this might take a few seconds to load required model files.")
|
)
|
||||||
|
cprint(
|
||||||
|
"⛳️ [bold cyan]First Step: Create an instance of WebCrawler and call the `warmup()` function.[/bold cyan]"
|
||||||
|
)
|
||||||
|
cprint(
|
||||||
|
"If this is the first time you're running Crawl4ai, this might take a few seconds to load required model files."
|
||||||
|
)
|
||||||
|
|
||||||
crawler = create_crawler()
|
crawler = create_crawler()
|
||||||
|
|
||||||
@@ -305,8 +397,10 @@ def main():
|
|||||||
interactive_extraction(crawler)
|
interactive_extraction(crawler)
|
||||||
multiple_scrip(crawler)
|
multiple_scrip(crawler)
|
||||||
|
|
||||||
cprint("\n🎉 [bold green]Congratulations! You've made it through the Crawl4ai Quickstart Guide! Now go forth and crawl the web like a pro! 🕸️[/bold green]")
|
cprint(
|
||||||
|
"\n🎉 [bold green]Congratulations! You've made it through the Crawl4ai Quickstart Guide! Now go forth and crawl the web like a pro! 🕸️[/bold green]"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
main()
|
main()
|
||||||
|
|
||||||
|
|||||||
@@ -11,7 +11,9 @@ from groq import Groq
|
|||||||
# Import threadpools to run the crawl_url function in a separate thread
|
# Import threadpools to run the crawl_url function in a separate thread
|
||||||
from concurrent.futures import ThreadPoolExecutor
|
from concurrent.futures import ThreadPoolExecutor
|
||||||
|
|
||||||
client = AsyncOpenAI(base_url="https://api.groq.com/openai/v1", api_key=os.getenv("GROQ_API_KEY"))
|
client = AsyncOpenAI(
|
||||||
|
base_url="https://api.groq.com/openai/v1", api_key=os.getenv("GROQ_API_KEY")
|
||||||
|
)
|
||||||
|
|
||||||
# Instrument the OpenAI client
|
# Instrument the OpenAI client
|
||||||
cl.instrument_openai()
|
cl.instrument_openai()
|
||||||
@@ -25,32 +27,31 @@ settings = {
|
|||||||
"presence_penalty": 0,
|
"presence_penalty": 0,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
def extract_urls(text):
|
def extract_urls(text):
|
||||||
url_pattern = re.compile(r'(https?://\S+)')
|
url_pattern = re.compile(r"(https?://\S+)")
|
||||||
return url_pattern.findall(text)
|
return url_pattern.findall(text)
|
||||||
|
|
||||||
|
|
||||||
def crawl_url(url):
|
def crawl_url(url):
|
||||||
data = {
|
data = {
|
||||||
"urls": [url],
|
"urls": [url],
|
||||||
"include_raw_html": True,
|
"include_raw_html": True,
|
||||||
"word_count_threshold": 10,
|
"word_count_threshold": 10,
|
||||||
"extraction_strategy": "NoExtractionStrategy",
|
"extraction_strategy": "NoExtractionStrategy",
|
||||||
"chunking_strategy": "RegexChunking"
|
"chunking_strategy": "RegexChunking",
|
||||||
}
|
}
|
||||||
response = requests.post("https://crawl4ai.com/crawl", json=data)
|
response = requests.post("https://crawl4ai.com/crawl", json=data)
|
||||||
response_data = response.json()
|
response_data = response.json()
|
||||||
response_data = response_data['results'][0]
|
response_data = response_data["results"][0]
|
||||||
return response_data['markdown']
|
return response_data["markdown"]
|
||||||
|
|
||||||
|
|
||||||
@cl.on_chat_start
|
@cl.on_chat_start
|
||||||
async def on_chat_start():
|
async def on_chat_start():
|
||||||
cl.user_session.set("session", {
|
cl.user_session.set("session", {"history": [], "context": {}})
|
||||||
"history": [],
|
await cl.Message(content="Welcome to the chat! How can I assist you today?").send()
|
||||||
"context": {}
|
|
||||||
})
|
|
||||||
await cl.Message(
|
|
||||||
content="Welcome to the chat! How can I assist you today?"
|
|
||||||
).send()
|
|
||||||
|
|
||||||
@cl.on_message
|
@cl.on_message
|
||||||
async def on_message(message: cl.Message):
|
async def on_message(message: cl.Message):
|
||||||
@@ -59,7 +60,6 @@ async def on_message(message: cl.Message):
|
|||||||
# Extract URLs from the user's message
|
# Extract URLs from the user's message
|
||||||
urls = extract_urls(message.content)
|
urls = extract_urls(message.content)
|
||||||
|
|
||||||
|
|
||||||
futures = []
|
futures = []
|
||||||
with ThreadPoolExecutor() as executor:
|
with ThreadPoolExecutor() as executor:
|
||||||
for url in urls:
|
for url in urls:
|
||||||
@@ -69,16 +69,9 @@ async def on_message(message: cl.Message):
|
|||||||
|
|
||||||
for url, result in zip(urls, results):
|
for url, result in zip(urls, results):
|
||||||
ref_number = f"REF_{len(user_session['context']) + 1}"
|
ref_number = f"REF_{len(user_session['context']) + 1}"
|
||||||
user_session["context"][ref_number] = {
|
user_session["context"][ref_number] = {"url": url, "content": result}
|
||||||
"url": url,
|
|
||||||
"content": result
|
|
||||||
}
|
|
||||||
|
|
||||||
|
user_session["history"].append({"role": "user", "content": message.content})
|
||||||
user_session["history"].append({
|
|
||||||
"role": "user",
|
|
||||||
"content": message.content
|
|
||||||
})
|
|
||||||
|
|
||||||
# Create a system message that includes the context
|
# Create a system message that includes the context
|
||||||
context_messages = [
|
context_messages = [
|
||||||
@@ -95,26 +88,17 @@ async def on_message(message: cl.Message):
|
|||||||
"If not, there is no need to add a references section. "
|
"If not, there is no need to add a references section. "
|
||||||
"At the end of your response, provide a reference section listing the URLs and their REF numbers only if sources from the appendices were used.\n\n"
|
"At the end of your response, provide a reference section listing the URLs and their REF numbers only if sources from the appendices were used.\n\n"
|
||||||
"\n\n".join(context_messages)
|
"\n\n".join(context_messages)
|
||||||
)
|
),
|
||||||
}
|
}
|
||||||
else:
|
else:
|
||||||
system_message = {
|
system_message = {"role": "system", "content": "You are a helpful assistant."}
|
||||||
"role": "system",
|
|
||||||
"content": "You are a helpful assistant."
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
msg = cl.Message(content="")
|
msg = cl.Message(content="")
|
||||||
await msg.send()
|
await msg.send()
|
||||||
|
|
||||||
# Get response from the LLM
|
# Get response from the LLM
|
||||||
stream = await client.chat.completions.create(
|
stream = await client.chat.completions.create(
|
||||||
messages=[
|
messages=[system_message, *user_session["history"]], stream=True, **settings
|
||||||
system_message,
|
|
||||||
*user_session["history"]
|
|
||||||
],
|
|
||||||
stream=True,
|
|
||||||
**settings
|
|
||||||
)
|
)
|
||||||
|
|
||||||
assistant_response = ""
|
assistant_response = ""
|
||||||
@@ -124,10 +108,7 @@ async def on_message(message: cl.Message):
|
|||||||
await msg.stream_token(token)
|
await msg.stream_token(token)
|
||||||
|
|
||||||
# Add assistant message to the history
|
# Add assistant message to the history
|
||||||
user_session["history"].append({
|
user_session["history"].append({"role": "assistant", "content": assistant_response})
|
||||||
"role": "assistant",
|
|
||||||
"content": assistant_response
|
|
||||||
})
|
|
||||||
await msg.update()
|
await msg.update()
|
||||||
|
|
||||||
# Append the reference section to the assistant's response
|
# Append the reference section to the assistant's response
|
||||||
@@ -154,6 +135,7 @@ async def on_audio_chunk(chunk: cl.AudioChunk):
|
|||||||
|
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
@cl.step(type="tool")
|
@cl.step(type="tool")
|
||||||
async def speech_to_text(audio_file):
|
async def speech_to_text(audio_file):
|
||||||
cli = Groq()
|
cli = Groq()
|
||||||
@@ -179,17 +161,12 @@ async def on_audio_end(elements: list[ElementBased]):
|
|||||||
end_time = time.time()
|
end_time = time.time()
|
||||||
print(f"Transcription took {end_time - start_time} seconds")
|
print(f"Transcription took {end_time - start_time} seconds")
|
||||||
|
|
||||||
user_msg = cl.Message(
|
user_msg = cl.Message(author="You", type="user_message", content=transcription)
|
||||||
author="You",
|
|
||||||
type="user_message",
|
|
||||||
content=transcription
|
|
||||||
)
|
|
||||||
await user_msg.send()
|
await user_msg.send()
|
||||||
await on_message(user_msg)
|
await on_message(user_msg)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
from chainlit.cli import run_chainlit
|
from chainlit.cli import run_chainlit
|
||||||
|
|
||||||
run_chainlit(__file__)
|
run_chainlit(__file__)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -1,4 +1,3 @@
|
|||||||
|
|
||||||
import requests, base64, os
|
import requests, base64, os
|
||||||
|
|
||||||
data = {
|
data = {
|
||||||
@@ -7,58 +6,49 @@ data = {
|
|||||||
}
|
}
|
||||||
|
|
||||||
response = requests.post("https://crawl4ai.com/crawl", json=data)
|
response = requests.post("https://crawl4ai.com/crawl", json=data)
|
||||||
result = response.json()['results'][0]
|
result = response.json()["results"][0]
|
||||||
print(result.keys())
|
print(result.keys())
|
||||||
# dict_keys(['url', 'html', 'success', 'cleaned_html', 'media',
|
# dict_keys(['url', 'html', 'success', 'cleaned_html', 'media',
|
||||||
# 'links', 'screenshot', 'markdown', 'extracted_content',
|
# 'links', 'screenshot', 'markdown', 'extracted_content',
|
||||||
# 'metadata', 'error_message'])
|
# 'metadata', 'error_message'])
|
||||||
with open("screenshot.png", "wb") as f:
|
with open("screenshot.png", "wb") as f:
|
||||||
f.write(base64.b64decode(result['screenshot']))
|
f.write(base64.b64decode(result["screenshot"]))
|
||||||
|
|
||||||
# Example of filtering the content using CSS selectors
|
# Example of filtering the content using CSS selectors
|
||||||
data = {
|
data = {
|
||||||
"urls": [
|
"urls": ["https://www.nbcnews.com/business"],
|
||||||
"https://www.nbcnews.com/business"
|
|
||||||
],
|
|
||||||
"css_selector": "article",
|
"css_selector": "article",
|
||||||
"screenshot": True,
|
"screenshot": True,
|
||||||
}
|
}
|
||||||
|
|
||||||
# Example of executing a JS script on the page before extracting the content
|
# Example of executing a JS script on the page before extracting the content
|
||||||
data = {
|
data = {
|
||||||
"urls": [
|
"urls": ["https://www.nbcnews.com/business"],
|
||||||
"https://www.nbcnews.com/business"
|
|
||||||
],
|
|
||||||
"screenshot": True,
|
"screenshot": True,
|
||||||
'js' : ["""
|
"js": [
|
||||||
|
"""
|
||||||
const loadMoreButton = Array.from(document.querySelectorAll('button')).
|
const loadMoreButton = Array.from(document.querySelectorAll('button')).
|
||||||
find(button => button.textContent.includes('Load More'));
|
find(button => button.textContent.includes('Load More'));
|
||||||
loadMoreButton && loadMoreButton.click();
|
loadMoreButton && loadMoreButton.click();
|
||||||
"""]
|
"""
|
||||||
|
],
|
||||||
}
|
}
|
||||||
|
|
||||||
# Example of using a custom extraction strategy
|
# Example of using a custom extraction strategy
|
||||||
data = {
|
data = {
|
||||||
"urls": [
|
"urls": ["https://www.nbcnews.com/business"],
|
||||||
"https://www.nbcnews.com/business"
|
|
||||||
],
|
|
||||||
"extraction_strategy": "CosineStrategy",
|
"extraction_strategy": "CosineStrategy",
|
||||||
"extraction_strategy_args": {
|
"extraction_strategy_args": {"semantic_filter": "inflation rent prices"},
|
||||||
"semantic_filter": "inflation rent prices"
|
|
||||||
},
|
|
||||||
}
|
}
|
||||||
|
|
||||||
# Example of using LLM to extract content
|
# Example of using LLM to extract content
|
||||||
data = {
|
data = {
|
||||||
"urls": [
|
"urls": ["https://www.nbcnews.com/business"],
|
||||||
"https://www.nbcnews.com/business"
|
|
||||||
],
|
|
||||||
"extraction_strategy": "LLMExtractionStrategy",
|
"extraction_strategy": "LLMExtractionStrategy",
|
||||||
"extraction_strategy_args": {
|
"extraction_strategy_args": {
|
||||||
"provider": "groq/llama3-8b-8192",
|
"provider": "groq/llama3-8b-8192",
|
||||||
"api_token": os.environ.get("GROQ_API_KEY"),
|
"api_token": os.environ.get("GROQ_API_KEY"),
|
||||||
"instruction": """I am interested in only financial news,
|
"instruction": """I am interested in only financial news,
|
||||||
and translate them in French."""
|
and translate them in French.""",
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -5,22 +5,22 @@ import os
|
|||||||
from crawl4ai import AsyncWebCrawler, CrawlerRunConfig, CacheMode
|
from crawl4ai import AsyncWebCrawler, CrawlerRunConfig, CacheMode
|
||||||
|
|
||||||
# Create tmp directory if it doesn't exist
|
# Create tmp directory if it doesn't exist
|
||||||
parent_dir = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
parent_dir = os.path.dirname(
|
||||||
|
os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
|
||||||
|
)
|
||||||
tmp_dir = os.path.join(parent_dir, "tmp")
|
tmp_dir = os.path.join(parent_dir, "tmp")
|
||||||
os.makedirs(tmp_dir, exist_ok=True)
|
os.makedirs(tmp_dir, exist_ok=True)
|
||||||
|
|
||||||
|
|
||||||
async def main():
|
async def main():
|
||||||
# Configure crawler to fetch SSL certificate
|
# Configure crawler to fetch SSL certificate
|
||||||
config = CrawlerRunConfig(
|
config = CrawlerRunConfig(
|
||||||
fetch_ssl_certificate=True,
|
fetch_ssl_certificate=True,
|
||||||
cache_mode=CacheMode.BYPASS # Bypass cache to always get fresh certificates
|
cache_mode=CacheMode.BYPASS, # Bypass cache to always get fresh certificates
|
||||||
)
|
)
|
||||||
|
|
||||||
async with AsyncWebCrawler() as crawler:
|
async with AsyncWebCrawler() as crawler:
|
||||||
result = await crawler.arun(
|
result = await crawler.arun(url="https://example.com", config=config)
|
||||||
url='https://example.com',
|
|
||||||
config=config
|
|
||||||
)
|
|
||||||
|
|
||||||
if result.success and result.ssl_certificate:
|
if result.success and result.ssl_certificate:
|
||||||
cert = result.ssl_certificate
|
cert = result.ssl_certificate
|
||||||
@@ -36,11 +36,16 @@ async def main():
|
|||||||
print("\nCertificate exported to:")
|
print("\nCertificate exported to:")
|
||||||
print(f"- JSON: {os.path.join(tmp_dir, 'certificate.json')}")
|
print(f"- JSON: {os.path.join(tmp_dir, 'certificate.json')}")
|
||||||
|
|
||||||
pem_data = cert.to_pem(os.path.join(tmp_dir, "certificate.pem")) # For web servers
|
pem_data = cert.to_pem(
|
||||||
|
os.path.join(tmp_dir, "certificate.pem")
|
||||||
|
) # For web servers
|
||||||
print(f"- PEM: {os.path.join(tmp_dir, 'certificate.pem')}")
|
print(f"- PEM: {os.path.join(tmp_dir, 'certificate.pem')}")
|
||||||
|
|
||||||
der_data = cert.to_der(os.path.join(tmp_dir, "certificate.der")) # For Java apps
|
der_data = cert.to_der(
|
||||||
|
os.path.join(tmp_dir, "certificate.der")
|
||||||
|
) # For Java apps
|
||||||
print(f"- DER: {os.path.join(tmp_dir, 'certificate.der')}")
|
print(f"- DER: {os.path.join(tmp_dir, 'certificate.der')}")
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
asyncio.run(main())
|
asyncio.run(main())
|
||||||
|
|||||||
@@ -1,39 +1,41 @@
|
|||||||
import os
|
import os
|
||||||
import time
|
|
||||||
import json
|
import json
|
||||||
from crawl4ai.web_crawler import WebCrawler
|
from crawl4ai.web_crawler import WebCrawler
|
||||||
from crawl4ai.chunking_strategy import *
|
from crawl4ai.chunking_strategy import *
|
||||||
from crawl4ai.extraction_strategy import *
|
from crawl4ai.extraction_strategy import *
|
||||||
from crawl4ai.crawler_strategy import *
|
from crawl4ai.crawler_strategy import *
|
||||||
|
|
||||||
url = r'https://marketplace.visualstudio.com/items?itemName=Unclecode.groqopilot'
|
url = r"https://marketplace.visualstudio.com/items?itemName=Unclecode.groqopilot"
|
||||||
|
|
||||||
crawler = WebCrawler()
|
crawler = WebCrawler()
|
||||||
crawler.warmup()
|
crawler.warmup()
|
||||||
|
|
||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
|
|
||||||
class PageSummary(BaseModel):
|
class PageSummary(BaseModel):
|
||||||
title: str = Field(..., description="Title of the page.")
|
title: str = Field(..., description="Title of the page.")
|
||||||
summary: str = Field(..., description="Summary of the page.")
|
summary: str = Field(..., description="Summary of the page.")
|
||||||
brief_summary: str = Field(..., description="Brief summary of the page.")
|
brief_summary: str = Field(..., description="Brief summary of the page.")
|
||||||
keywords: list = Field(..., description="Keywords assigned to the page.")
|
keywords: list = Field(..., description="Keywords assigned to the page.")
|
||||||
|
|
||||||
|
|
||||||
result = crawler.run(
|
result = crawler.run(
|
||||||
url=url,
|
url=url,
|
||||||
word_count_threshold=1,
|
word_count_threshold=1,
|
||||||
extraction_strategy=LLMExtractionStrategy(
|
extraction_strategy=LLMExtractionStrategy(
|
||||||
provider= "openai/gpt-4o", api_token = os.getenv('OPENAI_API_KEY'),
|
provider="openai/gpt-4o",
|
||||||
|
api_token=os.getenv("OPENAI_API_KEY"),
|
||||||
schema=PageSummary.model_json_schema(),
|
schema=PageSummary.model_json_schema(),
|
||||||
extraction_type="schema",
|
extraction_type="schema",
|
||||||
apply_chunking=False,
|
apply_chunking=False,
|
||||||
instruction="From the crawled content, extract the following details: "\
|
instruction="From the crawled content, extract the following details: "
|
||||||
"1. Title of the page "\
|
"1. Title of the page "
|
||||||
"2. Summary of the page, which is a detailed summary "\
|
"2. Summary of the page, which is a detailed summary "
|
||||||
"3. Brief summary of the page, which is a paragraph text "\
|
"3. Brief summary of the page, which is a paragraph text "
|
||||||
"4. Keywords assigned to the page, which is a list of keywords. "\
|
"4. Keywords assigned to the page, which is a list of keywords. "
|
||||||
'The extracted JSON format should look like this: '\
|
"The extracted JSON format should look like this: "
|
||||||
'{ "title": "Page Title", "summary": "Detailed summary of the page.", "brief_summary": "Brief summary in a paragraph.", "keywords": ["keyword1", "keyword2", "keyword3"] }'
|
'{ "title": "Page Title", "summary": "Detailed summary of the page.", "brief_summary": "Brief summary in a paragraph.", "keywords": ["keyword1", "keyword2", "keyword3"] }',
|
||||||
),
|
),
|
||||||
bypass_cache=True,
|
bypass_cache=True,
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -1,4 +1,5 @@
|
|||||||
import os, sys
|
import os, sys
|
||||||
|
|
||||||
# append the parent directory to the sys.path
|
# append the parent directory to the sys.path
|
||||||
parent_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
|
parent_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
|
||||||
sys.path.append(parent_dir)
|
sys.path.append(parent_dir)
|
||||||
@@ -13,6 +14,7 @@ import json
|
|||||||
from crawl4ai import AsyncWebCrawler, CacheMode
|
from crawl4ai import AsyncWebCrawler, CacheMode
|
||||||
from crawl4ai.content_filter_strategy import BM25ContentFilter
|
from crawl4ai.content_filter_strategy import BM25ContentFilter
|
||||||
|
|
||||||
|
|
||||||
# 1. File Download Processing Example
|
# 1. File Download Processing Example
|
||||||
async def download_example():
|
async def download_example():
|
||||||
"""Example of downloading files from Python.org"""
|
"""Example of downloading files from Python.org"""
|
||||||
@@ -23,9 +25,7 @@ async def download_example():
|
|||||||
print(f"Downloads will be saved to: {downloads_path}")
|
print(f"Downloads will be saved to: {downloads_path}")
|
||||||
|
|
||||||
async with AsyncWebCrawler(
|
async with AsyncWebCrawler(
|
||||||
accept_downloads=True,
|
accept_downloads=True, downloads_path=downloads_path, verbose=True
|
||||||
downloads_path=downloads_path,
|
|
||||||
verbose=True
|
|
||||||
) as crawler:
|
) as crawler:
|
||||||
result = await crawler.arun(
|
result = await crawler.arun(
|
||||||
url="https://www.python.org/downloads/",
|
url="https://www.python.org/downloads/",
|
||||||
@@ -40,7 +40,7 @@ async def download_example():
|
|||||||
}
|
}
|
||||||
""",
|
""",
|
||||||
delay_before_return_html=1, # Wait 1 second to ensure download starts
|
delay_before_return_html=1, # Wait 1 second to ensure download starts
|
||||||
cache_mode=CacheMode.BYPASS
|
cache_mode=CacheMode.BYPASS,
|
||||||
)
|
)
|
||||||
|
|
||||||
if result.downloaded_files:
|
if result.downloaded_files:
|
||||||
@@ -52,24 +52,25 @@ async def download_example():
|
|||||||
else:
|
else:
|
||||||
print("\nNo files were downloaded")
|
print("\nNo files were downloaded")
|
||||||
|
|
||||||
|
|
||||||
# 2. Local File and Raw HTML Processing Example
|
# 2. Local File and Raw HTML Processing Example
|
||||||
async def local_and_raw_html_example():
|
async def local_and_raw_html_example():
|
||||||
"""Example of processing local files and raw HTML"""
|
"""Example of processing local files and raw HTML"""
|
||||||
# Create a sample HTML file
|
# Create a sample HTML file
|
||||||
sample_file = os.path.join(__data__, "sample.html")
|
sample_file = os.path.join(__data__, "sample.html")
|
||||||
with open(sample_file, "w") as f:
|
with open(sample_file, "w") as f:
|
||||||
f.write("""
|
f.write(
|
||||||
|
"""
|
||||||
<html><body>
|
<html><body>
|
||||||
<h1>Test Content</h1>
|
<h1>Test Content</h1>
|
||||||
<p>This is a test paragraph.</p>
|
<p>This is a test paragraph.</p>
|
||||||
</body></html>
|
</body></html>
|
||||||
""")
|
"""
|
||||||
|
)
|
||||||
|
|
||||||
async with AsyncWebCrawler(verbose=True) as crawler:
|
async with AsyncWebCrawler(verbose=True) as crawler:
|
||||||
# Process local file
|
# Process local file
|
||||||
local_result = await crawler.arun(
|
local_result = await crawler.arun(url=f"file://{os.path.abspath(sample_file)}")
|
||||||
url=f"file://{os.path.abspath(sample_file)}"
|
|
||||||
)
|
|
||||||
|
|
||||||
# Process raw HTML
|
# Process raw HTML
|
||||||
raw_html = """
|
raw_html = """
|
||||||
@@ -78,9 +79,7 @@ async def local_and_raw_html_example():
|
|||||||
<p>This is a test of raw HTML processing.</p>
|
<p>This is a test of raw HTML processing.</p>
|
||||||
</body></html>
|
</body></html>
|
||||||
"""
|
"""
|
||||||
raw_result = await crawler.arun(
|
raw_result = await crawler.arun(url=f"raw:{raw_html}")
|
||||||
url=f"raw:{raw_html}"
|
|
||||||
)
|
|
||||||
|
|
||||||
# Clean up
|
# Clean up
|
||||||
os.remove(sample_file)
|
os.remove(sample_file)
|
||||||
@@ -88,6 +87,7 @@ async def local_and_raw_html_example():
|
|||||||
print("Local file content:", local_result.markdown)
|
print("Local file content:", local_result.markdown)
|
||||||
print("\nRaw HTML content:", raw_result.markdown)
|
print("\nRaw HTML content:", raw_result.markdown)
|
||||||
|
|
||||||
|
|
||||||
# 3. Enhanced Markdown Generation Example
|
# 3. Enhanced Markdown Generation Example
|
||||||
async def markdown_generation_example():
|
async def markdown_generation_example():
|
||||||
"""Example of enhanced markdown generation with citations and LLM-friendly features"""
|
"""Example of enhanced markdown generation with citations and LLM-friendly features"""
|
||||||
@@ -102,27 +102,32 @@ async def markdown_generation_example():
|
|||||||
url="https://en.wikipedia.org/wiki/Apple",
|
url="https://en.wikipedia.org/wiki/Apple",
|
||||||
css_selector="main div#bodyContent",
|
css_selector="main div#bodyContent",
|
||||||
content_filter=content_filter,
|
content_filter=content_filter,
|
||||||
cache_mode=CacheMode.BYPASS
|
cache_mode=CacheMode.BYPASS,
|
||||||
)
|
)
|
||||||
|
|
||||||
from crawl4ai import AsyncWebCrawler
|
|
||||||
from crawl4ai.content_filter_strategy import BM25ContentFilter
|
from crawl4ai.content_filter_strategy import BM25ContentFilter
|
||||||
|
|
||||||
result = await crawler.arun(
|
result = await crawler.arun(
|
||||||
url="https://en.wikipedia.org/wiki/Apple",
|
url="https://en.wikipedia.org/wiki/Apple",
|
||||||
css_selector="main div#bodyContent",
|
css_selector="main div#bodyContent",
|
||||||
content_filter=BM25ContentFilter()
|
content_filter=BM25ContentFilter(),
|
||||||
)
|
)
|
||||||
print(result.markdown_v2.fit_markdown)
|
print(result.markdown_v2.fit_markdown)
|
||||||
|
|
||||||
print("\nMarkdown Generation Results:")
|
print("\nMarkdown Generation Results:")
|
||||||
print(f"1. Original markdown length: {len(result.markdown)}")
|
print(f"1. Original markdown length: {len(result.markdown)}")
|
||||||
print(f"2. New markdown versions (markdown_v2):")
|
print("2. New markdown versions (markdown_v2):")
|
||||||
print(f" - Raw markdown length: {len(result.markdown_v2.raw_markdown)}")
|
print(f" - Raw markdown length: {len(result.markdown_v2.raw_markdown)}")
|
||||||
print(f" - Citations markdown length: {len(result.markdown_v2.markdown_with_citations)}")
|
print(
|
||||||
print(f" - References section length: {len(result.markdown_v2.references_markdown)}")
|
f" - Citations markdown length: {len(result.markdown_v2.markdown_with_citations)}"
|
||||||
|
)
|
||||||
|
print(
|
||||||
|
f" - References section length: {len(result.markdown_v2.references_markdown)}"
|
||||||
|
)
|
||||||
if result.markdown_v2.fit_markdown:
|
if result.markdown_v2.fit_markdown:
|
||||||
print(f" - Filtered markdown length: {len(result.markdown_v2.fit_markdown)}")
|
print(
|
||||||
|
f" - Filtered markdown length: {len(result.markdown_v2.fit_markdown)}"
|
||||||
|
)
|
||||||
|
|
||||||
# Save examples to files
|
# Save examples to files
|
||||||
output_dir = os.path.join(__data__, "markdown_examples")
|
output_dir = os.path.join(__data__, "markdown_examples")
|
||||||
@@ -148,7 +153,10 @@ async def markdown_generation_example():
|
|||||||
print("\nSample of markdown with citations:")
|
print("\nSample of markdown with citations:")
|
||||||
print(result.markdown_v2.markdown_with_citations[:500] + "...\n")
|
print(result.markdown_v2.markdown_with_citations[:500] + "...\n")
|
||||||
print("Sample of references:")
|
print("Sample of references:")
|
||||||
print('\n'.join(result.markdown_v2.references_markdown.split('\n')[:10]) + "...")
|
print(
|
||||||
|
"\n".join(result.markdown_v2.references_markdown.split("\n")[:10]) + "..."
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
# 4. Browser Management Example
|
# 4. Browser Management Example
|
||||||
async def browser_management_example():
|
async def browser_management_example():
|
||||||
@@ -163,31 +171,31 @@ async def browser_management_example():
|
|||||||
use_managed_browser=True,
|
use_managed_browser=True,
|
||||||
user_data_dir=user_data_dir,
|
user_data_dir=user_data_dir,
|
||||||
headless=False,
|
headless=False,
|
||||||
verbose=True
|
verbose=True,
|
||||||
) as crawler:
|
) as crawler:
|
||||||
|
|
||||||
result = await crawler.arun(
|
result = await crawler.arun(
|
||||||
url="https://crawl4ai.com",
|
url="https://crawl4ai.com",
|
||||||
# session_id="persistent_session_1",
|
# session_id="persistent_session_1",
|
||||||
cache_mode=CacheMode.BYPASS
|
cache_mode=CacheMode.BYPASS,
|
||||||
)
|
)
|
||||||
# Use GitHub as an example - it's a good test for browser management
|
# Use GitHub as an example - it's a good test for browser management
|
||||||
# because it requires proper browser handling
|
# because it requires proper browser handling
|
||||||
result = await crawler.arun(
|
result = await crawler.arun(
|
||||||
url="https://github.com/trending",
|
url="https://github.com/trending",
|
||||||
# session_id="persistent_session_1",
|
# session_id="persistent_session_1",
|
||||||
cache_mode=CacheMode.BYPASS
|
cache_mode=CacheMode.BYPASS,
|
||||||
)
|
)
|
||||||
|
|
||||||
print("\nBrowser session result:", result.success)
|
print("\nBrowser session result:", result.success)
|
||||||
if result.success:
|
if result.success:
|
||||||
print("Page title:", result.metadata.get('title', 'No title found'))
|
print("Page title:", result.metadata.get("title", "No title found"))
|
||||||
|
|
||||||
|
|
||||||
# 5. API Usage Example
|
# 5. API Usage Example
|
||||||
async def api_example():
|
async def api_example():
|
||||||
"""Example of using the new API endpoints"""
|
"""Example of using the new API endpoints"""
|
||||||
api_token = os.getenv('CRAWL4AI_API_TOKEN') or "test_api_code"
|
api_token = os.getenv("CRAWL4AI_API_TOKEN") or "test_api_code"
|
||||||
headers = {'Authorization': f'Bearer {api_token}'}
|
headers = {"Authorization": f"Bearer {api_token}"}
|
||||||
async with aiohttp.ClientSession() as session:
|
async with aiohttp.ClientSession() as session:
|
||||||
# Submit crawl job
|
# Submit crawl job
|
||||||
crawl_request = {
|
crawl_request = {
|
||||||
@@ -199,26 +207,18 @@ async def api_example():
|
|||||||
"name": "Hacker News Articles",
|
"name": "Hacker News Articles",
|
||||||
"baseSelector": ".athing",
|
"baseSelector": ".athing",
|
||||||
"fields": [
|
"fields": [
|
||||||
{
|
{"name": "title", "selector": ".title a", "type": "text"},
|
||||||
"name": "title",
|
{"name": "score", "selector": ".score", "type": "text"},
|
||||||
"selector": ".title a",
|
|
||||||
"type": "text"
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"name": "score",
|
|
||||||
"selector": ".score",
|
|
||||||
"type": "text"
|
|
||||||
},
|
|
||||||
{
|
{
|
||||||
"name": "url",
|
"name": "url",
|
||||||
"selector": ".title a",
|
"selector": ".title a",
|
||||||
"type": "attribute",
|
"type": "attribute",
|
||||||
"attribute": "href"
|
"attribute": "href",
|
||||||
}
|
},
|
||||||
]
|
],
|
||||||
}
|
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
|
},
|
||||||
"crawler_params": {
|
"crawler_params": {
|
||||||
"headless": True,
|
"headless": True,
|
||||||
# "use_managed_browser": True
|
# "use_managed_browser": True
|
||||||
@@ -229,9 +229,7 @@ async def api_example():
|
|||||||
}
|
}
|
||||||
|
|
||||||
async with session.post(
|
async with session.post(
|
||||||
"http://localhost:11235/crawl",
|
"http://localhost:11235/crawl", json=crawl_request, headers=headers
|
||||||
json=crawl_request,
|
|
||||||
headers=headers
|
|
||||||
) as response:
|
) as response:
|
||||||
task_data = await response.json()
|
task_data = await response.json()
|
||||||
task_id = task_data["task_id"]
|
task_id = task_data["task_id"]
|
||||||
@@ -239,8 +237,7 @@ async def api_example():
|
|||||||
# Check task status
|
# Check task status
|
||||||
while True:
|
while True:
|
||||||
async with session.get(
|
async with session.get(
|
||||||
f"http://localhost:11235/task/{task_id}",
|
f"http://localhost:11235/task/{task_id}", headers=headers
|
||||||
headers=headers
|
|
||||||
) as status_response:
|
) as status_response:
|
||||||
result = await status_response.json()
|
result = await status_response.json()
|
||||||
print(f"Task status: {result['status']}")
|
print(f"Task status: {result['status']}")
|
||||||
@@ -248,12 +245,13 @@ async def api_example():
|
|||||||
if result["status"] == "completed":
|
if result["status"] == "completed":
|
||||||
print("Task completed!")
|
print("Task completed!")
|
||||||
print("Results:")
|
print("Results:")
|
||||||
news = json.loads(result["results"][0]['extracted_content'])
|
news = json.loads(result["results"][0]["extracted_content"])
|
||||||
print(json.dumps(news[:4], indent=2))
|
print(json.dumps(news[:4], indent=2))
|
||||||
break
|
break
|
||||||
else:
|
else:
|
||||||
await asyncio.sleep(1)
|
await asyncio.sleep(1)
|
||||||
|
|
||||||
|
|
||||||
# Main execution
|
# Main execution
|
||||||
async def main():
|
async def main():
|
||||||
# print("Running Crawl4AI feature examples...")
|
# print("Running Crawl4AI feature examples...")
|
||||||
@@ -273,5 +271,6 @@ async def main():
|
|||||||
# print("\n5. Running API Example:")
|
# print("\n5. Running API Example:")
|
||||||
await api_example()
|
await api_example()
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
asyncio.run(main())
|
asyncio.run(main())
|
||||||
@@ -10,15 +10,14 @@ import asyncio
|
|||||||
import os
|
import os
|
||||||
import json
|
import json
|
||||||
import re
|
import re
|
||||||
from typing import List, Optional, Dict, Any
|
from typing import List
|
||||||
from pydantic import BaseModel, Field
|
|
||||||
from crawl4ai import (
|
from crawl4ai import (
|
||||||
AsyncWebCrawler,
|
AsyncWebCrawler,
|
||||||
BrowserConfig,
|
BrowserConfig,
|
||||||
CrawlerRunConfig,
|
CrawlerRunConfig,
|
||||||
CacheMode,
|
CacheMode,
|
||||||
LLMExtractionStrategy,
|
LLMExtractionStrategy,
|
||||||
JsonCssExtractionStrategy
|
JsonCssExtractionStrategy,
|
||||||
)
|
)
|
||||||
from crawl4ai.content_filter_strategy import RelevantContentFilter
|
from crawl4ai.content_filter_strategy import RelevantContentFilter
|
||||||
from crawl4ai.markdown_generation_strategy import DefaultMarkdownGenerator
|
from crawl4ai.markdown_generation_strategy import DefaultMarkdownGenerator
|
||||||
@@ -52,6 +51,7 @@ SAMPLE_HTML = """
|
|||||||
</div>
|
</div>
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
|
||||||
async def demo_ssl_features():
|
async def demo_ssl_features():
|
||||||
"""
|
"""
|
||||||
Enhanced SSL & Security Features Demo
|
Enhanced SSL & Security Features Demo
|
||||||
@@ -76,14 +76,11 @@ async def demo_ssl_features():
|
|||||||
|
|
||||||
run_config = CrawlerRunConfig(
|
run_config = CrawlerRunConfig(
|
||||||
cache_mode=CacheMode.BYPASS,
|
cache_mode=CacheMode.BYPASS,
|
||||||
fetch_ssl_certificate=True # Enable SSL certificate fetching
|
fetch_ssl_certificate=True, # Enable SSL certificate fetching
|
||||||
)
|
)
|
||||||
|
|
||||||
async with AsyncWebCrawler(config=browser_config) as crawler:
|
async with AsyncWebCrawler(config=browser_config) as crawler:
|
||||||
result = await crawler.arun(
|
result = await crawler.arun(url="https://example.com", config=run_config)
|
||||||
url="https://example.com",
|
|
||||||
config=run_config
|
|
||||||
)
|
|
||||||
print(f"SSL Crawl Success: {result.success}")
|
print(f"SSL Crawl Success: {result.success}")
|
||||||
result.ssl_certificate.to_json(
|
result.ssl_certificate.to_json(
|
||||||
os.path.join(os.getcwd(), "ssl_certificate.json")
|
os.path.join(os.getcwd(), "ssl_certificate.json")
|
||||||
@@ -91,6 +88,7 @@ async def demo_ssl_features():
|
|||||||
if not result.success:
|
if not result.success:
|
||||||
print(f"SSL Error: {result.error_message}")
|
print(f"SSL Error: {result.error_message}")
|
||||||
|
|
||||||
|
|
||||||
async def demo_content_filtering():
|
async def demo_content_filtering():
|
||||||
"""
|
"""
|
||||||
Smart Content Filtering Demo
|
Smart Content Filtering Demo
|
||||||
@@ -110,12 +108,14 @@ async def demo_content_filtering():
|
|||||||
super().__init__()
|
super().__init__()
|
||||||
# Add news-specific patterns
|
# Add news-specific patterns
|
||||||
self.negative_patterns = re.compile(
|
self.negative_patterns = re.compile(
|
||||||
r'nav|footer|header|sidebar|ads|comment|share|related|recommended|popular|trending',
|
r"nav|footer|header|sidebar|ads|comment|share|related|recommended|popular|trending",
|
||||||
re.I
|
re.I,
|
||||||
)
|
)
|
||||||
self.min_word_count = 30 # Higher threshold for news content
|
self.min_word_count = 30 # Higher threshold for news content
|
||||||
|
|
||||||
def filter_content(self, html: str, min_word_threshold: int = None) -> List[str]:
|
def filter_content(
|
||||||
|
self, html: str, min_word_threshold: int = None
|
||||||
|
) -> List[str]:
|
||||||
"""
|
"""
|
||||||
Implements news-specific content filtering logic.
|
Implements news-specific content filtering logic.
|
||||||
|
|
||||||
@@ -129,14 +129,16 @@ async def demo_content_filtering():
|
|||||||
if not html or not isinstance(html, str):
|
if not html or not isinstance(html, str):
|
||||||
return []
|
return []
|
||||||
|
|
||||||
soup = BeautifulSoup(html, 'lxml')
|
soup = BeautifulSoup(html, "lxml")
|
||||||
if not soup.body:
|
if not soup.body:
|
||||||
soup = BeautifulSoup(f'<body>{html}</body>', 'lxml')
|
soup = BeautifulSoup(f"<body>{html}</body>", "lxml")
|
||||||
|
|
||||||
body = soup.find('body')
|
body = soup.find("body")
|
||||||
|
|
||||||
# Extract chunks with metadata
|
# Extract chunks with metadata
|
||||||
chunks = self.extract_text_chunks(body, min_word_threshold or self.min_word_count)
|
chunks = self.extract_text_chunks(
|
||||||
|
body, min_word_threshold or self.min_word_count
|
||||||
|
)
|
||||||
|
|
||||||
# Filter chunks based on news-specific criteria
|
# Filter chunks based on news-specific criteria
|
||||||
filtered_chunks = []
|
filtered_chunks = []
|
||||||
@@ -146,7 +148,7 @@ async def demo_content_filtering():
|
|||||||
continue
|
continue
|
||||||
|
|
||||||
# Headers are important in news articles
|
# Headers are important in news articles
|
||||||
if tag_type == 'header':
|
if tag_type == "header":
|
||||||
filtered_chunks.append(self.clean_element(element))
|
filtered_chunks.append(self.clean_element(element))
|
||||||
continue
|
continue
|
||||||
|
|
||||||
@@ -154,7 +156,9 @@ async def demo_content_filtering():
|
|||||||
text = element.get_text(strip=True)
|
text = element.get_text(strip=True)
|
||||||
if len(text.split()) >= (min_word_threshold or self.min_word_count):
|
if len(text.split()) >= (min_word_threshold or self.min_word_count):
|
||||||
# Calculate link density
|
# Calculate link density
|
||||||
links_text = ' '.join(a.get_text(strip=True) for a in element.find_all('a'))
|
links_text = " ".join(
|
||||||
|
a.get_text(strip=True) for a in element.find_all("a")
|
||||||
|
)
|
||||||
link_density = len(links_text) / len(text) if text else 1
|
link_density = len(links_text) / len(text) if text else 1
|
||||||
|
|
||||||
# Accept if link density is reasonable
|
# Accept if link density is reasonable
|
||||||
@@ -164,23 +168,20 @@ async def demo_content_filtering():
|
|||||||
return filtered_chunks
|
return filtered_chunks
|
||||||
|
|
||||||
# Create markdown generator with custom filter
|
# Create markdown generator with custom filter
|
||||||
markdown_gen = DefaultMarkdownGenerator(
|
markdown_gen = DefaultMarkdownGenerator(content_filter=CustomNewsFilter())
|
||||||
content_filter=CustomNewsFilter()
|
|
||||||
)
|
|
||||||
|
|
||||||
run_config = CrawlerRunConfig(
|
run_config = CrawlerRunConfig(
|
||||||
markdown_generator=markdown_gen,
|
markdown_generator=markdown_gen, cache_mode=CacheMode.BYPASS
|
||||||
cache_mode=CacheMode.BYPASS
|
|
||||||
)
|
)
|
||||||
|
|
||||||
async with AsyncWebCrawler() as crawler:
|
async with AsyncWebCrawler() as crawler:
|
||||||
result = await crawler.arun(
|
result = await crawler.arun(
|
||||||
url="https://news.ycombinator.com",
|
url="https://news.ycombinator.com", config=run_config
|
||||||
config=run_config
|
|
||||||
)
|
)
|
||||||
print("Filtered Content Sample:")
|
print("Filtered Content Sample:")
|
||||||
print(result.markdown[:500]) # Show first 500 chars
|
print(result.markdown[:500]) # Show first 500 chars
|
||||||
|
|
||||||
|
|
||||||
async def demo_json_extraction():
|
async def demo_json_extraction():
|
||||||
"""
|
"""
|
||||||
Improved JSON Extraction Demo
|
Improved JSON Extraction Demo
|
||||||
@@ -206,7 +207,7 @@ async def demo_json_extraction():
|
|||||||
"baseSelector": "div.article-list",
|
"baseSelector": "div.article-list",
|
||||||
"baseFields": [
|
"baseFields": [
|
||||||
{"name": "list_id", "type": "attribute", "attribute": "data-list-id"},
|
{"name": "list_id", "type": "attribute", "attribute": "data-list-id"},
|
||||||
{"name": "category", "type": "attribute", "attribute": "data-category"}
|
{"name": "category", "type": "attribute", "attribute": "data-category"},
|
||||||
],
|
],
|
||||||
"fields": [
|
"fields": [
|
||||||
{
|
{
|
||||||
@@ -214,8 +215,16 @@ async def demo_json_extraction():
|
|||||||
"selector": "article.post",
|
"selector": "article.post",
|
||||||
"type": "nested_list",
|
"type": "nested_list",
|
||||||
"baseFields": [
|
"baseFields": [
|
||||||
{"name": "post_id", "type": "attribute", "attribute": "data-post-id"},
|
{
|
||||||
{"name": "author_id", "type": "attribute", "attribute": "data-author"}
|
"name": "post_id",
|
||||||
|
"type": "attribute",
|
||||||
|
"attribute": "data-post-id",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "author_id",
|
||||||
|
"type": "attribute",
|
||||||
|
"attribute": "data-author",
|
||||||
|
},
|
||||||
],
|
],
|
||||||
"fields": [
|
"fields": [
|
||||||
{
|
{
|
||||||
@@ -223,51 +232,59 @@ async def demo_json_extraction():
|
|||||||
"selector": "h2.title a",
|
"selector": "h2.title a",
|
||||||
"type": "text",
|
"type": "text",
|
||||||
"baseFields": [
|
"baseFields": [
|
||||||
{"name": "url", "type": "attribute", "attribute": "href"}
|
{
|
||||||
]
|
"name": "url",
|
||||||
|
"type": "attribute",
|
||||||
|
"attribute": "href",
|
||||||
|
}
|
||||||
|
],
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"name": "author",
|
"name": "author",
|
||||||
"selector": "div.meta a.author",
|
"selector": "div.meta a.author",
|
||||||
"type": "text",
|
"type": "text",
|
||||||
"baseFields": [
|
"baseFields": [
|
||||||
{"name": "profile_url", "type": "attribute", "attribute": "href"}
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
{
|
||||||
"name": "date",
|
"name": "profile_url",
|
||||||
"selector": "span.date",
|
"type": "attribute",
|
||||||
"type": "text"
|
"attribute": "href",
|
||||||
|
}
|
||||||
|
],
|
||||||
},
|
},
|
||||||
|
{"name": "date", "selector": "span.date", "type": "text"},
|
||||||
{
|
{
|
||||||
"name": "read_more",
|
"name": "read_more",
|
||||||
"selector": "a.read-more",
|
"selector": "a.read-more",
|
||||||
"type": "nested",
|
"type": "nested",
|
||||||
"fields": [
|
"fields": [
|
||||||
{"name": "text", "type": "text"},
|
{"name": "text", "type": "text"},
|
||||||
{"name": "url", "type": "attribute", "attribute": "href"}
|
{
|
||||||
]
|
"name": "url",
|
||||||
|
"type": "attribute",
|
||||||
|
"attribute": "href",
|
||||||
|
},
|
||||||
|
],
|
||||||
|
},
|
||||||
|
],
|
||||||
}
|
}
|
||||||
]
|
],
|
||||||
}
|
|
||||||
]
|
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
# Demonstrate extraction from raw HTML
|
# Demonstrate extraction from raw HTML
|
||||||
run_config = CrawlerRunConfig(
|
run_config = CrawlerRunConfig(
|
||||||
extraction_strategy=json_strategy,
|
extraction_strategy=json_strategy, cache_mode=CacheMode.BYPASS
|
||||||
cache_mode=CacheMode.BYPASS
|
|
||||||
)
|
)
|
||||||
|
|
||||||
async with AsyncWebCrawler() as crawler:
|
async with AsyncWebCrawler() as crawler:
|
||||||
result = await crawler.arun(
|
result = await crawler.arun(
|
||||||
url="raw:" + SAMPLE_HTML, # Use raw: prefix for raw HTML
|
url="raw:" + SAMPLE_HTML, # Use raw: prefix for raw HTML
|
||||||
config=run_config
|
config=run_config,
|
||||||
)
|
)
|
||||||
print("Extracted Content:")
|
print("Extracted Content:")
|
||||||
print(result.extracted_content)
|
print(result.extracted_content)
|
||||||
|
|
||||||
|
|
||||||
async def demo_input_formats():
|
async def demo_input_formats():
|
||||||
"""
|
"""
|
||||||
Input Format Handling Demo
|
Input Format Handling Demo
|
||||||
@@ -359,18 +376,30 @@ async def demo_input_formats():
|
|||||||
|
|
||||||
# Define our schema using Pydantic
|
# Define our schema using Pydantic
|
||||||
class JobRequirement(BaseModel):
|
class JobRequirement(BaseModel):
|
||||||
category: str = Field(description="Category of the requirement (e.g., Technical, Soft Skills)")
|
category: str = Field(
|
||||||
items: List[str] = Field(description="List of specific requirements in this category")
|
description="Category of the requirement (e.g., Technical, Soft Skills)"
|
||||||
priority: str = Field(description="Priority level (Required/Preferred) based on the HTML class or context")
|
)
|
||||||
|
items: List[str] = Field(
|
||||||
|
description="List of specific requirements in this category"
|
||||||
|
)
|
||||||
|
priority: str = Field(
|
||||||
|
description="Priority level (Required/Preferred) based on the HTML class or context"
|
||||||
|
)
|
||||||
|
|
||||||
class JobPosting(BaseModel):
|
class JobPosting(BaseModel):
|
||||||
title: str = Field(description="Job title")
|
title: str = Field(description="Job title")
|
||||||
department: str = Field(description="Department or team")
|
department: str = Field(description="Department or team")
|
||||||
location: str = Field(description="Job location, including remote options")
|
location: str = Field(description="Job location, including remote options")
|
||||||
salary_range: Optional[str] = Field(description="Salary range if specified")
|
salary_range: Optional[str] = Field(description="Salary range if specified")
|
||||||
requirements: List[JobRequirement] = Field(description="Categorized job requirements")
|
requirements: List[JobRequirement] = Field(
|
||||||
application_deadline: Optional[str] = Field(description="Application deadline if specified")
|
description="Categorized job requirements"
|
||||||
contact_info: Optional[dict] = Field(description="Contact information from footer or contact section")
|
)
|
||||||
|
application_deadline: Optional[str] = Field(
|
||||||
|
description="Application deadline if specified"
|
||||||
|
)
|
||||||
|
contact_info: Optional[dict] = Field(
|
||||||
|
description="Contact information from footer or contact section"
|
||||||
|
)
|
||||||
|
|
||||||
# First try with markdown (default)
|
# First try with markdown (default)
|
||||||
markdown_strategy = LLMExtractionStrategy(
|
markdown_strategy = LLMExtractionStrategy(
|
||||||
@@ -382,7 +411,7 @@ async def demo_input_formats():
|
|||||||
Extract job posting details into structured data. Focus on the visible text content
|
Extract job posting details into structured data. Focus on the visible text content
|
||||||
and organize requirements into categories.
|
and organize requirements into categories.
|
||||||
""",
|
""",
|
||||||
input_format="markdown" # default
|
input_format="markdown", # default
|
||||||
)
|
)
|
||||||
|
|
||||||
# Then with HTML for better structure understanding
|
# Then with HTML for better structure understanding
|
||||||
@@ -400,34 +429,25 @@ async def demo_input_formats():
|
|||||||
|
|
||||||
Use HTML attributes and classes to enhance extraction accuracy.
|
Use HTML attributes and classes to enhance extraction accuracy.
|
||||||
""",
|
""",
|
||||||
input_format="html" # explicitly use HTML
|
input_format="html", # explicitly use HTML
|
||||||
)
|
)
|
||||||
|
|
||||||
async with AsyncWebCrawler() as crawler:
|
async with AsyncWebCrawler() as crawler:
|
||||||
# Try with markdown first
|
# Try with markdown first
|
||||||
markdown_config = CrawlerRunConfig(
|
markdown_config = CrawlerRunConfig(extraction_strategy=markdown_strategy)
|
||||||
extraction_strategy=markdown_strategy
|
markdown_result = await crawler.arun(url=url, config=markdown_config)
|
||||||
)
|
|
||||||
markdown_result = await crawler.arun(
|
|
||||||
url=url,
|
|
||||||
config=markdown_config
|
|
||||||
)
|
|
||||||
print("\nMarkdown-based Extraction Result:")
|
print("\nMarkdown-based Extraction Result:")
|
||||||
items = json.loads(markdown_result.extracted_content)
|
items = json.loads(markdown_result.extracted_content)
|
||||||
print(json.dumps(items, indent=2))
|
print(json.dumps(items, indent=2))
|
||||||
|
|
||||||
# Then with HTML for better structure understanding
|
# Then with HTML for better structure understanding
|
||||||
html_config = CrawlerRunConfig(
|
html_config = CrawlerRunConfig(extraction_strategy=html_strategy)
|
||||||
extraction_strategy=html_strategy
|
html_result = await crawler.arun(url=url, config=html_config)
|
||||||
)
|
|
||||||
html_result = await crawler.arun(
|
|
||||||
url=url,
|
|
||||||
config=html_config
|
|
||||||
)
|
|
||||||
print("\nHTML-based Extraction Result:")
|
print("\nHTML-based Extraction Result:")
|
||||||
items = json.loads(html_result.extracted_content)
|
items = json.loads(html_result.extracted_content)
|
||||||
print(json.dumps(items, indent=2))
|
print(json.dumps(items, indent=2))
|
||||||
|
|
||||||
|
|
||||||
# Main execution
|
# Main execution
|
||||||
async def main():
|
async def main():
|
||||||
print("Crawl4AI v0.4.24 Feature Walkthrough")
|
print("Crawl4AI v0.4.24 Feature Walkthrough")
|
||||||
@@ -439,5 +459,6 @@ async def main():
|
|||||||
await demo_json_extraction()
|
await demo_json_extraction()
|
||||||
# await demo_input_formats()
|
# await demo_input_formats()
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
asyncio.run(main())
|
asyncio.run(main())
|
||||||
|
|||||||
78
main.py
78
main.py
@@ -1,14 +1,9 @@
|
|||||||
import asyncio, os
|
import asyncio, os
|
||||||
from fastapi import FastAPI, HTTPException, BackgroundTasks, Request
|
from fastapi import FastAPI, HTTPException
|
||||||
from fastapi.responses import JSONResponse
|
from fastapi import FastAPI, HTTPException
|
||||||
from fastapi import FastAPI, HTTPException, Request
|
|
||||||
from fastapi.responses import HTMLResponse, JSONResponse
|
|
||||||
from fastapi.staticfiles import StaticFiles
|
from fastapi.staticfiles import StaticFiles
|
||||||
from fastapi.middleware.cors import CORSMiddleware
|
from fastapi.middleware.cors import CORSMiddleware
|
||||||
from fastapi.templating import Jinja2Templates
|
from fastapi.templating import Jinja2Templates
|
||||||
from fastapi.exceptions import RequestValidationError
|
|
||||||
from starlette.middleware.base import BaseHTTPMiddleware
|
|
||||||
from starlette.responses import FileResponse
|
|
||||||
from fastapi.responses import RedirectResponse
|
from fastapi.responses import RedirectResponse
|
||||||
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
|
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
|
||||||
from fastapi import Depends, Security
|
from fastapi import Depends, Security
|
||||||
@@ -18,13 +13,10 @@ from typing import Optional, List, Dict, Any, Union
|
|||||||
import psutil
|
import psutil
|
||||||
import time
|
import time
|
||||||
import uuid
|
import uuid
|
||||||
from collections import defaultdict
|
|
||||||
from urllib.parse import urlparse
|
|
||||||
import math
|
import math
|
||||||
import logging
|
import logging
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
import json
|
|
||||||
from crawl4ai import AsyncWebCrawler, CrawlResult, CacheMode
|
from crawl4ai import AsyncWebCrawler, CrawlResult, CacheMode
|
||||||
from crawl4ai.config import MIN_WORD_THRESHOLD
|
from crawl4ai.config import MIN_WORD_THRESHOLD
|
||||||
from crawl4ai.extraction_strategy import (
|
from crawl4ai.extraction_strategy import (
|
||||||
@@ -38,30 +30,36 @@ __location__ = os.path.realpath(os.path.join(os.getcwd(), os.path.dirname(__file
|
|||||||
logging.basicConfig(level=logging.INFO)
|
logging.basicConfig(level=logging.INFO)
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class TaskStatus(str, Enum):
|
class TaskStatus(str, Enum):
|
||||||
PENDING = "pending"
|
PENDING = "pending"
|
||||||
PROCESSING = "processing"
|
PROCESSING = "processing"
|
||||||
COMPLETED = "completed"
|
COMPLETED = "completed"
|
||||||
FAILED = "failed"
|
FAILED = "failed"
|
||||||
|
|
||||||
|
|
||||||
class CrawlerType(str, Enum):
|
class CrawlerType(str, Enum):
|
||||||
BASIC = "basic"
|
BASIC = "basic"
|
||||||
LLM = "llm"
|
LLM = "llm"
|
||||||
COSINE = "cosine"
|
COSINE = "cosine"
|
||||||
JSON_CSS = "json_css"
|
JSON_CSS = "json_css"
|
||||||
|
|
||||||
|
|
||||||
class ExtractionConfig(BaseModel):
|
class ExtractionConfig(BaseModel):
|
||||||
type: CrawlerType
|
type: CrawlerType
|
||||||
params: Dict[str, Any] = {}
|
params: Dict[str, Any] = {}
|
||||||
|
|
||||||
|
|
||||||
class ChunkingStrategy(BaseModel):
|
class ChunkingStrategy(BaseModel):
|
||||||
type: str
|
type: str
|
||||||
params: Dict[str, Any] = {}
|
params: Dict[str, Any] = {}
|
||||||
|
|
||||||
|
|
||||||
class ContentFilter(BaseModel):
|
class ContentFilter(BaseModel):
|
||||||
type: str = "bm25"
|
type: str = "bm25"
|
||||||
params: Dict[str, Any] = {}
|
params: Dict[str, Any] = {}
|
||||||
|
|
||||||
|
|
||||||
class CrawlRequest(BaseModel):
|
class CrawlRequest(BaseModel):
|
||||||
urls: Union[HttpUrl, List[HttpUrl]]
|
urls: Union[HttpUrl, List[HttpUrl]]
|
||||||
word_count_threshold: int = MIN_WORD_THRESHOLD
|
word_count_threshold: int = MIN_WORD_THRESHOLD
|
||||||
@@ -80,6 +78,7 @@ class CrawlRequest(BaseModel):
|
|||||||
ttl: Optional[int] = 3600
|
ttl: Optional[int] = 3600
|
||||||
crawler_params: Dict[str, Any] = {}
|
crawler_params: Dict[str, Any] = {}
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class TaskInfo:
|
class TaskInfo:
|
||||||
id: str
|
id: str
|
||||||
@@ -89,6 +88,7 @@ class TaskInfo:
|
|||||||
created_at: float = time.time()
|
created_at: float = time.time()
|
||||||
ttl: int = 3600
|
ttl: int = 3600
|
||||||
|
|
||||||
|
|
||||||
class ResourceMonitor:
|
class ResourceMonitor:
|
||||||
def __init__(self, max_concurrent_tasks: int = 10):
|
def __init__(self, max_concurrent_tasks: int = 10):
|
||||||
self.max_concurrent_tasks = max_concurrent_tasks
|
self.max_concurrent_tasks = max_concurrent_tasks
|
||||||
@@ -106,7 +106,9 @@ class ResourceMonitor:
|
|||||||
mem_usage = psutil.virtual_memory().percent / 100
|
mem_usage = psutil.virtual_memory().percent / 100
|
||||||
cpu_usage = psutil.cpu_percent() / 100
|
cpu_usage = psutil.cpu_percent() / 100
|
||||||
|
|
||||||
memory_factor = max(0, (self.memory_threshold - mem_usage) / self.memory_threshold)
|
memory_factor = max(
|
||||||
|
0, (self.memory_threshold - mem_usage) / self.memory_threshold
|
||||||
|
)
|
||||||
cpu_factor = max(0, (self.cpu_threshold - cpu_usage) / self.cpu_threshold)
|
cpu_factor = max(0, (self.cpu_threshold - cpu_usage) / self.cpu_threshold)
|
||||||
|
|
||||||
self._last_available_slots = math.floor(
|
self._last_available_slots = math.floor(
|
||||||
@@ -116,6 +118,7 @@ class ResourceMonitor:
|
|||||||
|
|
||||||
return self._last_available_slots
|
return self._last_available_slots
|
||||||
|
|
||||||
|
|
||||||
class TaskManager:
|
class TaskManager:
|
||||||
def __init__(self, cleanup_interval: int = 300):
|
def __init__(self, cleanup_interval: int = 300):
|
||||||
self.tasks: Dict[str, TaskInfo] = {}
|
self.tasks: Dict[str, TaskInfo] = {}
|
||||||
@@ -149,12 +152,16 @@ class TaskManager:
|
|||||||
except asyncio.TimeoutError:
|
except asyncio.TimeoutError:
|
||||||
try:
|
try:
|
||||||
# Then try low priority
|
# Then try low priority
|
||||||
_, task_id = await asyncio.wait_for(self.low_priority.get(), timeout=0.1)
|
_, task_id = await asyncio.wait_for(
|
||||||
|
self.low_priority.get(), timeout=0.1
|
||||||
|
)
|
||||||
return task_id
|
return task_id
|
||||||
except asyncio.TimeoutError:
|
except asyncio.TimeoutError:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
def update_task(self, task_id: str, status: TaskStatus, result: Any = None, error: str = None):
|
def update_task(
|
||||||
|
self, task_id: str, status: TaskStatus, result: Any = None, error: str = None
|
||||||
|
):
|
||||||
if task_id in self.tasks:
|
if task_id in self.tasks:
|
||||||
task_info = self.tasks[task_id]
|
task_info = self.tasks[task_id]
|
||||||
task_info.status = status
|
task_info.status = status
|
||||||
@@ -180,6 +187,7 @@ class TaskManager:
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Error in cleanup loop: {e}")
|
logger.error(f"Error in cleanup loop: {e}")
|
||||||
|
|
||||||
|
|
||||||
class CrawlerPool:
|
class CrawlerPool:
|
||||||
def __init__(self, max_size: int = 10):
|
def __init__(self, max_size: int = 10):
|
||||||
self.max_size = max_size
|
self.max_size = max_size
|
||||||
@@ -222,6 +230,7 @@ class CrawlerPool:
|
|||||||
await crawler.__aexit__(None, None, None)
|
await crawler.__aexit__(None, None, None)
|
||||||
self.active_crawlers.clear()
|
self.active_crawlers.clear()
|
||||||
|
|
||||||
|
|
||||||
class CrawlerService:
|
class CrawlerService:
|
||||||
def __init__(self, max_concurrent_tasks: int = 10):
|
def __init__(self, max_concurrent_tasks: int = 10):
|
||||||
self.resource_monitor = ResourceMonitor(max_concurrent_tasks)
|
self.resource_monitor = ResourceMonitor(max_concurrent_tasks)
|
||||||
@@ -287,7 +296,9 @@ class CrawlerService:
|
|||||||
try:
|
try:
|
||||||
crawler = await self.crawler_pool.acquire(**request.crawler_params)
|
crawler = await self.crawler_pool.acquire(**request.crawler_params)
|
||||||
|
|
||||||
extraction_strategy = self._create_extraction_strategy(request.extraction_config)
|
extraction_strategy = self._create_extraction_strategy(
|
||||||
|
request.extraction_config
|
||||||
|
)
|
||||||
|
|
||||||
if isinstance(request.urls, list):
|
if isinstance(request.urls, list):
|
||||||
results = await crawler.arun_many(
|
results = await crawler.arun_many(
|
||||||
@@ -318,16 +329,21 @@ class CrawlerService:
|
|||||||
)
|
)
|
||||||
|
|
||||||
await self.crawler_pool.release(crawler)
|
await self.crawler_pool.release(crawler)
|
||||||
self.task_manager.update_task(task_id, TaskStatus.COMPLETED, results)
|
self.task_manager.update_task(
|
||||||
|
task_id, TaskStatus.COMPLETED, results
|
||||||
|
)
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Error processing task {task_id}: {str(e)}")
|
logger.error(f"Error processing task {task_id}: {str(e)}")
|
||||||
self.task_manager.update_task(task_id, TaskStatus.FAILED, error=str(e))
|
self.task_manager.update_task(
|
||||||
|
task_id, TaskStatus.FAILED, error=str(e)
|
||||||
|
)
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Error in queue processing: {str(e)}")
|
logger.error(f"Error in queue processing: {str(e)}")
|
||||||
await asyncio.sleep(1)
|
await asyncio.sleep(1)
|
||||||
|
|
||||||
|
|
||||||
app = FastAPI(title="Crawl4AI API")
|
app = FastAPI(title="Crawl4AI API")
|
||||||
|
|
||||||
# CORS configuration
|
# CORS configuration
|
||||||
@@ -344,6 +360,7 @@ app.add_middleware(
|
|||||||
security = HTTPBearer()
|
security = HTTPBearer()
|
||||||
CRAWL4AI_API_TOKEN = os.getenv("CRAWL4AI_API_TOKEN")
|
CRAWL4AI_API_TOKEN = os.getenv("CRAWL4AI_API_TOKEN")
|
||||||
|
|
||||||
|
|
||||||
async def verify_token(credentials: HTTPAuthorizationCredentials = Security(security)):
|
async def verify_token(credentials: HTTPAuthorizationCredentials = Security(security)):
|
||||||
if not CRAWL4AI_API_TOKEN:
|
if not CRAWL4AI_API_TOKEN:
|
||||||
return credentials # No token verification if CRAWL4AI_API_TOKEN is not set
|
return credentials # No token verification if CRAWL4AI_API_TOKEN is not set
|
||||||
@@ -351,10 +368,12 @@ async def verify_token(credentials: HTTPAuthorizationCredentials = Security(secu
|
|||||||
raise HTTPException(status_code=401, detail="Invalid token")
|
raise HTTPException(status_code=401, detail="Invalid token")
|
||||||
return credentials
|
return credentials
|
||||||
|
|
||||||
|
|
||||||
def secure_endpoint():
|
def secure_endpoint():
|
||||||
"""Returns security dependency only if CRAWL4AI_API_TOKEN is set"""
|
"""Returns security dependency only if CRAWL4AI_API_TOKEN is set"""
|
||||||
return Depends(verify_token) if CRAWL4AI_API_TOKEN else None
|
return Depends(verify_token) if CRAWL4AI_API_TOKEN else None
|
||||||
|
|
||||||
|
|
||||||
# Check if site directory exists
|
# Check if site directory exists
|
||||||
if os.path.exists(__location__ + "/site"):
|
if os.path.exists(__location__ + "/site"):
|
||||||
# Mount the site directory as a static directory
|
# Mount the site directory as a static directory
|
||||||
@@ -364,14 +383,17 @@ site_templates = Jinja2Templates(directory=__location__ + "/site")
|
|||||||
|
|
||||||
crawler_service = CrawlerService()
|
crawler_service = CrawlerService()
|
||||||
|
|
||||||
|
|
||||||
@app.on_event("startup")
|
@app.on_event("startup")
|
||||||
async def startup_event():
|
async def startup_event():
|
||||||
await crawler_service.start()
|
await crawler_service.start()
|
||||||
|
|
||||||
|
|
||||||
@app.on_event("shutdown")
|
@app.on_event("shutdown")
|
||||||
async def shutdown_event():
|
async def shutdown_event():
|
||||||
await crawler_service.stop()
|
await crawler_service.stop()
|
||||||
|
|
||||||
|
|
||||||
@app.get("/")
|
@app.get("/")
|
||||||
def read_root():
|
def read_root():
|
||||||
if os.path.exists(__location__ + "/site"):
|
if os.path.exists(__location__ + "/site"):
|
||||||
@@ -379,12 +401,16 @@ def read_root():
|
|||||||
# Return a json response
|
# Return a json response
|
||||||
return {"message": "Crawl4AI API service is running"}
|
return {"message": "Crawl4AI API service is running"}
|
||||||
|
|
||||||
|
|
||||||
@app.post("/crawl", dependencies=[secure_endpoint()] if CRAWL4AI_API_TOKEN else [])
|
@app.post("/crawl", dependencies=[secure_endpoint()] if CRAWL4AI_API_TOKEN else [])
|
||||||
async def crawl(request: CrawlRequest) -> Dict[str, str]:
|
async def crawl(request: CrawlRequest) -> Dict[str, str]:
|
||||||
task_id = await crawler_service.submit_task(request)
|
task_id = await crawler_service.submit_task(request)
|
||||||
return {"task_id": task_id}
|
return {"task_id": task_id}
|
||||||
|
|
||||||
@app.get("/task/{task_id}", dependencies=[secure_endpoint()] if CRAWL4AI_API_TOKEN else [])
|
|
||||||
|
@app.get(
|
||||||
|
"/task/{task_id}", dependencies=[secure_endpoint()] if CRAWL4AI_API_TOKEN else []
|
||||||
|
)
|
||||||
async def get_task_status(task_id: str):
|
async def get_task_status(task_id: str):
|
||||||
task_info = crawler_service.task_manager.get_task(task_id)
|
task_info = crawler_service.task_manager.get_task(task_id)
|
||||||
if not task_info:
|
if not task_info:
|
||||||
@@ -406,6 +432,7 @@ async def get_task_status(task_id: str):
|
|||||||
|
|
||||||
return response
|
return response
|
||||||
|
|
||||||
|
|
||||||
@app.post("/crawl_sync", dependencies=[secure_endpoint()] if CRAWL4AI_API_TOKEN else [])
|
@app.post("/crawl_sync", dependencies=[secure_endpoint()] if CRAWL4AI_API_TOKEN else [])
|
||||||
async def crawl_sync(request: CrawlRequest) -> Dict[str, Any]:
|
async def crawl_sync(request: CrawlRequest) -> Dict[str, Any]:
|
||||||
task_id = await crawler_service.submit_task(request)
|
task_id = await crawler_service.submit_task(request)
|
||||||
@@ -419,7 +446,10 @@ async def crawl_sync(request: CrawlRequest) -> Dict[str, Any]:
|
|||||||
if task_info.status == TaskStatus.COMPLETED:
|
if task_info.status == TaskStatus.COMPLETED:
|
||||||
# Return same format as /task/{task_id} endpoint
|
# Return same format as /task/{task_id} endpoint
|
||||||
if isinstance(task_info.result, list):
|
if isinstance(task_info.result, list):
|
||||||
return {"status": task_info.status, "results": [result.dict() for result in task_info.result]}
|
return {
|
||||||
|
"status": task_info.status,
|
||||||
|
"results": [result.dict() for result in task_info.result],
|
||||||
|
}
|
||||||
return {"status": task_info.status, "result": task_info.result.dict()}
|
return {"status": task_info.status, "result": task_info.result.dict()}
|
||||||
|
|
||||||
if task_info.status == TaskStatus.FAILED:
|
if task_info.status == TaskStatus.FAILED:
|
||||||
@@ -430,11 +460,16 @@ async def crawl_sync(request: CrawlRequest) -> Dict[str, Any]:
|
|||||||
# If we get here, task didn't complete within timeout
|
# If we get here, task didn't complete within timeout
|
||||||
raise HTTPException(status_code=408, detail="Task timed out")
|
raise HTTPException(status_code=408, detail="Task timed out")
|
||||||
|
|
||||||
@app.post("/crawl_direct", dependencies=[secure_endpoint()] if CRAWL4AI_API_TOKEN else [])
|
|
||||||
|
@app.post(
|
||||||
|
"/crawl_direct", dependencies=[secure_endpoint()] if CRAWL4AI_API_TOKEN else []
|
||||||
|
)
|
||||||
async def crawl_direct(request: CrawlRequest) -> Dict[str, Any]:
|
async def crawl_direct(request: CrawlRequest) -> Dict[str, Any]:
|
||||||
try:
|
try:
|
||||||
crawler = await crawler_service.crawler_pool.acquire(**request.crawler_params)
|
crawler = await crawler_service.crawler_pool.acquire(**request.crawler_params)
|
||||||
extraction_strategy = crawler_service._create_extraction_strategy(request.extraction_config)
|
extraction_strategy = crawler_service._create_extraction_strategy(
|
||||||
|
request.extraction_config
|
||||||
|
)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
if isinstance(request.urls, list):
|
if isinstance(request.urls, list):
|
||||||
@@ -471,6 +506,7 @@ async def crawl_direct(request: CrawlRequest) -> Dict[str, Any]:
|
|||||||
logger.error(f"Error in direct crawl: {str(e)}")
|
logger.error(f"Error in direct crawl: {str(e)}")
|
||||||
raise HTTPException(status_code=500, detail=str(e))
|
raise HTTPException(status_code=500, detail=str(e))
|
||||||
|
|
||||||
|
|
||||||
@app.get("/health")
|
@app.get("/health")
|
||||||
async def health_check():
|
async def health_check():
|
||||||
available_slots = await crawler_service.resource_monitor.get_available_slots()
|
available_slots = await crawler_service.resource_monitor.get_available_slots()
|
||||||
@@ -482,6 +518,8 @@ async def health_check():
|
|||||||
"cpu_usage": psutil.cpu_percent(),
|
"cpu_usage": psutil.cpu_percent(),
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
import uvicorn
|
import uvicorn
|
||||||
|
|
||||||
uvicorn.run(app, host="0.0.0.0", port=11235)
|
uvicorn.run(app, host="0.0.0.0", port=11235)
|
||||||
4
setup.py
4
setup.py
@@ -51,9 +51,7 @@ setup(
|
|||||||
author_email="unclecode@kidocode.com",
|
author_email="unclecode@kidocode.com",
|
||||||
license="MIT",
|
license="MIT",
|
||||||
packages=find_packages(),
|
packages=find_packages(),
|
||||||
package_data={
|
package_data={"crawl4ai": ["js_snippet/*.js"]},
|
||||||
'crawl4ai': ['js_snippet/*.js']
|
|
||||||
},
|
|
||||||
classifiers=[
|
classifiers=[
|
||||||
"Development Status :: 3 - Alpha",
|
"Development Status :: 3 - Alpha",
|
||||||
"Intended Audience :: Developers",
|
"Intended Audience :: Developers",
|
||||||
|
|||||||
@@ -1,17 +1,18 @@
|
|||||||
import os, sys
|
import os
|
||||||
|
import sys
|
||||||
|
import asyncio
|
||||||
|
from crawl4ai import AsyncWebCrawler, CacheMode
|
||||||
|
from crawl4ai.markdown_generation_strategy import DefaultMarkdownGenerator
|
||||||
|
|
||||||
parent_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
|
parent_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
|
||||||
sys.path.append(parent_dir)
|
sys.path.append(parent_dir)
|
||||||
__location__ = os.path.realpath(os.path.join(os.getcwd(), os.path.dirname(__file__)))
|
__location__ = os.path.realpath(os.path.join(os.getcwd(), os.path.dirname(__file__)))
|
||||||
|
|
||||||
import os, sys
|
|
||||||
import asyncio
|
|
||||||
from crawl4ai import AsyncWebCrawler, CacheMode
|
|
||||||
from crawl4ai.content_filter_strategy import PruningContentFilter
|
|
||||||
from crawl4ai.markdown_generation_strategy import DefaultMarkdownGenerator
|
|
||||||
|
|
||||||
# Assuming that the changes made allow different configurations
|
# Assuming that the changes made allow different configurations
|
||||||
# for managed browser, persistent context, and so forth.
|
# for managed browser, persistent context, and so forth.
|
||||||
|
|
||||||
|
|
||||||
async def test_default_headless():
|
async def test_default_headless():
|
||||||
async with AsyncWebCrawler(
|
async with AsyncWebCrawler(
|
||||||
headless=True,
|
headless=True,
|
||||||
@@ -24,13 +25,14 @@ async def test_default_headless():
|
|||||||
# Testing normal ephemeral context
|
# Testing normal ephemeral context
|
||||||
) as crawler:
|
) as crawler:
|
||||||
result = await crawler.arun(
|
result = await crawler.arun(
|
||||||
url='https://www.kidocode.com/degrees/technology',
|
url="https://www.kidocode.com/degrees/technology",
|
||||||
cache_mode=CacheMode.BYPASS,
|
cache_mode=CacheMode.BYPASS,
|
||||||
markdown_generator=DefaultMarkdownGenerator(options={"ignore_links": True}),
|
markdown_generator=DefaultMarkdownGenerator(options={"ignore_links": True}),
|
||||||
)
|
)
|
||||||
print("[test_default_headless] success:", result.success)
|
print("[test_default_headless] success:", result.success)
|
||||||
print("HTML length:", len(result.html if result.html else ""))
|
print("HTML length:", len(result.html if result.html else ""))
|
||||||
|
|
||||||
|
|
||||||
async def test_managed_browser_persistent():
|
async def test_managed_browser_persistent():
|
||||||
# Treating use_persistent_context=True as managed_browser scenario.
|
# Treating use_persistent_context=True as managed_browser scenario.
|
||||||
async with AsyncWebCrawler(
|
async with AsyncWebCrawler(
|
||||||
@@ -44,13 +46,14 @@ async def test_managed_browser_persistent():
|
|||||||
# This should store and reuse profile data across runs
|
# This should store and reuse profile data across runs
|
||||||
) as crawler:
|
) as crawler:
|
||||||
result = await crawler.arun(
|
result = await crawler.arun(
|
||||||
url='https://www.google.com',
|
url="https://www.google.com",
|
||||||
cache_mode=CacheMode.BYPASS,
|
cache_mode=CacheMode.BYPASS,
|
||||||
markdown_generator=DefaultMarkdownGenerator(options={"ignore_links": True})
|
markdown_generator=DefaultMarkdownGenerator(options={"ignore_links": True}),
|
||||||
)
|
)
|
||||||
print("[test_managed_browser_persistent] success:", result.success)
|
print("[test_managed_browser_persistent] success:", result.success)
|
||||||
print("HTML length:", len(result.html if result.html else ""))
|
print("HTML length:", len(result.html if result.html else ""))
|
||||||
|
|
||||||
|
|
||||||
async def test_session_reuse():
|
async def test_session_reuse():
|
||||||
# Test creating a session, using it for multiple calls
|
# Test creating a session, using it for multiple calls
|
||||||
session_id = "my_session"
|
session_id = "my_session"
|
||||||
@@ -62,25 +65,25 @@ async def test_session_reuse():
|
|||||||
use_managed_browser=False,
|
use_managed_browser=False,
|
||||||
use_persistent_context=False,
|
use_persistent_context=False,
|
||||||
) as crawler:
|
) as crawler:
|
||||||
|
|
||||||
# First call: create session
|
# First call: create session
|
||||||
result1 = await crawler.arun(
|
result1 = await crawler.arun(
|
||||||
url='https://www.example.com',
|
url="https://www.example.com",
|
||||||
cache_mode=CacheMode.BYPASS,
|
cache_mode=CacheMode.BYPASS,
|
||||||
session_id=session_id,
|
session_id=session_id,
|
||||||
markdown_generator=DefaultMarkdownGenerator(options={"ignore_links": True})
|
markdown_generator=DefaultMarkdownGenerator(options={"ignore_links": True}),
|
||||||
)
|
)
|
||||||
print("[test_session_reuse first call] success:", result1.success)
|
print("[test_session_reuse first call] success:", result1.success)
|
||||||
|
|
||||||
# Second call: same session, possibly cookie retained
|
# Second call: same session, possibly cookie retained
|
||||||
result2 = await crawler.arun(
|
result2 = await crawler.arun(
|
||||||
url='https://www.example.com/about',
|
url="https://www.example.com/about",
|
||||||
cache_mode=CacheMode.BYPASS,
|
cache_mode=CacheMode.BYPASS,
|
||||||
session_id=session_id,
|
session_id=session_id,
|
||||||
markdown_generator=DefaultMarkdownGenerator(options={"ignore_links": True})
|
markdown_generator=DefaultMarkdownGenerator(options={"ignore_links": True}),
|
||||||
)
|
)
|
||||||
print("[test_session_reuse second call] success:", result2.success)
|
print("[test_session_reuse second call] success:", result2.success)
|
||||||
|
|
||||||
|
|
||||||
async def test_magic_mode():
|
async def test_magic_mode():
|
||||||
# Test magic mode with override_navigator and simulate_user
|
# Test magic mode with override_navigator and simulate_user
|
||||||
async with AsyncWebCrawler(
|
async with AsyncWebCrawler(
|
||||||
@@ -95,13 +98,14 @@ async def test_magic_mode():
|
|||||||
simulate_user=True,
|
simulate_user=True,
|
||||||
) as crawler:
|
) as crawler:
|
||||||
result = await crawler.arun(
|
result = await crawler.arun(
|
||||||
url='https://www.kidocode.com/degrees/business',
|
url="https://www.kidocode.com/degrees/business",
|
||||||
cache_mode=CacheMode.BYPASS,
|
cache_mode=CacheMode.BYPASS,
|
||||||
markdown_generator=DefaultMarkdownGenerator(options={"ignore_links": True})
|
markdown_generator=DefaultMarkdownGenerator(options={"ignore_links": True}),
|
||||||
)
|
)
|
||||||
print("[test_magic_mode] success:", result.success)
|
print("[test_magic_mode] success:", result.success)
|
||||||
print("HTML length:", len(result.html if result.html else ""))
|
print("HTML length:", len(result.html if result.html else ""))
|
||||||
|
|
||||||
|
|
||||||
async def test_proxy_settings():
|
async def test_proxy_settings():
|
||||||
# Test with a proxy (if available) to ensure code runs with proxy
|
# Test with a proxy (if available) to ensure code runs with proxy
|
||||||
async with AsyncWebCrawler(
|
async with AsyncWebCrawler(
|
||||||
@@ -113,14 +117,15 @@ async def test_proxy_settings():
|
|||||||
use_persistent_context=False,
|
use_persistent_context=False,
|
||||||
) as crawler:
|
) as crawler:
|
||||||
result = await crawler.arun(
|
result = await crawler.arun(
|
||||||
url='https://httpbin.org/ip',
|
url="https://httpbin.org/ip",
|
||||||
cache_mode=CacheMode.BYPASS,
|
cache_mode=CacheMode.BYPASS,
|
||||||
markdown_generator=DefaultMarkdownGenerator(options={"ignore_links": True})
|
markdown_generator=DefaultMarkdownGenerator(options={"ignore_links": True}),
|
||||||
)
|
)
|
||||||
print("[test_proxy_settings] success:", result.success)
|
print("[test_proxy_settings] success:", result.success)
|
||||||
if result.success:
|
if result.success:
|
||||||
print("HTML preview:", result.html[:200] if result.html else "")
|
print("HTML preview:", result.html[:200] if result.html else "")
|
||||||
|
|
||||||
|
|
||||||
async def test_ignore_https_errors():
|
async def test_ignore_https_errors():
|
||||||
# Test ignore HTTPS errors with a self-signed or invalid cert domain
|
# Test ignore HTTPS errors with a self-signed or invalid cert domain
|
||||||
# This is just conceptual, the domain should be one that triggers SSL error.
|
# This is just conceptual, the domain should be one that triggers SSL error.
|
||||||
@@ -134,12 +139,13 @@ async def test_ignore_https_errors():
|
|||||||
use_persistent_context=False,
|
use_persistent_context=False,
|
||||||
) as crawler:
|
) as crawler:
|
||||||
result = await crawler.arun(
|
result = await crawler.arun(
|
||||||
url='https://self-signed.badssl.com/',
|
url="https://self-signed.badssl.com/",
|
||||||
cache_mode=CacheMode.BYPASS,
|
cache_mode=CacheMode.BYPASS,
|
||||||
markdown_generator=DefaultMarkdownGenerator(options={"ignore_links": True})
|
markdown_generator=DefaultMarkdownGenerator(options={"ignore_links": True}),
|
||||||
)
|
)
|
||||||
print("[test_ignore_https_errors] success:", result.success)
|
print("[test_ignore_https_errors] success:", result.success)
|
||||||
|
|
||||||
|
|
||||||
async def main():
|
async def main():
|
||||||
print("Running tests...")
|
print("Running tests...")
|
||||||
# await test_default_headless()
|
# await test_default_headless()
|
||||||
@@ -149,5 +155,6 @@ async def main():
|
|||||||
# await test_proxy_settings()
|
# await test_proxy_settings()
|
||||||
await test_ignore_https_errors()
|
await test_ignore_https_errors()
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
asyncio.run(main())
|
asyncio.run(main())
|
||||||
|
|||||||
@@ -1,4 +1,5 @@
|
|||||||
import os, sys
|
import os, sys
|
||||||
|
|
||||||
parent_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
|
parent_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
|
||||||
sys.path.append(parent_dir)
|
sys.path.append(parent_dir)
|
||||||
__location__ = os.path.realpath(os.path.join(os.getcwd(), os.path.dirname(__file__)))
|
__location__ = os.path.realpath(os.path.join(os.getcwd(), os.path.dirname(__file__)))
|
||||||
@@ -9,7 +10,7 @@ from crawl4ai.async_configs import BrowserConfig, CrawlerRunConfig
|
|||||||
from crawl4ai.content_filter_strategy import PruningContentFilter
|
from crawl4ai.content_filter_strategy import PruningContentFilter
|
||||||
from crawl4ai.extraction_strategy import JsonCssExtractionStrategy
|
from crawl4ai.extraction_strategy import JsonCssExtractionStrategy
|
||||||
from crawl4ai.chunking_strategy import RegexChunking
|
from crawl4ai.chunking_strategy import RegexChunking
|
||||||
from crawl4ai.markdown_generation_strategy import DefaultMarkdownGenerator
|
|
||||||
|
|
||||||
# Category 1: Browser Configuration Tests
|
# Category 1: Browser Configuration Tests
|
||||||
async def test_browser_config_object():
|
async def test_browser_config_object():
|
||||||
@@ -21,29 +22,31 @@ async def test_browser_config_object():
|
|||||||
viewport_height=1080,
|
viewport_height=1080,
|
||||||
use_managed_browser=True,
|
use_managed_browser=True,
|
||||||
user_agent_mode="random",
|
user_agent_mode="random",
|
||||||
user_agent_generator_config={"device_type": "desktop", "os_type": "windows"}
|
user_agent_generator_config={"device_type": "desktop", "os_type": "windows"},
|
||||||
)
|
)
|
||||||
|
|
||||||
async with AsyncWebCrawler(config=browser_config, verbose=True) as crawler:
|
async with AsyncWebCrawler(config=browser_config, verbose=True) as crawler:
|
||||||
result = await crawler.arun('https://example.com', cache_mode=CacheMode.BYPASS)
|
result = await crawler.arun("https://example.com", cache_mode=CacheMode.BYPASS)
|
||||||
assert result.success, "Browser config crawl failed"
|
assert result.success, "Browser config crawl failed"
|
||||||
assert len(result.html) > 0, "No HTML content retrieved"
|
assert len(result.html) > 0, "No HTML content retrieved"
|
||||||
|
|
||||||
|
|
||||||
async def test_browser_performance_config():
|
async def test_browser_performance_config():
|
||||||
"""Test browser configurations focused on performance"""
|
"""Test browser configurations focused on performance"""
|
||||||
browser_config = BrowserConfig(
|
browser_config = BrowserConfig(
|
||||||
text_mode=True,
|
text_mode=True,
|
||||||
light_mode=True,
|
light_mode=True,
|
||||||
extra_args=['--disable-gpu', '--disable-software-rasterizer'],
|
extra_args=["--disable-gpu", "--disable-software-rasterizer"],
|
||||||
ignore_https_errors=True,
|
ignore_https_errors=True,
|
||||||
java_script_enabled=False
|
java_script_enabled=False,
|
||||||
)
|
)
|
||||||
|
|
||||||
async with AsyncWebCrawler(config=browser_config) as crawler:
|
async with AsyncWebCrawler(config=browser_config) as crawler:
|
||||||
result = await crawler.arun('https://example.com')
|
result = await crawler.arun("https://example.com")
|
||||||
assert result.success, "Performance optimized crawl failed"
|
assert result.success, "Performance optimized crawl failed"
|
||||||
assert result.status_code == 200, "Unexpected status code"
|
assert result.status_code == 200, "Unexpected status code"
|
||||||
|
|
||||||
|
|
||||||
# Category 2: Content Processing Tests
|
# Category 2: Content Processing Tests
|
||||||
async def test_content_extraction_config():
|
async def test_content_extraction_config():
|
||||||
"""Test content extraction with various strategies"""
|
"""Test content extraction with various strategies"""
|
||||||
@@ -53,24 +56,20 @@ async def test_content_extraction_config():
|
|||||||
schema={
|
schema={
|
||||||
"name": "article",
|
"name": "article",
|
||||||
"baseSelector": "div",
|
"baseSelector": "div",
|
||||||
"fields": [{
|
"fields": [{"name": "title", "selector": "h1", "type": "text"}],
|
||||||
"name": "title",
|
|
||||||
"selector": "h1",
|
|
||||||
"type": "text"
|
|
||||||
}]
|
|
||||||
}
|
}
|
||||||
),
|
),
|
||||||
chunking_strategy=RegexChunking(),
|
chunking_strategy=RegexChunking(),
|
||||||
content_filter=PruningContentFilter()
|
content_filter=PruningContentFilter(),
|
||||||
)
|
)
|
||||||
|
|
||||||
async with AsyncWebCrawler() as crawler:
|
async with AsyncWebCrawler() as crawler:
|
||||||
result = await crawler.arun(
|
result = await crawler.arun(
|
||||||
'https://example.com/article',
|
"https://example.com/article", config=crawler_config
|
||||||
config=crawler_config
|
|
||||||
)
|
)
|
||||||
assert result.extracted_content is not None, "Content extraction failed"
|
assert result.extracted_content is not None, "Content extraction failed"
|
||||||
assert 'title' in result.extracted_content, "Missing expected content field"
|
assert "title" in result.extracted_content, "Missing expected content field"
|
||||||
|
|
||||||
|
|
||||||
# Category 3: Cache and Session Management Tests
|
# Category 3: Cache and Session Management Tests
|
||||||
async def test_cache_and_session_management():
|
async def test_cache_and_session_management():
|
||||||
@@ -79,25 +78,20 @@ async def test_cache_and_session_management():
|
|||||||
crawler_config = CrawlerRunConfig(
|
crawler_config = CrawlerRunConfig(
|
||||||
cache_mode=CacheMode.WRITE_ONLY,
|
cache_mode=CacheMode.WRITE_ONLY,
|
||||||
process_iframes=True,
|
process_iframes=True,
|
||||||
remove_overlay_elements=True
|
remove_overlay_elements=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
async with AsyncWebCrawler(config=browser_config) as crawler:
|
async with AsyncWebCrawler(config=browser_config) as crawler:
|
||||||
# First request - should write to cache
|
# First request - should write to cache
|
||||||
result1 = await crawler.arun(
|
result1 = await crawler.arun("https://example.com", config=crawler_config)
|
||||||
'https://example.com',
|
|
||||||
config=crawler_config
|
|
||||||
)
|
|
||||||
|
|
||||||
# Second request - should use fresh fetch due to WRITE_ONLY mode
|
# Second request - should use fresh fetch due to WRITE_ONLY mode
|
||||||
result2 = await crawler.arun(
|
result2 = await crawler.arun("https://example.com", config=crawler_config)
|
||||||
'https://example.com',
|
|
||||||
config=crawler_config
|
|
||||||
)
|
|
||||||
|
|
||||||
assert result1.success and result2.success, "Cache mode crawl failed"
|
assert result1.success and result2.success, "Cache mode crawl failed"
|
||||||
assert result1.html == result2.html, "Inconsistent results between requests"
|
assert result1.html == result2.html, "Inconsistent results between requests"
|
||||||
|
|
||||||
|
|
||||||
# Category 4: Media Handling Tests
|
# Category 4: Media Handling Tests
|
||||||
async def test_media_handling_config():
|
async def test_media_handling_config():
|
||||||
"""Test configurations related to media handling"""
|
"""Test configurations related to media handling"""
|
||||||
@@ -107,24 +101,22 @@ async def test_media_handling_config():
|
|||||||
viewport_width=1920,
|
viewport_width=1920,
|
||||||
viewport_height=1080,
|
viewport_height=1080,
|
||||||
accept_downloads=True,
|
accept_downloads=True,
|
||||||
downloads_path= os.path.expanduser("~/.crawl4ai/downloads")
|
downloads_path=os.path.expanduser("~/.crawl4ai/downloads"),
|
||||||
)
|
)
|
||||||
crawler_config = CrawlerRunConfig(
|
crawler_config = CrawlerRunConfig(
|
||||||
screenshot=True,
|
screenshot=True,
|
||||||
pdf=True,
|
pdf=True,
|
||||||
adjust_viewport_to_content=True,
|
adjust_viewport_to_content=True,
|
||||||
wait_for_images=True,
|
wait_for_images=True,
|
||||||
screenshot_height_threshold=20000
|
screenshot_height_threshold=20000,
|
||||||
)
|
)
|
||||||
|
|
||||||
async with AsyncWebCrawler(config=browser_config) as crawler:
|
async with AsyncWebCrawler(config=browser_config) as crawler:
|
||||||
result = await crawler.arun(
|
result = await crawler.arun("https://example.com", config=crawler_config)
|
||||||
'https://example.com',
|
|
||||||
config=crawler_config
|
|
||||||
)
|
|
||||||
assert result.screenshot is not None, "Screenshot capture failed"
|
assert result.screenshot is not None, "Screenshot capture failed"
|
||||||
assert result.pdf is not None, "PDF generation failed"
|
assert result.pdf is not None, "PDF generation failed"
|
||||||
|
|
||||||
|
|
||||||
# Category 5: Anti-Bot and Site Interaction Tests
|
# Category 5: Anti-Bot and Site Interaction Tests
|
||||||
async def test_antibot_config():
|
async def test_antibot_config():
|
||||||
"""Test configurations for handling anti-bot measures"""
|
"""Test configurations for handling anti-bot measures"""
|
||||||
@@ -135,57 +127,43 @@ async def test_antibot_config():
|
|||||||
wait_for="js:()=>document.querySelector('body')",
|
wait_for="js:()=>document.querySelector('body')",
|
||||||
delay_before_return_html=1.0,
|
delay_before_return_html=1.0,
|
||||||
log_console=True,
|
log_console=True,
|
||||||
cache_mode=CacheMode.BYPASS
|
cache_mode=CacheMode.BYPASS,
|
||||||
)
|
)
|
||||||
|
|
||||||
async with AsyncWebCrawler() as crawler:
|
async with AsyncWebCrawler() as crawler:
|
||||||
result = await crawler.arun(
|
result = await crawler.arun("https://example.com", config=crawler_config)
|
||||||
'https://example.com',
|
|
||||||
config=crawler_config
|
|
||||||
)
|
|
||||||
assert result.success, "Anti-bot measure handling failed"
|
assert result.success, "Anti-bot measure handling failed"
|
||||||
|
|
||||||
|
|
||||||
# Category 6: Parallel Processing Tests
|
# Category 6: Parallel Processing Tests
|
||||||
async def test_parallel_processing():
|
async def test_parallel_processing():
|
||||||
"""Test parallel processing capabilities"""
|
"""Test parallel processing capabilities"""
|
||||||
crawler_config = CrawlerRunConfig(
|
crawler_config = CrawlerRunConfig(mean_delay=0.5, max_range=1.0, semaphore_count=5)
|
||||||
mean_delay=0.5,
|
|
||||||
max_range=1.0,
|
|
||||||
semaphore_count=5
|
|
||||||
)
|
|
||||||
|
|
||||||
urls = [
|
urls = ["https://example.com/1", "https://example.com/2", "https://example.com/3"]
|
||||||
'https://example.com/1',
|
|
||||||
'https://example.com/2',
|
|
||||||
'https://example.com/3'
|
|
||||||
]
|
|
||||||
|
|
||||||
async with AsyncWebCrawler() as crawler:
|
async with AsyncWebCrawler() as crawler:
|
||||||
results = await crawler.arun_many(
|
results = await crawler.arun_many(urls, config=crawler_config)
|
||||||
urls,
|
|
||||||
config=crawler_config
|
|
||||||
)
|
|
||||||
assert len(results) == len(urls), "Not all URLs were processed"
|
assert len(results) == len(urls), "Not all URLs were processed"
|
||||||
assert all(r.success for r in results), "Some parallel requests failed"
|
assert all(r.success for r in results), "Some parallel requests failed"
|
||||||
|
|
||||||
|
|
||||||
# Category 7: Backwards Compatibility Tests
|
# Category 7: Backwards Compatibility Tests
|
||||||
async def test_legacy_parameter_support():
|
async def test_legacy_parameter_support():
|
||||||
"""Test that legacy parameters still work"""
|
"""Test that legacy parameters still work"""
|
||||||
async with AsyncWebCrawler(
|
async with AsyncWebCrawler(
|
||||||
headless=True,
|
headless=True, browser_type="chromium", viewport_width=1024, viewport_height=768
|
||||||
browser_type="chromium",
|
|
||||||
viewport_width=1024,
|
|
||||||
viewport_height=768
|
|
||||||
) as crawler:
|
) as crawler:
|
||||||
result = await crawler.arun(
|
result = await crawler.arun(
|
||||||
'https://example.com',
|
"https://example.com",
|
||||||
screenshot=True,
|
screenshot=True,
|
||||||
word_count_threshold=200,
|
word_count_threshold=200,
|
||||||
bypass_cache=True,
|
bypass_cache=True,
|
||||||
css_selector=".main-content"
|
css_selector=".main-content",
|
||||||
)
|
)
|
||||||
assert result.success, "Legacy parameter support failed"
|
assert result.success, "Legacy parameter support failed"
|
||||||
|
|
||||||
|
|
||||||
# Category 8: Mixed Configuration Tests
|
# Category 8: Mixed Configuration Tests
|
||||||
async def test_mixed_config_usage():
|
async def test_mixed_config_usage():
|
||||||
"""Test mixing new config objects with legacy parameters"""
|
"""Test mixing new config objects with legacy parameters"""
|
||||||
@@ -194,17 +172,19 @@ async def test_mixed_config_usage():
|
|||||||
|
|
||||||
async with AsyncWebCrawler(
|
async with AsyncWebCrawler(
|
||||||
config=browser_config,
|
config=browser_config,
|
||||||
verbose=True # legacy parameter
|
verbose=True, # legacy parameter
|
||||||
) as crawler:
|
) as crawler:
|
||||||
result = await crawler.arun(
|
result = await crawler.arun(
|
||||||
'https://example.com',
|
"https://example.com",
|
||||||
config=crawler_config,
|
config=crawler_config,
|
||||||
cache_mode=CacheMode.BYPASS, # legacy parameter
|
cache_mode=CacheMode.BYPASS, # legacy parameter
|
||||||
css_selector="body" # legacy parameter
|
css_selector="body", # legacy parameter
|
||||||
)
|
)
|
||||||
assert result.success, "Mixed configuration usage failed"
|
assert result.success, "Mixed configuration usage failed"
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
|
||||||
async def run_tests():
|
async def run_tests():
|
||||||
test_functions = [
|
test_functions = [
|
||||||
test_browser_config_object,
|
test_browser_config_object,
|
||||||
|
|||||||
@@ -4,7 +4,6 @@ import asyncio
|
|||||||
import shutil
|
import shutil
|
||||||
from typing import List
|
from typing import List
|
||||||
import tempfile
|
import tempfile
|
||||||
import time
|
|
||||||
|
|
||||||
# Add the parent directory to the Python path
|
# Add the parent directory to the Python path
|
||||||
parent_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
|
parent_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
|
||||||
@@ -12,6 +11,7 @@ sys.path.append(parent_dir)
|
|||||||
|
|
||||||
from crawl4ai.async_webcrawler import AsyncWebCrawler
|
from crawl4ai.async_webcrawler import AsyncWebCrawler
|
||||||
|
|
||||||
|
|
||||||
class TestDownloads:
|
class TestDownloads:
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
self.temp_dir = tempfile.mkdtemp(prefix="crawl4ai_test_")
|
self.temp_dir = tempfile.mkdtemp(prefix="crawl4ai_test_")
|
||||||
@@ -31,9 +31,7 @@ class TestDownloads:
|
|||||||
"""Test basic file download functionality"""
|
"""Test basic file download functionality"""
|
||||||
try:
|
try:
|
||||||
async with AsyncWebCrawler(
|
async with AsyncWebCrawler(
|
||||||
accept_downloads=True,
|
accept_downloads=True, downloads_path=self.download_dir, verbose=True
|
||||||
downloads_path=self.download_dir,
|
|
||||||
verbose=True
|
|
||||||
) as crawler:
|
) as crawler:
|
||||||
# Python.org downloads page typically has stable download links
|
# Python.org downloads page typically has stable download links
|
||||||
result = await crawler.arun(
|
result = await crawler.arun(
|
||||||
@@ -42,14 +40,19 @@ class TestDownloads:
|
|||||||
// Click first download link
|
// Click first download link
|
||||||
const downloadLink = document.querySelector('a[href$=".exe"]');
|
const downloadLink = document.querySelector('a[href$=".exe"]');
|
||||||
if (downloadLink) downloadLink.click();
|
if (downloadLink) downloadLink.click();
|
||||||
"""
|
""",
|
||||||
)
|
)
|
||||||
|
|
||||||
success = result.downloaded_files is not None and len(result.downloaded_files) > 0
|
success = (
|
||||||
|
result.downloaded_files is not None
|
||||||
|
and len(result.downloaded_files) > 0
|
||||||
|
)
|
||||||
self.log_result(
|
self.log_result(
|
||||||
"Basic Download",
|
"Basic Download",
|
||||||
success,
|
success,
|
||||||
f"Downloaded {len(result.downloaded_files or [])} files" if success else "No files downloaded"
|
f"Downloaded {len(result.downloaded_files or [])} files"
|
||||||
|
if success
|
||||||
|
else "No files downloaded",
|
||||||
)
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
self.log_result("Basic Download", False, str(e))
|
self.log_result("Basic Download", False, str(e))
|
||||||
@@ -65,21 +68,26 @@ class TestDownloads:
|
|||||||
downloads_path=self.download_dir,
|
downloads_path=self.download_dir,
|
||||||
use_persistent_context=True,
|
use_persistent_context=True,
|
||||||
user_data_dir=user_data_dir,
|
user_data_dir=user_data_dir,
|
||||||
verbose=True
|
verbose=True,
|
||||||
) as crawler:
|
) as crawler:
|
||||||
result = await crawler.arun(
|
result = await crawler.arun(
|
||||||
url="https://www.python.org/downloads/",
|
url="https://www.python.org/downloads/",
|
||||||
js_code="""
|
js_code="""
|
||||||
const downloadLink = document.querySelector('a[href$=".exe"]');
|
const downloadLink = document.querySelector('a[href$=".exe"]');
|
||||||
if (downloadLink) downloadLink.click();
|
if (downloadLink) downloadLink.click();
|
||||||
"""
|
""",
|
||||||
)
|
)
|
||||||
|
|
||||||
success = result.downloaded_files is not None and len(result.downloaded_files) > 0
|
success = (
|
||||||
|
result.downloaded_files is not None
|
||||||
|
and len(result.downloaded_files) > 0
|
||||||
|
)
|
||||||
self.log_result(
|
self.log_result(
|
||||||
"Persistent Context Download",
|
"Persistent Context Download",
|
||||||
success,
|
success,
|
||||||
f"Downloaded {len(result.downloaded_files or [])} files" if success else "No files downloaded"
|
f"Downloaded {len(result.downloaded_files or [])} files"
|
||||||
|
if success
|
||||||
|
else "No files downloaded",
|
||||||
)
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
self.log_result("Persistent Context Download", False, str(e))
|
self.log_result("Persistent Context Download", False, str(e))
|
||||||
@@ -88,9 +96,7 @@ class TestDownloads:
|
|||||||
"""Test multiple simultaneous downloads"""
|
"""Test multiple simultaneous downloads"""
|
||||||
try:
|
try:
|
||||||
async with AsyncWebCrawler(
|
async with AsyncWebCrawler(
|
||||||
accept_downloads=True,
|
accept_downloads=True, downloads_path=self.download_dir, verbose=True
|
||||||
downloads_path=self.download_dir,
|
|
||||||
verbose=True
|
|
||||||
) as crawler:
|
) as crawler:
|
||||||
result = await crawler.arun(
|
result = await crawler.arun(
|
||||||
url="https://www.python.org/downloads/",
|
url="https://www.python.org/downloads/",
|
||||||
@@ -98,14 +104,19 @@ class TestDownloads:
|
|||||||
// Click multiple download links
|
// Click multiple download links
|
||||||
const downloadLinks = document.querySelectorAll('a[href$=".exe"]');
|
const downloadLinks = document.querySelectorAll('a[href$=".exe"]');
|
||||||
downloadLinks.forEach(link => link.click());
|
downloadLinks.forEach(link => link.click());
|
||||||
"""
|
""",
|
||||||
)
|
)
|
||||||
|
|
||||||
success = result.downloaded_files is not None and len(result.downloaded_files) > 1
|
success = (
|
||||||
|
result.downloaded_files is not None
|
||||||
|
and len(result.downloaded_files) > 1
|
||||||
|
)
|
||||||
self.log_result(
|
self.log_result(
|
||||||
"Multiple Downloads",
|
"Multiple Downloads",
|
||||||
success,
|
success,
|
||||||
f"Downloaded {len(result.downloaded_files or [])} files" if success else "Not enough files downloaded"
|
f"Downloaded {len(result.downloaded_files or [])} files"
|
||||||
|
if success
|
||||||
|
else "Not enough files downloaded",
|
||||||
)
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
self.log_result("Multiple Downloads", False, str(e))
|
self.log_result("Multiple Downloads", False, str(e))
|
||||||
@@ -120,21 +131,26 @@ class TestDownloads:
|
|||||||
accept_downloads=True,
|
accept_downloads=True,
|
||||||
downloads_path=self.download_dir,
|
downloads_path=self.download_dir,
|
||||||
browser_type=browser_type,
|
browser_type=browser_type,
|
||||||
verbose=True
|
verbose=True,
|
||||||
) as crawler:
|
) as crawler:
|
||||||
result = await crawler.arun(
|
result = await crawler.arun(
|
||||||
url="https://www.python.org/downloads/",
|
url="https://www.python.org/downloads/",
|
||||||
js_code="""
|
js_code="""
|
||||||
const downloadLink = document.querySelector('a[href$=".exe"]');
|
const downloadLink = document.querySelector('a[href$=".exe"]');
|
||||||
if (downloadLink) downloadLink.click();
|
if (downloadLink) downloadLink.click();
|
||||||
"""
|
""",
|
||||||
)
|
)
|
||||||
|
|
||||||
success = result.downloaded_files is not None and len(result.downloaded_files) > 0
|
success = (
|
||||||
|
result.downloaded_files is not None
|
||||||
|
and len(result.downloaded_files) > 0
|
||||||
|
)
|
||||||
self.log_result(
|
self.log_result(
|
||||||
f"{browser_type.title()} Download",
|
f"{browser_type.title()} Download",
|
||||||
success,
|
success,
|
||||||
f"Downloaded {len(result.downloaded_files or [])} files" if success else "No files downloaded"
|
f"Downloaded {len(result.downloaded_files or [])} files"
|
||||||
|
if success
|
||||||
|
else "No files downloaded",
|
||||||
)
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
self.log_result(f"{browser_type.title()} Download", False, str(e))
|
self.log_result(f"{browser_type.title()} Download", False, str(e))
|
||||||
@@ -144,18 +160,15 @@ class TestDownloads:
|
|||||||
|
|
||||||
# Test 1: Downloads without specifying download path
|
# Test 1: Downloads without specifying download path
|
||||||
try:
|
try:
|
||||||
async with AsyncWebCrawler(
|
async with AsyncWebCrawler(accept_downloads=True, verbose=True) as crawler:
|
||||||
accept_downloads=True,
|
|
||||||
verbose=True
|
|
||||||
) as crawler:
|
|
||||||
result = await crawler.arun(
|
result = await crawler.arun(
|
||||||
url="https://www.python.org/downloads/",
|
url="https://www.python.org/downloads/",
|
||||||
js_code="document.querySelector('a[href$=\".exe\"]').click()"
|
js_code="document.querySelector('a[href$=\".exe\"]').click()",
|
||||||
)
|
)
|
||||||
self.log_result(
|
self.log_result(
|
||||||
"Default Download Path",
|
"Default Download Path",
|
||||||
True,
|
True,
|
||||||
f"Downloaded to default path: {result.downloaded_files[0] if result.downloaded_files else 'None'}"
|
f"Downloaded to default path: {result.downloaded_files[0] if result.downloaded_files else 'None'}",
|
||||||
)
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
self.log_result("Default Download Path", False, str(e))
|
self.log_result("Default Download Path", False, str(e))
|
||||||
@@ -165,31 +178,34 @@ class TestDownloads:
|
|||||||
async with AsyncWebCrawler(
|
async with AsyncWebCrawler(
|
||||||
accept_downloads=True,
|
accept_downloads=True,
|
||||||
downloads_path="/invalid/path/that/doesnt/exist",
|
downloads_path="/invalid/path/that/doesnt/exist",
|
||||||
verbose=True
|
verbose=True,
|
||||||
) as crawler:
|
) as crawler:
|
||||||
result = await crawler.arun(
|
result = await crawler.arun(
|
||||||
url="https://www.python.org/downloads/",
|
url="https://www.python.org/downloads/",
|
||||||
js_code="document.querySelector('a[href$=\".exe\"]').click()"
|
js_code="document.querySelector('a[href$=\".exe\"]').click()",
|
||||||
|
)
|
||||||
|
self.log_result(
|
||||||
|
"Invalid Download Path", False, "Should have raised an error"
|
||||||
|
)
|
||||||
|
except Exception:
|
||||||
|
self.log_result(
|
||||||
|
"Invalid Download Path", True, "Correctly handled invalid path"
|
||||||
)
|
)
|
||||||
self.log_result("Invalid Download Path", False, "Should have raised an error")
|
|
||||||
except Exception as e:
|
|
||||||
self.log_result("Invalid Download Path", True, "Correctly handled invalid path")
|
|
||||||
|
|
||||||
# Test 3: Download with accept_downloads=False
|
# Test 3: Download with accept_downloads=False
|
||||||
try:
|
try:
|
||||||
async with AsyncWebCrawler(
|
async with AsyncWebCrawler(accept_downloads=False, verbose=True) as crawler:
|
||||||
accept_downloads=False,
|
|
||||||
verbose=True
|
|
||||||
) as crawler:
|
|
||||||
result = await crawler.arun(
|
result = await crawler.arun(
|
||||||
url="https://www.python.org/downloads/",
|
url="https://www.python.org/downloads/",
|
||||||
js_code="document.querySelector('a[href$=\".exe\"]').click()"
|
js_code="document.querySelector('a[href$=\".exe\"]').click()",
|
||||||
)
|
)
|
||||||
success = result.downloaded_files is None
|
success = result.downloaded_files is None
|
||||||
self.log_result(
|
self.log_result(
|
||||||
"Disabled Downloads",
|
"Disabled Downloads",
|
||||||
success,
|
success,
|
||||||
"Correctly ignored downloads" if success else "Unexpectedly downloaded files"
|
"Correctly ignored downloads"
|
||||||
|
if success
|
||||||
|
else "Unexpectedly downloaded files",
|
||||||
)
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
self.log_result("Disabled Downloads", False, str(e))
|
self.log_result("Disabled Downloads", False, str(e))
|
||||||
@@ -203,7 +219,7 @@ class TestDownloads:
|
|||||||
self.test_persistent_context_download,
|
self.test_persistent_context_download,
|
||||||
self.test_multiple_downloads,
|
self.test_multiple_downloads,
|
||||||
self.test_different_browsers,
|
self.test_different_browsers,
|
||||||
self.test_edge_cases
|
self.test_edge_cases,
|
||||||
]
|
]
|
||||||
|
|
||||||
for test in test_methods:
|
for test in test_methods:
|
||||||
@@ -215,15 +231,17 @@ class TestDownloads:
|
|||||||
for result in self.results:
|
for result in self.results:
|
||||||
print(result)
|
print(result)
|
||||||
|
|
||||||
successes = len([r for r in self.results if '✅' in r])
|
successes = len([r for r in self.results if "✅" in r])
|
||||||
total = len(self.results)
|
total = len(self.results)
|
||||||
print(f"\nTotal: {successes}/{total} tests passed")
|
print(f"\nTotal: {successes}/{total} tests passed")
|
||||||
|
|
||||||
self.cleanup()
|
self.cleanup()
|
||||||
|
|
||||||
|
|
||||||
async def main():
|
async def main():
|
||||||
tester = TestDownloads()
|
tester = TestDownloads()
|
||||||
await tester.run_all_tests()
|
await tester.run_all_tests()
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
asyncio.run(main())
|
asyncio.run(main())
|
||||||
@@ -1,15 +1,17 @@
|
|||||||
import os
|
import os
|
||||||
import sys
|
import sys
|
||||||
import pytest
|
import pytest
|
||||||
import asyncio
|
|
||||||
import time
|
import time
|
||||||
|
|
||||||
# Add the parent directory to the Python path
|
# Add the parent directory to the Python path
|
||||||
parent_dir = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
parent_dir = os.path.dirname(
|
||||||
|
os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
|
||||||
|
)
|
||||||
sys.path.append(parent_dir)
|
sys.path.append(parent_dir)
|
||||||
|
|
||||||
from crawl4ai.async_webcrawler import AsyncWebCrawler
|
from crawl4ai.async_webcrawler import AsyncWebCrawler
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_successful_crawl():
|
async def test_successful_crawl():
|
||||||
async with AsyncWebCrawler(verbose=True) as crawler:
|
async with AsyncWebCrawler(verbose=True) as crawler:
|
||||||
@@ -21,6 +23,7 @@ async def test_successful_crawl():
|
|||||||
assert result.markdown
|
assert result.markdown
|
||||||
assert result.cleaned_html
|
assert result.cleaned_html
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_invalid_url():
|
async def test_invalid_url():
|
||||||
async with AsyncWebCrawler(verbose=True) as crawler:
|
async with AsyncWebCrawler(verbose=True) as crawler:
|
||||||
@@ -29,19 +32,21 @@ async def test_invalid_url():
|
|||||||
assert not result.success
|
assert not result.success
|
||||||
assert result.error_message
|
assert result.error_message
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_multiple_urls():
|
async def test_multiple_urls():
|
||||||
async with AsyncWebCrawler(verbose=True) as crawler:
|
async with AsyncWebCrawler(verbose=True) as crawler:
|
||||||
urls = [
|
urls = [
|
||||||
"https://www.nbcnews.com/business",
|
"https://www.nbcnews.com/business",
|
||||||
"https://www.example.com",
|
"https://www.example.com",
|
||||||
"https://www.python.org"
|
"https://www.python.org",
|
||||||
]
|
]
|
||||||
results = await crawler.arun_many(urls=urls, bypass_cache=True)
|
results = await crawler.arun_many(urls=urls, bypass_cache=True)
|
||||||
assert len(results) == len(urls)
|
assert len(results) == len(urls)
|
||||||
assert all(result.success for result in results)
|
assert all(result.success for result in results)
|
||||||
assert all(result.html for result in results)
|
assert all(result.html for result in results)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_javascript_execution():
|
async def test_javascript_execution():
|
||||||
async with AsyncWebCrawler(verbose=True) as crawler:
|
async with AsyncWebCrawler(verbose=True) as crawler:
|
||||||
@@ -51,6 +56,7 @@ async def test_javascript_execution():
|
|||||||
assert result.success
|
assert result.success
|
||||||
assert "<h1>Modified by JS</h1>" in result.html
|
assert "<h1>Modified by JS</h1>" in result.html
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_concurrent_crawling_performance():
|
async def test_concurrent_crawling_performance():
|
||||||
async with AsyncWebCrawler(verbose=True) as crawler:
|
async with AsyncWebCrawler(verbose=True) as crawler:
|
||||||
@@ -59,7 +65,7 @@ async def test_concurrent_crawling_performance():
|
|||||||
"https://www.example.com",
|
"https://www.example.com",
|
||||||
"https://www.python.org",
|
"https://www.python.org",
|
||||||
"https://www.github.com",
|
"https://www.github.com",
|
||||||
"https://www.stackoverflow.com"
|
"https://www.stackoverflow.com",
|
||||||
]
|
]
|
||||||
|
|
||||||
start_time = time.time()
|
start_time = time.time()
|
||||||
@@ -74,7 +80,10 @@ async def test_concurrent_crawling_performance():
|
|||||||
|
|
||||||
# Assert that concurrent crawling is faster than sequential
|
# Assert that concurrent crawling is faster than sequential
|
||||||
# This multiplier may need adjustment based on the number of URLs and their complexity
|
# This multiplier may need adjustment based on the number of URLs and their complexity
|
||||||
assert total_time < len(urls) * 5, f"Concurrent crawling not significantly faster: {total_time:.2f} seconds"
|
assert (
|
||||||
|
total_time < len(urls) * 5
|
||||||
|
), f"Concurrent crawling not significantly faster: {total_time:.2f} seconds"
|
||||||
|
|
||||||
|
|
||||||
# Entry point for debugging
|
# Entry point for debugging
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
|||||||
@@ -9,6 +9,7 @@ sys.path.append(parent_dir)
|
|||||||
|
|
||||||
from crawl4ai.async_webcrawler import AsyncWebCrawler
|
from crawl4ai.async_webcrawler import AsyncWebCrawler
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_caching():
|
async def test_caching():
|
||||||
async with AsyncWebCrawler(verbose=True) as crawler:
|
async with AsyncWebCrawler(verbose=True) as crawler:
|
||||||
@@ -31,6 +32,7 @@ async def test_caching():
|
|||||||
assert result2.success
|
assert result2.success
|
||||||
assert time_taken2 < time_taken1 # Cached result should be faster
|
assert time_taken2 < time_taken1 # Cached result should be faster
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_bypass_cache():
|
async def test_bypass_cache():
|
||||||
async with AsyncWebCrawler(verbose=True) as crawler:
|
async with AsyncWebCrawler(verbose=True) as crawler:
|
||||||
@@ -47,6 +49,7 @@ async def test_bypass_cache():
|
|||||||
# Content should be different (or at least, not guaranteed to be the same)
|
# Content should be different (or at least, not guaranteed to be the same)
|
||||||
assert result1.html != result2.html or result1.markdown != result2.markdown
|
assert result1.html != result2.html or result1.markdown != result2.markdown
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_clear_cache():
|
async def test_clear_cache():
|
||||||
async with AsyncWebCrawler(verbose=True) as crawler:
|
async with AsyncWebCrawler(verbose=True) as crawler:
|
||||||
@@ -62,6 +65,7 @@ async def test_clear_cache():
|
|||||||
cache_size = await crawler.aget_cache_size()
|
cache_size = await crawler.aget_cache_size()
|
||||||
assert cache_size == 0
|
assert cache_size == 0
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_flush_cache():
|
async def test_flush_cache():
|
||||||
async with AsyncWebCrawler(verbose=True) as crawler:
|
async with AsyncWebCrawler(verbose=True) as crawler:
|
||||||
@@ -77,6 +81,7 @@ async def test_flush_cache():
|
|||||||
cache_size = await crawler.aget_cache_size()
|
cache_size = await crawler.aget_cache_size()
|
||||||
assert cache_size == 0
|
assert cache_size == 0
|
||||||
|
|
||||||
|
|
||||||
# Entry point for debugging
|
# Entry point for debugging
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
pytest.main([__file__, "-v"])
|
pytest.main([__file__, "-v"])
|
||||||
@@ -1,7 +1,6 @@
|
|||||||
import os
|
import os
|
||||||
import sys
|
import sys
|
||||||
import pytest
|
import pytest
|
||||||
import asyncio
|
|
||||||
import json
|
import json
|
||||||
|
|
||||||
# Add the parent directory to the Python path
|
# Add the parent directory to the Python path
|
||||||
@@ -9,8 +8,9 @@ parent_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
|
|||||||
sys.path.append(parent_dir)
|
sys.path.append(parent_dir)
|
||||||
|
|
||||||
from crawl4ai.async_webcrawler import AsyncWebCrawler
|
from crawl4ai.async_webcrawler import AsyncWebCrawler
|
||||||
from crawl4ai.chunking_strategy import RegexChunking, NlpSentenceChunking
|
from crawl4ai.chunking_strategy import RegexChunking
|
||||||
from crawl4ai.extraction_strategy import CosineStrategy, LLMExtractionStrategy
|
from crawl4ai.extraction_strategy import LLMExtractionStrategy
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_regex_chunking():
|
async def test_regex_chunking():
|
||||||
@@ -18,15 +18,14 @@ async def test_regex_chunking():
|
|||||||
url = "https://www.nbcnews.com/business"
|
url = "https://www.nbcnews.com/business"
|
||||||
chunking_strategy = RegexChunking(patterns=["\n\n"])
|
chunking_strategy = RegexChunking(patterns=["\n\n"])
|
||||||
result = await crawler.arun(
|
result = await crawler.arun(
|
||||||
url=url,
|
url=url, chunking_strategy=chunking_strategy, bypass_cache=True
|
||||||
chunking_strategy=chunking_strategy,
|
|
||||||
bypass_cache=True
|
|
||||||
)
|
)
|
||||||
assert result.success
|
assert result.success
|
||||||
assert result.extracted_content
|
assert result.extracted_content
|
||||||
chunks = json.loads(result.extracted_content)
|
chunks = json.loads(result.extracted_content)
|
||||||
assert len(chunks) > 1 # Ensure multiple chunks were created
|
assert len(chunks) > 1 # Ensure multiple chunks were created
|
||||||
|
|
||||||
|
|
||||||
# @pytest.mark.asyncio
|
# @pytest.mark.asyncio
|
||||||
# async def test_cosine_strategy():
|
# async def test_cosine_strategy():
|
||||||
# async with AsyncWebCrawler(verbose=True) as crawler:
|
# async with AsyncWebCrawler(verbose=True) as crawler:
|
||||||
@@ -43,25 +42,25 @@ async def test_regex_chunking():
|
|||||||
# assert len(extracted_data) > 0
|
# assert len(extracted_data) > 0
|
||||||
# assert all('tags' in item for item in extracted_data)
|
# assert all('tags' in item for item in extracted_data)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_llm_extraction_strategy():
|
async def test_llm_extraction_strategy():
|
||||||
async with AsyncWebCrawler(verbose=True) as crawler:
|
async with AsyncWebCrawler(verbose=True) as crawler:
|
||||||
url = "https://www.nbcnews.com/business"
|
url = "https://www.nbcnews.com/business"
|
||||||
extraction_strategy = LLMExtractionStrategy(
|
extraction_strategy = LLMExtractionStrategy(
|
||||||
provider="openai/gpt-4o-mini",
|
provider="openai/gpt-4o-mini",
|
||||||
api_token=os.getenv('OPENAI_API_KEY'),
|
api_token=os.getenv("OPENAI_API_KEY"),
|
||||||
instruction="Extract only content related to technology"
|
instruction="Extract only content related to technology",
|
||||||
)
|
)
|
||||||
result = await crawler.arun(
|
result = await crawler.arun(
|
||||||
url=url,
|
url=url, extraction_strategy=extraction_strategy, bypass_cache=True
|
||||||
extraction_strategy=extraction_strategy,
|
|
||||||
bypass_cache=True
|
|
||||||
)
|
)
|
||||||
assert result.success
|
assert result.success
|
||||||
assert result.extracted_content
|
assert result.extracted_content
|
||||||
extracted_data = json.loads(result.extracted_content)
|
extracted_data = json.loads(result.extracted_content)
|
||||||
assert len(extracted_data) > 0
|
assert len(extracted_data) > 0
|
||||||
assert all('content' in item for item in extracted_data)
|
assert all("content" in item for item in extracted_data)
|
||||||
|
|
||||||
|
|
||||||
# @pytest.mark.asyncio
|
# @pytest.mark.asyncio
|
||||||
# async def test_combined_chunking_and_extraction():
|
# async def test_combined_chunking_and_extraction():
|
||||||
|
|||||||
@@ -1,8 +1,6 @@
|
|||||||
import os
|
import os
|
||||||
import sys
|
import sys
|
||||||
import pytest
|
import pytest
|
||||||
import asyncio
|
|
||||||
import json
|
|
||||||
|
|
||||||
# Add the parent directory to the Python path
|
# Add the parent directory to the Python path
|
||||||
parent_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
|
parent_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
|
||||||
@@ -10,6 +8,7 @@ sys.path.append(parent_dir)
|
|||||||
|
|
||||||
from crawl4ai.async_webcrawler import AsyncWebCrawler
|
from crawl4ai.async_webcrawler import AsyncWebCrawler
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_extract_markdown():
|
async def test_extract_markdown():
|
||||||
async with AsyncWebCrawler(verbose=True) as crawler:
|
async with AsyncWebCrawler(verbose=True) as crawler:
|
||||||
@@ -20,6 +19,7 @@ async def test_extract_markdown():
|
|||||||
assert isinstance(result.markdown, str)
|
assert isinstance(result.markdown, str)
|
||||||
assert len(result.markdown) > 0
|
assert len(result.markdown) > 0
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_extract_cleaned_html():
|
async def test_extract_cleaned_html():
|
||||||
async with AsyncWebCrawler(verbose=True) as crawler:
|
async with AsyncWebCrawler(verbose=True) as crawler:
|
||||||
@@ -30,6 +30,7 @@ async def test_extract_cleaned_html():
|
|||||||
assert isinstance(result.cleaned_html, str)
|
assert isinstance(result.cleaned_html, str)
|
||||||
assert len(result.cleaned_html) > 0
|
assert len(result.cleaned_html) > 0
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_extract_media():
|
async def test_extract_media():
|
||||||
async with AsyncWebCrawler(verbose=True) as crawler:
|
async with AsyncWebCrawler(verbose=True) as crawler:
|
||||||
@@ -46,6 +47,7 @@ async def test_extract_media():
|
|||||||
assert "alt" in image
|
assert "alt" in image
|
||||||
assert "type" in image
|
assert "type" in image
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_extract_links():
|
async def test_extract_links():
|
||||||
async with AsyncWebCrawler(verbose=True) as crawler:
|
async with AsyncWebCrawler(verbose=True) as crawler:
|
||||||
@@ -63,6 +65,7 @@ async def test_extract_links():
|
|||||||
assert "href" in link
|
assert "href" in link
|
||||||
assert "text" in link
|
assert "text" in link
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_extract_metadata():
|
async def test_extract_metadata():
|
||||||
async with AsyncWebCrawler(verbose=True) as crawler:
|
async with AsyncWebCrawler(verbose=True) as crawler:
|
||||||
@@ -75,16 +78,20 @@ async def test_extract_metadata():
|
|||||||
assert "title" in metadata
|
assert "title" in metadata
|
||||||
assert isinstance(metadata["title"], str)
|
assert isinstance(metadata["title"], str)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_css_selector_extraction():
|
async def test_css_selector_extraction():
|
||||||
async with AsyncWebCrawler(verbose=True) as crawler:
|
async with AsyncWebCrawler(verbose=True) as crawler:
|
||||||
url = "https://www.nbcnews.com/business"
|
url = "https://www.nbcnews.com/business"
|
||||||
css_selector = "h1, h2, h3"
|
css_selector = "h1, h2, h3"
|
||||||
result = await crawler.arun(url=url, bypass_cache=True, css_selector=css_selector)
|
result = await crawler.arun(
|
||||||
|
url=url, bypass_cache=True, css_selector=css_selector
|
||||||
|
)
|
||||||
assert result.success
|
assert result.success
|
||||||
assert result.markdown
|
assert result.markdown
|
||||||
assert all(heading in result.markdown for heading in ["#", "##", "###"])
|
assert all(heading in result.markdown for heading in ["#", "##", "###"])
|
||||||
|
|
||||||
|
|
||||||
# Entry point for debugging
|
# Entry point for debugging
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
pytest.main([__file__, "-v"])
|
pytest.main([__file__, "-v"])
|
||||||
@@ -1,7 +1,6 @@
|
|||||||
import os, sys
|
import os, sys
|
||||||
import pytest
|
import pytest
|
||||||
from bs4 import BeautifulSoup
|
from bs4 import BeautifulSoup
|
||||||
from typing import List
|
|
||||||
|
|
||||||
# Add the parent directory to the Python path
|
# Add the parent directory to the Python path
|
||||||
parent_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
|
parent_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
|
||||||
@@ -9,6 +8,7 @@ sys.path.append(parent_dir)
|
|||||||
|
|
||||||
from crawl4ai.content_filter_strategy import BM25ContentFilter
|
from crawl4ai.content_filter_strategy import BM25ContentFilter
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def basic_html():
|
def basic_html():
|
||||||
return """
|
return """
|
||||||
@@ -28,6 +28,7 @@ def basic_html():
|
|||||||
</html>
|
</html>
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def wiki_html():
|
def wiki_html():
|
||||||
return """
|
return """
|
||||||
@@ -46,6 +47,7 @@ def wiki_html():
|
|||||||
</html>
|
</html>
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def no_meta_html():
|
def no_meta_html():
|
||||||
return """
|
return """
|
||||||
@@ -57,6 +59,7 @@ def no_meta_html():
|
|||||||
</html>
|
</html>
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
|
||||||
class TestBM25ContentFilter:
|
class TestBM25ContentFilter:
|
||||||
def test_basic_extraction(self, basic_html):
|
def test_basic_extraction(self, basic_html):
|
||||||
"""Test basic content extraction functionality"""
|
"""Test basic content extraction functionality"""
|
||||||
@@ -65,8 +68,8 @@ class TestBM25ContentFilter:
|
|||||||
|
|
||||||
assert contents, "Should extract content"
|
assert contents, "Should extract content"
|
||||||
assert len(contents) >= 1, "Should extract at least one content block"
|
assert len(contents) >= 1, "Should extract at least one content block"
|
||||||
assert "long paragraph" in ' '.join(contents).lower()
|
assert "long paragraph" in " ".join(contents).lower()
|
||||||
assert "navigation" not in ' '.join(contents).lower()
|
assert "navigation" not in " ".join(contents).lower()
|
||||||
|
|
||||||
def test_user_query_override(self, basic_html):
|
def test_user_query_override(self, basic_html):
|
||||||
"""Test that user query overrides metadata extraction"""
|
"""Test that user query overrides metadata extraction"""
|
||||||
@@ -74,8 +77,8 @@ class TestBM25ContentFilter:
|
|||||||
filter = BM25ContentFilter(user_query=user_query)
|
filter = BM25ContentFilter(user_query=user_query)
|
||||||
|
|
||||||
# Access internal state to verify query usage
|
# Access internal state to verify query usage
|
||||||
soup = BeautifulSoup(basic_html, 'lxml')
|
soup = BeautifulSoup(basic_html, "lxml")
|
||||||
extracted_query = filter.extract_page_query(soup.find('head'))
|
extracted_query = filter.extract_page_query(soup.find("head"))
|
||||||
|
|
||||||
assert extracted_query == user_query
|
assert extracted_query == user_query
|
||||||
assert "Test description" not in extracted_query
|
assert "Test description" not in extracted_query
|
||||||
@@ -85,7 +88,7 @@ class TestBM25ContentFilter:
|
|||||||
filter = BM25ContentFilter()
|
filter = BM25ContentFilter()
|
||||||
contents = filter.filter_content(wiki_html)
|
contents = filter.filter_content(wiki_html)
|
||||||
|
|
||||||
combined_content = ' '.join(contents).lower()
|
combined_content = " ".join(contents).lower()
|
||||||
assert "section 1" in combined_content, "Should include section header"
|
assert "section 1" in combined_content, "Should include section header"
|
||||||
assert "article title" in combined_content, "Should include main title"
|
assert "article title" in combined_content, "Should include main title"
|
||||||
|
|
||||||
@@ -95,7 +98,9 @@ class TestBM25ContentFilter:
|
|||||||
contents = filter.filter_content(no_meta_html)
|
contents = filter.filter_content(no_meta_html)
|
||||||
|
|
||||||
assert contents, "Should extract content even without metadata"
|
assert contents, "Should extract content even without metadata"
|
||||||
assert "First paragraph" in ' '.join(contents), "Should use first paragraph content"
|
assert "First paragraph" in " ".join(
|
||||||
|
contents
|
||||||
|
), "Should use first paragraph content"
|
||||||
|
|
||||||
def test_empty_input(self):
|
def test_empty_input(self):
|
||||||
"""Test handling of empty input"""
|
"""Test handling of empty input"""
|
||||||
@@ -119,18 +124,19 @@ class TestBM25ContentFilter:
|
|||||||
strict_contents = strict_filter.filter_content(basic_html)
|
strict_contents = strict_filter.filter_content(basic_html)
|
||||||
lenient_contents = lenient_filter.filter_content(basic_html)
|
lenient_contents = lenient_filter.filter_content(basic_html)
|
||||||
|
|
||||||
assert len(strict_contents) <= len(lenient_contents), \
|
assert len(strict_contents) <= len(
|
||||||
"Strict threshold should extract fewer elements"
|
lenient_contents
|
||||||
|
), "Strict threshold should extract fewer elements"
|
||||||
|
|
||||||
def test_html_cleaning(self, basic_html):
|
def test_html_cleaning(self, basic_html):
|
||||||
"""Test HTML cleaning functionality"""
|
"""Test HTML cleaning functionality"""
|
||||||
filter = BM25ContentFilter()
|
filter = BM25ContentFilter()
|
||||||
contents = filter.filter_content(basic_html)
|
contents = filter.filter_content(basic_html)
|
||||||
|
|
||||||
cleaned_content = ' '.join(contents)
|
cleaned_content = " ".join(contents)
|
||||||
assert 'class=' not in cleaned_content, "Should remove class attributes"
|
assert "class=" not in cleaned_content, "Should remove class attributes"
|
||||||
assert 'style=' not in cleaned_content, "Should remove style attributes"
|
assert "style=" not in cleaned_content, "Should remove style attributes"
|
||||||
assert '<script' not in cleaned_content, "Should remove script tags"
|
assert "<script" not in cleaned_content, "Should remove script tags"
|
||||||
|
|
||||||
def test_large_content(self):
|
def test_large_content(self):
|
||||||
"""Test handling of large content blocks"""
|
"""Test handling of large content blocks"""
|
||||||
@@ -143,9 +149,9 @@ class TestBM25ContentFilter:
|
|||||||
contents = filter.filter_content(large_html)
|
contents = filter.filter_content(large_html)
|
||||||
assert contents, "Should handle large content blocks"
|
assert contents, "Should handle large content blocks"
|
||||||
|
|
||||||
@pytest.mark.parametrize("unwanted_tag", [
|
@pytest.mark.parametrize(
|
||||||
'script', 'style', 'nav', 'footer', 'header'
|
"unwanted_tag", ["script", "style", "nav", "footer", "header"]
|
||||||
])
|
)
|
||||||
def test_excluded_tags(self, unwanted_tag):
|
def test_excluded_tags(self, unwanted_tag):
|
||||||
"""Test that specific tags are properly excluded"""
|
"""Test that specific tags are properly excluded"""
|
||||||
html = f"""
|
html = f"""
|
||||||
@@ -157,7 +163,7 @@ class TestBM25ContentFilter:
|
|||||||
filter = BM25ContentFilter()
|
filter = BM25ContentFilter()
|
||||||
contents = filter.filter_content(html)
|
contents = filter.filter_content(html)
|
||||||
|
|
||||||
combined_content = ' '.join(contents).lower()
|
combined_content = " ".join(contents).lower()
|
||||||
assert "should not appear" not in combined_content
|
assert "should not appear" not in combined_content
|
||||||
|
|
||||||
def test_performance(self, basic_html):
|
def test_performance(self, basic_html):
|
||||||
@@ -165,11 +171,13 @@ class TestBM25ContentFilter:
|
|||||||
filter = BM25ContentFilter()
|
filter = BM25ContentFilter()
|
||||||
|
|
||||||
import time
|
import time
|
||||||
|
|
||||||
start = time.perf_counter()
|
start = time.perf_counter()
|
||||||
filter.filter_content(basic_html)
|
filter.filter_content(basic_html)
|
||||||
duration = time.perf_counter() - start
|
duration = time.perf_counter() - start
|
||||||
|
|
||||||
assert duration < 1.0, f"Processing took too long: {duration:.2f} seconds"
|
assert duration < 1.0, f"Processing took too long: {duration:.2f} seconds"
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
pytest.main([__file__])
|
pytest.main([__file__])
|
||||||
@@ -1,12 +1,12 @@
|
|||||||
import os, sys
|
import os, sys
|
||||||
import pytest
|
import pytest
|
||||||
from bs4 import BeautifulSoup
|
|
||||||
|
|
||||||
parent_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
|
parent_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
|
||||||
sys.path.append(parent_dir)
|
sys.path.append(parent_dir)
|
||||||
|
|
||||||
from crawl4ai.content_filter_strategy import PruningContentFilter
|
from crawl4ai.content_filter_strategy import PruningContentFilter
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def basic_html():
|
def basic_html():
|
||||||
return """
|
return """
|
||||||
@@ -22,6 +22,7 @@ def basic_html():
|
|||||||
</html>
|
</html>
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def link_heavy_html():
|
def link_heavy_html():
|
||||||
return """
|
return """
|
||||||
@@ -40,6 +41,7 @@ def link_heavy_html():
|
|||||||
</html>
|
</html>
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def mixed_content_html():
|
def mixed_content_html():
|
||||||
return """
|
return """
|
||||||
@@ -60,13 +62,14 @@ def mixed_content_html():
|
|||||||
</html>
|
</html>
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
|
||||||
class TestPruningContentFilter:
|
class TestPruningContentFilter:
|
||||||
def test_basic_pruning(self, basic_html):
|
def test_basic_pruning(self, basic_html):
|
||||||
"""Test basic content pruning functionality"""
|
"""Test basic content pruning functionality"""
|
||||||
filter = PruningContentFilter(min_word_threshold=5)
|
filter = PruningContentFilter(min_word_threshold=5)
|
||||||
contents = filter.filter_content(basic_html)
|
contents = filter.filter_content(basic_html)
|
||||||
|
|
||||||
combined_content = ' '.join(contents).lower()
|
combined_content = " ".join(contents).lower()
|
||||||
assert "high-quality paragraph" in combined_content
|
assert "high-quality paragraph" in combined_content
|
||||||
assert "sidebar content" not in combined_content
|
assert "sidebar content" not in combined_content
|
||||||
assert "share buttons" not in combined_content
|
assert "share buttons" not in combined_content
|
||||||
@@ -76,39 +79,41 @@ class TestPruningContentFilter:
|
|||||||
filter = PruningContentFilter(min_word_threshold=10)
|
filter = PruningContentFilter(min_word_threshold=10)
|
||||||
contents = filter.filter_content(mixed_content_html)
|
contents = filter.filter_content(mixed_content_html)
|
||||||
|
|
||||||
combined_content = ' '.join(contents).lower()
|
combined_content = " ".join(contents).lower()
|
||||||
assert "short summary" not in combined_content
|
assert "short summary" not in combined_content
|
||||||
assert "long high-quality paragraph" in combined_content
|
assert "long high-quality paragraph" in combined_content
|
||||||
assert "short comment" not in combined_content
|
assert "short comment" not in combined_content
|
||||||
|
|
||||||
def test_threshold_types(self, basic_html):
|
def test_threshold_types(self, basic_html):
|
||||||
"""Test fixed vs dynamic thresholds"""
|
"""Test fixed vs dynamic thresholds"""
|
||||||
fixed_filter = PruningContentFilter(threshold_type='fixed', threshold=0.48)
|
fixed_filter = PruningContentFilter(threshold_type="fixed", threshold=0.48)
|
||||||
dynamic_filter = PruningContentFilter(threshold_type='dynamic', threshold=0.45)
|
dynamic_filter = PruningContentFilter(threshold_type="dynamic", threshold=0.45)
|
||||||
|
|
||||||
fixed_contents = fixed_filter.filter_content(basic_html)
|
fixed_contents = fixed_filter.filter_content(basic_html)
|
||||||
dynamic_contents = dynamic_filter.filter_content(basic_html)
|
dynamic_contents = dynamic_filter.filter_content(basic_html)
|
||||||
|
|
||||||
assert len(fixed_contents) != len(dynamic_contents), \
|
assert len(fixed_contents) != len(
|
||||||
"Fixed and dynamic thresholds should yield different results"
|
dynamic_contents
|
||||||
|
), "Fixed and dynamic thresholds should yield different results"
|
||||||
|
|
||||||
def test_link_density_impact(self, link_heavy_html):
|
def test_link_density_impact(self, link_heavy_html):
|
||||||
"""Test handling of link-heavy content"""
|
"""Test handling of link-heavy content"""
|
||||||
filter = PruningContentFilter(threshold_type='dynamic')
|
filter = PruningContentFilter(threshold_type="dynamic")
|
||||||
contents = filter.filter_content(link_heavy_html)
|
contents = filter.filter_content(link_heavy_html)
|
||||||
|
|
||||||
combined_content = ' '.join(contents).lower()
|
combined_content = " ".join(contents).lower()
|
||||||
assert "good content paragraph" in combined_content
|
assert "good content paragraph" in combined_content
|
||||||
assert len([c for c in contents if 'href' in c]) < 2, \
|
assert (
|
||||||
"Should prune link-heavy sections"
|
len([c for c in contents if "href" in c]) < 2
|
||||||
|
), "Should prune link-heavy sections"
|
||||||
|
|
||||||
def test_tag_importance(self, mixed_content_html):
|
def test_tag_importance(self, mixed_content_html):
|
||||||
"""Test tag importance in scoring"""
|
"""Test tag importance in scoring"""
|
||||||
filter = PruningContentFilter(threshold_type='dynamic')
|
filter = PruningContentFilter(threshold_type="dynamic")
|
||||||
contents = filter.filter_content(mixed_content_html)
|
contents = filter.filter_content(mixed_content_html)
|
||||||
|
|
||||||
has_article = any('article' in c.lower() for c in contents)
|
has_article = any("article" in c.lower() for c in contents)
|
||||||
has_h1 = any('h1' in c.lower() for c in contents)
|
has_h1 = any("h1" in c.lower() for c in contents)
|
||||||
assert has_article or has_h1, "Should retain important tags"
|
assert has_article or has_h1, "Should retain important tags"
|
||||||
|
|
||||||
def test_empty_input(self):
|
def test_empty_input(self):
|
||||||
@@ -129,6 +134,7 @@ class TestPruningContentFilter:
|
|||||||
filter = PruningContentFilter()
|
filter = PruningContentFilter()
|
||||||
|
|
||||||
import time
|
import time
|
||||||
|
|
||||||
start = time.perf_counter()
|
start = time.perf_counter()
|
||||||
filter.filter_content(basic_html)
|
filter.filter_content(basic_html)
|
||||||
duration = time.perf_counter() - start
|
duration = time.perf_counter() - start
|
||||||
@@ -136,17 +142,21 @@ class TestPruningContentFilter:
|
|||||||
# Extra strict on performance since you mentioned milliseconds matter
|
# Extra strict on performance since you mentioned milliseconds matter
|
||||||
assert duration < 0.1, f"Processing took too long: {duration:.3f} seconds"
|
assert duration < 0.1, f"Processing took too long: {duration:.3f} seconds"
|
||||||
|
|
||||||
@pytest.mark.parametrize("threshold,expected_count", [
|
@pytest.mark.parametrize(
|
||||||
|
"threshold,expected_count",
|
||||||
|
[
|
||||||
(0.3, 4), # Very lenient
|
(0.3, 4), # Very lenient
|
||||||
(0.48, 2), # Default
|
(0.48, 2), # Default
|
||||||
(0.7, 1), # Very strict
|
(0.7, 1), # Very strict
|
||||||
])
|
],
|
||||||
|
)
|
||||||
def test_threshold_levels(self, mixed_content_html, threshold, expected_count):
|
def test_threshold_levels(self, mixed_content_html, threshold, expected_count):
|
||||||
"""Test different threshold levels"""
|
"""Test different threshold levels"""
|
||||||
filter = PruningContentFilter(threshold_type='fixed', threshold=threshold)
|
filter = PruningContentFilter(threshold_type="fixed", threshold=threshold)
|
||||||
contents = filter.filter_content(mixed_content_html)
|
contents = filter.filter_content(mixed_content_html)
|
||||||
assert len(contents) <= expected_count, \
|
assert (
|
||||||
f"Expected {expected_count} or fewer elements with threshold {threshold}"
|
len(contents) <= expected_count
|
||||||
|
), f"Expected {expected_count} or fewer elements with threshold {threshold}"
|
||||||
|
|
||||||
def test_consistent_output(self, basic_html):
|
def test_consistent_output(self, basic_html):
|
||||||
"""Test output consistency across multiple runs"""
|
"""Test output consistency across multiple runs"""
|
||||||
@@ -155,5 +165,6 @@ class TestPruningContentFilter:
|
|||||||
second_run = filter.filter_content(basic_html)
|
second_run = filter.filter_content(basic_html)
|
||||||
assert first_run == second_run, "Output should be consistent"
|
assert first_run == second_run, "Output should be consistent"
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
pytest.main([__file__])
|
pytest.main([__file__])
|
||||||
@@ -1,22 +1,24 @@
|
|||||||
import asyncio
|
|
||||||
from bs4 import BeautifulSoup
|
|
||||||
from typing import Dict, Any
|
|
||||||
import os
|
import os
|
||||||
import sys
|
import sys
|
||||||
import time
|
import time
|
||||||
import csv
|
import csv
|
||||||
from tabulate import tabulate
|
from tabulate import tabulate
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import List, Dict
|
from typing import List
|
||||||
|
|
||||||
parent_dir = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
parent_dir = os.path.dirname(
|
||||||
|
os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
|
||||||
|
)
|
||||||
sys.path.append(parent_dir)
|
sys.path.append(parent_dir)
|
||||||
__location__ = os.path.realpath(os.path.join(os.getcwd(), os.path.dirname(__file__)))
|
__location__ = os.path.realpath(os.path.join(os.getcwd(), os.path.dirname(__file__)))
|
||||||
|
|
||||||
from crawl4ai.content_scraping_strategy import WebScrapingStrategy
|
from crawl4ai.content_scraping_strategy import WebScrapingStrategy
|
||||||
from crawl4ai.content_scraping_strategy import WebScrapingStrategy as WebScrapingStrategyCurrent
|
from crawl4ai.content_scraping_strategy import (
|
||||||
|
WebScrapingStrategy as WebScrapingStrategyCurrent,
|
||||||
|
)
|
||||||
# from crawl4ai.content_scrapping_strategy_current import WebScrapingStrategy as WebScrapingStrategyCurrent
|
# from crawl4ai.content_scrapping_strategy_current import WebScrapingStrategy as WebScrapingStrategyCurrent
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class TestResult:
|
class TestResult:
|
||||||
name: str
|
name: str
|
||||||
@@ -27,33 +29,32 @@ class TestResult:
|
|||||||
markdown_length: int
|
markdown_length: int
|
||||||
execution_time: float
|
execution_time: float
|
||||||
|
|
||||||
|
|
||||||
class StrategyTester:
|
class StrategyTester:
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
self.new_scraper = WebScrapingStrategy()
|
self.new_scraper = WebScrapingStrategy()
|
||||||
self.current_scraper = WebScrapingStrategyCurrent()
|
self.current_scraper = WebScrapingStrategyCurrent()
|
||||||
with open(__location__ + '/sample_wikipedia.html', 'r', encoding='utf-8') as f:
|
with open(__location__ + "/sample_wikipedia.html", "r", encoding="utf-8") as f:
|
||||||
self.WIKI_HTML = f.read()
|
self.WIKI_HTML = f.read()
|
||||||
self.results = {'new': [], 'current': []}
|
self.results = {"new": [], "current": []}
|
||||||
|
|
||||||
def run_test(self, name: str, **kwargs) -> tuple[TestResult, TestResult]:
|
def run_test(self, name: str, **kwargs) -> tuple[TestResult, TestResult]:
|
||||||
results = []
|
results = []
|
||||||
for scraper in [self.new_scraper, self.current_scraper]:
|
for scraper in [self.new_scraper, self.current_scraper]:
|
||||||
start_time = time.time()
|
start_time = time.time()
|
||||||
result = scraper._get_content_of_website_optimized(
|
result = scraper._get_content_of_website_optimized(
|
||||||
url="https://en.wikipedia.org/wiki/Test",
|
url="https://en.wikipedia.org/wiki/Test", html=self.WIKI_HTML, **kwargs
|
||||||
html=self.WIKI_HTML,
|
|
||||||
**kwargs
|
|
||||||
)
|
)
|
||||||
execution_time = time.time() - start_time
|
execution_time = time.time() - start_time
|
||||||
|
|
||||||
test_result = TestResult(
|
test_result = TestResult(
|
||||||
name=name,
|
name=name,
|
||||||
success=result['success'],
|
success=result["success"],
|
||||||
images=len(result['media']['images']),
|
images=len(result["media"]["images"]),
|
||||||
internal_links=len(result['links']['internal']),
|
internal_links=len(result["links"]["internal"]),
|
||||||
external_links=len(result['links']['external']),
|
external_links=len(result["links"]["external"]),
|
||||||
markdown_length=len(result['markdown']),
|
markdown_length=len(result["markdown"]),
|
||||||
execution_time=execution_time
|
execution_time=execution_time,
|
||||||
)
|
)
|
||||||
results.append(test_result)
|
results.append(test_result)
|
||||||
|
|
||||||
@@ -62,34 +63,37 @@ class StrategyTester:
|
|||||||
def run_all_tests(self):
|
def run_all_tests(self):
|
||||||
test_cases = [
|
test_cases = [
|
||||||
("Basic Extraction", {}),
|
("Basic Extraction", {}),
|
||||||
("Exclude Tags", {'excluded_tags': ['table', 'div.infobox', 'div.navbox']}),
|
("Exclude Tags", {"excluded_tags": ["table", "div.infobox", "div.navbox"]}),
|
||||||
("Word Threshold", {'word_count_threshold': 50}),
|
("Word Threshold", {"word_count_threshold": 50}),
|
||||||
("CSS Selector", {'css_selector': 'div.mw-parser-output > p'}),
|
("CSS Selector", {"css_selector": "div.mw-parser-output > p"}),
|
||||||
("Link Exclusions", {
|
(
|
||||||
'exclude_external_links': True,
|
"Link Exclusions",
|
||||||
'exclude_social_media_links': True,
|
{
|
||||||
'exclude_domains': ['facebook.com', 'twitter.com']
|
"exclude_external_links": True,
|
||||||
}),
|
"exclude_social_media_links": True,
|
||||||
("Media Handling", {
|
"exclude_domains": ["facebook.com", "twitter.com"],
|
||||||
'exclude_external_images': True,
|
},
|
||||||
'image_description_min_word_threshold': 20
|
),
|
||||||
}),
|
(
|
||||||
("Text Only", {
|
"Media Handling",
|
||||||
'only_text': True,
|
{
|
||||||
'remove_forms': True
|
"exclude_external_images": True,
|
||||||
}),
|
"image_description_min_word_threshold": 20,
|
||||||
("HTML Cleaning", {
|
},
|
||||||
'clean_html': True,
|
),
|
||||||
'keep_data_attributes': True
|
("Text Only", {"only_text": True, "remove_forms": True}),
|
||||||
}),
|
("HTML Cleaning", {"clean_html": True, "keep_data_attributes": True}),
|
||||||
("HTML2Text Options", {
|
(
|
||||||
'html2text': {
|
"HTML2Text Options",
|
||||||
'skip_internal_links': True,
|
{
|
||||||
'single_line_break': True,
|
"html2text": {
|
||||||
'mark_code': True,
|
"skip_internal_links": True,
|
||||||
'preserve_tags': ['pre', 'code']
|
"single_line_break": True,
|
||||||
|
"mark_code": True,
|
||||||
|
"preserve_tags": ["pre", "code"],
|
||||||
}
|
}
|
||||||
})
|
},
|
||||||
|
),
|
||||||
]
|
]
|
||||||
|
|
||||||
all_results = []
|
all_results = []
|
||||||
@@ -104,58 +108,111 @@ class StrategyTester:
|
|||||||
self.print_comparison_table(all_results)
|
self.print_comparison_table(all_results)
|
||||||
|
|
||||||
def save_results_to_csv(self, all_results: List[tuple]):
|
def save_results_to_csv(self, all_results: List[tuple]):
|
||||||
csv_file = os.path.join(__location__, 'strategy_comparison_results.csv')
|
csv_file = os.path.join(__location__, "strategy_comparison_results.csv")
|
||||||
with open(csv_file, 'w', newline='') as f:
|
with open(csv_file, "w", newline="") as f:
|
||||||
writer = csv.writer(f)
|
writer = csv.writer(f)
|
||||||
writer.writerow(['Test Name', 'Strategy', 'Success', 'Images', 'Internal Links',
|
writer.writerow(
|
||||||
'External Links', 'Markdown Length', 'Execution Time'])
|
[
|
||||||
|
"Test Name",
|
||||||
|
"Strategy",
|
||||||
|
"Success",
|
||||||
|
"Images",
|
||||||
|
"Internal Links",
|
||||||
|
"External Links",
|
||||||
|
"Markdown Length",
|
||||||
|
"Execution Time",
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
for name, new_result, current_result in all_results:
|
for name, new_result, current_result in all_results:
|
||||||
writer.writerow([name, 'New', new_result.success, new_result.images,
|
writer.writerow(
|
||||||
new_result.internal_links, new_result.external_links,
|
[
|
||||||
new_result.markdown_length, f"{new_result.execution_time:.3f}"])
|
name,
|
||||||
writer.writerow([name, 'Current', current_result.success, current_result.images,
|
"New",
|
||||||
current_result.internal_links, current_result.external_links,
|
new_result.success,
|
||||||
current_result.markdown_length, f"{current_result.execution_time:.3f}"])
|
new_result.images,
|
||||||
|
new_result.internal_links,
|
||||||
|
new_result.external_links,
|
||||||
|
new_result.markdown_length,
|
||||||
|
f"{new_result.execution_time:.3f}",
|
||||||
|
]
|
||||||
|
)
|
||||||
|
writer.writerow(
|
||||||
|
[
|
||||||
|
name,
|
||||||
|
"Current",
|
||||||
|
current_result.success,
|
||||||
|
current_result.images,
|
||||||
|
current_result.internal_links,
|
||||||
|
current_result.external_links,
|
||||||
|
current_result.markdown_length,
|
||||||
|
f"{current_result.execution_time:.3f}",
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
def print_comparison_table(self, all_results: List[tuple]):
|
def print_comparison_table(self, all_results: List[tuple]):
|
||||||
table_data = []
|
table_data = []
|
||||||
headers = ['Test Name', 'Strategy', 'Success', 'Images', 'Internal Links',
|
headers = [
|
||||||
'External Links', 'Markdown Length', 'Time (s)']
|
"Test Name",
|
||||||
|
"Strategy",
|
||||||
|
"Success",
|
||||||
|
"Images",
|
||||||
|
"Internal Links",
|
||||||
|
"External Links",
|
||||||
|
"Markdown Length",
|
||||||
|
"Time (s)",
|
||||||
|
]
|
||||||
|
|
||||||
for name, new_result, current_result in all_results:
|
for name, new_result, current_result in all_results:
|
||||||
# Check for differences
|
# Check for differences
|
||||||
differences = []
|
differences = []
|
||||||
if new_result.images != current_result.images: differences.append('images')
|
if new_result.images != current_result.images:
|
||||||
if new_result.internal_links != current_result.internal_links: differences.append('internal_links')
|
differences.append("images")
|
||||||
if new_result.external_links != current_result.external_links: differences.append('external_links')
|
if new_result.internal_links != current_result.internal_links:
|
||||||
if new_result.markdown_length != current_result.markdown_length: differences.append('markdown')
|
differences.append("internal_links")
|
||||||
|
if new_result.external_links != current_result.external_links:
|
||||||
|
differences.append("external_links")
|
||||||
|
if new_result.markdown_length != current_result.markdown_length:
|
||||||
|
differences.append("markdown")
|
||||||
|
|
||||||
# Add row for new strategy
|
# Add row for new strategy
|
||||||
new_row = [
|
new_row = [
|
||||||
name, 'New', new_result.success, new_result.images,
|
name,
|
||||||
new_result.internal_links, new_result.external_links,
|
"New",
|
||||||
new_result.markdown_length, f"{new_result.execution_time:.3f}"
|
new_result.success,
|
||||||
|
new_result.images,
|
||||||
|
new_result.internal_links,
|
||||||
|
new_result.external_links,
|
||||||
|
new_result.markdown_length,
|
||||||
|
f"{new_result.execution_time:.3f}",
|
||||||
]
|
]
|
||||||
table_data.append(new_row)
|
table_data.append(new_row)
|
||||||
|
|
||||||
# Add row for current strategy
|
# Add row for current strategy
|
||||||
current_row = [
|
current_row = [
|
||||||
'', 'Current', current_result.success, current_result.images,
|
"",
|
||||||
current_result.internal_links, current_result.external_links,
|
"Current",
|
||||||
current_result.markdown_length, f"{current_result.execution_time:.3f}"
|
current_result.success,
|
||||||
|
current_result.images,
|
||||||
|
current_result.internal_links,
|
||||||
|
current_result.external_links,
|
||||||
|
current_result.markdown_length,
|
||||||
|
f"{current_result.execution_time:.3f}",
|
||||||
]
|
]
|
||||||
table_data.append(current_row)
|
table_data.append(current_row)
|
||||||
|
|
||||||
# Add difference summary if any
|
# Add difference summary if any
|
||||||
if differences:
|
if differences:
|
||||||
table_data.append(['', '⚠️ Differences', ', '.join(differences), '', '', '', '', ''])
|
table_data.append(
|
||||||
|
["", "⚠️ Differences", ", ".join(differences), "", "", "", "", ""]
|
||||||
|
)
|
||||||
|
|
||||||
# Add empty row for better readability
|
# Add empty row for better readability
|
||||||
table_data.append([''] * len(headers))
|
table_data.append([""] * len(headers))
|
||||||
|
|
||||||
print("\nStrategy Comparison Results:")
|
print("\nStrategy Comparison Results:")
|
||||||
print(tabulate(table_data, headers=headers, tablefmt='grid'))
|
print(tabulate(table_data, headers=headers, tablefmt="grid"))
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
tester = StrategyTester()
|
tester = StrategyTester()
|
||||||
|
|||||||
@@ -1,14 +1,13 @@
|
|||||||
import os
|
import os
|
||||||
import sys
|
import sys
|
||||||
import pytest
|
import pytest
|
||||||
import asyncio
|
|
||||||
|
|
||||||
# Add the parent directory to the Python path
|
# Add the parent directory to the Python path
|
||||||
parent_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
|
parent_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
|
||||||
sys.path.append(parent_dir)
|
sys.path.append(parent_dir)
|
||||||
|
|
||||||
from crawl4ai.async_webcrawler import AsyncWebCrawler
|
from crawl4ai.async_webcrawler import AsyncWebCrawler
|
||||||
from crawl4ai.async_crawler_strategy import AsyncPlaywrightCrawlerStrategy
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_custom_user_agent():
|
async def test_custom_user_agent():
|
||||||
@@ -20,6 +19,7 @@ async def test_custom_user_agent():
|
|||||||
assert result.success
|
assert result.success
|
||||||
assert custom_user_agent in result.html
|
assert custom_user_agent in result.html
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_custom_headers():
|
async def test_custom_headers():
|
||||||
async with AsyncWebCrawler(verbose=True) as crawler:
|
async with AsyncWebCrawler(verbose=True) as crawler:
|
||||||
@@ -31,6 +31,7 @@ async def test_custom_headers():
|
|||||||
assert "X-Test-Header" in result.html
|
assert "X-Test-Header" in result.html
|
||||||
assert "TestValue" in result.html
|
assert "TestValue" in result.html
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_javascript_execution():
|
async def test_javascript_execution():
|
||||||
async with AsyncWebCrawler(verbose=True) as crawler:
|
async with AsyncWebCrawler(verbose=True) as crawler:
|
||||||
@@ -40,19 +41,22 @@ async def test_javascript_execution():
|
|||||||
assert result.success
|
assert result.success
|
||||||
assert "<h1>Modified by JS</h1>" in result.html
|
assert "<h1>Modified by JS</h1>" in result.html
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_hook_execution():
|
async def test_hook_execution():
|
||||||
async with AsyncWebCrawler(verbose=True) as crawler:
|
async with AsyncWebCrawler(verbose=True) as crawler:
|
||||||
|
|
||||||
async def test_hook(page):
|
async def test_hook(page):
|
||||||
await page.evaluate("document.body.style.backgroundColor = 'red';")
|
await page.evaluate("document.body.style.backgroundColor = 'red';")
|
||||||
return page
|
return page
|
||||||
|
|
||||||
crawler.crawler_strategy.set_hook('after_goto', test_hook)
|
crawler.crawler_strategy.set_hook("after_goto", test_hook)
|
||||||
url = "https://www.example.com"
|
url = "https://www.example.com"
|
||||||
result = await crawler.arun(url=url, bypass_cache=True)
|
result = await crawler.arun(url=url, bypass_cache=True)
|
||||||
assert result.success
|
assert result.success
|
||||||
assert "background-color: red" in result.html
|
assert "background-color: red" in result.html
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_screenshot():
|
async def test_screenshot():
|
||||||
async with AsyncWebCrawler(verbose=True) as crawler:
|
async with AsyncWebCrawler(verbose=True) as crawler:
|
||||||
@@ -63,6 +67,7 @@ async def test_screenshot():
|
|||||||
assert isinstance(result.screenshot, str)
|
assert isinstance(result.screenshot, str)
|
||||||
assert len(result.screenshot) > 0
|
assert len(result.screenshot) > 0
|
||||||
|
|
||||||
|
|
||||||
# Entry point for debugging
|
# Entry point for debugging
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
pytest.main([__file__, "-v"])
|
pytest.main([__file__, "-v"])
|
||||||
@@ -1,8 +1,6 @@
|
|||||||
import os
|
import os
|
||||||
import sys
|
import sys
|
||||||
import pytest
|
import pytest
|
||||||
import asyncio
|
|
||||||
import json
|
|
||||||
|
|
||||||
# Add the parent directory to the Python path
|
# Add the parent directory to the Python path
|
||||||
parent_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
|
parent_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
|
||||||
@@ -10,6 +8,7 @@ sys.path.append(parent_dir)
|
|||||||
|
|
||||||
from crawl4ai.async_webcrawler import AsyncWebCrawler
|
from crawl4ai.async_webcrawler import AsyncWebCrawler
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_cache_url():
|
async def test_cache_url():
|
||||||
async with AsyncWebCrawler(verbose=True) as crawler:
|
async with AsyncWebCrawler(verbose=True) as crawler:
|
||||||
@@ -23,6 +22,7 @@ async def test_cache_url():
|
|||||||
assert result2.success
|
assert result2.success
|
||||||
assert result2.html == result1.html
|
assert result2.html == result1.html
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_bypass_cache():
|
async def test_bypass_cache():
|
||||||
async with AsyncWebCrawler(verbose=True) as crawler:
|
async with AsyncWebCrawler(verbose=True) as crawler:
|
||||||
@@ -34,7 +34,10 @@ async def test_bypass_cache():
|
|||||||
# Second run bypassing cache
|
# Second run bypassing cache
|
||||||
result2 = await crawler.arun(url=url, bypass_cache=True)
|
result2 = await crawler.arun(url=url, bypass_cache=True)
|
||||||
assert result2.success
|
assert result2.success
|
||||||
assert result2.html != result1.html # Content might be different due to dynamic nature of websites
|
assert (
|
||||||
|
result2.html != result1.html
|
||||||
|
) # Content might be different due to dynamic nature of websites
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_cache_size():
|
async def test_cache_size():
|
||||||
@@ -47,6 +50,7 @@ async def test_cache_size():
|
|||||||
new_size = await crawler.aget_cache_size()
|
new_size = await crawler.aget_cache_size()
|
||||||
assert new_size == initial_size + 1
|
assert new_size == initial_size + 1
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_clear_cache():
|
async def test_clear_cache():
|
||||||
async with AsyncWebCrawler(verbose=True) as crawler:
|
async with AsyncWebCrawler(verbose=True) as crawler:
|
||||||
@@ -60,6 +64,7 @@ async def test_clear_cache():
|
|||||||
new_size = await crawler.aget_cache_size()
|
new_size = await crawler.aget_cache_size()
|
||||||
assert new_size == 0
|
assert new_size == 0
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_flush_cache():
|
async def test_flush_cache():
|
||||||
async with AsyncWebCrawler(verbose=True) as crawler:
|
async with AsyncWebCrawler(verbose=True) as crawler:
|
||||||
@@ -75,7 +80,10 @@ async def test_flush_cache():
|
|||||||
|
|
||||||
# Try to retrieve the previously cached URL
|
# Try to retrieve the previously cached URL
|
||||||
result = await crawler.arun(url=url, bypass_cache=False)
|
result = await crawler.arun(url=url, bypass_cache=False)
|
||||||
assert result.success # The crawler should still succeed, but it will fetch the content anew
|
assert (
|
||||||
|
result.success
|
||||||
|
) # The crawler should still succeed, but it will fetch the content anew
|
||||||
|
|
||||||
|
|
||||||
# Entry point for debugging
|
# Entry point for debugging
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
|||||||
@@ -1,114 +1,133 @@
|
|||||||
import pytest
|
import pytest
|
||||||
import asyncio, time
|
import time
|
||||||
from crawl4ai import (
|
from crawl4ai import (
|
||||||
AsyncWebCrawler, BrowserConfig, CrawlerRunConfig,
|
AsyncWebCrawler,
|
||||||
MemoryAdaptiveDispatcher, SemaphoreDispatcher,
|
BrowserConfig,
|
||||||
RateLimiter, CrawlerMonitor, DisplayMode, CacheMode
|
CrawlerRunConfig,
|
||||||
|
MemoryAdaptiveDispatcher,
|
||||||
|
SemaphoreDispatcher,
|
||||||
|
RateLimiter,
|
||||||
|
CrawlerMonitor,
|
||||||
|
DisplayMode,
|
||||||
|
CacheMode,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def browser_config():
|
def browser_config():
|
||||||
return BrowserConfig(
|
return BrowserConfig(headless=True, verbose=False)
|
||||||
headless=True,
|
|
||||||
verbose=False
|
|
||||||
)
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def run_config():
|
def run_config():
|
||||||
return CrawlerRunConfig(
|
return CrawlerRunConfig(cache_mode=CacheMode.BYPASS, verbose=False)
|
||||||
cache_mode=CacheMode.BYPASS,
|
|
||||||
verbose=False
|
|
||||||
)
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def test_urls():
|
def test_urls():
|
||||||
return [
|
return [
|
||||||
"http://example.com",
|
"http://example.com",
|
||||||
"http://example.com/page1",
|
"http://example.com/page1",
|
||||||
"http://example.com/page2"
|
"http://example.com/page2",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
class TestDispatchStrategies:
|
class TestDispatchStrategies:
|
||||||
|
|
||||||
async def test_memory_adaptive_basic(self, browser_config, run_config, test_urls):
|
async def test_memory_adaptive_basic(self, browser_config, run_config, test_urls):
|
||||||
async with AsyncWebCrawler(config=browser_config) as crawler:
|
async with AsyncWebCrawler(config=browser_config) as crawler:
|
||||||
dispatcher = MemoryAdaptiveDispatcher(
|
dispatcher = MemoryAdaptiveDispatcher(
|
||||||
memory_threshold_percent=70.0,
|
memory_threshold_percent=70.0, max_session_permit=2, check_interval=0.1
|
||||||
max_session_permit=2,
|
)
|
||||||
check_interval=0.1
|
results = await crawler.arun_many(
|
||||||
|
test_urls, config=run_config, dispatcher=dispatcher
|
||||||
)
|
)
|
||||||
results = await crawler.arun_many(test_urls, config=run_config, dispatcher=dispatcher)
|
|
||||||
assert len(results) == len(test_urls)
|
assert len(results) == len(test_urls)
|
||||||
assert all(r.success for r in results)
|
assert all(r.success for r in results)
|
||||||
|
|
||||||
async def test_memory_adaptive_with_rate_limit(self, browser_config, run_config, test_urls):
|
async def test_memory_adaptive_with_rate_limit(
|
||||||
|
self, browser_config, run_config, test_urls
|
||||||
|
):
|
||||||
async with AsyncWebCrawler(config=browser_config) as crawler:
|
async with AsyncWebCrawler(config=browser_config) as crawler:
|
||||||
dispatcher = MemoryAdaptiveDispatcher(
|
dispatcher = MemoryAdaptiveDispatcher(
|
||||||
memory_threshold_percent=70.0,
|
memory_threshold_percent=70.0,
|
||||||
max_session_permit=2,
|
max_session_permit=2,
|
||||||
check_interval=0.1,
|
check_interval=0.1,
|
||||||
rate_limiter=RateLimiter(
|
rate_limiter=RateLimiter(
|
||||||
base_delay=(0.1, 0.2),
|
base_delay=(0.1, 0.2), max_delay=1.0, max_retries=2
|
||||||
max_delay=1.0,
|
),
|
||||||
max_retries=2
|
|
||||||
)
|
)
|
||||||
|
results = await crawler.arun_many(
|
||||||
|
test_urls, config=run_config, dispatcher=dispatcher
|
||||||
)
|
)
|
||||||
results = await crawler.arun_many(test_urls, config=run_config, dispatcher=dispatcher)
|
|
||||||
assert len(results) == len(test_urls)
|
assert len(results) == len(test_urls)
|
||||||
assert all(r.success for r in results)
|
assert all(r.success for r in results)
|
||||||
|
|
||||||
async def test_semaphore_basic(self, browser_config, run_config, test_urls):
|
async def test_semaphore_basic(self, browser_config, run_config, test_urls):
|
||||||
async with AsyncWebCrawler(config=browser_config) as crawler:
|
async with AsyncWebCrawler(config=browser_config) as crawler:
|
||||||
dispatcher = SemaphoreDispatcher(
|
dispatcher = SemaphoreDispatcher(semaphore_count=2)
|
||||||
semaphore_count=2
|
results = await crawler.arun_many(
|
||||||
|
test_urls, config=run_config, dispatcher=dispatcher
|
||||||
)
|
)
|
||||||
results = await crawler.arun_many(test_urls, config=run_config, dispatcher=dispatcher)
|
|
||||||
assert len(results) == len(test_urls)
|
assert len(results) == len(test_urls)
|
||||||
assert all(r.success for r in results)
|
assert all(r.success for r in results)
|
||||||
|
|
||||||
async def test_semaphore_with_rate_limit(self, browser_config, run_config, test_urls):
|
async def test_semaphore_with_rate_limit(
|
||||||
|
self, browser_config, run_config, test_urls
|
||||||
|
):
|
||||||
async with AsyncWebCrawler(config=browser_config) as crawler:
|
async with AsyncWebCrawler(config=browser_config) as crawler:
|
||||||
dispatcher = SemaphoreDispatcher(
|
dispatcher = SemaphoreDispatcher(
|
||||||
semaphore_count=2,
|
semaphore_count=2,
|
||||||
rate_limiter=RateLimiter(
|
rate_limiter=RateLimiter(
|
||||||
base_delay=(0.1, 0.2),
|
base_delay=(0.1, 0.2), max_delay=1.0, max_retries=2
|
||||||
max_delay=1.0,
|
),
|
||||||
max_retries=2
|
|
||||||
)
|
)
|
||||||
|
results = await crawler.arun_many(
|
||||||
|
test_urls, config=run_config, dispatcher=dispatcher
|
||||||
)
|
)
|
||||||
results = await crawler.arun_many(test_urls, config=run_config, dispatcher=dispatcher)
|
|
||||||
assert len(results) == len(test_urls)
|
assert len(results) == len(test_urls)
|
||||||
assert all(r.success for r in results)
|
assert all(r.success for r in results)
|
||||||
|
|
||||||
async def test_memory_adaptive_memory_error(self, browser_config, run_config, test_urls):
|
async def test_memory_adaptive_memory_error(
|
||||||
|
self, browser_config, run_config, test_urls
|
||||||
|
):
|
||||||
async with AsyncWebCrawler(config=browser_config) as crawler:
|
async with AsyncWebCrawler(config=browser_config) as crawler:
|
||||||
dispatcher = MemoryAdaptiveDispatcher(
|
dispatcher = MemoryAdaptiveDispatcher(
|
||||||
memory_threshold_percent=1.0, # Set unrealistically low threshold
|
memory_threshold_percent=1.0, # Set unrealistically low threshold
|
||||||
max_session_permit=2,
|
max_session_permit=2,
|
||||||
check_interval=0.1,
|
check_interval=0.1,
|
||||||
memory_wait_timeout=1.0 # Short timeout for testing
|
memory_wait_timeout=1.0, # Short timeout for testing
|
||||||
)
|
)
|
||||||
with pytest.raises(MemoryError):
|
with pytest.raises(MemoryError):
|
||||||
await crawler.arun_many(test_urls, config=run_config, dispatcher=dispatcher)
|
await crawler.arun_many(
|
||||||
|
test_urls, config=run_config, dispatcher=dispatcher
|
||||||
|
)
|
||||||
|
|
||||||
async def test_empty_urls(self, browser_config, run_config):
|
async def test_empty_urls(self, browser_config, run_config):
|
||||||
async with AsyncWebCrawler(config=browser_config) as crawler:
|
async with AsyncWebCrawler(config=browser_config) as crawler:
|
||||||
dispatcher = MemoryAdaptiveDispatcher(max_session_permit=2)
|
dispatcher = MemoryAdaptiveDispatcher(max_session_permit=2)
|
||||||
results = await crawler.arun_many([], config=run_config, dispatcher=dispatcher)
|
results = await crawler.arun_many(
|
||||||
|
[], config=run_config, dispatcher=dispatcher
|
||||||
|
)
|
||||||
assert len(results) == 0
|
assert len(results) == 0
|
||||||
|
|
||||||
async def test_single_url(self, browser_config, run_config):
|
async def test_single_url(self, browser_config, run_config):
|
||||||
async with AsyncWebCrawler(config=browser_config) as crawler:
|
async with AsyncWebCrawler(config=browser_config) as crawler:
|
||||||
dispatcher = MemoryAdaptiveDispatcher(max_session_permit=2)
|
dispatcher = MemoryAdaptiveDispatcher(max_session_permit=2)
|
||||||
results = await crawler.arun_many(["http://example.com"], config=run_config, dispatcher=dispatcher)
|
results = await crawler.arun_many(
|
||||||
|
["http://example.com"], config=run_config, dispatcher=dispatcher
|
||||||
|
)
|
||||||
assert len(results) == 1
|
assert len(results) == 1
|
||||||
assert results[0].success
|
assert results[0].success
|
||||||
|
|
||||||
async def test_invalid_urls(self, browser_config, run_config):
|
async def test_invalid_urls(self, browser_config, run_config):
|
||||||
async with AsyncWebCrawler(config=browser_config) as crawler:
|
async with AsyncWebCrawler(config=browser_config) as crawler:
|
||||||
dispatcher = MemoryAdaptiveDispatcher(max_session_permit=2)
|
dispatcher = MemoryAdaptiveDispatcher(max_session_permit=2)
|
||||||
results = await crawler.arun_many(["http://invalid.url.that.doesnt.exist"], config=run_config, dispatcher=dispatcher)
|
results = await crawler.arun_many(
|
||||||
|
["http://invalid.url.that.doesnt.exist"],
|
||||||
|
config=run_config,
|
||||||
|
dispatcher=dispatcher,
|
||||||
|
)
|
||||||
assert len(results) == 1
|
assert len(results) == 1
|
||||||
assert not results[0].success
|
assert not results[0].success
|
||||||
|
|
||||||
@@ -121,27 +140,31 @@ class TestDispatchStrategies:
|
|||||||
base_delay=(0.1, 0.2),
|
base_delay=(0.1, 0.2),
|
||||||
max_delay=1.0,
|
max_delay=1.0,
|
||||||
max_retries=2,
|
max_retries=2,
|
||||||
rate_limit_codes=[200] # Force rate limiting for testing
|
rate_limit_codes=[200], # Force rate limiting for testing
|
||||||
)
|
),
|
||||||
)
|
)
|
||||||
start_time = time.time()
|
start_time = time.time()
|
||||||
results = await crawler.arun_many(urls, config=run_config, dispatcher=dispatcher)
|
results = await crawler.arun_many(
|
||||||
|
urls, config=run_config, dispatcher=dispatcher
|
||||||
|
)
|
||||||
duration = time.time() - start_time
|
duration = time.time() - start_time
|
||||||
assert len(results) == len(urls)
|
assert len(results) == len(urls)
|
||||||
assert duration > 1.0 # Ensure rate limiting caused delays
|
assert duration > 1.0 # Ensure rate limiting caused delays
|
||||||
|
|
||||||
async def test_monitor_integration(self, browser_config, run_config, test_urls):
|
async def test_monitor_integration(self, browser_config, run_config, test_urls):
|
||||||
async with AsyncWebCrawler(config=browser_config) as crawler:
|
async with AsyncWebCrawler(config=browser_config) as crawler:
|
||||||
monitor = CrawlerMonitor(max_visible_rows=5, display_mode=DisplayMode.DETAILED)
|
monitor = CrawlerMonitor(
|
||||||
dispatcher = MemoryAdaptiveDispatcher(
|
max_visible_rows=5, display_mode=DisplayMode.DETAILED
|
||||||
max_session_permit=2,
|
)
|
||||||
monitor=monitor
|
dispatcher = MemoryAdaptiveDispatcher(max_session_permit=2, monitor=monitor)
|
||||||
|
results = await crawler.arun_many(
|
||||||
|
test_urls, config=run_config, dispatcher=dispatcher
|
||||||
)
|
)
|
||||||
results = await crawler.arun_many(test_urls, config=run_config, dispatcher=dispatcher)
|
|
||||||
assert len(results) == len(test_urls)
|
assert len(results) == len(test_urls)
|
||||||
# Check monitor stats
|
# Check monitor stats
|
||||||
assert len(monitor.stats) == len(test_urls)
|
assert len(monitor.stats) == len(test_urls)
|
||||||
assert all(stat.end_time is not None for stat in monitor.stats.values())
|
assert all(stat.end_time is not None for stat in monitor.stats.values())
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
pytest.main([__file__, "-v", "--asyncio-mode=auto"])
|
pytest.main([__file__, "-v", "--asyncio-mode=auto"])
|
||||||
@@ -2,9 +2,9 @@ import os
|
|||||||
import re
|
import re
|
||||||
import sys
|
import sys
|
||||||
import pytest
|
import pytest
|
||||||
import json
|
|
||||||
from bs4 import BeautifulSoup
|
from bs4 import BeautifulSoup
|
||||||
import asyncio
|
import asyncio
|
||||||
|
|
||||||
# Add the parent directory to the Python path
|
# Add the parent directory to the Python path
|
||||||
parent_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
|
parent_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
|
||||||
sys.path.append(parent_dir)
|
sys.path.append(parent_dir)
|
||||||
@@ -59,19 +59,21 @@ from crawl4ai.async_webcrawler import AsyncWebCrawler
|
|||||||
# assert result.success
|
# assert result.success
|
||||||
# assert "github" in result.html.lower()
|
# assert "github" in result.html.lower()
|
||||||
|
|
||||||
|
|
||||||
# Add this test to your existing test file
|
# Add this test to your existing test file
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_typescript_commits_multi_page():
|
async def test_typescript_commits_multi_page():
|
||||||
first_commit = ""
|
first_commit = ""
|
||||||
|
|
||||||
async def on_execution_started(page):
|
async def on_execution_started(page):
|
||||||
nonlocal first_commit
|
nonlocal first_commit
|
||||||
try:
|
try:
|
||||||
# Check if the page firct commit h4 text is different from the first commit (use document.querySelector('li.Box-sc-g0xbh4-0 h4'))
|
# Check if the page firct commit h4 text is different from the first commit (use document.querySelector('li.Box-sc-g0xbh4-0 h4'))
|
||||||
while True:
|
while True:
|
||||||
await page.wait_for_selector('li.Box-sc-g0xbh4-0 h4')
|
await page.wait_for_selector("li.Box-sc-g0xbh4-0 h4")
|
||||||
commit = await page.query_selector('li.Box-sc-g0xbh4-0 h4')
|
commit = await page.query_selector("li.Box-sc-g0xbh4-0 h4")
|
||||||
commit = await commit.evaluate('(element) => element.textContent')
|
commit = await commit.evaluate("(element) => element.textContent")
|
||||||
commit = re.sub(r'\s+', '', commit)
|
commit = re.sub(r"\s+", "", commit)
|
||||||
if commit and commit != first_commit:
|
if commit and commit != first_commit:
|
||||||
first_commit = commit
|
first_commit = commit
|
||||||
break
|
break
|
||||||
@@ -79,9 +81,8 @@ async def test_typescript_commits_multi_page():
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"Warning: New content didn't appear after JavaScript execution: {e}")
|
print(f"Warning: New content didn't appear after JavaScript execution: {e}")
|
||||||
|
|
||||||
|
|
||||||
async with AsyncWebCrawler(verbose=True) as crawler:
|
async with AsyncWebCrawler(verbose=True) as crawler:
|
||||||
crawler.crawler_strategy.set_hook('on_execution_started', on_execution_started)
|
crawler.crawler_strategy.set_hook("on_execution_started", on_execution_started)
|
||||||
|
|
||||||
url = "https://github.com/microsoft/TypeScript/commits/main"
|
url = "https://github.com/microsoft/TypeScript/commits/main"
|
||||||
session_id = "typescript_commits_session"
|
session_id = "typescript_commits_session"
|
||||||
@@ -97,19 +98,21 @@ async def test_typescript_commits_multi_page():
|
|||||||
url=url, # Only use URL for the first page
|
url=url, # Only use URL for the first page
|
||||||
session_id=session_id,
|
session_id=session_id,
|
||||||
css_selector="li.Box-sc-g0xbh4-0",
|
css_selector="li.Box-sc-g0xbh4-0",
|
||||||
js=js_next_page if page > 0 else None, # Don't click 'next' on the first page
|
js=js_next_page
|
||||||
|
if page > 0
|
||||||
|
else None, # Don't click 'next' on the first page
|
||||||
bypass_cache=True,
|
bypass_cache=True,
|
||||||
js_only=page > 0 # Use js_only for subsequent pages
|
js_only=page > 0, # Use js_only for subsequent pages
|
||||||
)
|
)
|
||||||
|
|
||||||
assert result.success, f"Failed to crawl page {page + 1}"
|
assert result.success, f"Failed to crawl page {page + 1}"
|
||||||
|
|
||||||
# Parse the HTML and extract commits
|
# Parse the HTML and extract commits
|
||||||
soup = BeautifulSoup(result.cleaned_html, 'html.parser')
|
soup = BeautifulSoup(result.cleaned_html, "html.parser")
|
||||||
commits = soup.select("li")
|
commits = soup.select("li")
|
||||||
# Take first commit find h4 extract text
|
# Take first commit find h4 extract text
|
||||||
first_commit = commits[0].find("h4").text
|
first_commit = commits[0].find("h4").text
|
||||||
first_commit = re.sub(r'\s+', '', first_commit)
|
first_commit = re.sub(r"\s+", "", first_commit)
|
||||||
all_commits.extend(commits)
|
all_commits.extend(commits)
|
||||||
|
|
||||||
print(f"Page {page + 1}: Found {len(commits)} commits")
|
print(f"Page {page + 1}: Found {len(commits)} commits")
|
||||||
@@ -118,10 +121,13 @@ async def test_typescript_commits_multi_page():
|
|||||||
await crawler.crawler_strategy.kill_session(session_id)
|
await crawler.crawler_strategy.kill_session(session_id)
|
||||||
|
|
||||||
# Assertions
|
# Assertions
|
||||||
assert len(all_commits) >= 90, f"Expected at least 90 commits, but got {len(all_commits)}"
|
assert (
|
||||||
|
len(all_commits) >= 90
|
||||||
|
), f"Expected at least 90 commits, but got {len(all_commits)}"
|
||||||
|
|
||||||
print(f"Successfully crawled {len(all_commits)} commits across 3 pages")
|
print(f"Successfully crawled {len(all_commits)} commits across 3 pages")
|
||||||
|
|
||||||
|
|
||||||
# Entry point for debugging
|
# Entry point for debugging
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
pytest.main([__file__, "-v"])
|
pytest.main([__file__, "-v"])
|
||||||
@@ -1,11 +1,15 @@
|
|||||||
import json
|
import json
|
||||||
import time
|
import time
|
||||||
from bs4 import BeautifulSoup
|
from bs4 import BeautifulSoup
|
||||||
from crawl4ai.content_scraping_strategy import WebScrapingStrategy, LXMLWebScrapingStrategy
|
from crawl4ai.content_scraping_strategy import (
|
||||||
from typing import Dict, Any, List, Tuple
|
WebScrapingStrategy,
|
||||||
|
LXMLWebScrapingStrategy,
|
||||||
|
)
|
||||||
|
from typing import Dict, List, Tuple
|
||||||
import difflib
|
import difflib
|
||||||
from lxml import html as lhtml, etree
|
from lxml import html as lhtml, etree
|
||||||
|
|
||||||
|
|
||||||
def normalize_dom(element):
|
def normalize_dom(element):
|
||||||
"""
|
"""
|
||||||
Recursively normalizes an lxml HTML element:
|
Recursively normalizes an lxml HTML element:
|
||||||
@@ -15,7 +19,7 @@ def normalize_dom(element):
|
|||||||
Returns the same element (mutated).
|
Returns the same element (mutated).
|
||||||
"""
|
"""
|
||||||
# Remove comment nodes
|
# Remove comment nodes
|
||||||
comments = element.xpath('//comment()')
|
comments = element.xpath("//comment()")
|
||||||
for c in comments:
|
for c in comments:
|
||||||
p = c.getparent()
|
p = c.getparent()
|
||||||
if p is not None:
|
if p is not None:
|
||||||
@@ -53,8 +57,8 @@ def strip_html_body(root):
|
|||||||
tag_name = (root.tag or "").lower()
|
tag_name = (root.tag or "").lower()
|
||||||
|
|
||||||
# Case 1: The root is <html>
|
# Case 1: The root is <html>
|
||||||
if tag_name == 'html':
|
if tag_name == "html":
|
||||||
bodies = root.xpath('./body')
|
bodies = root.xpath("./body")
|
||||||
if bodies:
|
if bodies:
|
||||||
body = bodies[0]
|
body = bodies[0]
|
||||||
new_div = lhtml.Element("div")
|
new_div = lhtml.Element("div")
|
||||||
@@ -66,7 +70,7 @@ def strip_html_body(root):
|
|||||||
return root
|
return root
|
||||||
|
|
||||||
# Case 2: The root is <body>
|
# Case 2: The root is <body>
|
||||||
elif tag_name == 'body':
|
elif tag_name == "body":
|
||||||
new_div = lhtml.Element("div")
|
new_div = lhtml.Element("div")
|
||||||
for child in root:
|
for child in root:
|
||||||
new_div.append(child)
|
new_div.append(child)
|
||||||
@@ -92,7 +96,9 @@ def compare_nodes(node1, node2, differences, path="/"):
|
|||||||
attrs1 = list(node1.attrib.items())
|
attrs1 = list(node1.attrib.items())
|
||||||
attrs2 = list(node2.attrib.items())
|
attrs2 = list(node2.attrib.items())
|
||||||
if attrs1 != attrs2:
|
if attrs1 != attrs2:
|
||||||
differences.append(f"Attribute mismatch at {path}/{node1.tag}: {attrs1} vs. {attrs2}")
|
differences.append(
|
||||||
|
f"Attribute mismatch at {path}/{node1.tag}: {attrs1} vs. {attrs2}"
|
||||||
|
)
|
||||||
|
|
||||||
# 3) Compare text (trim or unify whitespace as needed)
|
# 3) Compare text (trim or unify whitespace as needed)
|
||||||
text1 = (node1.text or "").strip()
|
text1 = (node1.text or "").strip()
|
||||||
@@ -102,7 +108,9 @@ def compare_nodes(node1, node2, differences, path="/"):
|
|||||||
text2 = " ".join(text2.split())
|
text2 = " ".join(text2.split())
|
||||||
if text1 != text2:
|
if text1 != text2:
|
||||||
# If you prefer ignoring newlines or multiple whitespace, do a more robust cleanup
|
# If you prefer ignoring newlines or multiple whitespace, do a more robust cleanup
|
||||||
differences.append(f"Text mismatch at {path}/{node1.tag}: '{text1}' vs. '{text2}'")
|
differences.append(
|
||||||
|
f"Text mismatch at {path}/{node1.tag}: '{text1}' vs. '{text2}'"
|
||||||
|
)
|
||||||
|
|
||||||
# 4) Compare number of children
|
# 4) Compare number of children
|
||||||
children1 = list(node1)
|
children1 = list(node1)
|
||||||
@@ -123,7 +131,9 @@ def compare_nodes(node1, node2, differences, path="/"):
|
|||||||
tail1 = (node1.tail or "").strip()
|
tail1 = (node1.tail or "").strip()
|
||||||
tail2 = (node2.tail or "").strip()
|
tail2 = (node2.tail or "").strip()
|
||||||
if tail1 != tail2:
|
if tail1 != tail2:
|
||||||
differences.append(f"Tail mismatch after {path}/{node1.tag}: '{tail1}' vs. '{tail2}'")
|
differences.append(
|
||||||
|
f"Tail mismatch after {path}/{node1.tag}: '{tail1}' vs. '{tail2}'"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def compare_html_structurally(html1, html2):
|
def compare_html_structurally(html1, html2):
|
||||||
@@ -156,11 +166,11 @@ def compare_html_structurally(html1, html2):
|
|||||||
return differences
|
return differences
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
def generate_large_html(n_elements=1000):
|
def generate_large_html(n_elements=1000):
|
||||||
html = ['<!DOCTYPE html><html><head></head><body>']
|
html = ["<!DOCTYPE html><html><head></head><body>"]
|
||||||
for i in range(n_elements):
|
for i in range(n_elements):
|
||||||
html.append(f'''
|
html.append(
|
||||||
|
f"""
|
||||||
<div class="article">
|
<div class="article">
|
||||||
<h2>Heading {i}</h2>
|
<h2>Heading {i}</h2>
|
||||||
<p>This is paragraph {i} with some content and a <a href="http://example.com/{i}">link</a></p>
|
<p>This is paragraph {i} with some content and a <a href="http://example.com/{i}">link</a></p>
|
||||||
@@ -170,9 +180,11 @@ def generate_large_html(n_elements=1000):
|
|||||||
<li>List item {i}.2</li>
|
<li>List item {i}.2</li>
|
||||||
</ul>
|
</ul>
|
||||||
</div>
|
</div>
|
||||||
''')
|
"""
|
||||||
html.append('</body></html>')
|
)
|
||||||
return ''.join(html)
|
html.append("</body></html>")
|
||||||
|
return "".join(html)
|
||||||
|
|
||||||
|
|
||||||
def generate_complicated_html():
|
def generate_complicated_html():
|
||||||
"""
|
"""
|
||||||
@@ -352,13 +364,12 @@ def get_test_scenarios():
|
|||||||
return TEST_SCENARIOS
|
return TEST_SCENARIOS
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
class ScraperEquivalenceTester:
|
class ScraperEquivalenceTester:
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
self.test_cases = {
|
self.test_cases = {
|
||||||
'basic': self.generate_basic_html(),
|
"basic": self.generate_basic_html(),
|
||||||
'complex': self.generate_complex_html(),
|
"complex": self.generate_complex_html(),
|
||||||
'malformed': self.generate_malformed_html(),
|
"malformed": self.generate_malformed_html(),
|
||||||
# 'real_world': self.load_real_samples()
|
# 'real_world': self.load_real_samples()
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -399,20 +410,19 @@ class ScraperEquivalenceTester:
|
|||||||
def load_real_samples(self):
|
def load_real_samples(self):
|
||||||
# Load some real-world HTML samples you've collected
|
# Load some real-world HTML samples you've collected
|
||||||
samples = {
|
samples = {
|
||||||
'article': open('tests/samples/article.html').read(),
|
"article": open("tests/samples/article.html").read(),
|
||||||
'product': open('tests/samples/product.html').read(),
|
"product": open("tests/samples/product.html").read(),
|
||||||
'blog': open('tests/samples/blog.html').read()
|
"blog": open("tests/samples/blog.html").read(),
|
||||||
}
|
}
|
||||||
return samples
|
return samples
|
||||||
|
|
||||||
|
|
||||||
def deep_compare_links(self, old_links: Dict, new_links: Dict) -> List[str]:
|
def deep_compare_links(self, old_links: Dict, new_links: Dict) -> List[str]:
|
||||||
"""Detailed comparison of link structures"""
|
"""Detailed comparison of link structures"""
|
||||||
differences = []
|
differences = []
|
||||||
|
|
||||||
for category in ['internal', 'external']:
|
for category in ["internal", "external"]:
|
||||||
old_urls = {link['href'] for link in old_links[category]}
|
old_urls = {link["href"] for link in old_links[category]}
|
||||||
new_urls = {link['href'] for link in new_links[category]}
|
new_urls = {link["href"] for link in new_links[category]}
|
||||||
|
|
||||||
missing = old_urls - new_urls
|
missing = old_urls - new_urls
|
||||||
extra = new_urls - old_urls
|
extra = new_urls - old_urls
|
||||||
@@ -425,10 +435,10 @@ class ScraperEquivalenceTester:
|
|||||||
# Compare link attributes for common URLs
|
# Compare link attributes for common URLs
|
||||||
common = old_urls & new_urls
|
common = old_urls & new_urls
|
||||||
for url in common:
|
for url in common:
|
||||||
old_link = next(l for l in old_links[category] if l['href'] == url)
|
old_link = next(l for l in old_links[category] if l["href"] == url)
|
||||||
new_link = next(l for l in new_links[category] if l['href'] == url)
|
new_link = next(l for l in new_links[category] if l["href"] == url)
|
||||||
|
|
||||||
for attr in ['text', 'title']:
|
for attr in ["text", "title"]:
|
||||||
if old_link[attr] != new_link[attr]:
|
if old_link[attr] != new_link[attr]:
|
||||||
differences.append(
|
differences.append(
|
||||||
f"Link attribute mismatch for {url} - {attr}:"
|
f"Link attribute mismatch for {url} - {attr}:"
|
||||||
@@ -441,9 +451,9 @@ class ScraperEquivalenceTester:
|
|||||||
"""Detailed comparison of media elements"""
|
"""Detailed comparison of media elements"""
|
||||||
differences = []
|
differences = []
|
||||||
|
|
||||||
for media_type in ['images', 'videos', 'audios']:
|
for media_type in ["images", "videos", "audios"]:
|
||||||
old_srcs = {item['src'] for item in old_media[media_type]}
|
old_srcs = {item["src"] for item in old_media[media_type]}
|
||||||
new_srcs = {item['src'] for item in new_media[media_type]}
|
new_srcs = {item["src"] for item in new_media[media_type]}
|
||||||
|
|
||||||
missing = old_srcs - new_srcs
|
missing = old_srcs - new_srcs
|
||||||
extra = new_srcs - old_srcs
|
extra = new_srcs - old_srcs
|
||||||
@@ -456,10 +466,10 @@ class ScraperEquivalenceTester:
|
|||||||
# Compare media attributes for common sources
|
# Compare media attributes for common sources
|
||||||
common = old_srcs & new_srcs
|
common = old_srcs & new_srcs
|
||||||
for src in common:
|
for src in common:
|
||||||
old_item = next(m for m in old_media[media_type] if m['src'] == src)
|
old_item = next(m for m in old_media[media_type] if m["src"] == src)
|
||||||
new_item = next(m for m in new_media[media_type] if m['src'] == src)
|
new_item = next(m for m in new_media[media_type] if m["src"] == src)
|
||||||
|
|
||||||
for attr in ['alt', 'description']:
|
for attr in ["alt", "description"]:
|
||||||
if old_item.get(attr) != new_item.get(attr):
|
if old_item.get(attr) != new_item.get(attr):
|
||||||
differences.append(
|
differences.append(
|
||||||
f"{media_type} attribute mismatch for {src} - {attr}:"
|
f"{media_type} attribute mismatch for {src} - {attr}:"
|
||||||
@@ -474,10 +484,10 @@ class ScraperEquivalenceTester:
|
|||||||
differences = []
|
differences = []
|
||||||
|
|
||||||
def normalize_html(html: str) -> Tuple[str, str]:
|
def normalize_html(html: str) -> Tuple[str, str]:
|
||||||
soup = BeautifulSoup(html, 'lxml')
|
soup = BeautifulSoup(html, "lxml")
|
||||||
# Get both structure and text
|
# Get both structure and text
|
||||||
structure = ' '.join(tag.name for tag in soup.find_all())
|
structure = " ".join(tag.name for tag in soup.find_all())
|
||||||
text = ' '.join(soup.get_text().split())
|
text = " ".join(soup.get_text().split())
|
||||||
return structure, text
|
return structure, text
|
||||||
|
|
||||||
old_structure, old_text = normalize_html(old_html)
|
old_structure, old_text = normalize_html(old_html)
|
||||||
@@ -487,46 +497,47 @@ class ScraperEquivalenceTester:
|
|||||||
if abs(len(old_structure) - len(new_structure)) > 100:
|
if abs(len(old_structure) - len(new_structure)) > 100:
|
||||||
# if old_structure != new_structure:
|
# if old_structure != new_structure:
|
||||||
diff = difflib.unified_diff(
|
diff = difflib.unified_diff(
|
||||||
old_structure.split(),
|
old_structure.split(), new_structure.split(), lineterm=""
|
||||||
new_structure.split(),
|
|
||||||
lineterm=''
|
|
||||||
)
|
)
|
||||||
differences.append("HTML structure differences:\n" + '\n'.join(diff))
|
differences.append("HTML structure differences:\n" + "\n".join(diff))
|
||||||
|
|
||||||
# Compare text content
|
# Compare text content
|
||||||
if abs(len(old_text) - len(new_text)) > 100:
|
if abs(len(old_text) - len(new_text)) > 100:
|
||||||
# if old_text != new_text:
|
# if old_text != new_text:
|
||||||
# Show detailed text differences
|
# Show detailed text differences
|
||||||
text_diff = difflib.unified_diff(
|
text_diff = difflib.unified_diff(
|
||||||
old_text.split(),
|
old_text.split(), new_text.split(), lineterm=""
|
||||||
new_text.split(),
|
|
||||||
lineterm=''
|
|
||||||
)
|
)
|
||||||
differences.append("Text content differences:\n" + '\n'.join(text_diff))
|
differences.append("Text content differences:\n" + "\n".join(text_diff))
|
||||||
|
|
||||||
return differences
|
return differences
|
||||||
|
|
||||||
def compare_results(self, old_result: Dict, new_result: Dict) -> Dict[str, List[str]]:
|
def compare_results(
|
||||||
|
self, old_result: Dict, new_result: Dict
|
||||||
|
) -> Dict[str, List[str]]:
|
||||||
"""Comprehensive comparison of scraper outputs"""
|
"""Comprehensive comparison of scraper outputs"""
|
||||||
differences = {}
|
differences = {}
|
||||||
|
|
||||||
# Compare links
|
# Compare links
|
||||||
link_differences = self.deep_compare_links(old_result['links'], new_result['links'])
|
link_differences = self.deep_compare_links(
|
||||||
|
old_result["links"], new_result["links"]
|
||||||
|
)
|
||||||
if link_differences:
|
if link_differences:
|
||||||
differences['links'] = link_differences
|
differences["links"] = link_differences
|
||||||
|
|
||||||
# Compare media
|
# Compare media
|
||||||
media_differences = self.deep_compare_media(old_result['media'], new_result['media'])
|
media_differences = self.deep_compare_media(
|
||||||
|
old_result["media"], new_result["media"]
|
||||||
|
)
|
||||||
if media_differences:
|
if media_differences:
|
||||||
differences['media'] = media_differences
|
differences["media"] = media_differences
|
||||||
|
|
||||||
# Compare HTML
|
# Compare HTML
|
||||||
html_differences = self.compare_html_content(
|
html_differences = self.compare_html_content(
|
||||||
old_result['cleaned_html'],
|
old_result["cleaned_html"], new_result["cleaned_html"]
|
||||||
new_result['cleaned_html']
|
|
||||||
)
|
)
|
||||||
if html_differences:
|
if html_differences:
|
||||||
differences['html'] = html_differences
|
differences["html"] = html_differences
|
||||||
|
|
||||||
return differences
|
return differences
|
||||||
|
|
||||||
@@ -535,10 +546,7 @@ class ScraperEquivalenceTester:
|
|||||||
# We'll still keep some "test_cases" logic from above (basic, complex, malformed).
|
# We'll still keep some "test_cases" logic from above (basic, complex, malformed).
|
||||||
# But we add a new section for the complicated HTML scenarios.
|
# But we add a new section for the complicated HTML scenarios.
|
||||||
|
|
||||||
results = {
|
results = {"tests": [], "summary": {"passed": 0, "failed": 0}}
|
||||||
'tests': [],
|
|
||||||
'summary': {'passed': 0, 'failed': 0}
|
|
||||||
}
|
|
||||||
|
|
||||||
# 1) First, run the existing 3 built-in test cases (basic, complex, malformed).
|
# 1) First, run the existing 3 built-in test cases (basic, complex, malformed).
|
||||||
# for case_name, html in self.test_cases.items():
|
# for case_name, html in self.test_cases.items():
|
||||||
@@ -616,33 +624,38 @@ class ScraperEquivalenceTester:
|
|||||||
lxml_time = time.time() - start
|
lxml_time = time.time() - start
|
||||||
|
|
||||||
diffs = {}
|
diffs = {}
|
||||||
link_diff = self.deep_compare_links(orig_result['links'], lxml_result['links'])
|
link_diff = self.deep_compare_links(
|
||||||
|
orig_result["links"], lxml_result["links"]
|
||||||
|
)
|
||||||
if link_diff:
|
if link_diff:
|
||||||
diffs['links'] = link_diff
|
diffs["links"] = link_diff
|
||||||
|
|
||||||
media_diff = self.deep_compare_media(orig_result['media'], lxml_result['media'])
|
media_diff = self.deep_compare_media(
|
||||||
|
orig_result["media"], lxml_result["media"]
|
||||||
|
)
|
||||||
if media_diff:
|
if media_diff:
|
||||||
diffs['media'] = media_diff
|
diffs["media"] = media_diff
|
||||||
|
|
||||||
html_diff = self.compare_html_content(orig_result['cleaned_html'], lxml_result['cleaned_html'])
|
html_diff = self.compare_html_content(
|
||||||
|
orig_result["cleaned_html"], lxml_result["cleaned_html"]
|
||||||
|
)
|
||||||
if html_diff:
|
if html_diff:
|
||||||
diffs['html'] = html_diff
|
diffs["html"] = html_diff
|
||||||
|
|
||||||
test_result = {
|
test_result = {
|
||||||
'case': f"complicated_{scenario_name}",
|
"case": f"complicated_{scenario_name}",
|
||||||
'lxml_mode': {
|
"lxml_mode": {"differences": diffs, "execution_time": lxml_time},
|
||||||
'differences': diffs,
|
"original_time": orig_time,
|
||||||
'execution_time': lxml_time
|
|
||||||
},
|
|
||||||
'original_time': orig_time
|
|
||||||
}
|
}
|
||||||
results['tests'].append(test_result)
|
results["tests"].append(test_result)
|
||||||
|
|
||||||
if not diffs:
|
if not diffs:
|
||||||
results['summary']['passed'] += 1
|
results["summary"]["passed"] += 1
|
||||||
print(f"✅ [OK] No differences found. Time(Orig: {orig_time:.3f}s, LXML: {lxml_time:.3f}s)")
|
print(
|
||||||
|
f"✅ [OK] No differences found. Time(Orig: {orig_time:.3f}s, LXML: {lxml_time:.3f}s)"
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
results['summary']['failed'] += 1
|
results["summary"]["failed"] += 1
|
||||||
print("❌ Differences found:")
|
print("❌ Differences found:")
|
||||||
for category, dlist in diffs.items():
|
for category, dlist in diffs.items():
|
||||||
print(f" {category}:")
|
print(f" {category}:")
|
||||||
@@ -658,19 +671,21 @@ class ScraperEquivalenceTester:
|
|||||||
print(f"Passed: {results['summary']['passed']}")
|
print(f"Passed: {results['summary']['passed']}")
|
||||||
print(f"Failed: {results['summary']['failed']}")
|
print(f"Failed: {results['summary']['failed']}")
|
||||||
|
|
||||||
for test in results['tests']:
|
for test in results["tests"]:
|
||||||
print(f"\nTest Case: {test['case']}")
|
print(f"\nTest Case: {test['case']}")
|
||||||
|
|
||||||
if not test['lxml_mode']['differences']:
|
if not test["lxml_mode"]["differences"]:
|
||||||
print("✅ All implementations produced identical results")
|
print("✅ All implementations produced identical results")
|
||||||
print(f"Times - Original: {test['original_time']:.3f}s, "
|
print(
|
||||||
f"LXML: {test['lxml_mode']['execution_time']:.3f}s")
|
f"Times - Original: {test['original_time']:.3f}s, "
|
||||||
|
f"LXML: {test['lxml_mode']['execution_time']:.3f}s"
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
print("❌ Differences found:")
|
print("❌ Differences found:")
|
||||||
|
|
||||||
if test['lxml_mode']['differences']:
|
if test["lxml_mode"]["differences"]:
|
||||||
print("\nLXML Mode Differences:")
|
print("\nLXML Mode Differences:")
|
||||||
for category, diffs in test['lxml_mode']['differences'].items():
|
for category, diffs in test["lxml_mode"]["differences"].items():
|
||||||
print(f"\n{category}:")
|
print(f"\n{category}:")
|
||||||
for diff in diffs:
|
for diff in diffs:
|
||||||
print(f" - {diff}")
|
print(f" - {diff}")
|
||||||
@@ -682,7 +697,7 @@ def main():
|
|||||||
tester.print_report(results)
|
tester.print_report(results)
|
||||||
|
|
||||||
# Save detailed results for debugging
|
# Save detailed results for debugging
|
||||||
with open('scraper_equivalence_results.json', 'w') as f:
|
with open("scraper_equivalence_results.json", "w") as f:
|
||||||
json.dump(results, f, indent=2)
|
json.dump(results, f, indent=2)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -4,10 +4,10 @@
|
|||||||
# - **State:** open
|
# - **State:** open
|
||||||
|
|
||||||
import os, sys, time
|
import os, sys, time
|
||||||
|
|
||||||
parent_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
|
parent_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
|
||||||
sys.path.append(parent_dir)
|
sys.path.append(parent_dir)
|
||||||
__location__ = os.path.realpath(os.path.join(os.getcwd(), os.path.dirname(__file__)))
|
__location__ = os.path.realpath(os.path.join(os.getcwd(), os.path.dirname(__file__)))
|
||||||
import asyncio
|
|
||||||
import os
|
import os
|
||||||
import time
|
import time
|
||||||
from typing import Dict, Any
|
from typing import Dict, Any
|
||||||
@@ -16,12 +16,12 @@ from crawl4ai.markdown_generation_strategy import DefaultMarkdownGenerator
|
|||||||
# Get current directory
|
# Get current directory
|
||||||
__location__ = os.path.realpath(os.path.join(os.getcwd(), os.path.dirname(__file__)))
|
__location__ = os.path.realpath(os.path.join(os.getcwd(), os.path.dirname(__file__)))
|
||||||
|
|
||||||
|
|
||||||
def print_test_result(name: str, result: Dict[str, Any], execution_time: float):
|
def print_test_result(name: str, result: Dict[str, Any], execution_time: float):
|
||||||
"""Helper function to print test results."""
|
"""Helper function to print test results."""
|
||||||
print(f"\n{'='*20} {name} {'='*20}")
|
print(f"\n{'='*20} {name} {'='*20}")
|
||||||
print(f"Execution time: {execution_time:.4f} seconds")
|
print(f"Execution time: {execution_time:.4f} seconds")
|
||||||
|
|
||||||
|
|
||||||
# Save markdown to files
|
# Save markdown to files
|
||||||
for key, content in result.items():
|
for key, content in result.items():
|
||||||
if isinstance(content, str):
|
if isinstance(content, str):
|
||||||
@@ -36,6 +36,7 @@ def print_test_result(name: str, result: Dict[str, Any], execution_time: float):
|
|||||||
# print(preview)
|
# print(preview)
|
||||||
# print(f"Total length: {len(content)} characters")
|
# print(f"Total length: {len(content)} characters")
|
||||||
|
|
||||||
|
|
||||||
def test_basic_markdown_conversion():
|
def test_basic_markdown_conversion():
|
||||||
"""Test basic markdown conversion with links."""
|
"""Test basic markdown conversion with links."""
|
||||||
with open(__location__ + "/data/wikipedia.html", "r") as f:
|
with open(__location__ + "/data/wikipedia.html", "r") as f:
|
||||||
@@ -45,23 +46,29 @@ def test_basic_markdown_conversion():
|
|||||||
|
|
||||||
start_time = time.perf_counter()
|
start_time = time.perf_counter()
|
||||||
result = generator.generate_markdown(
|
result = generator.generate_markdown(
|
||||||
cleaned_html=cleaned_html,
|
cleaned_html=cleaned_html, base_url="https://en.wikipedia.org"
|
||||||
base_url="https://en.wikipedia.org"
|
|
||||||
)
|
)
|
||||||
execution_time = time.perf_counter() - start_time
|
execution_time = time.perf_counter() - start_time
|
||||||
|
|
||||||
print_test_result("Basic Markdown Conversion", {
|
print_test_result(
|
||||||
'raw': result.raw_markdown,
|
"Basic Markdown Conversion",
|
||||||
'with_citations': result.markdown_with_citations,
|
{
|
||||||
'references': result.references_markdown
|
"raw": result.raw_markdown,
|
||||||
}, execution_time)
|
"with_citations": result.markdown_with_citations,
|
||||||
|
"references": result.references_markdown,
|
||||||
|
},
|
||||||
|
execution_time,
|
||||||
|
)
|
||||||
|
|
||||||
# Basic assertions
|
# Basic assertions
|
||||||
assert result.raw_markdown, "Raw markdown should not be empty"
|
assert result.raw_markdown, "Raw markdown should not be empty"
|
||||||
assert result.markdown_with_citations, "Markdown with citations should not be empty"
|
assert result.markdown_with_citations, "Markdown with citations should not be empty"
|
||||||
assert result.references_markdown, "References should not be empty"
|
assert result.references_markdown, "References should not be empty"
|
||||||
assert "⟨" in result.markdown_with_citations, "Citations should use ⟨⟩ brackets"
|
assert "⟨" in result.markdown_with_citations, "Citations should use ⟨⟩ brackets"
|
||||||
assert "## References" in result.references_markdown, "Should contain references section"
|
assert (
|
||||||
|
"## References" in result.references_markdown
|
||||||
|
), "Should contain references section"
|
||||||
|
|
||||||
|
|
||||||
def test_relative_links():
|
def test_relative_links():
|
||||||
"""Test handling of relative links with base URL."""
|
"""Test handling of relative links with base URL."""
|
||||||
@@ -72,14 +79,14 @@ def test_relative_links():
|
|||||||
|
|
||||||
generator = DefaultMarkdownGenerator()
|
generator = DefaultMarkdownGenerator()
|
||||||
result = generator.generate_markdown(
|
result = generator.generate_markdown(
|
||||||
cleaned_html=markdown,
|
cleaned_html=markdown, base_url="https://en.wikipedia.org"
|
||||||
base_url="https://en.wikipedia.org"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
assert "https://en.wikipedia.org/wiki/Apple" in result.references_markdown
|
assert "https://en.wikipedia.org/wiki/Apple" in result.references_markdown
|
||||||
assert "https://example.com" in result.references_markdown
|
assert "https://example.com" in result.references_markdown
|
||||||
assert "https://en.wikipedia.org/images/test.png" in result.references_markdown
|
assert "https://en.wikipedia.org/images/test.png" in result.references_markdown
|
||||||
|
|
||||||
|
|
||||||
def test_duplicate_links():
|
def test_duplicate_links():
|
||||||
"""Test handling of duplicate links."""
|
"""Test handling of duplicate links."""
|
||||||
markdown = """
|
markdown = """
|
||||||
@@ -88,14 +95,14 @@ def test_duplicate_links():
|
|||||||
|
|
||||||
generator = DefaultMarkdownGenerator()
|
generator = DefaultMarkdownGenerator()
|
||||||
result = generator.generate_markdown(
|
result = generator.generate_markdown(
|
||||||
cleaned_html=markdown,
|
cleaned_html=markdown, base_url="https://example.com"
|
||||||
base_url="https://example.com"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
# Count citations in markdown
|
# Count citations in markdown
|
||||||
citations = result.markdown_with_citations.count("⟨1⟩")
|
citations = result.markdown_with_citations.count("⟨1⟩")
|
||||||
assert citations == 2, "Same link should use same citation number"
|
assert citations == 2, "Same link should use same citation number"
|
||||||
|
|
||||||
|
|
||||||
def test_link_descriptions():
|
def test_link_descriptions():
|
||||||
"""Test handling of link titles and descriptions."""
|
"""Test handling of link titles and descriptions."""
|
||||||
markdown = """
|
markdown = """
|
||||||
@@ -104,12 +111,16 @@ def test_link_descriptions():
|
|||||||
|
|
||||||
generator = DefaultMarkdownGenerator()
|
generator = DefaultMarkdownGenerator()
|
||||||
result = generator.generate_markdown(
|
result = generator.generate_markdown(
|
||||||
cleaned_html=markdown,
|
cleaned_html=markdown, base_url="https://example.com"
|
||||||
base_url="https://example.com"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
assert "Test Title" in result.references_markdown, "Link title should be in references"
|
assert (
|
||||||
assert "link with description" in result.references_markdown, "Link text should be in references"
|
"Test Title" in result.references_markdown
|
||||||
|
), "Link title should be in references"
|
||||||
|
assert (
|
||||||
|
"link with description" in result.references_markdown
|
||||||
|
), "Link text should be in references"
|
||||||
|
|
||||||
|
|
||||||
def test_performance_large_document():
|
def test_performance_large_document():
|
||||||
"""Test performance with large document."""
|
"""Test performance with large document."""
|
||||||
@@ -125,18 +136,20 @@ def test_performance_large_document():
|
|||||||
for i in range(iterations):
|
for i in range(iterations):
|
||||||
start_time = time.perf_counter()
|
start_time = time.perf_counter()
|
||||||
result = generator.generate_markdown(
|
result = generator.generate_markdown(
|
||||||
cleaned_html=markdown,
|
cleaned_html=markdown, base_url="https://en.wikipedia.org"
|
||||||
base_url="https://en.wikipedia.org"
|
|
||||||
)
|
)
|
||||||
end_time = time.perf_counter()
|
end_time = time.perf_counter()
|
||||||
times.append(end_time - start_time)
|
times.append(end_time - start_time)
|
||||||
|
|
||||||
avg_time = sum(times) / len(times)
|
avg_time = sum(times) / len(times)
|
||||||
print(f"\n{'='*20} Performance Test {'='*20}")
|
print(f"\n{'='*20} Performance Test {'='*20}")
|
||||||
print(f"Average execution time over {iterations} iterations: {avg_time:.4f} seconds")
|
print(
|
||||||
|
f"Average execution time over {iterations} iterations: {avg_time:.4f} seconds"
|
||||||
|
)
|
||||||
print(f"Min time: {min(times):.4f} seconds")
|
print(f"Min time: {min(times):.4f} seconds")
|
||||||
print(f"Max time: {max(times):.4f} seconds")
|
print(f"Max time: {max(times):.4f} seconds")
|
||||||
|
|
||||||
|
|
||||||
def test_image_links():
|
def test_image_links():
|
||||||
"""Test handling of image links."""
|
"""Test handling of image links."""
|
||||||
markdown = """
|
markdown = """
|
||||||
@@ -146,12 +159,16 @@ def test_image_links():
|
|||||||
|
|
||||||
generator = DefaultMarkdownGenerator()
|
generator = DefaultMarkdownGenerator()
|
||||||
result = generator.generate_markdown(
|
result = generator.generate_markdown(
|
||||||
cleaned_html=markdown,
|
cleaned_html=markdown, base_url="https://example.com"
|
||||||
base_url="https://example.com"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
assert "![" in result.markdown_with_citations, "Image markdown syntax should be preserved"
|
assert (
|
||||||
assert "Image Title" in result.references_markdown, "Image title should be in references"
|
"![" in result.markdown_with_citations
|
||||||
|
), "Image markdown syntax should be preserved"
|
||||||
|
assert (
|
||||||
|
"Image Title" in result.references_markdown
|
||||||
|
), "Image title should be in references"
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
print("Running markdown generation strategy tests...")
|
print("Running markdown generation strategy tests...")
|
||||||
@@ -162,4 +179,3 @@ if __name__ == "__main__":
|
|||||||
test_link_descriptions()
|
test_link_descriptions()
|
||||||
test_performance_large_document()
|
test_performance_large_document()
|
||||||
test_image_links()
|
test_image_links()
|
||||||
|
|
||||||
@@ -1,8 +1,6 @@
|
|||||||
import os
|
import os
|
||||||
import sys
|
import sys
|
||||||
import pytest
|
import pytest
|
||||||
import asyncio
|
|
||||||
import json
|
|
||||||
|
|
||||||
# Add the parent directory to the Python path
|
# Add the parent directory to the Python path
|
||||||
parent_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
|
parent_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
|
||||||
@@ -10,24 +8,37 @@ sys.path.append(parent_dir)
|
|||||||
|
|
||||||
from crawl4ai.async_webcrawler import AsyncWebCrawler
|
from crawl4ai.async_webcrawler import AsyncWebCrawler
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_word_count_threshold():
|
async def test_word_count_threshold():
|
||||||
async with AsyncWebCrawler(verbose=True) as crawler:
|
async with AsyncWebCrawler(verbose=True) as crawler:
|
||||||
url = "https://www.nbcnews.com/business"
|
url = "https://www.nbcnews.com/business"
|
||||||
result_no_threshold = await crawler.arun(url=url, word_count_threshold=0, bypass_cache=True)
|
result_no_threshold = await crawler.arun(
|
||||||
result_with_threshold = await crawler.arun(url=url, word_count_threshold=50, bypass_cache=True)
|
url=url, word_count_threshold=0, bypass_cache=True
|
||||||
|
)
|
||||||
|
result_with_threshold = await crawler.arun(
|
||||||
|
url=url, word_count_threshold=50, bypass_cache=True
|
||||||
|
)
|
||||||
|
|
||||||
assert len(result_no_threshold.markdown) > len(result_with_threshold.markdown)
|
assert len(result_no_threshold.markdown) > len(result_with_threshold.markdown)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_css_selector():
|
async def test_css_selector():
|
||||||
async with AsyncWebCrawler(verbose=True) as crawler:
|
async with AsyncWebCrawler(verbose=True) as crawler:
|
||||||
url = "https://www.nbcnews.com/business"
|
url = "https://www.nbcnews.com/business"
|
||||||
css_selector = "h1, h2, h3"
|
css_selector = "h1, h2, h3"
|
||||||
result = await crawler.arun(url=url, css_selector=css_selector, bypass_cache=True)
|
result = await crawler.arun(
|
||||||
|
url=url, css_selector=css_selector, bypass_cache=True
|
||||||
|
)
|
||||||
|
|
||||||
assert result.success
|
assert result.success
|
||||||
assert "<h1" in result.cleaned_html or "<h2" in result.cleaned_html or "<h3" in result.cleaned_html
|
assert (
|
||||||
|
"<h1" in result.cleaned_html
|
||||||
|
or "<h2" in result.cleaned_html
|
||||||
|
or "<h3" in result.cleaned_html
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_javascript_execution():
|
async def test_javascript_execution():
|
||||||
@@ -37,12 +48,15 @@ async def test_javascript_execution():
|
|||||||
# Crawl without JS
|
# Crawl without JS
|
||||||
result_without_more = await crawler.arun(url=url, bypass_cache=True)
|
result_without_more = await crawler.arun(url=url, bypass_cache=True)
|
||||||
|
|
||||||
js_code = ["const loadMoreButton = Array.from(document.querySelectorAll('button')).find(button => button.textContent.includes('Load More')); loadMoreButton && loadMoreButton.click();"]
|
js_code = [
|
||||||
|
"const loadMoreButton = Array.from(document.querySelectorAll('button')).find(button => button.textContent.includes('Load More')); loadMoreButton && loadMoreButton.click();"
|
||||||
|
]
|
||||||
result_with_more = await crawler.arun(url=url, js=js_code, bypass_cache=True)
|
result_with_more = await crawler.arun(url=url, js=js_code, bypass_cache=True)
|
||||||
|
|
||||||
assert result_with_more.success
|
assert result_with_more.success
|
||||||
assert len(result_with_more.markdown) > len(result_without_more.markdown)
|
assert len(result_with_more.markdown) > len(result_without_more.markdown)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_screenshot():
|
async def test_screenshot():
|
||||||
async with AsyncWebCrawler(verbose=True) as crawler:
|
async with AsyncWebCrawler(verbose=True) as crawler:
|
||||||
@@ -53,16 +67,20 @@ async def test_screenshot():
|
|||||||
assert result.screenshot
|
assert result.screenshot
|
||||||
assert isinstance(result.screenshot, str) # Should be a base64 encoded string
|
assert isinstance(result.screenshot, str) # Should be a base64 encoded string
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_custom_user_agent():
|
async def test_custom_user_agent():
|
||||||
async with AsyncWebCrawler(verbose=True) as crawler:
|
async with AsyncWebCrawler(verbose=True) as crawler:
|
||||||
url = "https://www.nbcnews.com/business"
|
url = "https://www.nbcnews.com/business"
|
||||||
custom_user_agent = "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.124 Safari/537.36 Crawl4AI/1.0"
|
custom_user_agent = "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.124 Safari/537.36 Crawl4AI/1.0"
|
||||||
result = await crawler.arun(url=url, user_agent=custom_user_agent, bypass_cache=True)
|
result = await crawler.arun(
|
||||||
|
url=url, user_agent=custom_user_agent, bypass_cache=True
|
||||||
|
)
|
||||||
|
|
||||||
assert result.success
|
assert result.success
|
||||||
# Note: We can't directly verify the user agent in the result, but we can check if the crawl was successful
|
# Note: We can't directly verify the user agent in the result, but we can check if the crawl was successful
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_extract_media_and_links():
|
async def test_extract_media_and_links():
|
||||||
async with AsyncWebCrawler(verbose=True) as crawler:
|
async with AsyncWebCrawler(verbose=True) as crawler:
|
||||||
@@ -72,10 +90,11 @@ async def test_extract_media_and_links():
|
|||||||
assert result.success
|
assert result.success
|
||||||
assert result.media
|
assert result.media
|
||||||
assert isinstance(result.media, dict)
|
assert isinstance(result.media, dict)
|
||||||
assert 'images' in result.media
|
assert "images" in result.media
|
||||||
assert result.links
|
assert result.links
|
||||||
assert isinstance(result.links, dict)
|
assert isinstance(result.links, dict)
|
||||||
assert 'internal' in result.links and 'external' in result.links
|
assert "internal" in result.links and "external" in result.links
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_metadata_extraction():
|
async def test_metadata_extraction():
|
||||||
@@ -87,7 +106,10 @@ async def test_metadata_extraction():
|
|||||||
assert result.metadata
|
assert result.metadata
|
||||||
assert isinstance(result.metadata, dict)
|
assert isinstance(result.metadata, dict)
|
||||||
# Check for common metadata fields
|
# Check for common metadata fields
|
||||||
assert any(key in result.metadata for key in ['title', 'description', 'keywords'])
|
assert any(
|
||||||
|
key in result.metadata for key in ["title", "description", "keywords"]
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
# Entry point for debugging
|
# Entry point for debugging
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
|||||||
@@ -1,7 +1,6 @@
|
|||||||
import os
|
import os
|
||||||
import sys
|
import sys
|
||||||
import pytest
|
import pytest
|
||||||
import asyncio
|
|
||||||
import time
|
import time
|
||||||
|
|
||||||
# Add the parent directory to the Python path
|
# Add the parent directory to the Python path
|
||||||
@@ -10,6 +9,7 @@ sys.path.append(parent_dir)
|
|||||||
|
|
||||||
from crawl4ai.async_webcrawler import AsyncWebCrawler
|
from crawl4ai.async_webcrawler import AsyncWebCrawler
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_crawl_speed():
|
async def test_crawl_speed():
|
||||||
async with AsyncWebCrawler(verbose=True) as crawler:
|
async with AsyncWebCrawler(verbose=True) as crawler:
|
||||||
@@ -24,6 +24,7 @@ async def test_crawl_speed():
|
|||||||
|
|
||||||
assert crawl_time < 10, f"Crawl took too long: {crawl_time:.2f} seconds"
|
assert crawl_time < 10, f"Crawl took too long: {crawl_time:.2f} seconds"
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_concurrent_crawling_performance():
|
async def test_concurrent_crawling_performance():
|
||||||
async with AsyncWebCrawler(verbose=True) as crawler:
|
async with AsyncWebCrawler(verbose=True) as crawler:
|
||||||
@@ -32,7 +33,7 @@ async def test_concurrent_crawling_performance():
|
|||||||
"https://www.example.com",
|
"https://www.example.com",
|
||||||
"https://www.python.org",
|
"https://www.python.org",
|
||||||
"https://www.github.com",
|
"https://www.github.com",
|
||||||
"https://www.stackoverflow.com"
|
"https://www.stackoverflow.com",
|
||||||
]
|
]
|
||||||
|
|
||||||
start_time = time.time()
|
start_time = time.time()
|
||||||
@@ -45,7 +46,10 @@ async def test_concurrent_crawling_performance():
|
|||||||
assert all(result.success for result in results)
|
assert all(result.success for result in results)
|
||||||
assert len(results) == len(urls)
|
assert len(results) == len(urls)
|
||||||
|
|
||||||
assert total_time < len(urls) * 5, f"Concurrent crawling not significantly faster: {total_time:.2f} seconds"
|
assert (
|
||||||
|
total_time < len(urls) * 5
|
||||||
|
), f"Concurrent crawling not significantly faster: {total_time:.2f} seconds"
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_crawl_speed_with_caching():
|
async def test_crawl_speed_with_caching():
|
||||||
@@ -66,7 +70,10 @@ async def test_crawl_speed_with_caching():
|
|||||||
print(f"First crawl time: {first_crawl_time:.2f} seconds")
|
print(f"First crawl time: {first_crawl_time:.2f} seconds")
|
||||||
print(f"Second crawl time (cached): {second_crawl_time:.2f} seconds")
|
print(f"Second crawl time (cached): {second_crawl_time:.2f} seconds")
|
||||||
|
|
||||||
assert second_crawl_time < first_crawl_time / 2, "Cached crawl not significantly faster"
|
assert (
|
||||||
|
second_crawl_time < first_crawl_time / 2
|
||||||
|
), "Cached crawl not significantly faster"
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
pytest.main([__file__, "-v"])
|
pytest.main([__file__, "-v"])
|
||||||
@@ -1,7 +1,6 @@
|
|||||||
import os
|
import os
|
||||||
import sys
|
import sys
|
||||||
import pytest
|
import pytest
|
||||||
import asyncio
|
|
||||||
import base64
|
import base64
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
import io
|
import io
|
||||||
@@ -12,6 +11,7 @@ sys.path.append(parent_dir)
|
|||||||
|
|
||||||
from crawl4ai.async_webcrawler import AsyncWebCrawler
|
from crawl4ai.async_webcrawler import AsyncWebCrawler
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_basic_screenshot():
|
async def test_basic_screenshot():
|
||||||
async with AsyncWebCrawler(verbose=True) as crawler:
|
async with AsyncWebCrawler(verbose=True) as crawler:
|
||||||
@@ -26,6 +26,7 @@ async def test_basic_screenshot():
|
|||||||
image = Image.open(io.BytesIO(image_data))
|
image = Image.open(io.BytesIO(image_data))
|
||||||
assert image.format == "PNG"
|
assert image.format == "PNG"
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_screenshot_with_wait_for():
|
async def test_screenshot_with_wait_for():
|
||||||
async with AsyncWebCrawler(verbose=True) as crawler:
|
async with AsyncWebCrawler(verbose=True) as crawler:
|
||||||
@@ -34,10 +35,7 @@ async def test_screenshot_with_wait_for():
|
|||||||
wait_for = "css:#content" # Wait for the main content to load
|
wait_for = "css:#content" # Wait for the main content to load
|
||||||
|
|
||||||
result = await crawler.arun(
|
result = await crawler.arun(
|
||||||
url=url,
|
url=url, bypass_cache=True, screenshot=True, wait_for=wait_for
|
||||||
bypass_cache=True,
|
|
||||||
screenshot=True,
|
|
||||||
wait_for=wait_for
|
|
||||||
)
|
)
|
||||||
|
|
||||||
assert result.success
|
assert result.success
|
||||||
@@ -51,6 +49,7 @@ async def test_screenshot_with_wait_for():
|
|||||||
# You might want to add more specific checks here, like image dimensions
|
# You might want to add more specific checks here, like image dimensions
|
||||||
# or even use image recognition to verify certain elements are present
|
# or even use image recognition to verify certain elements are present
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_screenshot_with_js_wait_for():
|
async def test_screenshot_with_js_wait_for():
|
||||||
async with AsyncWebCrawler(verbose=True) as crawler:
|
async with AsyncWebCrawler(verbose=True) as crawler:
|
||||||
@@ -58,10 +57,7 @@ async def test_screenshot_with_js_wait_for():
|
|||||||
wait_for = "js:() => document.querySelector('#nav-logo-sprites') !== null"
|
wait_for = "js:() => document.querySelector('#nav-logo-sprites') !== null"
|
||||||
|
|
||||||
result = await crawler.arun(
|
result = await crawler.arun(
|
||||||
url=url,
|
url=url, bypass_cache=True, screenshot=True, wait_for=wait_for
|
||||||
bypass_cache=True,
|
|
||||||
screenshot=True,
|
|
||||||
wait_for=wait_for
|
|
||||||
)
|
)
|
||||||
|
|
||||||
assert result.success
|
assert result.success
|
||||||
@@ -71,6 +67,7 @@ async def test_screenshot_with_js_wait_for():
|
|||||||
image = Image.open(io.BytesIO(image_data))
|
image = Image.open(io.BytesIO(image_data))
|
||||||
assert image.format == "PNG"
|
assert image.format == "PNG"
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_screenshot_without_wait_for():
|
async def test_screenshot_without_wait_for():
|
||||||
async with AsyncWebCrawler(verbose=True) as crawler:
|
async with AsyncWebCrawler(verbose=True) as crawler:
|
||||||
@@ -85,6 +82,7 @@ async def test_screenshot_without_wait_for():
|
|||||||
image = Image.open(io.BytesIO(image_data))
|
image = Image.open(io.BytesIO(image_data))
|
||||||
assert image.format == "PNG"
|
assert image.format == "PNG"
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_screenshot_comparison():
|
async def test_screenshot_comparison():
|
||||||
async with AsyncWebCrawler(verbose=True) as crawler:
|
async with AsyncWebCrawler(verbose=True) as crawler:
|
||||||
@@ -93,17 +91,12 @@ async def test_screenshot_comparison():
|
|||||||
|
|
||||||
# Take screenshot without wait_for
|
# Take screenshot without wait_for
|
||||||
result_without_wait = await crawler.arun(
|
result_without_wait = await crawler.arun(
|
||||||
url=url,
|
url=url, bypass_cache=True, screenshot=True
|
||||||
bypass_cache=True,
|
|
||||||
screenshot=True
|
|
||||||
)
|
)
|
||||||
|
|
||||||
# Take screenshot with wait_for
|
# Take screenshot with wait_for
|
||||||
result_with_wait = await crawler.arun(
|
result_with_wait = await crawler.arun(
|
||||||
url=url,
|
url=url, bypass_cache=True, screenshot=True, wait_for=wait_for
|
||||||
bypass_cache=True,
|
|
||||||
screenshot=True,
|
|
||||||
wait_for=wait_for
|
|
||||||
)
|
)
|
||||||
|
|
||||||
assert result_without_wait.success and result_with_wait.success
|
assert result_without_wait.success and result_with_wait.success
|
||||||
@@ -111,14 +104,19 @@ async def test_screenshot_comparison():
|
|||||||
assert result_with_wait.screenshot is not None
|
assert result_with_wait.screenshot is not None
|
||||||
|
|
||||||
# Compare the two screenshots
|
# Compare the two screenshots
|
||||||
image_without_wait = Image.open(io.BytesIO(base64.b64decode(result_without_wait.screenshot)))
|
image_without_wait = Image.open(
|
||||||
image_with_wait = Image.open(io.BytesIO(base64.b64decode(result_with_wait.screenshot)))
|
io.BytesIO(base64.b64decode(result_without_wait.screenshot))
|
||||||
|
)
|
||||||
|
image_with_wait = Image.open(
|
||||||
|
io.BytesIO(base64.b64decode(result_with_wait.screenshot))
|
||||||
|
)
|
||||||
|
|
||||||
# This is a simple size comparison. In a real-world scenario, you might want to use
|
# This is a simple size comparison. In a real-world scenario, you might want to use
|
||||||
# more sophisticated image comparison techniques.
|
# more sophisticated image comparison techniques.
|
||||||
assert image_with_wait.size[0] >= image_without_wait.size[0]
|
assert image_with_wait.size[0] >= image_without_wait.size[0]
|
||||||
assert image_with_wait.size[1] >= image_without_wait.size[1]
|
assert image_with_wait.size[1] >= image_without_wait.size[1]
|
||||||
|
|
||||||
|
|
||||||
# Entry point for debugging
|
# Entry point for debugging
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
pytest.main([__file__, "-v"])
|
pytest.main([__file__, "-v"])
|
||||||
@@ -6,15 +6,24 @@ import base64
|
|||||||
import os
|
import os
|
||||||
from typing import Dict, Any
|
from typing import Dict, Any
|
||||||
|
|
||||||
|
|
||||||
class Crawl4AiTester:
|
class Crawl4AiTester:
|
||||||
def __init__(self, base_url: str = "http://localhost:11235", api_token: str = None):
|
def __init__(self, base_url: str = "http://localhost:11235", api_token: str = None):
|
||||||
self.base_url = base_url
|
self.base_url = base_url
|
||||||
self.api_token = api_token or os.getenv('CRAWL4AI_API_TOKEN') # Check environment variable as fallback
|
self.api_token = api_token or os.getenv(
|
||||||
self.headers = {'Authorization': f'Bearer {self.api_token}'} if self.api_token else {}
|
"CRAWL4AI_API_TOKEN"
|
||||||
|
) # Check environment variable as fallback
|
||||||
|
self.headers = (
|
||||||
|
{"Authorization": f"Bearer {self.api_token}"} if self.api_token else {}
|
||||||
|
)
|
||||||
|
|
||||||
def submit_and_wait(self, request_data: Dict[str, Any], timeout: int = 300) -> Dict[str, Any]:
|
def submit_and_wait(
|
||||||
|
self, request_data: Dict[str, Any], timeout: int = 300
|
||||||
|
) -> Dict[str, Any]:
|
||||||
# Submit crawl job
|
# Submit crawl job
|
||||||
response = requests.post(f"{self.base_url}/crawl", json=request_data, headers=self.headers)
|
response = requests.post(
|
||||||
|
f"{self.base_url}/crawl", json=request_data, headers=self.headers
|
||||||
|
)
|
||||||
if response.status_code == 403:
|
if response.status_code == 403:
|
||||||
raise Exception("API token is invalid or missing")
|
raise Exception("API token is invalid or missing")
|
||||||
task_id = response.json()["task_id"]
|
task_id = response.json()["task_id"]
|
||||||
@@ -24,9 +33,13 @@ class Crawl4AiTester:
|
|||||||
start_time = time.time()
|
start_time = time.time()
|
||||||
while True:
|
while True:
|
||||||
if time.time() - start_time > timeout:
|
if time.time() - start_time > timeout:
|
||||||
raise TimeoutError(f"Task {task_id} did not complete within {timeout} seconds")
|
raise TimeoutError(
|
||||||
|
f"Task {task_id} did not complete within {timeout} seconds"
|
||||||
|
)
|
||||||
|
|
||||||
result = requests.get(f"{self.base_url}/task/{task_id}", headers=self.headers)
|
result = requests.get(
|
||||||
|
f"{self.base_url}/task/{task_id}", headers=self.headers
|
||||||
|
)
|
||||||
status = result.json()
|
status = result.json()
|
||||||
|
|
||||||
if status["status"] == "failed":
|
if status["status"] == "failed":
|
||||||
@@ -39,17 +52,23 @@ class Crawl4AiTester:
|
|||||||
time.sleep(2)
|
time.sleep(2)
|
||||||
|
|
||||||
def submit_sync(self, request_data: Dict[str, Any]) -> Dict[str, Any]:
|
def submit_sync(self, request_data: Dict[str, Any]) -> Dict[str, Any]:
|
||||||
response = requests.post(f"{self.base_url}/crawl_sync", json=request_data, headers=self.headers, timeout=60)
|
response = requests.post(
|
||||||
|
f"{self.base_url}/crawl_sync",
|
||||||
|
json=request_data,
|
||||||
|
headers=self.headers,
|
||||||
|
timeout=60,
|
||||||
|
)
|
||||||
if response.status_code == 408:
|
if response.status_code == 408:
|
||||||
raise TimeoutError("Task did not complete within server timeout")
|
raise TimeoutError("Task did not complete within server timeout")
|
||||||
response.raise_for_status()
|
response.raise_for_status()
|
||||||
return response.json()
|
return response.json()
|
||||||
|
|
||||||
|
|
||||||
def test_docker_deployment(version="basic"):
|
def test_docker_deployment(version="basic"):
|
||||||
tester = Crawl4AiTester(
|
tester = Crawl4AiTester(
|
||||||
# base_url="http://localhost:11235" ,
|
# base_url="http://localhost:11235" ,
|
||||||
base_url="https://crawl4ai-sby74.ondigitalocean.app",
|
base_url="https://crawl4ai-sby74.ondigitalocean.app",
|
||||||
api_token="test"
|
api_token="test",
|
||||||
)
|
)
|
||||||
print(f"Testing Crawl4AI Docker {version} version")
|
print(f"Testing Crawl4AI Docker {version} version")
|
||||||
|
|
||||||
@@ -60,7 +79,7 @@ def test_docker_deployment(version="basic"):
|
|||||||
health = requests.get(f"{tester.base_url}/health", timeout=10)
|
health = requests.get(f"{tester.base_url}/health", timeout=10)
|
||||||
print("Health check:", health.json())
|
print("Health check:", health.json())
|
||||||
break
|
break
|
||||||
except requests.exceptions.RequestException as e:
|
except requests.exceptions.RequestException:
|
||||||
if i == max_retries - 1:
|
if i == max_retries - 1:
|
||||||
print(f"Failed to connect after {max_retries} attempts")
|
print(f"Failed to connect after {max_retries} attempts")
|
||||||
sys.exit(1)
|
sys.exit(1)
|
||||||
@@ -88,7 +107,7 @@ def test_basic_crawl(tester: Crawl4AiTester):
|
|||||||
request = {
|
request = {
|
||||||
"urls": "https://www.nbcnews.com/business",
|
"urls": "https://www.nbcnews.com/business",
|
||||||
"priority": 10,
|
"priority": 10,
|
||||||
"session_id": "test"
|
"session_id": "test",
|
||||||
}
|
}
|
||||||
|
|
||||||
result = tester.submit_and_wait(request)
|
result = tester.submit_and_wait(request)
|
||||||
@@ -96,19 +115,21 @@ def test_basic_crawl(tester: Crawl4AiTester):
|
|||||||
assert result["result"]["success"]
|
assert result["result"]["success"]
|
||||||
assert len(result["result"]["markdown"]) > 0
|
assert len(result["result"]["markdown"]) > 0
|
||||||
|
|
||||||
|
|
||||||
def test_basic_crawl_sync(tester: Crawl4AiTester):
|
def test_basic_crawl_sync(tester: Crawl4AiTester):
|
||||||
print("\n=== Testing Basic Crawl (Sync) ===")
|
print("\n=== Testing Basic Crawl (Sync) ===")
|
||||||
request = {
|
request = {
|
||||||
"urls": "https://www.nbcnews.com/business",
|
"urls": "https://www.nbcnews.com/business",
|
||||||
"priority": 10,
|
"priority": 10,
|
||||||
"session_id": "test"
|
"session_id": "test",
|
||||||
}
|
}
|
||||||
|
|
||||||
result = tester.submit_sync(request)
|
result = tester.submit_sync(request)
|
||||||
print(f"Basic crawl result length: {len(result['result']['markdown'])}")
|
print(f"Basic crawl result length: {len(result['result']['markdown'])}")
|
||||||
assert result['status'] == 'completed'
|
assert result["status"] == "completed"
|
||||||
assert result['result']['success']
|
assert result["result"]["success"]
|
||||||
assert len(result['result']['markdown']) > 0
|
assert len(result["result"]["markdown"]) > 0
|
||||||
|
|
||||||
|
|
||||||
def test_js_execution(tester: Crawl4AiTester):
|
def test_js_execution(tester: Crawl4AiTester):
|
||||||
print("\n=== Testing JS Execution ===")
|
print("\n=== Testing JS Execution ===")
|
||||||
@@ -119,32 +140,29 @@ def test_js_execution(tester: Crawl4AiTester):
|
|||||||
"const loadMoreButton = Array.from(document.querySelectorAll('button')).find(button => button.textContent.includes('Load More')); loadMoreButton && loadMoreButton.click();"
|
"const loadMoreButton = Array.from(document.querySelectorAll('button')).find(button => button.textContent.includes('Load More')); loadMoreButton && loadMoreButton.click();"
|
||||||
],
|
],
|
||||||
"wait_for": "article.tease-card:nth-child(10)",
|
"wait_for": "article.tease-card:nth-child(10)",
|
||||||
"crawler_params": {
|
"crawler_params": {"headless": True},
|
||||||
"headless": True
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
result = tester.submit_and_wait(request)
|
result = tester.submit_and_wait(request)
|
||||||
print(f"JS execution result length: {len(result['result']['markdown'])}")
|
print(f"JS execution result length: {len(result['result']['markdown'])}")
|
||||||
assert result["result"]["success"]
|
assert result["result"]["success"]
|
||||||
|
|
||||||
|
|
||||||
def test_css_selector(tester: Crawl4AiTester):
|
def test_css_selector(tester: Crawl4AiTester):
|
||||||
print("\n=== Testing CSS Selector ===")
|
print("\n=== Testing CSS Selector ===")
|
||||||
request = {
|
request = {
|
||||||
"urls": "https://www.nbcnews.com/business",
|
"urls": "https://www.nbcnews.com/business",
|
||||||
"priority": 7,
|
"priority": 7,
|
||||||
"css_selector": ".wide-tease-item__description",
|
"css_selector": ".wide-tease-item__description",
|
||||||
"crawler_params": {
|
"crawler_params": {"headless": True},
|
||||||
"headless": True
|
"extra": {"word_count_threshold": 10},
|
||||||
},
|
|
||||||
"extra": {"word_count_threshold": 10}
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
result = tester.submit_and_wait(request)
|
result = tester.submit_and_wait(request)
|
||||||
print(f"CSS selector result length: {len(result['result']['markdown'])}")
|
print(f"CSS selector result length: {len(result['result']['markdown'])}")
|
||||||
assert result["result"]["success"]
|
assert result["result"]["success"]
|
||||||
|
|
||||||
|
|
||||||
def test_structured_extraction(tester: Crawl4AiTester):
|
def test_structured_extraction(tester: Crawl4AiTester):
|
||||||
print("\n=== Testing Structured Extraction ===")
|
print("\n=== Testing Structured Extraction ===")
|
||||||
schema = {
|
schema = {
|
||||||
@@ -165,19 +183,14 @@ def test_structured_extraction(tester: Crawl4AiTester):
|
|||||||
"name": "price",
|
"name": "price",
|
||||||
"selector": "td:nth-child(2)",
|
"selector": "td:nth-child(2)",
|
||||||
"type": "text",
|
"type": "text",
|
||||||
}
|
},
|
||||||
],
|
],
|
||||||
}
|
}
|
||||||
|
|
||||||
request = {
|
request = {
|
||||||
"urls": "https://www.coinbase.com/explore",
|
"urls": "https://www.coinbase.com/explore",
|
||||||
"priority": 9,
|
"priority": 9,
|
||||||
"extraction_config": {
|
"extraction_config": {"type": "json_css", "params": {"schema": schema}},
|
||||||
"type": "json_css",
|
|
||||||
"params": {
|
|
||||||
"schema": schema
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
result = tester.submit_and_wait(request)
|
result = tester.submit_and_wait(request)
|
||||||
@@ -187,6 +200,7 @@ def test_structured_extraction(tester: Crawl4AiTester):
|
|||||||
assert result["result"]["success"]
|
assert result["result"]["success"]
|
||||||
assert len(extracted) > 0
|
assert len(extracted) > 0
|
||||||
|
|
||||||
|
|
||||||
def test_llm_extraction(tester: Crawl4AiTester):
|
def test_llm_extraction(tester: Crawl4AiTester):
|
||||||
print("\n=== Testing LLM Extraction ===")
|
print("\n=== Testing LLM Extraction ===")
|
||||||
schema = {
|
schema = {
|
||||||
@@ -194,18 +208,18 @@ def test_llm_extraction(tester: Crawl4AiTester):
|
|||||||
"properties": {
|
"properties": {
|
||||||
"model_name": {
|
"model_name": {
|
||||||
"type": "string",
|
"type": "string",
|
||||||
"description": "Name of the OpenAI model."
|
"description": "Name of the OpenAI model.",
|
||||||
},
|
},
|
||||||
"input_fee": {
|
"input_fee": {
|
||||||
"type": "string",
|
"type": "string",
|
||||||
"description": "Fee for input token for the OpenAI model."
|
"description": "Fee for input token for the OpenAI model.",
|
||||||
},
|
},
|
||||||
"output_fee": {
|
"output_fee": {
|
||||||
"type": "string",
|
"type": "string",
|
||||||
"description": "Fee for output token for the OpenAI model."
|
"description": "Fee for output token for the OpenAI model.",
|
||||||
}
|
|
||||||
},
|
},
|
||||||
"required": ["model_name", "input_fee", "output_fee"]
|
},
|
||||||
|
"required": ["model_name", "input_fee", "output_fee"],
|
||||||
}
|
}
|
||||||
|
|
||||||
request = {
|
request = {
|
||||||
@@ -218,10 +232,10 @@ def test_llm_extraction(tester: Crawl4AiTester):
|
|||||||
"api_token": os.getenv("OPENAI_API_KEY"),
|
"api_token": os.getenv("OPENAI_API_KEY"),
|
||||||
"schema": schema,
|
"schema": schema,
|
||||||
"extraction_type": "schema",
|
"extraction_type": "schema",
|
||||||
"instruction": """From the crawled content, extract all mentioned model names along with their fees for input and output tokens."""
|
"instruction": """From the crawled content, extract all mentioned model names along with their fees for input and output tokens.""",
|
||||||
}
|
|
||||||
},
|
},
|
||||||
"crawler_params": {"word_count_threshold": 1}
|
},
|
||||||
|
"crawler_params": {"word_count_threshold": 1},
|
||||||
}
|
}
|
||||||
|
|
||||||
try:
|
try:
|
||||||
@@ -233,6 +247,7 @@ def test_llm_extraction(tester: Crawl4AiTester):
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"LLM extraction test failed (might be due to missing API key): {str(e)}")
|
print(f"LLM extraction test failed (might be due to missing API key): {str(e)}")
|
||||||
|
|
||||||
|
|
||||||
def test_llm_with_ollama(tester: Crawl4AiTester):
|
def test_llm_with_ollama(tester: Crawl4AiTester):
|
||||||
print("\n=== Testing LLM with Ollama ===")
|
print("\n=== Testing LLM with Ollama ===")
|
||||||
schema = {
|
schema = {
|
||||||
@@ -240,18 +255,18 @@ def test_llm_with_ollama(tester: Crawl4AiTester):
|
|||||||
"properties": {
|
"properties": {
|
||||||
"article_title": {
|
"article_title": {
|
||||||
"type": "string",
|
"type": "string",
|
||||||
"description": "The main title of the news article"
|
"description": "The main title of the news article",
|
||||||
},
|
},
|
||||||
"summary": {
|
"summary": {
|
||||||
"type": "string",
|
"type": "string",
|
||||||
"description": "A brief summary of the article content"
|
"description": "A brief summary of the article content",
|
||||||
},
|
},
|
||||||
"main_topics": {
|
"main_topics": {
|
||||||
"type": "array",
|
"type": "array",
|
||||||
"items": {"type": "string"},
|
"items": {"type": "string"},
|
||||||
"description": "Main topics or themes discussed in the article"
|
"description": "Main topics or themes discussed in the article",
|
||||||
}
|
},
|
||||||
}
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
request = {
|
request = {
|
||||||
@@ -263,11 +278,11 @@ def test_llm_with_ollama(tester: Crawl4AiTester):
|
|||||||
"provider": "ollama/llama2",
|
"provider": "ollama/llama2",
|
||||||
"schema": schema,
|
"schema": schema,
|
||||||
"extraction_type": "schema",
|
"extraction_type": "schema",
|
||||||
"instruction": "Extract the main article information including title, summary, and main topics."
|
"instruction": "Extract the main article information including title, summary, and main topics.",
|
||||||
}
|
},
|
||||||
},
|
},
|
||||||
"extra": {"word_count_threshold": 1},
|
"extra": {"word_count_threshold": 1},
|
||||||
"crawler_params": {"verbose": True}
|
"crawler_params": {"verbose": True},
|
||||||
}
|
}
|
||||||
|
|
||||||
try:
|
try:
|
||||||
@@ -278,6 +293,7 @@ def test_llm_with_ollama(tester: Crawl4AiTester):
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"Ollama extraction test failed: {str(e)}")
|
print(f"Ollama extraction test failed: {str(e)}")
|
||||||
|
|
||||||
|
|
||||||
def test_cosine_extraction(tester: Crawl4AiTester):
|
def test_cosine_extraction(tester: Crawl4AiTester):
|
||||||
print("\n=== Testing Cosine Extraction ===")
|
print("\n=== Testing Cosine Extraction ===")
|
||||||
request = {
|
request = {
|
||||||
@@ -289,9 +305,9 @@ def test_cosine_extraction(tester: Crawl4AiTester):
|
|||||||
"semantic_filter": "business finance economy",
|
"semantic_filter": "business finance economy",
|
||||||
"word_count_threshold": 10,
|
"word_count_threshold": 10,
|
||||||
"max_dist": 0.2,
|
"max_dist": 0.2,
|
||||||
"top_k": 3
|
"top_k": 3,
|
||||||
}
|
},
|
||||||
}
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
try:
|
try:
|
||||||
@@ -303,15 +319,14 @@ def test_cosine_extraction(tester: Crawl4AiTester):
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"Cosine extraction test failed: {str(e)}")
|
print(f"Cosine extraction test failed: {str(e)}")
|
||||||
|
|
||||||
|
|
||||||
def test_screenshot(tester: Crawl4AiTester):
|
def test_screenshot(tester: Crawl4AiTester):
|
||||||
print("\n=== Testing Screenshot ===")
|
print("\n=== Testing Screenshot ===")
|
||||||
request = {
|
request = {
|
||||||
"urls": "https://www.nbcnews.com/business",
|
"urls": "https://www.nbcnews.com/business",
|
||||||
"priority": 5,
|
"priority": 5,
|
||||||
"screenshot": True,
|
"screenshot": True,
|
||||||
"crawler_params": {
|
"crawler_params": {"headless": True},
|
||||||
"headless": True
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
result = tester.submit_and_wait(request)
|
result = tester.submit_and_wait(request)
|
||||||
@@ -326,6 +341,7 @@ def test_screenshot(tester: Crawl4AiTester):
|
|||||||
|
|
||||||
assert result["result"]["success"]
|
assert result["result"]["success"]
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
version = sys.argv[1] if len(sys.argv) > 1 else "basic"
|
version = sys.argv[1] if len(sys.argv) > 1 else "basic"
|
||||||
# version = "full"
|
# version = "full"
|
||||||
|
|||||||
@@ -1,9 +1,9 @@
|
|||||||
import asyncio
|
import asyncio
|
||||||
from pathlib import Path
|
|
||||||
from crawl4ai.docs_manager import DocsManager
|
from crawl4ai.docs_manager import DocsManager
|
||||||
from click.testing import CliRunner
|
from click.testing import CliRunner
|
||||||
from crawl4ai.cli import cli
|
from crawl4ai.cli import cli
|
||||||
|
|
||||||
|
|
||||||
def test_cli():
|
def test_cli():
|
||||||
"""Test all CLI commands"""
|
"""Test all CLI commands"""
|
||||||
runner = CliRunner()
|
runner = CliRunner()
|
||||||
@@ -35,9 +35,10 @@ def test_cli():
|
|||||||
# print(f"First 200 chars: {result.output[:200]}...")
|
# print(f"First 200 chars: {result.output[:200]}...")
|
||||||
|
|
||||||
print("\n5. Testing combine all sections...")
|
print("\n5. Testing combine all sections...")
|
||||||
result = runner.invoke(cli, ['docs', 'combine', '--mode', 'condensed'])
|
result = runner.invoke(cli, ["docs", "combine", "--mode", "condensed"])
|
||||||
print(f"Status: {'✅' if result.exit_code == 0 else '❌'}")
|
print(f"Status: {'✅' if result.exit_code == 0 else '❌'}")
|
||||||
print(f"First 200 chars: {result.output[:200]}...")
|
print(f"First 200 chars: {result.output[:200]}...")
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
test_cli()
|
test_cli()
|
||||||
@@ -6,11 +6,14 @@ import base64
|
|||||||
import os
|
import os
|
||||||
from typing import Dict, Any
|
from typing import Dict, Any
|
||||||
|
|
||||||
|
|
||||||
class Crawl4AiTester:
|
class Crawl4AiTester:
|
||||||
def __init__(self, base_url: str = "http://localhost:11235"):
|
def __init__(self, base_url: str = "http://localhost:11235"):
|
||||||
self.base_url = base_url
|
self.base_url = base_url
|
||||||
|
|
||||||
def submit_and_wait(self, request_data: Dict[str, Any], timeout: int = 300) -> Dict[str, Any]:
|
def submit_and_wait(
|
||||||
|
self, request_data: Dict[str, Any], timeout: int = 300
|
||||||
|
) -> Dict[str, Any]:
|
||||||
# Submit crawl job
|
# Submit crawl job
|
||||||
response = requests.post(f"{self.base_url}/crawl", json=request_data)
|
response = requests.post(f"{self.base_url}/crawl", json=request_data)
|
||||||
task_id = response.json()["task_id"]
|
task_id = response.json()["task_id"]
|
||||||
@@ -20,7 +23,9 @@ class Crawl4AiTester:
|
|||||||
start_time = time.time()
|
start_time = time.time()
|
||||||
while True:
|
while True:
|
||||||
if time.time() - start_time > timeout:
|
if time.time() - start_time > timeout:
|
||||||
raise TimeoutError(f"Task {task_id} did not complete within {timeout} seconds")
|
raise TimeoutError(
|
||||||
|
f"Task {task_id} did not complete within {timeout} seconds"
|
||||||
|
)
|
||||||
|
|
||||||
result = requests.get(f"{self.base_url}/task/{task_id}")
|
result = requests.get(f"{self.base_url}/task/{task_id}")
|
||||||
status = result.json()
|
status = result.json()
|
||||||
@@ -34,6 +39,7 @@ class Crawl4AiTester:
|
|||||||
|
|
||||||
time.sleep(2)
|
time.sleep(2)
|
||||||
|
|
||||||
|
|
||||||
def test_docker_deployment(version="basic"):
|
def test_docker_deployment(version="basic"):
|
||||||
tester = Crawl4AiTester()
|
tester = Crawl4AiTester()
|
||||||
print(f"Testing Crawl4AI Docker {version} version")
|
print(f"Testing Crawl4AI Docker {version} version")
|
||||||
@@ -45,7 +51,7 @@ def test_docker_deployment(version="basic"):
|
|||||||
health = requests.get(f"{tester.base_url}/health", timeout=10)
|
health = requests.get(f"{tester.base_url}/health", timeout=10)
|
||||||
print("Health check:", health.json())
|
print("Health check:", health.json())
|
||||||
break
|
break
|
||||||
except requests.exceptions.RequestException as e:
|
except requests.exceptions.RequestException:
|
||||||
if i == max_retries - 1:
|
if i == max_retries - 1:
|
||||||
print(f"Failed to connect after {max_retries} attempts")
|
print(f"Failed to connect after {max_retries} attempts")
|
||||||
sys.exit(1)
|
sys.exit(1)
|
||||||
@@ -68,16 +74,14 @@ def test_docker_deployment(version="basic"):
|
|||||||
|
|
||||||
def test_basic_crawl(tester: Crawl4AiTester):
|
def test_basic_crawl(tester: Crawl4AiTester):
|
||||||
print("\n=== Testing Basic Crawl ===")
|
print("\n=== Testing Basic Crawl ===")
|
||||||
request = {
|
request = {"urls": "https://www.nbcnews.com/business", "priority": 10}
|
||||||
"urls": "https://www.nbcnews.com/business",
|
|
||||||
"priority": 10
|
|
||||||
}
|
|
||||||
|
|
||||||
result = tester.submit_and_wait(request)
|
result = tester.submit_and_wait(request)
|
||||||
print(f"Basic crawl result length: {len(result['result']['markdown'])}")
|
print(f"Basic crawl result length: {len(result['result']['markdown'])}")
|
||||||
assert result["result"]["success"]
|
assert result["result"]["success"]
|
||||||
assert len(result["result"]["markdown"]) > 0
|
assert len(result["result"]["markdown"]) > 0
|
||||||
|
|
||||||
|
|
||||||
def test_js_execution(tester: Crawl4AiTester):
|
def test_js_execution(tester: Crawl4AiTester):
|
||||||
print("\n=== Testing JS Execution ===")
|
print("\n=== Testing JS Execution ===")
|
||||||
request = {
|
request = {
|
||||||
@@ -87,32 +91,29 @@ def test_js_execution(tester: Crawl4AiTester):
|
|||||||
"const loadMoreButton = Array.from(document.querySelectorAll('button')).find(button => button.textContent.includes('Load More')); loadMoreButton && loadMoreButton.click();"
|
"const loadMoreButton = Array.from(document.querySelectorAll('button')).find(button => button.textContent.includes('Load More')); loadMoreButton && loadMoreButton.click();"
|
||||||
],
|
],
|
||||||
"wait_for": "article.tease-card:nth-child(10)",
|
"wait_for": "article.tease-card:nth-child(10)",
|
||||||
"crawler_params": {
|
"crawler_params": {"headless": True},
|
||||||
"headless": True
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
result = tester.submit_and_wait(request)
|
result = tester.submit_and_wait(request)
|
||||||
print(f"JS execution result length: {len(result['result']['markdown'])}")
|
print(f"JS execution result length: {len(result['result']['markdown'])}")
|
||||||
assert result["result"]["success"]
|
assert result["result"]["success"]
|
||||||
|
|
||||||
|
|
||||||
def test_css_selector(tester: Crawl4AiTester):
|
def test_css_selector(tester: Crawl4AiTester):
|
||||||
print("\n=== Testing CSS Selector ===")
|
print("\n=== Testing CSS Selector ===")
|
||||||
request = {
|
request = {
|
||||||
"urls": "https://www.nbcnews.com/business",
|
"urls": "https://www.nbcnews.com/business",
|
||||||
"priority": 7,
|
"priority": 7,
|
||||||
"css_selector": ".wide-tease-item__description",
|
"css_selector": ".wide-tease-item__description",
|
||||||
"crawler_params": {
|
"crawler_params": {"headless": True},
|
||||||
"headless": True
|
"extra": {"word_count_threshold": 10},
|
||||||
},
|
|
||||||
"extra": {"word_count_threshold": 10}
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
result = tester.submit_and_wait(request)
|
result = tester.submit_and_wait(request)
|
||||||
print(f"CSS selector result length: {len(result['result']['markdown'])}")
|
print(f"CSS selector result length: {len(result['result']['markdown'])}")
|
||||||
assert result["result"]["success"]
|
assert result["result"]["success"]
|
||||||
|
|
||||||
|
|
||||||
def test_structured_extraction(tester: Crawl4AiTester):
|
def test_structured_extraction(tester: Crawl4AiTester):
|
||||||
print("\n=== Testing Structured Extraction ===")
|
print("\n=== Testing Structured Extraction ===")
|
||||||
schema = {
|
schema = {
|
||||||
@@ -133,19 +134,14 @@ def test_structured_extraction(tester: Crawl4AiTester):
|
|||||||
"name": "price",
|
"name": "price",
|
||||||
"selector": "td:nth-child(2)",
|
"selector": "td:nth-child(2)",
|
||||||
"type": "text",
|
"type": "text",
|
||||||
}
|
},
|
||||||
],
|
],
|
||||||
}
|
}
|
||||||
|
|
||||||
request = {
|
request = {
|
||||||
"urls": "https://www.coinbase.com/explore",
|
"urls": "https://www.coinbase.com/explore",
|
||||||
"priority": 9,
|
"priority": 9,
|
||||||
"extraction_config": {
|
"extraction_config": {"type": "json_css", "params": {"schema": schema}},
|
||||||
"type": "json_css",
|
|
||||||
"params": {
|
|
||||||
"schema": schema
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
result = tester.submit_and_wait(request)
|
result = tester.submit_and_wait(request)
|
||||||
@@ -155,6 +151,7 @@ def test_structured_extraction(tester: Crawl4AiTester):
|
|||||||
assert result["result"]["success"]
|
assert result["result"]["success"]
|
||||||
assert len(extracted) > 0
|
assert len(extracted) > 0
|
||||||
|
|
||||||
|
|
||||||
def test_llm_extraction(tester: Crawl4AiTester):
|
def test_llm_extraction(tester: Crawl4AiTester):
|
||||||
print("\n=== Testing LLM Extraction ===")
|
print("\n=== Testing LLM Extraction ===")
|
||||||
schema = {
|
schema = {
|
||||||
@@ -162,18 +159,18 @@ def test_llm_extraction(tester: Crawl4AiTester):
|
|||||||
"properties": {
|
"properties": {
|
||||||
"model_name": {
|
"model_name": {
|
||||||
"type": "string",
|
"type": "string",
|
||||||
"description": "Name of the OpenAI model."
|
"description": "Name of the OpenAI model.",
|
||||||
},
|
},
|
||||||
"input_fee": {
|
"input_fee": {
|
||||||
"type": "string",
|
"type": "string",
|
||||||
"description": "Fee for input token for the OpenAI model."
|
"description": "Fee for input token for the OpenAI model.",
|
||||||
},
|
},
|
||||||
"output_fee": {
|
"output_fee": {
|
||||||
"type": "string",
|
"type": "string",
|
||||||
"description": "Fee for output token for the OpenAI model."
|
"description": "Fee for output token for the OpenAI model.",
|
||||||
}
|
|
||||||
},
|
},
|
||||||
"required": ["model_name", "input_fee", "output_fee"]
|
},
|
||||||
|
"required": ["model_name", "input_fee", "output_fee"],
|
||||||
}
|
}
|
||||||
|
|
||||||
request = {
|
request = {
|
||||||
@@ -186,10 +183,10 @@ def test_llm_extraction(tester: Crawl4AiTester):
|
|||||||
"api_token": os.getenv("OPENAI_API_KEY"),
|
"api_token": os.getenv("OPENAI_API_KEY"),
|
||||||
"schema": schema,
|
"schema": schema,
|
||||||
"extraction_type": "schema",
|
"extraction_type": "schema",
|
||||||
"instruction": """From the crawled content, extract all mentioned model names along with their fees for input and output tokens."""
|
"instruction": """From the crawled content, extract all mentioned model names along with their fees for input and output tokens.""",
|
||||||
}
|
|
||||||
},
|
},
|
||||||
"crawler_params": {"word_count_threshold": 1}
|
},
|
||||||
|
"crawler_params": {"word_count_threshold": 1},
|
||||||
}
|
}
|
||||||
|
|
||||||
try:
|
try:
|
||||||
@@ -201,6 +198,7 @@ def test_llm_extraction(tester: Crawl4AiTester):
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"LLM extraction test failed (might be due to missing API key): {str(e)}")
|
print(f"LLM extraction test failed (might be due to missing API key): {str(e)}")
|
||||||
|
|
||||||
|
|
||||||
def test_llm_with_ollama(tester: Crawl4AiTester):
|
def test_llm_with_ollama(tester: Crawl4AiTester):
|
||||||
print("\n=== Testing LLM with Ollama ===")
|
print("\n=== Testing LLM with Ollama ===")
|
||||||
schema = {
|
schema = {
|
||||||
@@ -208,18 +206,18 @@ def test_llm_with_ollama(tester: Crawl4AiTester):
|
|||||||
"properties": {
|
"properties": {
|
||||||
"article_title": {
|
"article_title": {
|
||||||
"type": "string",
|
"type": "string",
|
||||||
"description": "The main title of the news article"
|
"description": "The main title of the news article",
|
||||||
},
|
},
|
||||||
"summary": {
|
"summary": {
|
||||||
"type": "string",
|
"type": "string",
|
||||||
"description": "A brief summary of the article content"
|
"description": "A brief summary of the article content",
|
||||||
},
|
},
|
||||||
"main_topics": {
|
"main_topics": {
|
||||||
"type": "array",
|
"type": "array",
|
||||||
"items": {"type": "string"},
|
"items": {"type": "string"},
|
||||||
"description": "Main topics or themes discussed in the article"
|
"description": "Main topics or themes discussed in the article",
|
||||||
}
|
},
|
||||||
}
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
request = {
|
request = {
|
||||||
@@ -231,11 +229,11 @@ def test_llm_with_ollama(tester: Crawl4AiTester):
|
|||||||
"provider": "ollama/llama2",
|
"provider": "ollama/llama2",
|
||||||
"schema": schema,
|
"schema": schema,
|
||||||
"extraction_type": "schema",
|
"extraction_type": "schema",
|
||||||
"instruction": "Extract the main article information including title, summary, and main topics."
|
"instruction": "Extract the main article information including title, summary, and main topics.",
|
||||||
}
|
},
|
||||||
},
|
},
|
||||||
"extra": {"word_count_threshold": 1},
|
"extra": {"word_count_threshold": 1},
|
||||||
"crawler_params": {"verbose": True}
|
"crawler_params": {"verbose": True},
|
||||||
}
|
}
|
||||||
|
|
||||||
try:
|
try:
|
||||||
@@ -246,6 +244,7 @@ def test_llm_with_ollama(tester: Crawl4AiTester):
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"Ollama extraction test failed: {str(e)}")
|
print(f"Ollama extraction test failed: {str(e)}")
|
||||||
|
|
||||||
|
|
||||||
def test_cosine_extraction(tester: Crawl4AiTester):
|
def test_cosine_extraction(tester: Crawl4AiTester):
|
||||||
print("\n=== Testing Cosine Extraction ===")
|
print("\n=== Testing Cosine Extraction ===")
|
||||||
request = {
|
request = {
|
||||||
@@ -257,9 +256,9 @@ def test_cosine_extraction(tester: Crawl4AiTester):
|
|||||||
"semantic_filter": "business finance economy",
|
"semantic_filter": "business finance economy",
|
||||||
"word_count_threshold": 10,
|
"word_count_threshold": 10,
|
||||||
"max_dist": 0.2,
|
"max_dist": 0.2,
|
||||||
"top_k": 3
|
"top_k": 3,
|
||||||
}
|
},
|
||||||
}
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
try:
|
try:
|
||||||
@@ -271,15 +270,14 @@ def test_cosine_extraction(tester: Crawl4AiTester):
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"Cosine extraction test failed: {str(e)}")
|
print(f"Cosine extraction test failed: {str(e)}")
|
||||||
|
|
||||||
|
|
||||||
def test_screenshot(tester: Crawl4AiTester):
|
def test_screenshot(tester: Crawl4AiTester):
|
||||||
print("\n=== Testing Screenshot ===")
|
print("\n=== Testing Screenshot ===")
|
||||||
request = {
|
request = {
|
||||||
"urls": "https://www.nbcnews.com/business",
|
"urls": "https://www.nbcnews.com/business",
|
||||||
"priority": 5,
|
"priority": 5,
|
||||||
"screenshot": True,
|
"screenshot": True,
|
||||||
"crawler_params": {
|
"crawler_params": {"headless": True},
|
||||||
"headless": True
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
result = tester.submit_and_wait(request)
|
result = tester.submit_and_wait(request)
|
||||||
@@ -294,6 +292,7 @@ def test_screenshot(tester: Crawl4AiTester):
|
|||||||
|
|
||||||
assert result["result"]["success"]
|
assert result["result"]["success"]
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
version = sys.argv[1] if len(sys.argv) > 1 else "basic"
|
version = sys.argv[1] if len(sys.argv) > 1 else "basic"
|
||||||
# version = "full"
|
# version = "full"
|
||||||
|
|||||||
@@ -3,6 +3,7 @@ from crawl4ai.async_logger import AsyncLogger
|
|||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
import asyncio
|
import asyncio
|
||||||
|
|
||||||
|
|
||||||
async def main():
|
async def main():
|
||||||
current_file = Path(__file__).resolve()
|
current_file = Path(__file__).resolve()
|
||||||
# base_dir = current_file.parent.parent / "local/_docs/llm.txt/test_docs"
|
# base_dir = current_file.parent.parent / "local/_docs/llm.txt/test_docs"
|
||||||
@@ -26,8 +27,7 @@ async def main():
|
|||||||
# Generate index files
|
# Generate index files
|
||||||
print("\nGenerating index files...")
|
print("\nGenerating index files...")
|
||||||
await manager.generate_index_files(
|
await manager.generate_index_files(
|
||||||
force_generate_facts=False,
|
force_generate_facts=False, clear_bm25_cache=False
|
||||||
clear_bm25_cache=False
|
|
||||||
)
|
)
|
||||||
|
|
||||||
# Test some relevant queries about Crawl4AI
|
# Test some relevant queries about Crawl4AI
|
||||||
@@ -41,9 +41,12 @@ async def main():
|
|||||||
results = manager.search(query, top_k=2)
|
results = manager.search(query, top_k=2)
|
||||||
print(f"Results length: {len(results)} characters")
|
print(f"Results length: {len(results)} characters")
|
||||||
if results:
|
if results:
|
||||||
print("First 200 chars of results:", results[:200].replace('\n', ' '), "...")
|
print(
|
||||||
|
"First 200 chars of results:", results[:200].replace("\n", " "), "..."
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
print("No results found")
|
print("No results found")
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
asyncio.run(main())
|
asyncio.run(main())
|
||||||
@@ -3,8 +3,8 @@ import aiohttp
|
|||||||
import json
|
import json
|
||||||
import time
|
import time
|
||||||
import os
|
import os
|
||||||
from typing import Optional, Dict, Any
|
from typing import Dict, Any
|
||||||
from pydantic import BaseModel, HttpUrl
|
|
||||||
|
|
||||||
class NBCNewsAPITest:
|
class NBCNewsAPITest:
|
||||||
def __init__(self, base_url: str = "http://localhost:8000"):
|
def __init__(self, base_url: str = "http://localhost:8000"):
|
||||||
@@ -20,7 +20,9 @@ class NBCNewsAPITest:
|
|||||||
await self.session.close()
|
await self.session.close()
|
||||||
|
|
||||||
async def submit_crawl(self, request_data: Dict[str, Any]) -> str:
|
async def submit_crawl(self, request_data: Dict[str, Any]) -> str:
|
||||||
async with self.session.post(f"{self.base_url}/crawl", json=request_data) as response:
|
async with self.session.post(
|
||||||
|
f"{self.base_url}/crawl", json=request_data
|
||||||
|
) as response:
|
||||||
result = await response.json()
|
result = await response.json()
|
||||||
return result["task_id"]
|
return result["task_id"]
|
||||||
|
|
||||||
@@ -28,11 +30,15 @@ class NBCNewsAPITest:
|
|||||||
async with self.session.get(f"{self.base_url}/task/{task_id}") as response:
|
async with self.session.get(f"{self.base_url}/task/{task_id}") as response:
|
||||||
return await response.json()
|
return await response.json()
|
||||||
|
|
||||||
async def wait_for_task(self, task_id: str, timeout: int = 300, poll_interval: int = 2) -> Dict[str, Any]:
|
async def wait_for_task(
|
||||||
|
self, task_id: str, timeout: int = 300, poll_interval: int = 2
|
||||||
|
) -> Dict[str, Any]:
|
||||||
start_time = time.time()
|
start_time = time.time()
|
||||||
while True:
|
while True:
|
||||||
if time.time() - start_time > timeout:
|
if time.time() - start_time > timeout:
|
||||||
raise TimeoutError(f"Task {task_id} did not complete within {timeout} seconds")
|
raise TimeoutError(
|
||||||
|
f"Task {task_id} did not complete within {timeout} seconds"
|
||||||
|
)
|
||||||
|
|
||||||
status = await self.get_task_status(task_id)
|
status = await self.get_task_status(task_id)
|
||||||
if status["status"] in ["completed", "failed"]:
|
if status["status"] in ["completed", "failed"]:
|
||||||
@@ -44,13 +50,11 @@ class NBCNewsAPITest:
|
|||||||
async with self.session.get(f"{self.base_url}/health") as response:
|
async with self.session.get(f"{self.base_url}/health") as response:
|
||||||
return await response.json()
|
return await response.json()
|
||||||
|
|
||||||
|
|
||||||
async def test_basic_crawl():
|
async def test_basic_crawl():
|
||||||
print("\n=== Testing Basic Crawl ===")
|
print("\n=== Testing Basic Crawl ===")
|
||||||
async with NBCNewsAPITest() as api:
|
async with NBCNewsAPITest() as api:
|
||||||
request = {
|
request = {"urls": "https://www.nbcnews.com/business", "priority": 10}
|
||||||
"urls": "https://www.nbcnews.com/business",
|
|
||||||
"priority": 10
|
|
||||||
}
|
|
||||||
task_id = await api.submit_crawl(request)
|
task_id = await api.submit_crawl(request)
|
||||||
result = await api.wait_for_task(task_id)
|
result = await api.wait_for_task(task_id)
|
||||||
print(f"Basic crawl result length: {len(result['result']['markdown'])}")
|
print(f"Basic crawl result length: {len(result['result']['markdown'])}")
|
||||||
@@ -58,6 +62,7 @@ async def test_basic_crawl():
|
|||||||
assert "result" in result
|
assert "result" in result
|
||||||
assert result["result"]["success"]
|
assert result["result"]["success"]
|
||||||
|
|
||||||
|
|
||||||
async def test_js_execution():
|
async def test_js_execution():
|
||||||
print("\n=== Testing JS Execution ===")
|
print("\n=== Testing JS Execution ===")
|
||||||
async with NBCNewsAPITest() as api:
|
async with NBCNewsAPITest() as api:
|
||||||
@@ -68,9 +73,7 @@ async def test_js_execution():
|
|||||||
"const loadMoreButton = Array.from(document.querySelectorAll('button')).find(button => button.textContent.includes('Load More')); loadMoreButton && loadMoreButton.click();"
|
"const loadMoreButton = Array.from(document.querySelectorAll('button')).find(button => button.textContent.includes('Load More')); loadMoreButton && loadMoreButton.click();"
|
||||||
],
|
],
|
||||||
"wait_for": "article.tease-card:nth-child(10)",
|
"wait_for": "article.tease-card:nth-child(10)",
|
||||||
"crawler_params": {
|
"crawler_params": {"headless": True},
|
||||||
"headless": True
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
task_id = await api.submit_crawl(request)
|
task_id = await api.submit_crawl(request)
|
||||||
result = await api.wait_for_task(task_id)
|
result = await api.wait_for_task(task_id)
|
||||||
@@ -78,13 +81,14 @@ async def test_js_execution():
|
|||||||
assert result["status"] == "completed"
|
assert result["status"] == "completed"
|
||||||
assert result["result"]["success"]
|
assert result["result"]["success"]
|
||||||
|
|
||||||
|
|
||||||
async def test_css_selector():
|
async def test_css_selector():
|
||||||
print("\n=== Testing CSS Selector ===")
|
print("\n=== Testing CSS Selector ===")
|
||||||
async with NBCNewsAPITest() as api:
|
async with NBCNewsAPITest() as api:
|
||||||
request = {
|
request = {
|
||||||
"urls": "https://www.nbcnews.com/business",
|
"urls": "https://www.nbcnews.com/business",
|
||||||
"priority": 7,
|
"priority": 7,
|
||||||
"css_selector": ".wide-tease-item__description"
|
"css_selector": ".wide-tease-item__description",
|
||||||
}
|
}
|
||||||
task_id = await api.submit_crawl(request)
|
task_id = await api.submit_crawl(request)
|
||||||
result = await api.wait_for_task(task_id)
|
result = await api.wait_for_task(task_id)
|
||||||
@@ -92,6 +96,7 @@ async def test_css_selector():
|
|||||||
assert result["status"] == "completed"
|
assert result["status"] == "completed"
|
||||||
assert result["result"]["success"]
|
assert result["result"]["success"]
|
||||||
|
|
||||||
|
|
||||||
async def test_structured_extraction():
|
async def test_structured_extraction():
|
||||||
print("\n=== Testing Structured Extraction ===")
|
print("\n=== Testing Structured Extraction ===")
|
||||||
async with NBCNewsAPITest() as api:
|
async with NBCNewsAPITest() as api:
|
||||||
@@ -99,34 +104,25 @@ async def test_structured_extraction():
|
|||||||
"name": "NBC News Articles",
|
"name": "NBC News Articles",
|
||||||
"baseSelector": "article.tease-card",
|
"baseSelector": "article.tease-card",
|
||||||
"fields": [
|
"fields": [
|
||||||
{
|
{"name": "title", "selector": "h2", "type": "text"},
|
||||||
"name": "title",
|
|
||||||
"selector": "h2",
|
|
||||||
"type": "text"
|
|
||||||
},
|
|
||||||
{
|
{
|
||||||
"name": "description",
|
"name": "description",
|
||||||
"selector": ".tease-card__description",
|
"selector": ".tease-card__description",
|
||||||
"type": "text"
|
"type": "text",
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"name": "link",
|
"name": "link",
|
||||||
"selector": "a",
|
"selector": "a",
|
||||||
"type": "attribute",
|
"type": "attribute",
|
||||||
"attribute": "href"
|
"attribute": "href",
|
||||||
}
|
},
|
||||||
]
|
],
|
||||||
}
|
}
|
||||||
|
|
||||||
request = {
|
request = {
|
||||||
"urls": "https://www.nbcnews.com/business",
|
"urls": "https://www.nbcnews.com/business",
|
||||||
"priority": 9,
|
"priority": 9,
|
||||||
"extraction_config": {
|
"extraction_config": {"type": "json_css", "params": {"schema": schema}},
|
||||||
"type": "json_css",
|
|
||||||
"params": {
|
|
||||||
"schema": schema
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
task_id = await api.submit_crawl(request)
|
task_id = await api.submit_crawl(request)
|
||||||
result = await api.wait_for_task(task_id)
|
result = await api.wait_for_task(task_id)
|
||||||
@@ -136,6 +132,7 @@ async def test_structured_extraction():
|
|||||||
assert result["result"]["success"]
|
assert result["result"]["success"]
|
||||||
assert len(extracted) > 0
|
assert len(extracted) > 0
|
||||||
|
|
||||||
|
|
||||||
async def test_batch_crawl():
|
async def test_batch_crawl():
|
||||||
print("\n=== Testing Batch Crawl ===")
|
print("\n=== Testing Batch Crawl ===")
|
||||||
async with NBCNewsAPITest() as api:
|
async with NBCNewsAPITest() as api:
|
||||||
@@ -143,12 +140,10 @@ async def test_batch_crawl():
|
|||||||
"urls": [
|
"urls": [
|
||||||
"https://www.nbcnews.com/business",
|
"https://www.nbcnews.com/business",
|
||||||
"https://www.nbcnews.com/business/consumer",
|
"https://www.nbcnews.com/business/consumer",
|
||||||
"https://www.nbcnews.com/business/economy"
|
"https://www.nbcnews.com/business/economy",
|
||||||
],
|
],
|
||||||
"priority": 6,
|
"priority": 6,
|
||||||
"crawler_params": {
|
"crawler_params": {"headless": True},
|
||||||
"headless": True
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
task_id = await api.submit_crawl(request)
|
task_id = await api.submit_crawl(request)
|
||||||
result = await api.wait_for_task(task_id)
|
result = await api.wait_for_task(task_id)
|
||||||
@@ -157,6 +152,7 @@ async def test_batch_crawl():
|
|||||||
assert "results" in result
|
assert "results" in result
|
||||||
assert len(result["results"]) == 3
|
assert len(result["results"]) == 3
|
||||||
|
|
||||||
|
|
||||||
async def test_llm_extraction():
|
async def test_llm_extraction():
|
||||||
print("\n=== Testing LLM Extraction with Ollama ===")
|
print("\n=== Testing LLM Extraction with Ollama ===")
|
||||||
async with NBCNewsAPITest() as api:
|
async with NBCNewsAPITest() as api:
|
||||||
@@ -165,19 +161,19 @@ async def test_llm_extraction():
|
|||||||
"properties": {
|
"properties": {
|
||||||
"article_title": {
|
"article_title": {
|
||||||
"type": "string",
|
"type": "string",
|
||||||
"description": "The main title of the news article"
|
"description": "The main title of the news article",
|
||||||
},
|
},
|
||||||
"summary": {
|
"summary": {
|
||||||
"type": "string",
|
"type": "string",
|
||||||
"description": "A brief summary of the article content"
|
"description": "A brief summary of the article content",
|
||||||
},
|
},
|
||||||
"main_topics": {
|
"main_topics": {
|
||||||
"type": "array",
|
"type": "array",
|
||||||
"items": {"type": "string"},
|
"items": {"type": "string"},
|
||||||
"description": "Main topics or themes discussed in the article"
|
"description": "Main topics or themes discussed in the article",
|
||||||
}
|
|
||||||
},
|
},
|
||||||
"required": ["article_title", "summary", "main_topics"]
|
},
|
||||||
|
"required": ["article_title", "summary", "main_topics"],
|
||||||
}
|
}
|
||||||
|
|
||||||
request = {
|
request = {
|
||||||
@@ -191,13 +187,10 @@ async def test_llm_extraction():
|
|||||||
"schema": schema,
|
"schema": schema,
|
||||||
"extraction_type": "schema",
|
"extraction_type": "schema",
|
||||||
"instruction": """Extract the main article information including title, a brief summary, and main topics discussed.
|
"instruction": """Extract the main article information including title, a brief summary, and main topics discussed.
|
||||||
Focus on the primary business news article on the page."""
|
Focus on the primary business news article on the page.""",
|
||||||
}
|
|
||||||
},
|
},
|
||||||
"crawler_params": {
|
},
|
||||||
"headless": True,
|
"crawler_params": {"headless": True, "word_count_threshold": 1},
|
||||||
"word_count_threshold": 1
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
task_id = await api.submit_crawl(request)
|
task_id = await api.submit_crawl(request)
|
||||||
@@ -205,12 +198,13 @@ async def test_llm_extraction():
|
|||||||
|
|
||||||
if result["status"] == "completed":
|
if result["status"] == "completed":
|
||||||
extracted = json.loads(result["result"]["extracted_content"])
|
extracted = json.loads(result["result"]["extracted_content"])
|
||||||
print(f"Extracted article analysis:")
|
print("Extracted article analysis:")
|
||||||
print(json.dumps(extracted, indent=2))
|
print(json.dumps(extracted, indent=2))
|
||||||
|
|
||||||
assert result["status"] == "completed"
|
assert result["status"] == "completed"
|
||||||
assert result["result"]["success"]
|
assert result["result"]["success"]
|
||||||
|
|
||||||
|
|
||||||
async def test_screenshot():
|
async def test_screenshot():
|
||||||
print("\n=== Testing Screenshot ===")
|
print("\n=== Testing Screenshot ===")
|
||||||
async with NBCNewsAPITest() as api:
|
async with NBCNewsAPITest() as api:
|
||||||
@@ -218,9 +212,7 @@ async def test_screenshot():
|
|||||||
"urls": "https://www.nbcnews.com/business",
|
"urls": "https://www.nbcnews.com/business",
|
||||||
"priority": 5,
|
"priority": 5,
|
||||||
"screenshot": True,
|
"screenshot": True,
|
||||||
"crawler_params": {
|
"crawler_params": {"headless": True},
|
||||||
"headless": True
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
task_id = await api.submit_crawl(request)
|
task_id = await api.submit_crawl(request)
|
||||||
result = await api.wait_for_task(task_id)
|
result = await api.wait_for_task(task_id)
|
||||||
@@ -229,6 +221,7 @@ async def test_screenshot():
|
|||||||
assert result["result"]["success"]
|
assert result["result"]["success"]
|
||||||
assert result["result"]["screenshot"] is not None
|
assert result["result"]["screenshot"] is not None
|
||||||
|
|
||||||
|
|
||||||
async def test_priority_handling():
|
async def test_priority_handling():
|
||||||
print("\n=== Testing Priority Handling ===")
|
print("\n=== Testing Priority Handling ===")
|
||||||
async with NBCNewsAPITest() as api:
|
async with NBCNewsAPITest() as api:
|
||||||
@@ -236,7 +229,7 @@ async def test_priority_handling():
|
|||||||
low_priority = {
|
low_priority = {
|
||||||
"urls": "https://www.nbcnews.com/business",
|
"urls": "https://www.nbcnews.com/business",
|
||||||
"priority": 1,
|
"priority": 1,
|
||||||
"crawler_params": {"headless": True}
|
"crawler_params": {"headless": True},
|
||||||
}
|
}
|
||||||
low_task_id = await api.submit_crawl(low_priority)
|
low_task_id = await api.submit_crawl(low_priority)
|
||||||
|
|
||||||
@@ -244,7 +237,7 @@ async def test_priority_handling():
|
|||||||
high_priority = {
|
high_priority = {
|
||||||
"urls": "https://www.nbcnews.com/business/consumer",
|
"urls": "https://www.nbcnews.com/business/consumer",
|
||||||
"priority": 10,
|
"priority": 10,
|
||||||
"crawler_params": {"headless": True}
|
"crawler_params": {"headless": True},
|
||||||
}
|
}
|
||||||
high_task_id = await api.submit_crawl(high_priority)
|
high_task_id = await api.submit_crawl(high_priority)
|
||||||
|
|
||||||
@@ -256,6 +249,7 @@ async def test_priority_handling():
|
|||||||
assert high_result["status"] == "completed"
|
assert high_result["status"] == "completed"
|
||||||
assert low_result["status"] == "completed"
|
assert low_result["status"] == "completed"
|
||||||
|
|
||||||
|
|
||||||
async def main():
|
async def main():
|
||||||
try:
|
try:
|
||||||
# Start with health check
|
# Start with health check
|
||||||
@@ -277,5 +271,6 @@ async def main():
|
|||||||
print(f"Test failed: {str(e)}")
|
print(f"Test failed: {str(e)}")
|
||||||
raise
|
raise
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
asyncio.run(main())
|
asyncio.run(main())
|
||||||
@@ -1,21 +1,26 @@
|
|||||||
import nest_asyncio
|
import nest_asyncio
|
||||||
|
|
||||||
nest_asyncio.apply()
|
nest_asyncio.apply()
|
||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
from crawl4ai import AsyncWebCrawler, CrawlerRunConfig, LXMLWebScrapingStrategy, CacheMode
|
from crawl4ai import (
|
||||||
|
AsyncWebCrawler,
|
||||||
|
CrawlerRunConfig,
|
||||||
|
LXMLWebScrapingStrategy,
|
||||||
|
CacheMode,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
async def main():
|
async def main():
|
||||||
config = CrawlerRunConfig(
|
config = CrawlerRunConfig(
|
||||||
cache_mode=CacheMode.BYPASS,
|
cache_mode=CacheMode.BYPASS,
|
||||||
scraping_strategy=LXMLWebScrapingStrategy() # Faster alternative to default BeautifulSoup
|
scraping_strategy=LXMLWebScrapingStrategy(), # Faster alternative to default BeautifulSoup
|
||||||
)
|
)
|
||||||
async with AsyncWebCrawler() as crawler:
|
async with AsyncWebCrawler() as crawler:
|
||||||
result = await crawler.arun(
|
result = await crawler.arun(url="https://example.com", config=config)
|
||||||
url="https://example.com",
|
|
||||||
config=config
|
|
||||||
)
|
|
||||||
print(f"Success: {result.success}")
|
print(f"Success: {result.success}")
|
||||||
print(f"Markdown length: {len(result.markdown_v2.raw_markdown)}")
|
print(f"Markdown length: {len(result.markdown_v2.raw_markdown)}")
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
asyncio.run(main())
|
asyncio.run(main())
|
||||||
|
|||||||
@@ -1,10 +1,19 @@
|
|||||||
import unittest, os
|
import unittest, os
|
||||||
from crawl4ai.web_crawler import WebCrawler
|
from crawl4ai.web_crawler import WebCrawler
|
||||||
from crawl4ai.chunking_strategy import RegexChunking, FixedLengthWordChunking, SlidingWindowChunking
|
from crawl4ai.chunking_strategy import (
|
||||||
from crawl4ai.extraction_strategy import CosineStrategy, LLMExtractionStrategy, TopicExtractionStrategy, NoExtractionStrategy
|
RegexChunking,
|
||||||
|
FixedLengthWordChunking,
|
||||||
|
SlidingWindowChunking,
|
||||||
|
)
|
||||||
|
from crawl4ai.extraction_strategy import (
|
||||||
|
CosineStrategy,
|
||||||
|
LLMExtractionStrategy,
|
||||||
|
TopicExtractionStrategy,
|
||||||
|
NoExtractionStrategy,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class TestWebCrawler(unittest.TestCase):
|
class TestWebCrawler(unittest.TestCase):
|
||||||
|
|
||||||
def setUp(self):
|
def setUp(self):
|
||||||
self.crawler = WebCrawler()
|
self.crawler = WebCrawler()
|
||||||
|
|
||||||
@@ -14,52 +23,72 @@ class TestWebCrawler(unittest.TestCase):
|
|||||||
|
|
||||||
def test_run_default_strategies(self):
|
def test_run_default_strategies(self):
|
||||||
result = self.crawler.run(
|
result = self.crawler.run(
|
||||||
url='https://www.nbcnews.com/business',
|
url="https://www.nbcnews.com/business",
|
||||||
word_count_threshold=5,
|
word_count_threshold=5,
|
||||||
chunking_strategy=RegexChunking(),
|
chunking_strategy=RegexChunking(),
|
||||||
extraction_strategy=CosineStrategy(), bypass_cache=True
|
extraction_strategy=CosineStrategy(),
|
||||||
|
bypass_cache=True,
|
||||||
|
)
|
||||||
|
self.assertTrue(
|
||||||
|
result.success, "Failed to crawl and extract using default strategies"
|
||||||
)
|
)
|
||||||
self.assertTrue(result.success, "Failed to crawl and extract using default strategies")
|
|
||||||
|
|
||||||
def test_run_different_strategies(self):
|
def test_run_different_strategies(self):
|
||||||
url = 'https://www.nbcnews.com/business'
|
url = "https://www.nbcnews.com/business"
|
||||||
|
|
||||||
# Test with FixedLengthWordChunking and LLMExtractionStrategy
|
# Test with FixedLengthWordChunking and LLMExtractionStrategy
|
||||||
result = self.crawler.run(
|
result = self.crawler.run(
|
||||||
url=url,
|
url=url,
|
||||||
word_count_threshold=5,
|
word_count_threshold=5,
|
||||||
chunking_strategy=FixedLengthWordChunking(chunk_size=100),
|
chunking_strategy=FixedLengthWordChunking(chunk_size=100),
|
||||||
extraction_strategy=LLMExtractionStrategy(provider="openai/gpt-3.5-turbo", api_token=os.getenv('OPENAI_API_KEY')), bypass_cache=True
|
extraction_strategy=LLMExtractionStrategy(
|
||||||
|
provider="openai/gpt-3.5-turbo", api_token=os.getenv("OPENAI_API_KEY")
|
||||||
|
),
|
||||||
|
bypass_cache=True,
|
||||||
|
)
|
||||||
|
self.assertTrue(
|
||||||
|
result.success,
|
||||||
|
"Failed to crawl and extract with FixedLengthWordChunking and LLMExtractionStrategy",
|
||||||
)
|
)
|
||||||
self.assertTrue(result.success, "Failed to crawl and extract with FixedLengthWordChunking and LLMExtractionStrategy")
|
|
||||||
|
|
||||||
# Test with SlidingWindowChunking and TopicExtractionStrategy
|
# Test with SlidingWindowChunking and TopicExtractionStrategy
|
||||||
result = self.crawler.run(
|
result = self.crawler.run(
|
||||||
url=url,
|
url=url,
|
||||||
word_count_threshold=5,
|
word_count_threshold=5,
|
||||||
chunking_strategy=SlidingWindowChunking(window_size=100, step=50),
|
chunking_strategy=SlidingWindowChunking(window_size=100, step=50),
|
||||||
extraction_strategy=TopicExtractionStrategy(num_keywords=5), bypass_cache=True
|
extraction_strategy=TopicExtractionStrategy(num_keywords=5),
|
||||||
|
bypass_cache=True,
|
||||||
|
)
|
||||||
|
self.assertTrue(
|
||||||
|
result.success,
|
||||||
|
"Failed to crawl and extract with SlidingWindowChunking and TopicExtractionStrategy",
|
||||||
)
|
)
|
||||||
self.assertTrue(result.success, "Failed to crawl and extract with SlidingWindowChunking and TopicExtractionStrategy")
|
|
||||||
|
|
||||||
def test_invalid_url(self):
|
def test_invalid_url(self):
|
||||||
with self.assertRaises(Exception) as context:
|
with self.assertRaises(Exception) as context:
|
||||||
self.crawler.run(url='invalid_url', bypass_cache=True)
|
self.crawler.run(url="invalid_url", bypass_cache=True)
|
||||||
self.assertIn("Invalid URL", str(context.exception))
|
self.assertIn("Invalid URL", str(context.exception))
|
||||||
|
|
||||||
def test_unsupported_extraction_strategy(self):
|
def test_unsupported_extraction_strategy(self):
|
||||||
with self.assertRaises(Exception) as context:
|
with self.assertRaises(Exception) as context:
|
||||||
self.crawler.run(url='https://www.nbcnews.com/business', extraction_strategy="UnsupportedStrategy", bypass_cache=True)
|
self.crawler.run(
|
||||||
|
url="https://www.nbcnews.com/business",
|
||||||
|
extraction_strategy="UnsupportedStrategy",
|
||||||
|
bypass_cache=True,
|
||||||
|
)
|
||||||
self.assertIn("Unsupported extraction strategy", str(context.exception))
|
self.assertIn("Unsupported extraction strategy", str(context.exception))
|
||||||
|
|
||||||
def test_invalid_css_selector(self):
|
def test_invalid_css_selector(self):
|
||||||
with self.assertRaises(ValueError) as context:
|
with self.assertRaises(ValueError) as context:
|
||||||
self.crawler.run(url='https://www.nbcnews.com/business', css_selector="invalid_selector", bypass_cache=True)
|
self.crawler.run(
|
||||||
|
url="https://www.nbcnews.com/business",
|
||||||
|
css_selector="invalid_selector",
|
||||||
|
bypass_cache=True,
|
||||||
|
)
|
||||||
self.assertIn("Invalid CSS selector", str(context.exception))
|
self.assertIn("Invalid CSS selector", str(context.exception))
|
||||||
|
|
||||||
|
|
||||||
def test_crawl_with_cache_and_bypass_cache(self):
|
def test_crawl_with_cache_and_bypass_cache(self):
|
||||||
url = 'https://www.nbcnews.com/business'
|
url = "https://www.nbcnews.com/business"
|
||||||
|
|
||||||
# First crawl with cache enabled
|
# First crawl with cache enabled
|
||||||
result = self.crawler.run(url=url, bypass_cache=False)
|
result = self.crawler.run(url=url, bypass_cache=False)
|
||||||
@@ -70,10 +99,7 @@ class TestWebCrawler(unittest.TestCase):
|
|||||||
self.assertTrue(result.success, "Failed to bypass cache and fetch fresh data")
|
self.assertTrue(result.success, "Failed to bypass cache and fetch fresh data")
|
||||||
|
|
||||||
def test_fetch_multiple_pages(self):
|
def test_fetch_multiple_pages(self):
|
||||||
urls = [
|
urls = ["https://www.nbcnews.com/business", "https://www.bbc.com/news"]
|
||||||
'https://www.nbcnews.com/business',
|
|
||||||
'https://www.bbc.com/news'
|
|
||||||
]
|
|
||||||
results = []
|
results = []
|
||||||
for url in urls:
|
for url in urls:
|
||||||
result = self.crawler.run(
|
result = self.crawler.run(
|
||||||
@@ -81,31 +107,42 @@ class TestWebCrawler(unittest.TestCase):
|
|||||||
word_count_threshold=5,
|
word_count_threshold=5,
|
||||||
chunking_strategy=RegexChunking(),
|
chunking_strategy=RegexChunking(),
|
||||||
extraction_strategy=CosineStrategy(),
|
extraction_strategy=CosineStrategy(),
|
||||||
bypass_cache=True
|
bypass_cache=True,
|
||||||
)
|
)
|
||||||
results.append(result)
|
results.append(result)
|
||||||
|
|
||||||
self.assertEqual(len(results), 2, "Failed to crawl and extract multiple pages")
|
self.assertEqual(len(results), 2, "Failed to crawl and extract multiple pages")
|
||||||
for result in results:
|
for result in results:
|
||||||
self.assertTrue(result.success, "Failed to crawl and extract a page in the list")
|
self.assertTrue(
|
||||||
|
result.success, "Failed to crawl and extract a page in the list"
|
||||||
|
)
|
||||||
|
|
||||||
def test_run_fixed_length_word_chunking_and_no_extraction(self):
|
def test_run_fixed_length_word_chunking_and_no_extraction(self):
|
||||||
result = self.crawler.run(
|
result = self.crawler.run(
|
||||||
url='https://www.nbcnews.com/business',
|
url="https://www.nbcnews.com/business",
|
||||||
word_count_threshold=5,
|
word_count_threshold=5,
|
||||||
chunking_strategy=FixedLengthWordChunking(chunk_size=100),
|
chunking_strategy=FixedLengthWordChunking(chunk_size=100),
|
||||||
extraction_strategy=NoExtractionStrategy(), bypass_cache=True
|
extraction_strategy=NoExtractionStrategy(),
|
||||||
|
bypass_cache=True,
|
||||||
|
)
|
||||||
|
self.assertTrue(
|
||||||
|
result.success,
|
||||||
|
"Failed to crawl and extract with FixedLengthWordChunking and NoExtractionStrategy",
|
||||||
)
|
)
|
||||||
self.assertTrue(result.success, "Failed to crawl and extract with FixedLengthWordChunking and NoExtractionStrategy")
|
|
||||||
|
|
||||||
def test_run_sliding_window_and_no_extraction(self):
|
def test_run_sliding_window_and_no_extraction(self):
|
||||||
result = self.crawler.run(
|
result = self.crawler.run(
|
||||||
url='https://www.nbcnews.com/business',
|
url="https://www.nbcnews.com/business",
|
||||||
word_count_threshold=5,
|
word_count_threshold=5,
|
||||||
chunking_strategy=SlidingWindowChunking(window_size=100, step=50),
|
chunking_strategy=SlidingWindowChunking(window_size=100, step=50),
|
||||||
extraction_strategy=NoExtractionStrategy(), bypass_cache=True
|
extraction_strategy=NoExtractionStrategy(),
|
||||||
|
bypass_cache=True,
|
||||||
|
)
|
||||||
|
self.assertTrue(
|
||||||
|
result.success,
|
||||||
|
"Failed to crawl and extract with SlidingWindowChunking and NoExtractionStrategy",
|
||||||
)
|
)
|
||||||
self.assertTrue(result.success, "Failed to crawl and extract with SlidingWindowChunking and NoExtractionStrategy")
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
|
||||||
|
if __name__ == "__main__":
|
||||||
unittest.main()
|
unittest.main()
|
||||||
|
|||||||
Reference in New Issue
Block a user