diff --git a/crawl4ai/__init__.py b/crawl4ai/__init__.py index ad9475b4..0ccf13d8 100644 --- a/crawl4ai/__init__.py +++ b/crawl4ai/__init__.py @@ -1,6 +1,7 @@ # __init__.py from .async_webcrawler import AsyncWebCrawler, CacheMode + from .models import CrawlResult from .__version__ import __version__ # __version__ = "0.3.73" diff --git a/crawl4ai/async_crawler_strategy.py b/crawl4ai/async_crawler_strategy.py index e7dc9c54..3f332eb0 100644 --- a/crawl4ai/async_crawler_strategy.py +++ b/crawl4ai/async_crawler_strategy.py @@ -229,6 +229,7 @@ class AsyncPlaywrightCrawlerStrategy(AsyncCrawlerStrategy): self.headless = kwargs.get("headless", True) self.browser_type = kwargs.get("browser_type", "chromium") self.headers = kwargs.get("headers", {}) + self.cookies = kwargs.get("cookies", []) self.sessions = {} self.session_ttl = 1800 self.js_code = js_code @@ -295,6 +296,8 @@ class AsyncPlaywrightCrawlerStrategy(AsyncCrawlerStrategy): # Set up the default context if self.default_context: await self.default_context.set_extra_http_headers(self.headers) + if self.cookies: + await self.default_context.add_cookies(self.cookies) if self.accept_downloads: await self.default_context.set_default_timeout(60000) await self.default_context.set_default_navigation_timeout(60000) @@ -669,6 +672,8 @@ class AsyncPlaywrightCrawlerStrategy(AsyncCrawlerStrategy): # downloads_path=self.downloads_path if self.accept_downloads else None ) await context.add_cookies([{"name": "cookiesEnabled", "value": "true", "url": url}]) + if self.cookies: + await context.add_cookies(self.cookies) await context.set_extra_http_headers(self.headers) page = await context.new_page() self.sessions[session_id] = (context, page, time.time()) @@ -684,6 +689,8 @@ class AsyncPlaywrightCrawlerStrategy(AsyncCrawlerStrategy): proxy={"server": self.proxy} if self.proxy else None, accept_downloads=self.accept_downloads, ) + if self.cookies: + await context.add_cookies(self.cookies) await 
context.set_extra_http_headers(self.headers) if kwargs.get("override_navigator", False) or kwargs.get("simulate_user", False) or kwargs.get("magic", False): @@ -828,7 +835,8 @@ class AsyncPlaywrightCrawlerStrategy(AsyncCrawlerStrategy): for js in js_code: await page.evaluate(js) - await page.wait_for_load_state('networkidle') + # await page.wait_for_timeout(100) + # Check for on execution event await self.execute_hook('on_execution_started', page) @@ -846,6 +854,9 @@ class AsyncPlaywrightCrawlerStrategy(AsyncCrawlerStrategy): await self.smart_wait(page, wait_for, timeout=kwargs.get("page_timeout", 60000)) except Exception as e: raise RuntimeError(f"Wait condition failed: {str(e)}") + + # if not wait_for and js_code: + # await page.wait_for_load_state('networkidle', timeout=5000) # Update image dimensions update_image_dimensions_js = """ diff --git a/crawl4ai/async_database.3.73.py b/crawl4ai/async_database.3.73.py deleted file mode 100644 index f86c7f1d..00000000 --- a/crawl4ai/async_database.3.73.py +++ /dev/null @@ -1,285 +0,0 @@ -import os -from pathlib import Path -import aiosqlite -import asyncio -from typing import Optional, Tuple, Dict -from contextlib import asynccontextmanager -import logging -import json # Added for serialization/deserialization -from .utils import ensure_content_dirs, generate_content_hash -import xxhash -import aiofiles -# Set up logging -logging.basicConfig(level=logging.INFO) -logger = logging.getLogger(__name__) - -DB_PATH = os.path.join(Path.home(), ".crawl4ai") -os.makedirs(DB_PATH, exist_ok=True) -DB_PATH = os.path.join(DB_PATH, "crawl4ai.db") - -class AsyncDatabaseManager: - def __init__(self, pool_size: int = 10, max_retries: int = 3): - self.db_path = DB_PATH - self.content_paths = ensure_content_dirs(os.path.dirname(DB_PATH)) - self.pool_size = pool_size - self.max_retries = max_retries - self.connection_pool: Dict[int, aiosqlite.Connection] = {} - self.pool_lock = asyncio.Lock() - self.connection_semaphore = 
asyncio.Semaphore(pool_size) - - async def initialize(self): - """Initialize the database and connection pool""" - await self.ainit_db() - - async def cleanup(self): - """Cleanup connections when shutting down""" - async with self.pool_lock: - for conn in self.connection_pool.values(): - await conn.close() - self.connection_pool.clear() - - @asynccontextmanager - async def get_connection(self): - """Connection pool manager""" - async with self.connection_semaphore: - task_id = id(asyncio.current_task()) - try: - async with self.pool_lock: - if task_id not in self.connection_pool: - conn = await aiosqlite.connect( - self.db_path, - timeout=30.0 - ) - await conn.execute('PRAGMA journal_mode = WAL') - await conn.execute('PRAGMA busy_timeout = 5000') - self.connection_pool[task_id] = conn - - yield self.connection_pool[task_id] - - except Exception as e: - logger.error(f"Connection error: {e}") - raise - finally: - async with self.pool_lock: - if task_id in self.connection_pool: - await self.connection_pool[task_id].close() - del self.connection_pool[task_id] - - async def execute_with_retry(self, operation, *args): - """Execute database operations with retry logic""" - for attempt in range(self.max_retries): - try: - async with self.get_connection() as db: - result = await operation(db, *args) - await db.commit() - return result - except Exception as e: - if attempt == self.max_retries - 1: - logger.error(f"Operation failed after {self.max_retries} attempts: {e}") - raise - await asyncio.sleep(1 * (attempt + 1)) # Exponential backoff - - async def ainit_db(self): - """Initialize database schema""" - async def _init(db): - await db.execute(''' - CREATE TABLE IF NOT EXISTS crawled_data ( - url TEXT PRIMARY KEY, - html TEXT, - cleaned_html TEXT, - markdown TEXT, - extracted_content TEXT, - success BOOLEAN, - media TEXT DEFAULT "{}", - links TEXT DEFAULT "{}", - metadata TEXT DEFAULT "{}", - screenshot TEXT DEFAULT "", - response_headers TEXT DEFAULT "{}", - 
downloaded_files TEXT DEFAULT "{}" -- New column added - ) - ''') - - await self.execute_with_retry(_init) - await self.update_db_schema() - - async def update_db_schema(self): - """Update database schema if needed""" - async def _check_columns(db): - cursor = await db.execute("PRAGMA table_info(crawled_data)") - columns = await cursor.fetchall() - return [column[1] for column in columns] - - column_names = await self.execute_with_retry(_check_columns) - - # List of new columns to add - new_columns = ['media', 'links', 'metadata', 'screenshot', 'response_headers', 'downloaded_files'] - - for column in new_columns: - if column not in column_names: - await self.aalter_db_add_column(column) - - async def aalter_db_add_column(self, new_column: str): - """Add new column to the database""" - async def _alter(db): - if new_column == 'response_headers': - await db.execute(f'ALTER TABLE crawled_data ADD COLUMN {new_column} TEXT DEFAULT "{{}}"') - else: - await db.execute(f'ALTER TABLE crawled_data ADD COLUMN {new_column} TEXT DEFAULT ""') - logger.info(f"Added column '{new_column}' to the database.") - - await self.execute_with_retry(_alter) - - - async def aget_cached_url(self, url: str) -> Optional[Tuple[str, str, str, str, str, bool, str, str, str, str]]: - """Retrieve cached URL data""" - async def _get(db): - async with db.execute( - ''' - SELECT url, html, cleaned_html, markdown, - extracted_content, success, media, links, - metadata, screenshot, response_headers, - downloaded_files - FROM crawled_data WHERE url = ? 
- ''', - (url,) - ) as cursor: - row = await cursor.fetchone() - if row: - # Load content from files using stored hashes - html = await self._load_content(row[1], 'html') if row[1] else "" - cleaned = await self._load_content(row[2], 'cleaned') if row[2] else "" - markdown = await self._load_content(row[3], 'markdown') if row[3] else "" - extracted = await self._load_content(row[4], 'extracted') if row[4] else "" - screenshot = await self._load_content(row[9], 'screenshots') if row[9] else "" - - return ( - row[0], # url - html or "", # Return empty string if file not found - cleaned or "", - markdown or "", - extracted or "", - row[5], # success - json.loads(row[6] or '{}'), # media - json.loads(row[7] or '{}'), # links - json.loads(row[8] or '{}'), # metadata - screenshot or "", - json.loads(row[10] or '{}'), # response_headers - json.loads(row[11] or '[]') # downloaded_files - ) - return None - - try: - return await self.execute_with_retry(_get) - except Exception as e: - logger.error(f"Error retrieving cached URL: {e}") - return None - - async def acache_url(self, url: str, html: str, cleaned_html: str, - markdown: str, extracted_content: str, success: bool, - media: str = "{}", links: str = "{}", - metadata: str = "{}", screenshot: str = "", - response_headers: str = "{}", downloaded_files: str = "[]"): - """Cache URL data with content stored in filesystem""" - - # Store content files and get hashes - html_hash = await self._store_content(html, 'html') - cleaned_hash = await self._store_content(cleaned_html, 'cleaned') - markdown_hash = await self._store_content(markdown, 'markdown') - extracted_hash = await self._store_content(extracted_content, 'extracted') - screenshot_hash = await self._store_content(screenshot, 'screenshots') - - async def _cache(db): - await db.execute(''' - INSERT INTO crawled_data ( - url, html, cleaned_html, markdown, - extracted_content, success, media, links, metadata, - screenshot, response_headers, downloaded_files - ) - VALUES 
(?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?) - ON CONFLICT(url) DO UPDATE SET - html = excluded.html, - cleaned_html = excluded.cleaned_html, - markdown = excluded.markdown, - extracted_content = excluded.extracted_content, - success = excluded.success, - media = excluded.media, - links = excluded.links, - metadata = excluded.metadata, - screenshot = excluded.screenshot, - response_headers = excluded.response_headers, - downloaded_files = excluded.downloaded_files - ''', (url, html_hash, cleaned_hash, markdown_hash, extracted_hash, - success, media, links, metadata, screenshot_hash, - response_headers, downloaded_files)) - - try: - await self.execute_with_retry(_cache) - except Exception as e: - logger.error(f"Error caching URL: {e}") - - - - async def aget_total_count(self) -> int: - """Get total number of cached URLs""" - async def _count(db): - async with db.execute('SELECT COUNT(*) FROM crawled_data') as cursor: - result = await cursor.fetchone() - return result[0] if result else 0 - - try: - return await self.execute_with_retry(_count) - except Exception as e: - logger.error(f"Error getting total count: {e}") - return 0 - - async def aclear_db(self): - """Clear all data from the database""" - async def _clear(db): - await db.execute('DELETE FROM crawled_data') - - try: - await self.execute_with_retry(_clear) - except Exception as e: - logger.error(f"Error clearing database: {e}") - - async def aflush_db(self): - """Drop the entire table""" - async def _flush(db): - await db.execute('DROP TABLE IF EXISTS crawled_data') - - try: - await self.execute_with_retry(_flush) - except Exception as e: - logger.error(f"Error flushing database: {e}") - - - async def _store_content(self, content: str, content_type: str) -> str: - """Store content in filesystem and return hash""" - if not content: - return "" - - content_hash = generate_content_hash(content) - file_path = os.path.join(self.content_paths[content_type], content_hash) - - # Only write if file doesn't exist - if not 
os.path.exists(file_path): - async with aiofiles.open(file_path, 'w', encoding='utf-8') as f: - await f.write(content) - - return content_hash - - async def _load_content(self, content_hash: str, content_type: str) -> Optional[str]: - """Load content from filesystem by hash""" - if not content_hash: - return None - - file_path = os.path.join(self.content_paths[content_type], content_hash) - try: - async with aiofiles.open(file_path, 'r', encoding='utf-8') as f: - return await f.read() - except: - logger.error(f"Failed to load content: {file_path}") - return None - -# Create a singleton instance -async_db_manager = AsyncDatabaseManager() diff --git a/crawl4ai/async_webcrawler.3.73.py b/crawl4ai/async_webcrawler.3.73.py deleted file mode 100644 index 03e7a393..00000000 --- a/crawl4ai/async_webcrawler.3.73.py +++ /dev/null @@ -1,344 +0,0 @@ -import os -import time -from pathlib import Path -from typing import Optional -import json -import asyncio -from .models import CrawlResult -from .async_database import async_db_manager -from .chunking_strategy import * -from .extraction_strategy import * -from .async_crawler_strategy import AsyncCrawlerStrategy, AsyncPlaywrightCrawlerStrategy, AsyncCrawlResponse -from .content_scrapping_strategy import WebScrapingStrategy -from .config import MIN_WORD_THRESHOLD, IMAGE_DESCRIPTION_MIN_WORD_THRESHOLD -from .utils import ( - sanitize_input_encode, - InvalidCSSSelectorError, - format_html -) -from .__version__ import __version__ as crawl4ai_version - -class AsyncWebCrawler: - def __init__( - self, - crawler_strategy: Optional[AsyncCrawlerStrategy] = None, - always_by_pass_cache: bool = False, - base_directory: str = str(Path.home()), - **kwargs, - ): - self.crawler_strategy = crawler_strategy or AsyncPlaywrightCrawlerStrategy( - **kwargs - ) - self.always_by_pass_cache = always_by_pass_cache - # self.crawl4ai_folder = os.path.join(Path.home(), ".crawl4ai") - self.crawl4ai_folder = os.path.join(base_directory, ".crawl4ai") - 
os.makedirs(self.crawl4ai_folder, exist_ok=True) - os.makedirs(f"{self.crawl4ai_folder}/cache", exist_ok=True) - self.ready = False - self.verbose = kwargs.get("verbose", False) - - async def __aenter__(self): - await self.crawler_strategy.__aenter__() - await self.awarmup() - return self - - async def __aexit__(self, exc_type, exc_val, exc_tb): - await self.crawler_strategy.__aexit__(exc_type, exc_val, exc_tb) - - async def awarmup(self): - # Print a message for crawl4ai and its version - if self.verbose: - print(f"[LOG] 🚀 Crawl4AI {crawl4ai_version}") - print("[LOG] 🌤️ Warming up the AsyncWebCrawler") - # await async_db_manager.ainit_db() - # # await async_db_manager.initialize() - # await self.arun( - # url="https://google.com/", - # word_count_threshold=5, - # bypass_cache=False, - # verbose=False, - # ) - self.ready = True - if self.verbose: - print("[LOG] 🌞 AsyncWebCrawler is ready to crawl") - - async def arun( - self, - url: str, - word_count_threshold=MIN_WORD_THRESHOLD, - extraction_strategy: ExtractionStrategy = None, - chunking_strategy: ChunkingStrategy = RegexChunking(), - bypass_cache: bool = False, - css_selector: str = None, - screenshot: bool = False, - user_agent: str = None, - verbose=True, - disable_cache: bool = False, - no_cache_read: bool = False, - no_cache_write: bool = False, - **kwargs, - ) -> CrawlResult: - """ - Runs the crawler for a single source: URL (web, local file, or raw HTML). - - Args: - url (str): The URL to crawl. Supported prefixes: - - 'http://' or 'https://': Web URL to crawl. - - 'file://': Local file path to process. - - 'raw:': Raw HTML content to process. - ... [other existing parameters] - - Returns: - CrawlResult: The result of the crawling and processing. 
- """ - try: - if disable_cache: - bypass_cache = True - no_cache_read = True - no_cache_write = True - - extraction_strategy = extraction_strategy or NoExtractionStrategy() - extraction_strategy.verbose = verbose - if not isinstance(extraction_strategy, ExtractionStrategy): - raise ValueError("Unsupported extraction strategy") - if not isinstance(chunking_strategy, ChunkingStrategy): - raise ValueError("Unsupported chunking strategy") - - word_count_threshold = max(word_count_threshold, MIN_WORD_THRESHOLD) - - async_response: AsyncCrawlResponse = None - cached = None - screenshot_data = None - extracted_content = None - - is_web_url = url.startswith(('http://', 'https://')) - is_local_file = url.startswith("file://") - is_raw_html = url.startswith("raw:") - _url = url if not is_raw_html else "Raw HTML" - - start_time = time.perf_counter() - cached_result = None - if is_web_url and (not bypass_cache or not no_cache_read) and not self.always_by_pass_cache: - cached_result = await async_db_manager.aget_cached_url(url) - - if cached_result: - html = sanitize_input_encode(cached_result.html) - extracted_content = sanitize_input_encode(cached_result.extracted_content or "") - if screenshot: - screenshot_data = cached_result.screenshot - if not screenshot_data: - cached_result = None - if verbose: - print( - f"[LOG] 1️⃣ ✅ Page fetched (cache) for {_url}, success: {bool(html)}, time taken: {time.perf_counter() - start_time:.2f} seconds" - ) - - - if not cached or not html: - t1 = time.perf_counter() - - if user_agent: - self.crawler_strategy.update_user_agent(user_agent) - async_response: AsyncCrawlResponse = await self.crawler_strategy.crawl(url, screenshot=screenshot, **kwargs) - html = sanitize_input_encode(async_response.html) - screenshot_data = async_response.screenshot - t2 = time.perf_counter() - if verbose: - print( - f"[LOG] 1️⃣ ✅ Page fetched (no-cache) for {_url}, success: {bool(html)}, time taken: {t2 - t1:.2f} seconds" - ) - - t1 = time.perf_counter() - 
crawl_result = await self.aprocess_html( - url=url, - html=html, - extracted_content=extracted_content, - word_count_threshold=word_count_threshold, - extraction_strategy=extraction_strategy, - chunking_strategy=chunking_strategy, - css_selector=css_selector, - screenshot=screenshot_data, - verbose=verbose, - is_cached=bool(cached), - async_response=async_response, - bypass_cache=bypass_cache, - is_web_url = is_web_url, - is_local_file = is_local_file, - is_raw_html = is_raw_html, - **kwargs, - ) - - if async_response: - crawl_result.status_code = async_response.status_code - crawl_result.response_headers = async_response.response_headers - crawl_result.downloaded_files = async_response.downloaded_files - else: - crawl_result.status_code = 200 - crawl_result.response_headers = cached_result.response_headers if cached_result else {} - - crawl_result.success = bool(html) - crawl_result.session_id = kwargs.get("session_id", None) - - if verbose: - print( - f"[LOG] 🔥 🚀 Crawling done for {_url}, success: {crawl_result.success}, time taken: {time.perf_counter() - start_time:.2f} seconds" - ) - - if not is_raw_html and not no_cache_write: - if not bool(cached_result) or kwargs.get("bypass_cache", False) or self.always_by_pass_cache: - await async_db_manager.acache_url(crawl_result) - - - return crawl_result - - except Exception as e: - if not hasattr(e, "msg"): - e.msg = str(e) - print(f"[ERROR] 🚫 arun(): Failed to crawl {_url}, error: {e.msg}") - return CrawlResult(url=url, html="", markdown = f"[ERROR] 🚫 arun(): Failed to crawl {_url}, error: {e.msg}", success=False, error_message=e.msg) - - async def arun_many( - self, - urls: List[str], - word_count_threshold=MIN_WORD_THRESHOLD, - extraction_strategy: ExtractionStrategy = None, - chunking_strategy: ChunkingStrategy = RegexChunking(), - bypass_cache: bool = False, - css_selector: str = None, - screenshot: bool = False, - user_agent: str = None, - verbose=True, - **kwargs, - ) -> List[CrawlResult]: - """ - Runs the 
crawler for multiple sources: URLs (web, local files, or raw HTML). - - Args: - urls (List[str]): A list of URLs with supported prefixes: - - 'http://' or 'https://': Web URL to crawl. - - 'file://': Local file path to process. - - 'raw:': Raw HTML content to process. - ... [other existing parameters] - - Returns: - List[CrawlResult]: The results of the crawling and processing. - """ - semaphore_count = kwargs.get('semaphore_count', 5) # Adjust as needed - semaphore = asyncio.Semaphore(semaphore_count) - - async def crawl_with_semaphore(url): - async with semaphore: - return await self.arun( - url, - word_count_threshold=word_count_threshold, - extraction_strategy=extraction_strategy, - chunking_strategy=chunking_strategy, - bypass_cache=bypass_cache, - css_selector=css_selector, - screenshot=screenshot, - user_agent=user_agent, - verbose=verbose, - **kwargs, - ) - - tasks = [crawl_with_semaphore(url) for url in urls] - results = await asyncio.gather(*tasks, return_exceptions=True) - return [result if not isinstance(result, Exception) else str(result) for result in results] - - async def aprocess_html( - self, - url: str, - html: str, - extracted_content: str, - word_count_threshold: int, - extraction_strategy: ExtractionStrategy, - chunking_strategy: ChunkingStrategy, - css_selector: str, - screenshot: str, - verbose: bool, - **kwargs, - ) -> CrawlResult: - t = time.perf_counter() - # Extract content from HTML - try: - _url = url if not kwargs.get("is_raw_html", False) else "Raw HTML" - t1 = time.perf_counter() - scrapping_strategy = WebScrapingStrategy() - # result = await scrapping_strategy.ascrap( - result = scrapping_strategy.scrap( - url, - html, - word_count_threshold=word_count_threshold, - css_selector=css_selector, - only_text=kwargs.get("only_text", False), - image_description_min_word_threshold=kwargs.get( - "image_description_min_word_threshold", IMAGE_DESCRIPTION_MIN_WORD_THRESHOLD - ), - **kwargs, - ) - - if result is None: - raise 
ValueError(f"Process HTML, Failed to extract content from the website: {url}") - except InvalidCSSSelectorError as e: - raise ValueError(str(e)) - except Exception as e: - raise ValueError(f"Process HTML, Failed to extract content from the website: {url}, error: {str(e)}") - - cleaned_html = sanitize_input_encode(result.get("cleaned_html", "")) - markdown = sanitize_input_encode(result.get("markdown", "")) - fit_markdown = sanitize_input_encode(result.get("fit_markdown", "")) - fit_html = sanitize_input_encode(result.get("fit_html", "")) - media = result.get("media", []) - links = result.get("links", []) - metadata = result.get("metadata", {}) - - if verbose: - print( - f"[LOG] 2️⃣ ✅ Scraping done for {_url}, success: True, time taken: {time.perf_counter() - t1:.2f} seconds" - ) - - if extracted_content is None and extraction_strategy and chunking_strategy and not isinstance(extraction_strategy, NoExtractionStrategy): - t1 = time.perf_counter() - # Check if extraction strategy is type of JsonCssExtractionStrategy - if isinstance(extraction_strategy, JsonCssExtractionStrategy) or isinstance(extraction_strategy, JsonCssExtractionStrategy): - extraction_strategy.verbose = verbose - extracted_content = extraction_strategy.run(url, [html]) - extracted_content = json.dumps(extracted_content, indent=4, default=str, ensure_ascii=False) - else: - sections = chunking_strategy.chunk(markdown) - extracted_content = extraction_strategy.run(url, sections) - extracted_content = json.dumps(extracted_content, indent=4, default=str, ensure_ascii=False) - if verbose: - print( - f"[LOG] 3️⃣ ✅ Extraction done for {_url}, time taken: {time.perf_counter() - t1:.2f} seconds" - ) - - screenshot = None if not screenshot else screenshot - - return CrawlResult( - url=url, - html=html, - cleaned_html=format_html(cleaned_html), - markdown=markdown, - fit_markdown=fit_markdown, - fit_html= fit_html, - media=media, - links=links, - metadata=metadata, - screenshot=screenshot, - 
extracted_content=extracted_content, - success=True, - error_message="", - ) - - async def aclear_cache(self): - # await async_db_manager.aclear_db() - await async_db_manager.cleanup() - - async def aflush_cache(self): - await async_db_manager.aflush_db() - - async def aget_cache_size(self): - return await async_db_manager.aget_total_count() - - diff --git a/crawl4ai/async_webcrawler.py b/crawl4ai/async_webcrawler.py index 7d1814b6..2ff7ce0f 100644 --- a/crawl4ai/async_webcrawler.py +++ b/crawl4ai/async_webcrawler.py @@ -7,14 +7,14 @@ from pathlib import Path from typing import Optional, List, Union import json import asyncio -from .models import CrawlResult +from .models import CrawlResult, MarkdownGenerationResult from .async_database import async_db_manager from .chunking_strategy import * from .content_filter_strategy import * from .extraction_strategy import * from .async_crawler_strategy import AsyncCrawlerStrategy, AsyncPlaywrightCrawlerStrategy, AsyncCrawlResponse from .cache_context import CacheMode, CacheContext, _legacy_to_cache_mode -from .content_scrapping_strategy import WebScrapingStrategy +from .content_scraping_strategy import WebScrapingStrategy from .async_logger import AsyncLogger from .config import ( @@ -476,7 +476,7 @@ class AsyncWebCrawler: html, word_count_threshold=word_count_threshold, css_selector=css_selector, - only_text=kwargs.get("only_text", False), + only_text=kwargs.pop("only_text", False), image_description_min_word_threshold=kwargs.get( "image_description_min_word_threshold", IMAGE_DESCRIPTION_MIN_WORD_THRESHOLD ), @@ -491,6 +491,8 @@ class AsyncWebCrawler: except Exception as e: raise ValueError(f"Process HTML, Failed to extract content from the website: {url}, error: {str(e)}") + markdown_v2: MarkdownGenerationResult = result.get("markdown_v2", None) + cleaned_html = sanitize_input_encode(result.get("cleaned_html", "")) markdown = sanitize_input_encode(result.get("markdown", "")) fit_markdown = 
sanitize_input_encode(result.get("fit_markdown", "")) @@ -542,6 +544,7 @@ class AsyncWebCrawler: url=url, html=html, cleaned_html=format_html(cleaned_html), + markdown_v2=markdown_v2, markdown=markdown, fit_markdown=fit_markdown, fit_html= fit_html, diff --git a/crawl4ai/content_scrapping_strategy.py b/crawl4ai/content_scraping_strategy.py similarity index 84% rename from crawl4ai/content_scrapping_strategy.py rename to crawl4ai/content_scraping_strategy.py index 0f470671..3823a78d 100644 --- a/crawl4ai/content_scrapping_strategy.py +++ b/crawl4ai/content_scraping_strategy.py @@ -1,6 +1,6 @@ import re # Point 1: Pre-Compile Regular Expressions from abc import ABC, abstractmethod -from typing import Dict, Any +from typing import Dict, Any, Optional from bs4 import BeautifulSoup from concurrent.futures import ThreadPoolExecutor import asyncio, requests, re, os @@ -10,103 +10,19 @@ from urllib.parse import urljoin from requests.exceptions import InvalidSchema # from .content_cleaning_strategy import ContentCleaningStrategy from .content_filter_strategy import RelevantContentFilter, BM25ContentFilter - +from .markdown_generation_strategy import MarkdownGenerationStrategy, DefaultMarkdownGenerationStrategy +from .models import MarkdownGenerationResult from .utils import ( sanitize_input_encode, sanitize_html, extract_metadata, InvalidCSSSelectorError, - # CustomHTML2Text, + CustomHTML2Text, normalize_url, is_external_url ) -from .html2text import HTML2Text -class CustomHTML2Text(HTML2Text): - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - self.inside_pre = False - self.inside_code = False - self.preserve_tags = set() # Set of tags to preserve - self.current_preserved_tag = None - self.preserved_content = [] - self.preserve_depth = 0 - - # Configuration options - self.skip_internal_links = False - self.single_line_break = False - self.mark_code = False - self.include_sup_sub = False - self.body_width = 0 - self.ignore_mailto_links = True - 
self.ignore_links = False - self.escape_backslash = False - self.escape_dot = False - self.escape_plus = False - self.escape_dash = False - self.escape_snob = False - - def update_params(self, **kwargs): - """Update parameters and set preserved tags.""" - for key, value in kwargs.items(): - if key == 'preserve_tags': - self.preserve_tags = set(value) - else: - setattr(self, key, value) - - def handle_tag(self, tag, attrs, start): - # Handle preserved tags - if tag in self.preserve_tags: - if start: - if self.preserve_depth == 0: - self.current_preserved_tag = tag - self.preserved_content = [] - # Format opening tag with attributes - attr_str = ''.join(f' {k}="{v}"' for k, v in attrs.items() if v is not None) - self.preserved_content.append(f'<{tag}{attr_str}>') - self.preserve_depth += 1 - return - else: - self.preserve_depth -= 1 - if self.preserve_depth == 0: - self.preserved_content.append(f'') - # Output the preserved HTML block with proper spacing - preserved_html = ''.join(self.preserved_content) - self.o('\n' + preserved_html + '\n') - self.current_preserved_tag = None - return - - # If we're inside a preserved tag, collect all content - if self.preserve_depth > 0: - if start: - # Format nested tags with attributes - attr_str = ''.join(f' {k}="{v}"' for k, v in attrs.items() if v is not None) - self.preserved_content.append(f'<{tag}{attr_str}>') - else: - self.preserved_content.append(f'') - return - - # Handle pre tags - if tag == 'pre': - if start: - self.o('```\n') - self.inside_pre = True - else: - self.o('\n```') - self.inside_pre = False - # elif tag in ["h1", "h2", "h3", "h4", "h5", "h6"]: - # pass - else: - super().handle_tag(tag, attrs, start) - - def handle_data(self, data, entity_char=False): - """Override handle_data to capture content within preserved tags.""" - if self.preserve_depth > 0: - self.preserved_content.append(data) - return - super().handle_data(data, entity_char) - # Pre-compile regular expressions for Open Graph and Twitter 
metadata OG_REGEX = re.compile(r'^og:') TWITTER_REGEX = re.compile(r'^twitter:') @@ -164,6 +80,98 @@ class WebScrapingStrategy(ContentScrapingStrategy): async def ascrap(self, url: str, html: str, **kwargs) -> Dict[str, Any]: return await asyncio.to_thread(self._get_content_of_website_optimized, url, html, **kwargs) + + def _generate_markdown_content(self, + cleaned_html: str, + html: str, + url: str, + success: bool, + **kwargs) -> Dict[str, Any]: + """Generate markdown content using either new strategy or legacy method. + + Args: + cleaned_html: Sanitized HTML content + html: Original HTML content + url: Base URL of the page + success: Whether scraping was successful + **kwargs: Additional options including: + - markdown_generator: Optional[MarkdownGenerationStrategy] + - html2text: Dict[str, Any] options for HTML2Text + - content_filter: Optional[RelevantContentFilter] + - fit_markdown: bool + - fit_markdown_user_query: Optional[str] + - fit_markdown_bm25_threshold: float + + Returns: + Dict containing markdown content in various formats + """ + markdown_generator: Optional[MarkdownGenerationStrategy] = kwargs.get('markdown_generator', DefaultMarkdownGenerationStrategy()) + + if markdown_generator: + try: + markdown_result = markdown_generator.generate_markdown( + cleaned_html=cleaned_html, + base_url=url, + html2text_options=kwargs.get('html2text', {}), + content_filter=kwargs.get('content_filter', None) + ) + + markdown_v2 = MarkdownGenerationResult( + raw_markdown=markdown_result.raw_markdown, + markdown_with_citations=markdown_result.markdown_with_citations, + references_markdown=markdown_result.references_markdown, + fit_markdown=markdown_result.fit_markdown + ) + + return { + 'markdown': markdown_result.raw_markdown, + 'fit_markdown': markdown_result.fit_markdown or "Set flag 'fit_markdown' to True to get cleaned HTML content.", + 'fit_html': kwargs.get('content_filter', None).filter_content(html) if kwargs.get('content_filter') else "Set flag 
'fit_markdown' to True to get cleaned HTML content.", + 'markdown_v2': markdown_v2 + } + except Exception as e: + self._log('error', + message="Error using new markdown generation strategy: {error}", + tag="SCRAPE", + params={"error": str(e)} + ) + markdown_generator = None + + # Legacy method + h = CustomHTML2Text() + h.update_params(**kwargs.get('html2text', {})) + markdown = h.handle(cleaned_html) + markdown = markdown.replace(' ```', '```') + + fit_markdown = "Set flag 'fit_markdown' to True to get cleaned HTML content." + fit_html = "Set flag 'fit_markdown' to True to get cleaned HTML content." + + if kwargs.get('content_filter', None) or kwargs.get('fit_markdown', False): + content_filter = kwargs.get('content_filter', None) + if not content_filter: + content_filter = BM25ContentFilter( + user_query=kwargs.get('fit_markdown_user_query', None), + bm25_threshold=kwargs.get('fit_markdown_bm25_threshold', 1.0) + ) + fit_html = content_filter.filter_content(html) + fit_html = '\n'.join('
{}
'.format(s) for s in fit_html) + fit_markdown = h.handle(fit_html) + + markdown_v2 = MarkdownGenerationResult( + raw_markdown=markdown, + markdown_with_citations=markdown, + references_markdown=markdown, + fit_markdown=fit_markdown + ) + + return { + 'markdown': markdown, + 'fit_markdown': fit_markdown, + 'fit_html': fit_html, + 'markdown_v2' : markdown_v2 + } + + def _get_content_of_website_optimized(self, url: str, html: str, word_count_threshold: int = MIN_WORD_THRESHOLD, css_selector: str = None, **kwargs) -> Dict[str, Any]: success = True if not html: @@ -242,8 +250,6 @@ class WebScrapingStrategy(ContentScrapingStrategy): #Score an image for it's usefulness def score_image_for_usefulness(img, base_url, index, images_count): - - image_height = img.get('height') height_value, height_unit = parse_dimension(image_height) image_width = img.get('width') @@ -282,7 +288,7 @@ class WebScrapingStrategy(ContentScrapingStrategy): if not is_valid_image(img, img.parent, img.parent.get('class', [])): return None score = score_image_for_usefulness(img, url, index, total_images) - if score <= IMAGE_SCORE_THRESHOLD: + if score <= kwargs.get('image_score_threshold', IMAGE_SCORE_THRESHOLD): return None return { 'src': img.get('src', ''), @@ -545,41 +551,16 @@ class WebScrapingStrategy(ContentScrapingStrategy): cleaned_html = str_body.replace('\n\n', '\n').replace(' ', ' ') - try: - h = CustomHTML2Text() - h.update_params(**kwargs.get('html2text', {})) - markdown = h.handle(cleaned_html) - except Exception as e: - if not h: - h = CustomHTML2Text() - self._log('error', - message="Error converting HTML to markdown: {error}", - tag="SCRAPE", - params={"error": str(e)} - ) - markdown = h.handle(sanitize_html(cleaned_html)) - markdown = markdown.replace(' ```', '```') - + markdown_content = self._generate_markdown_content( + cleaned_html=cleaned_html, + html=html, + url=url, + success=success, + **kwargs + ) - - fit_markdown = "Set flag 'fit_markdown' to True to get cleaned HTML 
content." - fit_html = "Set flag 'fit_markdown' to True to get cleaned HTML content." - if kwargs.get('content_filter', None) or kwargs.get('fit_markdown', False): - content_filter = kwargs.get('content_filter', None) - if not content_filter: - content_filter = BM25ContentFilter( - user_query= kwargs.get('fit_markdown_user_query', None), - bm25_threshold= kwargs.get('fit_markdown_bm25_threshold', 1.0) - ) - fit_html = content_filter.filter_content(html) - fit_html = '\n'.join('
{}
from abc import ABC, abstractmethod
from typing import Optional, Dict, Any, Tuple
from urllib.parse import urljoin
import re

from .models import MarkdownGenerationResult
from .utils import CustomHTML2Text
from .content_filter_strategy import RelevantContentFilter, BM25ContentFilter

# Pre-compiled pattern matching markdown links and images:
# [text](url "optional title") with an optional leading '!' for images.
LINK_PATTERN = re.compile(r'!?\[([^\]]+)\]\(([^)]+?)(?:\s+"([^"]*)")?\)')


def fast_urljoin(base: str, url: str) -> str:
    """Fast URL joining for common cases, falling back to urljoin otherwise."""
    if url.startswith(('http://', 'https://', 'mailto:', '//')):
        return url
    if url.startswith('/'):
        # Absolute path: append directly to the base, avoiding a double slash.
        if base.endswith('/'):
            return base[:-1] + url
        return base + url
    return urljoin(base, url)


class MarkdownGenerationStrategy(ABC):
    """Abstract base class for markdown generation strategies."""

    @abstractmethod
    def generate_markdown(self,
                          cleaned_html: str,
                          base_url: str = "",
                          html2text_options: Optional[Dict[str, Any]] = None,
                          content_filter: Optional[RelevantContentFilter] = None,
                          citations: bool = True,
                          **kwargs) -> MarkdownGenerationResult:
        """Generate markdown from cleaned HTML."""
        pass


class DefaultMarkdownGenerationStrategy(MarkdownGenerationStrategy):
    """Default implementation of markdown generation strategy."""

    def convert_links_to_citations(self, markdown: str, base_url: str = "") -> Tuple[str, str]:
        """Replace markdown links with numeric citations.

        Returns a tuple of (converted markdown, references section). Each
        distinct URL gets one citation number; link text/title become the
        reference description.
        """
        link_map = {}      # url -> (citation number, description suffix)
        url_cache = {}     # cache for resolved relative URLs
        parts = []
        last_end = 0
        counter = 1

        for match in LINK_PATTERN.finditer(markdown):
            parts.append(markdown[last_end:match.start()])
            text, url, title = match.groups()

            # Resolve relative URLs once per distinct target.
            if base_url and not url.startswith(('http://', 'https://', 'mailto:')):
                if url not in url_cache:
                    url_cache[url] = fast_urljoin(base_url, url)
                url = url_cache[url]

            if url not in link_map:
                desc = []
                if title:
                    desc.append(title)
                if text and text != title:
                    desc.append(text)
                link_map[url] = (counter, ": " + " - ".join(desc) if desc else "")
                counter += 1

            num = link_map[url][0]
            # Images keep the '![...]' syntax; plain links become text⟨n⟩.
            parts.append(f"{text}⟨{num}⟩" if not match.group(0).startswith('!') else f"![{text}⟨{num}⟩]")
            last_end = match.end()

        parts.append(markdown[last_end:])
        converted_text = ''.join(parts)

        # Build the references section ordered by citation number.
        references = ["\n\n## References\n\n"]
        references.extend(
            f"⟨{num}⟩ {url}{desc}\n"
            for url, (num, desc) in sorted(link_map.items(), key=lambda x: x[1][0])
        )

        return converted_text, ''.join(references)

    def generate_markdown(self,
                          cleaned_html: str,
                          base_url: str = "",
                          html2text_options: Optional[Dict[str, Any]] = None,
                          content_filter: Optional[RelevantContentFilter] = None,
                          citations: bool = True,
                          **kwargs) -> MarkdownGenerationResult:
        """Generate markdown (with optional citations and fit_markdown) from cleaned HTML."""
        h = CustomHTML2Text()
        if html2text_options:
            h.update_params(**html2text_options)

        raw_markdown = h.handle(cleaned_html)
        raw_markdown = raw_markdown.replace(' ```', '```')

        if citations:
            markdown_with_citations, references_markdown = self.convert_links_to_citations(
                raw_markdown, base_url
            )
        else:
            # Bug fix: these were left unbound when citations=False, which
            # raised UnboundLocalError at result construction below.
            markdown_with_citations, references_markdown = raw_markdown, ""

        # Generate fit markdown only when a content filter is provided.
        fit_markdown: Optional[str] = None
        if content_filter:
            fragments = content_filter.filter_content(cleaned_html)
            # NOTE(review): the original wrapped each fragment in <pre>...</pre>;
            # the markup was stripped in this copy — confirm against upstream.
            filtered_html = '\n'.join('<pre>{}</pre>'.format(s) for s in fragments)
            fit_markdown = h.handle(filtered_html)

        return MarkdownGenerationResult(
            raw_markdown=raw_markdown,
            markdown_with_citations=markdown_with_citations,
            references_markdown=references_markdown,
            fit_markdown=fit_markdown,
        )
from .html2text import HTML2Text


class CustomHTML2Text(HTML2Text):
    """HTML2Text subclass with tag preservation and crawl4ai-specific defaults.

    Tags listed in ``preserve_tags`` are emitted verbatim (opening tag,
    attributes, nested markup, and text) instead of being converted to
    markdown. ``<pre>`` blocks are rendered as fenced code blocks.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.inside_pre = False
        self.inside_code = False
        self.preserve_tags = set()          # tags to pass through as raw HTML
        self.current_preserved_tag = None   # outermost preserved tag currently open
        self.preserved_content = []         # collected raw HTML for that tag
        self.preserve_depth = 0             # nesting depth inside preserved tags

        # Configuration defaults for crawl4ai output.
        self.skip_internal_links = False
        self.single_line_break = False
        self.mark_code = False
        self.include_sup_sub = False
        self.body_width = 0                 # no hard line wrapping
        self.ignore_mailto_links = True
        self.ignore_links = False
        self.escape_backslash = False
        self.escape_dot = False
        self.escape_plus = False
        self.escape_dash = False
        self.escape_snob = False

    def update_params(self, **kwargs):
        """Update parameters; 'preserve_tags' is normalized to a set."""
        for key, value in kwargs.items():
            if key == 'preserve_tags':
                self.preserve_tags = set(value)
            else:
                setattr(self, key, value)

    def handle_tag(self, tag, attrs, start):
        # Preserved tags: capture the raw HTML instead of converting it.
        if tag in self.preserve_tags:
            if start:
                if self.preserve_depth == 0:
                    self.current_preserved_tag = tag
                    self.preserved_content = []
                # Re-emit the opening tag with its attributes.
                attr_str = ''.join(f' {k}="{v}"' for k, v in attrs.items() if v is not None)
                self.preserved_content.append(f'<{tag}{attr_str}>')
                self.preserve_depth += 1
                return
            else:
                self.preserve_depth -= 1
                if self.preserve_depth == 0:
                    # Bug fix: the closing tag was emitted as an empty string
                    # (stripped in this copy); restore '</tag>'.
                    self.preserved_content.append(f'</{tag}>')
                    # Output the preserved HTML block with surrounding newlines.
                    preserved_html = ''.join(self.preserved_content)
                    self.o('\n' + preserved_html + '\n')
                    self.current_preserved_tag = None
                return

        # Inside a preserved tag: collect nested markup verbatim.
        if self.preserve_depth > 0:
            if start:
                attr_str = ''.join(f' {k}="{v}"' for k, v in attrs.items() if v is not None)
                self.preserved_content.append(f'<{tag}{attr_str}>')
            else:
                # Bug fix: restore the stripped '</tag>' here as well.
                self.preserved_content.append(f'</{tag}>')
            return

        # <pre> blocks become fenced code blocks.
        if tag == 'pre':
            if start:
                self.o('```\n')
                self.inside_pre = True
            else:
                self.o('\n```')
                self.inside_pre = False
        else:
            super().handle_tag(tag, attrs, start)

    def handle_data(self, data, entity_char=False):
        """Capture text verbatim while inside preserved tags; otherwise delegate."""
        if self.preserve_depth > 0:
            self.preserved_content.append(data)
            return
        super().handle_data(data, entity_char)
# Tests for DefaultMarkdownGenerationStrategy (markdown generation + citations).
# NOTE(review): the original file header referenced GitHub issue #236
# (user-data browser windows), which is unrelated to this module — it looks
# like a copy-paste artifact from another test file.

import os
import sys
import time
from typing import Dict, Any

parent_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
sys.path.append(parent_dir)
__location__ = os.path.realpath(os.path.join(os.getcwd(), os.path.dirname(__file__)))

from crawl4ai.markdown_generation_strategy import DefaultMarkdownGenerationStrategy


def print_test_result(name: str, result: Dict[str, Any], execution_time: float):
    """Print a summary for one test run and dump each string result as a .md file."""
    print(f"\n{'='*20} {name} {'='*20}")
    print(f"Execution time: {execution_time:.4f} seconds")

    # Bug fix: the output directory may not exist on a fresh checkout;
    # create it before writing instead of failing with FileNotFoundError.
    output_dir = os.path.join(__location__, "output")
    os.makedirs(output_dir, exist_ok=True)

    for key, content in result.items():
        if isinstance(content, str):
            with open(os.path.join(output_dir, f"{name.lower()}_{key}.md"), "w") as f:
                f.write(content)


def test_basic_markdown_conversion():
    """Test basic markdown conversion with links."""
    with open(__location__ + "/data/wikipedia.html", "r") as f:
        cleaned_html = f.read()

    generator = DefaultMarkdownGenerationStrategy()

    start_time = time.perf_counter()
    result = generator.generate_markdown(
        cleaned_html=cleaned_html,
        base_url="https://en.wikipedia.org"
    )
    execution_time = time.perf_counter() - start_time

    print_test_result("Basic Markdown Conversion", {
        'raw': result.raw_markdown,
        'with_citations': result.markdown_with_citations,
        'references': result.references_markdown
    }, execution_time)

    assert result.raw_markdown, "Raw markdown should not be empty"
    assert result.markdown_with_citations, "Markdown with citations should not be empty"
    assert result.references_markdown, "References should not be empty"
    assert "⟨" in result.markdown_with_citations, "Citations should use ⟨⟩ brackets"
    assert "## References" in result.references_markdown, "Should contain references section"


def test_relative_links():
    """Test handling of relative links with base URL."""
    markdown = """
    Here's a [relative link](/wiki/Apple) and an [absolute link](https://example.com).
    Also an [image](/images/test.png) and another [page](/wiki/Banana).
    """

    generator = DefaultMarkdownGenerationStrategy()
    result = generator.generate_markdown(
        cleaned_html=markdown,
        base_url="https://en.wikipedia.org"
    )

    assert "https://en.wikipedia.org/wiki/Apple" in result.references_markdown
    assert "https://example.com" in result.references_markdown
    assert "https://en.wikipedia.org/images/test.png" in result.references_markdown


def test_duplicate_links():
    """Test that the same URL is assigned one citation number."""
    markdown = """
    Here's a [link](/test) and another [link](/test) and a [different link](/other).
    """

    generator = DefaultMarkdownGenerationStrategy()
    result = generator.generate_markdown(
        cleaned_html=markdown,
        base_url="https://example.com"
    )

    citations = result.markdown_with_citations.count("⟨1⟩")
    assert citations == 2, "Same link should use same citation number"


def test_link_descriptions():
    """Test handling of link titles and descriptions."""
    markdown = """
    Here's a [link with title](/test "Test Title") and a [link with description](/other) to test.
    """

    generator = DefaultMarkdownGenerationStrategy()
    result = generator.generate_markdown(
        cleaned_html=markdown,
        base_url="https://example.com"
    )

    assert "Test Title" in result.references_markdown, "Link title should be in references"
    assert "link with description" in result.references_markdown, "Link text should be in references"


def test_performance_large_document():
    """Measure generation time over several iterations on a large document."""
    with open(__location__ + "/data/wikipedia.md", "r") as f:
        markdown = f.read()

    iterations = 5
    times = []
    generator = DefaultMarkdownGenerationStrategy()

    for _ in range(iterations):
        start_time = time.perf_counter()
        generator.generate_markdown(
            cleaned_html=markdown,
            base_url="https://en.wikipedia.org"
        )
        times.append(time.perf_counter() - start_time)

    avg_time = sum(times) / len(times)
    print(f"\n{'='*20} Performance Test {'='*20}")
    print(f"Average execution time over {iterations} iterations: {avg_time:.4f} seconds")
    print(f"Min time: {min(times):.4f} seconds")
    print(f"Max time: {max(times):.4f} seconds")


def test_image_links():
    """Test handling of image links."""
    markdown = """
    Here's an ![image](/image.png "Image Title") and another ![image](/other.jpg).
    And a regular [link](/page).
    """

    generator = DefaultMarkdownGenerationStrategy()
    result = generator.generate_markdown(
        cleaned_html=markdown,
        base_url="https://example.com"
    )

    assert "![" in result.markdown_with_citations, "Image markdown syntax should be preserved"
    assert "Image Title" in result.references_markdown, "Image title should be in references"


if __name__ == "__main__":
    print("Running markdown generation strategy tests...")

    test_basic_markdown_conversion()
    test_relative_links()
    test_duplicate_links()
    test_link_descriptions()
    test_performance_large_document()
    test_image_links()