From dbb751c8f09f76ffce4046784c2cd2b0021de7d0 Mon Sep 17 00:00:00 2001 From: UncleCode Date: Thu, 21 Nov 2024 18:21:43 +0800 Subject: [PATCH] In this commit, we introduce the new concept of MakrdownGenerationStrategy, which allows us to expand our future strategies to generate better markdown. Right now, we generate raw markdown as we were doing before. We have a new algorithm for fitting markdown based on BM25, and now we add the ability to refine markdown into a citation form. Our links will be extracted and replaced by a citation reference number, and then we will have reference sections at the very end; we add all the links with the descriptions. This format is more suitable for large language models. In case we don't need to pass links, we can reduce the size of the markdown significantly and also attach the list of references as a separate file to a large language model. This commit contains changes for this direction. --- crawl4ai/__init__.py | 1 + crawl4ai/async_crawler_strategy.py | 13 +- crawl4ai/async_database.3.73.py | 285 --------------- crawl4ai/async_webcrawler.3.73.py | 344 ------------------ crawl4ai/async_webcrawler.py | 9 +- ...rategy.py => content_scraping_strategy.py} | 229 ++++++------ crawl4ai/markdown_generation_strategy.py | 115 ++++++ crawl4ai/models.py | 13 +- crawl4ai/utils.py | 88 +++++ crawl4ai/web_crawler.py | 2 +- tests/async/test_content_scraper_strategy.py | 4 +- tests/async/test_markdown_genertor.py | 165 +++++++++ 12 files changed, 506 insertions(+), 762 deletions(-) delete mode 100644 crawl4ai/async_database.3.73.py delete mode 100644 crawl4ai/async_webcrawler.3.73.py rename crawl4ai/{content_scrapping_strategy.py => content_scraping_strategy.py} (84%) create mode 100644 crawl4ai/markdown_generation_strategy.py create mode 100644 tests/async/test_markdown_genertor.py diff --git a/crawl4ai/__init__.py b/crawl4ai/__init__.py index ad9475b4..0ccf13d8 100644 --- a/crawl4ai/__init__.py +++ b/crawl4ai/__init__.py @@ -1,6 +1,7 @@ # __init__.py from .async_webcrawler import AsyncWebCrawler, CacheMode + from .models import CrawlResult from .__version__ import __version__ # __version__ = "0.3.73" diff --git a/crawl4ai/async_crawler_strategy.py b/crawl4ai/async_crawler_strategy.py index e7dc9c54..3f332eb0 100644 --- a/crawl4ai/async_crawler_strategy.py +++ b/crawl4ai/async_crawler_strategy.py @@ -229,6 +229,7 @@ class AsyncPlaywrightCrawlerStrategy(AsyncCrawlerStrategy): self.headless = kwargs.get("headless", True) self.browser_type = kwargs.get("browser_type", "chromium") self.headers = kwargs.get("headers", {}) + self.cookies = kwargs.get("cookies", []) self.sessions = {} self.session_ttl = 1800 self.js_code = js_code @@ -295,6 +296,8 @@ class AsyncPlaywrightCrawlerStrategy(AsyncCrawlerStrategy): # Set up the default context if self.default_context: await self.default_context.set_extra_http_headers(self.headers) + if self.cookies: + await self.default_context.add_cookies(self.cookies) if self.accept_downloads: await self.default_context.set_default_timeout(60000) await self.default_context.set_default_navigation_timeout(60000) @@ -669,6 +672,8 @@ class AsyncPlaywrightCrawlerStrategy(AsyncCrawlerStrategy): # downloads_path=self.downloads_path if self.accept_downloads else None ) await context.add_cookies([{"name": "cookiesEnabled", "value": "true", "url": url}]) + if self.cookies: + await context.add_cookies(self.cookies) await context.set_extra_http_headers(self.headers) page = await context.new_page() self.sessions[session_id] = (context, page, time.time()) @@ -684,6 +689,8 @@ class AsyncPlaywrightCrawlerStrategy(AsyncCrawlerStrategy): proxy={"server": self.proxy} if self.proxy else None, accept_downloads=self.accept_downloads, ) + if self.cookies: + await context.add_cookies(self.cookies) await context.set_extra_http_headers(self.headers) if kwargs.get("override_navigator", False) or kwargs.get("simulate_user", False) or kwargs.get("magic", False): @@ -828,7 +835,8 @@ class AsyncPlaywrightCrawlerStrategy(AsyncCrawlerStrategy): for js in js_code: await page.evaluate(js) - await page.wait_for_load_state('networkidle') + # await page.wait_for_timeout(100) + # Check for on execution event await self.execute_hook('on_execution_started', page) @@ -846,6 +854,9 @@ class AsyncPlaywrightCrawlerStrategy(AsyncCrawlerStrategy): await self.smart_wait(page, wait_for, timeout=kwargs.get("page_timeout", 60000)) except Exception as e: raise RuntimeError(f"Wait condition failed: {str(e)}") + + # if not wait_for and js_code: + # await page.wait_for_load_state('networkidle', timeout=5000) # Update image dimensions update_image_dimensions_js = """ diff --git a/crawl4ai/async_database.3.73.py b/crawl4ai/async_database.3.73.py deleted file mode 100644 index f86c7f1d..00000000 --- a/crawl4ai/async_database.3.73.py +++ /dev/null @@ -1,285 +0,0 @@ -import os -from pathlib import Path -import aiosqlite -import asyncio -from typing import Optional, Tuple, Dict -from contextlib import asynccontextmanager -import logging -import json # Added for serialization/deserialization -from .utils import ensure_content_dirs, generate_content_hash -import xxhash -import aiofiles -# Set up logging -logging.basicConfig(level=logging.INFO) -logger = logging.getLogger(__name__) - -DB_PATH = os.path.join(Path.home(), ".crawl4ai") -os.makedirs(DB_PATH, exist_ok=True) -DB_PATH = os.path.join(DB_PATH, "crawl4ai.db") - -class AsyncDatabaseManager: - def __init__(self, pool_size: int = 10, max_retries: int = 3): - self.db_path = DB_PATH - self.content_paths = ensure_content_dirs(os.path.dirname(DB_PATH)) - self.pool_size = pool_size - self.max_retries = max_retries - self.connection_pool: Dict[int, aiosqlite.Connection] = {} - self.pool_lock = asyncio.Lock() - self.connection_semaphore = asyncio.Semaphore(pool_size) - - async def initialize(self): - """Initialize the database and connection pool""" - await self.ainit_db() - - async def cleanup(self): - """Cleanup connections when shutting down""" - async with self.pool_lock: - for conn in self.connection_pool.values(): - await conn.close() - self.connection_pool.clear() - - @asynccontextmanager - async def get_connection(self): - """Connection pool manager""" - async with self.connection_semaphore: - task_id = id(asyncio.current_task()) - try: - async with self.pool_lock: - if task_id not in self.connection_pool: - conn = await aiosqlite.connect( - self.db_path, - timeout=30.0 - ) - await conn.execute('PRAGMA journal_mode = WAL') - await conn.execute('PRAGMA busy_timeout = 5000') - self.connection_pool[task_id] = conn - - yield self.connection_pool[task_id] - - except Exception as e: - logger.error(f"Connection error: {e}") - raise - finally: - async with self.pool_lock: - if task_id in self.connection_pool: - await self.connection_pool[task_id].close() - del self.connection_pool[task_id] - - async def execute_with_retry(self, operation, *args): - """Execute database operations with retry logic""" - for attempt in range(self.max_retries): - try: - async with self.get_connection() as db: - result = await operation(db, *args) - await db.commit() - return result - except Exception as e: - if attempt == self.max_retries - 1: - logger.error(f"Operation failed after {self.max_retries} attempts: {e}") - raise - await asyncio.sleep(1 * (attempt + 1)) # Exponential backoff - - async def ainit_db(self): - """Initialize database schema""" - async def _init(db): - await db.execute(''' - CREATE TABLE IF NOT EXISTS crawled_data ( - url TEXT PRIMARY KEY, - html TEXT, - cleaned_html TEXT, - markdown TEXT, - extracted_content TEXT, - success BOOLEAN, - media TEXT DEFAULT "{}", - links TEXT DEFAULT "{}", - metadata TEXT DEFAULT "{}", - screenshot TEXT DEFAULT "", - response_headers TEXT DEFAULT "{}", - downloaded_files TEXT DEFAULT "{}" -- New column added - ) - ''') - - await self.execute_with_retry(_init) - await self.update_db_schema() - - async def update_db_schema(self): - """Update database schema if needed""" - async def _check_columns(db): - cursor = await db.execute("PRAGMA table_info(crawled_data)") - columns = await cursor.fetchall() - return [column[1] for column in columns] - - column_names = await self.execute_with_retry(_check_columns) - - # List of new columns to add - new_columns = ['media', 'links', 'metadata', 'screenshot', 'response_headers', 'downloaded_files'] - - for column in new_columns: - if column not in column_names: - await self.aalter_db_add_column(column) - - async def aalter_db_add_column(self, new_column: str): - """Add new column to the database""" - async def _alter(db): - if new_column == 'response_headers': - await db.execute(f'ALTER TABLE crawled_data ADD COLUMN {new_column} TEXT DEFAULT "{{}}"') - else: - await db.execute(f'ALTER TABLE crawled_data ADD COLUMN {new_column} TEXT DEFAULT ""') - logger.info(f"Added column '{new_column}' to the database.") - - await self.execute_with_retry(_alter) - - - async def aget_cached_url(self, url: str) -> Optional[Tuple[str, str, str, str, str, bool, str, str, str, str]]: - """Retrieve cached URL data""" - async def _get(db): - async with db.execute( - ''' - SELECT url, html, cleaned_html, markdown, - extracted_content, success, media, links, - metadata, screenshot, response_headers, - downloaded_files - FROM crawled_data WHERE url = ? - ''', - (url,) - ) as cursor: - row = await cursor.fetchone() - if row: - # Load content from files using stored hashes - html = await self._load_content(row[1], 'html') if row[1] else "" - cleaned = await self._load_content(row[2], 'cleaned') if row[2] else "" - markdown = await self._load_content(row[3], 'markdown') if row[3] else "" - extracted = await self._load_content(row[4], 'extracted') if row[4] else "" - screenshot = await self._load_content(row[9], 'screenshots') if row[9] else "" - - return ( - row[0], # url - html or "", # Return empty string if file not found - cleaned or "", - markdown or "", - extracted or "", - row[5], # success - json.loads(row[6] or '{}'), # media - json.loads(row[7] or '{}'), # links - json.loads(row[8] or '{}'), # metadata - screenshot or "", - json.loads(row[10] or '{}'), # response_headers - json.loads(row[11] or '[]') # downloaded_files - ) - return None - - try: - return await self.execute_with_retry(_get) - except Exception as e: - logger.error(f"Error retrieving cached URL: {e}") - return None - - async def acache_url(self, url: str, html: str, cleaned_html: str, - markdown: str, extracted_content: str, success: bool, - media: str = "{}", links: str = "{}", - metadata: str = "{}", screenshot: str = "", - response_headers: str = "{}", downloaded_files: str = "[]"): - """Cache URL data with content stored in filesystem""" - - # Store content files and get hashes - html_hash = await self._store_content(html, 'html') - cleaned_hash = await self._store_content(cleaned_html, 'cleaned') - markdown_hash = await self._store_content(markdown, 'markdown') - extracted_hash = await self._store_content(extracted_content, 'extracted') - screenshot_hash = await self._store_content(screenshot, 'screenshots') - - async def _cache(db): - await db.execute(''' - INSERT INTO crawled_data ( - url, html, cleaned_html, markdown, - extracted_content, success, media, links, metadata, - screenshot, response_headers, downloaded_files - ) - VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?) - ON CONFLICT(url) DO UPDATE SET - html = excluded.html, - cleaned_html = excluded.cleaned_html, - markdown = excluded.markdown, - extracted_content = excluded.extracted_content, - success = excluded.success, - media = excluded.media, - links = excluded.links, - metadata = excluded.metadata, - screenshot = excluded.screenshot, - response_headers = excluded.response_headers, - downloaded_files = excluded.downloaded_files - ''', (url, html_hash, cleaned_hash, markdown_hash, extracted_hash, - success, media, links, metadata, screenshot_hash, - response_headers, downloaded_files)) - - try: - await self.execute_with_retry(_cache) - except Exception as e: - logger.error(f"Error caching URL: {e}") - - - - async def aget_total_count(self) -> int: - """Get total number of cached URLs""" - async def _count(db): - async with db.execute('SELECT COUNT(*) FROM crawled_data') as cursor: - result = await cursor.fetchone() - return result[0] if result else 0 - - try: - return await self.execute_with_retry(_count) - except Exception as e: - logger.error(f"Error getting total count: {e}") - return 0 - - async def aclear_db(self): - """Clear all data from the database""" - async def _clear(db): - await db.execute('DELETE FROM crawled_data') - - try: - await self.execute_with_retry(_clear) - except Exception as e: - logger.error(f"Error clearing database: {e}") - - async def aflush_db(self): - """Drop the entire table""" - async def _flush(db): - await db.execute('DROP TABLE IF EXISTS crawled_data') - - try: - await self.execute_with_retry(_flush) - except Exception as e: - logger.error(f"Error flushing database: {e}") - - - async def _store_content(self, content: str, content_type: str) -> str: - """Store content in filesystem and return hash""" - if not content: - return "" - - content_hash = generate_content_hash(content) - file_path = os.path.join(self.content_paths[content_type], content_hash) - - # Only write if file doesn't exist - if not os.path.exists(file_path): - async with aiofiles.open(file_path, 'w', encoding='utf-8') as f: - await f.write(content) - - return content_hash - - async def _load_content(self, content_hash: str, content_type: str) -> Optional[str]: - """Load content from filesystem by hash""" - if not content_hash: - return None - - file_path = os.path.join(self.content_paths[content_type], content_hash) - try: - async with aiofiles.open(file_path, 'r', encoding='utf-8') as f: - return await f.read() - except: - logger.error(f"Failed to load content: {file_path}") - return None - -# Create a singleton instance -async_db_manager = AsyncDatabaseManager() diff --git a/crawl4ai/async_webcrawler.3.73.py b/crawl4ai/async_webcrawler.3.73.py deleted file mode 100644 index 03e7a393..00000000 --- a/crawl4ai/async_webcrawler.3.73.py +++ /dev/null @@ -1,344 +0,0 @@ -import os -import time -from pathlib import Path -from typing import Optional -import json -import asyncio -from .models import CrawlResult -from .async_database import async_db_manager -from .chunking_strategy import * -from .extraction_strategy import * -from .async_crawler_strategy import AsyncCrawlerStrategy, AsyncPlaywrightCrawlerStrategy, AsyncCrawlResponse -from .content_scrapping_strategy import WebScrapingStrategy -from .config import MIN_WORD_THRESHOLD, IMAGE_DESCRIPTION_MIN_WORD_THRESHOLD -from .utils import ( - sanitize_input_encode, - InvalidCSSSelectorError, - format_html -) -from .__version__ import __version__ as crawl4ai_version - -class AsyncWebCrawler: - def __init__( - self, - crawler_strategy: Optional[AsyncCrawlerStrategy] = None, - always_by_pass_cache: bool = False, - base_directory: str = str(Path.home()), - **kwargs, - ): - self.crawler_strategy = crawler_strategy or AsyncPlaywrightCrawlerStrategy( - **kwargs - ) - self.always_by_pass_cache = always_by_pass_cache - # self.crawl4ai_folder = os.path.join(Path.home(), ".crawl4ai") - self.crawl4ai_folder = os.path.join(base_directory, ".crawl4ai") - os.makedirs(self.crawl4ai_folder, exist_ok=True) - os.makedirs(f"{self.crawl4ai_folder}/cache", exist_ok=True) - self.ready = False - self.verbose = kwargs.get("verbose", False) - - async def __aenter__(self): - await self.crawler_strategy.__aenter__() - await self.awarmup() - return self - - async def __aexit__(self, exc_type, exc_val, exc_tb): - await self.crawler_strategy.__aexit__(exc_type, exc_val, exc_tb) - - async def awarmup(self): - # Print a message for crawl4ai and its version - if self.verbose: - print(f"[LOG] 🚀 Crawl4AI {crawl4ai_version}") - print("[LOG] 🌤️ Warming up the AsyncWebCrawler") - # await async_db_manager.ainit_db() - # # await async_db_manager.initialize() - # await self.arun( - # url="https://google.com/", - # word_count_threshold=5, - # bypass_cache=False, - # verbose=False, - # ) - self.ready = True - if self.verbose: - print("[LOG] 🌞 AsyncWebCrawler is ready to crawl") - - async def arun( - self, - url: str, - word_count_threshold=MIN_WORD_THRESHOLD, - extraction_strategy: ExtractionStrategy = None, - chunking_strategy: ChunkingStrategy = RegexChunking(), - bypass_cache: bool = False, - css_selector: str = None, - screenshot: bool = False, - user_agent: str = None, - verbose=True, - disable_cache: bool = False, - no_cache_read: bool = False, - no_cache_write: bool = False, - **kwargs, - ) -> CrawlResult: - """ - Runs the crawler for a single source: URL (web, local file, or raw HTML). - - Args: - url (str): The URL to crawl. Supported prefixes: - - 'http://' or 'https://': Web URL to crawl. - - 'file://': Local file path to process. - - 'raw:': Raw HTML content to process. - ... [other existing parameters] - - Returns: - CrawlResult: The result of the crawling and processing. - """ - try: - if disable_cache: - bypass_cache = True - no_cache_read = True - no_cache_write = True - - extraction_strategy = extraction_strategy or NoExtractionStrategy() - extraction_strategy.verbose = verbose - if not isinstance(extraction_strategy, ExtractionStrategy): - raise ValueError("Unsupported extraction strategy") - if not isinstance(chunking_strategy, ChunkingStrategy): - raise ValueError("Unsupported chunking strategy") - - word_count_threshold = max(word_count_threshold, MIN_WORD_THRESHOLD) - - async_response: AsyncCrawlResponse = None - cached = None - screenshot_data = None - extracted_content = None - - is_web_url = url.startswith(('http://', 'https://')) - is_local_file = url.startswith("file://") - is_raw_html = url.startswith("raw:") - _url = url if not is_raw_html else "Raw HTML" - - start_time = time.perf_counter() - cached_result = None - if is_web_url and (not bypass_cache or not no_cache_read) and not self.always_by_pass_cache: - cached_result = await async_db_manager.aget_cached_url(url) - - if cached_result: - html = sanitize_input_encode(cached_result.html) - extracted_content = sanitize_input_encode(cached_result.extracted_content or "") - if screenshot: - screenshot_data = cached_result.screenshot - if not screenshot_data: - cached_result = None - if verbose: - print( - f"[LOG] 1️⃣ ✅ Page fetched (cache) for {_url}, success: {bool(html)}, time taken: {time.perf_counter() - start_time:.2f} seconds" - ) - - - if not cached or not html: - t1 = time.perf_counter() - - if user_agent: - self.crawler_strategy.update_user_agent(user_agent) - async_response: AsyncCrawlResponse = await self.crawler_strategy.crawl(url, screenshot=screenshot, **kwargs) - html = sanitize_input_encode(async_response.html) - screenshot_data = async_response.screenshot - t2 = time.perf_counter() - if verbose: - print( - f"[LOG] 1️⃣ ✅ Page fetched (no-cache) for {_url}, success: {bool(html)}, time taken: {t2 - t1:.2f} seconds" - ) - - t1 = time.perf_counter() - crawl_result = await self.aprocess_html( - url=url, - html=html, - extracted_content=extracted_content, - word_count_threshold=word_count_threshold, - extraction_strategy=extraction_strategy, - chunking_strategy=chunking_strategy, - css_selector=css_selector, - screenshot=screenshot_data, - verbose=verbose, - is_cached=bool(cached), - async_response=async_response, - bypass_cache=bypass_cache, - is_web_url = is_web_url, - is_local_file = is_local_file, - is_raw_html = is_raw_html, - **kwargs, - ) - - if async_response: - crawl_result.status_code = async_response.status_code - crawl_result.response_headers = async_response.response_headers - crawl_result.downloaded_files = async_response.downloaded_files - else: - crawl_result.status_code = 200 - crawl_result.response_headers = cached_result.response_headers if cached_result else {} - - crawl_result.success = bool(html) - crawl_result.session_id = kwargs.get("session_id", None) - - if verbose: - print( - f"[LOG] 🔥 🚀 Crawling done for {_url}, success: {crawl_result.success}, time taken: {time.perf_counter() - start_time:.2f} seconds" - ) - - if not is_raw_html and not no_cache_write: - if not bool(cached_result) or kwargs.get("bypass_cache", False) or self.always_by_pass_cache: - await async_db_manager.acache_url(crawl_result) - - - return crawl_result - - except Exception as e: - if not hasattr(e, "msg"): - e.msg = str(e) - print(f"[ERROR] 🚫 arun(): Failed to crawl {_url}, error: {e.msg}") - return CrawlResult(url=url, html="", markdown = f"[ERROR] 🚫 arun(): Failed to crawl {_url}, error: {e.msg}", success=False, error_message=e.msg) - - async def arun_many( - self, - urls: List[str], - word_count_threshold=MIN_WORD_THRESHOLD, - extraction_strategy: ExtractionStrategy = None, - chunking_strategy: ChunkingStrategy = RegexChunking(), - bypass_cache: bool = False, - css_selector: str = None, - screenshot: bool = False, - user_agent: str = None, - verbose=True, - **kwargs, - ) -> List[CrawlResult]: - """ - Runs the crawler for multiple sources: URLs (web, local files, or raw HTML). - - Args: - urls (List[str]): A list of URLs with supported prefixes: - - 'http://' or 'https://': Web URL to crawl. - - 'file://': Local file path to process. - - 'raw:': Raw HTML content to process. - ... [other existing parameters] - - Returns: - List[CrawlResult]: The results of the crawling and processing. - """ - semaphore_count = kwargs.get('semaphore_count', 5) # Adjust as needed - semaphore = asyncio.Semaphore(semaphore_count) - - async def crawl_with_semaphore(url): - async with semaphore: - return await self.arun( - url, - word_count_threshold=word_count_threshold, - extraction_strategy=extraction_strategy, - chunking_strategy=chunking_strategy, - bypass_cache=bypass_cache, - css_selector=css_selector, - screenshot=screenshot, - user_agent=user_agent, - verbose=verbose, - **kwargs, - ) - - tasks = [crawl_with_semaphore(url) for url in urls] - results = await asyncio.gather(*tasks, return_exceptions=True) - return [result if not isinstance(result, Exception) else str(result) for result in results] - - async def aprocess_html( - self, - url: str, - html: str, - extracted_content: str, - word_count_threshold: int, - extraction_strategy: ExtractionStrategy, - chunking_strategy: ChunkingStrategy, - css_selector: str, - screenshot: str, - verbose: bool, - **kwargs, - ) -> CrawlResult: - t = time.perf_counter() - # Extract content from HTML - try: - _url = url if not kwargs.get("is_raw_html", False) else "Raw HTML" - t1 = time.perf_counter() - scrapping_strategy = WebScrapingStrategy() - # result = await scrapping_strategy.ascrap( - result = scrapping_strategy.scrap( - url, - html, - word_count_threshold=word_count_threshold, - css_selector=css_selector, - only_text=kwargs.get("only_text", False), - image_description_min_word_threshold=kwargs.get( - "image_description_min_word_threshold", IMAGE_DESCRIPTION_MIN_WORD_THRESHOLD - ), - **kwargs, - ) - - if result is None: - raise ValueError(f"Process HTML, Failed to extract content from the website: {url}") - except InvalidCSSSelectorError as e: - raise ValueError(str(e)) - except Exception as e: - raise ValueError(f"Process HTML, Failed to extract content from the website: {url}, error: {str(e)}") - - cleaned_html = sanitize_input_encode(result.get("cleaned_html", "")) - markdown = sanitize_input_encode(result.get("markdown", "")) - fit_markdown = sanitize_input_encode(result.get("fit_markdown", "")) - fit_html = sanitize_input_encode(result.get("fit_html", "")) - media = result.get("media", []) - links = result.get("links", []) - metadata = result.get("metadata", {}) - - if verbose: - print( - f"[LOG] 2️⃣ ✅ Scraping done for {_url}, success: True, time taken: {time.perf_counter() - t1:.2f} seconds" - ) - - if extracted_content is None and extraction_strategy and chunking_strategy and not isinstance(extraction_strategy, NoExtractionStrategy): - t1 = time.perf_counter() - # Check if extraction strategy is type of JsonCssExtractionStrategy - if isinstance(extraction_strategy, JsonCssExtractionStrategy) or isinstance(extraction_strategy, JsonCssExtractionStrategy): - extraction_strategy.verbose = verbose - extracted_content = extraction_strategy.run(url, [html]) - extracted_content = json.dumps(extracted_content, indent=4, default=str, ensure_ascii=False) - else: - sections = chunking_strategy.chunk(markdown) - extracted_content = extraction_strategy.run(url, sections) - extracted_content = json.dumps(extracted_content, indent=4, default=str, ensure_ascii=False) - if verbose: - print( - f"[LOG] 3️⃣ ✅ Extraction done for {_url}, time taken: {time.perf_counter() - t1:.2f} seconds" - ) - - screenshot = None if not screenshot else screenshot - - return CrawlResult( - url=url, - html=html, - cleaned_html=format_html(cleaned_html), - markdown=markdown, - fit_markdown=fit_markdown, - fit_html= fit_html, - media=media, - links=links, - metadata=metadata, - screenshot=screenshot, - extracted_content=extracted_content, - success=True, - error_message="", - ) - - async def aclear_cache(self): - # await async_db_manager.aclear_db() - await async_db_manager.cleanup() - - async def aflush_cache(self): - await async_db_manager.aflush_db() - - async def aget_cache_size(self): - return await async_db_manager.aget_total_count() - - diff --git a/crawl4ai/async_webcrawler.py b/crawl4ai/async_webcrawler.py index 7d1814b6..2ff7ce0f 100644 --- a/crawl4ai/async_webcrawler.py +++ b/crawl4ai/async_webcrawler.py @@ -7,14 +7,14 @@ from pathlib import Path from typing import Optional, List, Union import json import asyncio -from .models import CrawlResult +from .models import CrawlResult, MarkdownGenerationResult from .async_database import async_db_manager from .chunking_strategy import * from .content_filter_strategy import * from .extraction_strategy import * from .async_crawler_strategy import AsyncCrawlerStrategy, AsyncPlaywrightCrawlerStrategy, AsyncCrawlResponse from .cache_context import CacheMode, CacheContext, _legacy_to_cache_mode -from .content_scrapping_strategy import WebScrapingStrategy +from .content_scraping_strategy import WebScrapingStrategy from .async_logger import AsyncLogger from .config import ( @@ -476,7 +476,7 @@ class AsyncWebCrawler: html, word_count_threshold=word_count_threshold, css_selector=css_selector, - only_text=kwargs.get("only_text", False), + only_text=kwargs.pop("only_text", False), image_description_min_word_threshold=kwargs.get( "image_description_min_word_threshold", IMAGE_DESCRIPTION_MIN_WORD_THRESHOLD ), @@ -491,6 +491,8 @@ class AsyncWebCrawler: except Exception as e: raise ValueError(f"Process HTML, Failed to extract content from the website: {url}, error: {str(e)}") + markdown_v2: MarkdownGenerationResult = result.get("markdown_v2", None) + cleaned_html = sanitize_input_encode(result.get("cleaned_html", "")) markdown = sanitize_input_encode(result.get("markdown", "")) fit_markdown = sanitize_input_encode(result.get("fit_markdown", "")) @@ -542,6 +544,7 @@ class AsyncWebCrawler: url=url, html=html, cleaned_html=format_html(cleaned_html), + markdown_v2=markdown_v2, markdown=markdown, fit_markdown=fit_markdown, fit_html= fit_html, diff --git a/crawl4ai/content_scrapping_strategy.py b/crawl4ai/content_scraping_strategy.py similarity index 84% rename from crawl4ai/content_scrapping_strategy.py rename to crawl4ai/content_scraping_strategy.py index 0f470671..3823a78d 100644 --- a/crawl4ai/content_scrapping_strategy.py +++ b/crawl4ai/content_scraping_strategy.py @@ -1,6 +1,6 @@ import re # Point 1: Pre-Compile Regular Expressions from abc import ABC, abstractmethod -from typing import Dict, Any +from typing import Dict, Any, Optional from bs4 import BeautifulSoup from concurrent.futures import ThreadPoolExecutor import asyncio, requests, re, os @@ -10,103 +10,19 @@ from urllib.parse import urljoin from requests.exceptions import InvalidSchema # from .content_cleaning_strategy import ContentCleaningStrategy from .content_filter_strategy import RelevantContentFilter, BM25ContentFilter - +from .markdown_generation_strategy import MarkdownGenerationStrategy, DefaultMarkdownGenerationStrategy +from .models import MarkdownGenerationResult from .utils import ( sanitize_input_encode, sanitize_html, extract_metadata, InvalidCSSSelectorError, - # CustomHTML2Text, + CustomHTML2Text, normalize_url, is_external_url ) -from .html2text import HTML2Text -class CustomHTML2Text(HTML2Text): - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - self.inside_pre = False - self.inside_code = False - self.preserve_tags = set() # Set of tags to preserve - self.current_preserved_tag = None - self.preserved_content = [] - self.preserve_depth = 0 - - # Configuration options - self.skip_internal_links = False - self.single_line_break = False - self.mark_code = False - self.include_sup_sub = False - self.body_width = 0 - self.ignore_mailto_links = True - self.ignore_links = False - self.escape_backslash = False - self.escape_dot = False - self.escape_plus = False - self.escape_dash = False - self.escape_snob = False - - def update_params(self, **kwargs): - """Update parameters and set preserved tags.""" - for key, value in kwargs.items(): - if key == 'preserve_tags': - self.preserve_tags = set(value) - else: - setattr(self, key, value) - - def handle_tag(self, tag, attrs, start): - # Handle preserved tags - if tag in self.preserve_tags: - if start: - if self.preserve_depth == 0: - self.current_preserved_tag = tag - self.preserved_content = [] - # Format opening tag with attributes - attr_str = ''.join(f' {k}="{v}"' for k, v in attrs.items() if v is not None) - self.preserved_content.append(f'<{tag}{attr_str}>') - self.preserve_depth += 1 - return - else: - self.preserve_depth -= 1 - if self.preserve_depth == 0: - self.preserved_content.append(f'') - # Output the preserved HTML block with proper spacing - preserved_html = ''.join(self.preserved_content) - self.o('\n' + preserved_html + '\n') - self.current_preserved_tag = None - return - - # If we're inside a preserved tag, collect all content - if self.preserve_depth > 0: - if start: - # Format nested tags with attributes - attr_str = ''.join(f' {k}="{v}"' for k, v in attrs.items() if v is not None) - self.preserved_content.append(f'<{tag}{attr_str}>') - else: - self.preserved_content.append(f'') - return - - # Handle pre tags - if tag == 'pre': - if start: - self.o('```\n') - self.inside_pre = True - else: - self.o('\n```') - self.inside_pre = False - # elif tag in ["h1", "h2", "h3", "h4", "h5", "h6"]: - # pass - else: - super().handle_tag(tag, attrs, start) - - def handle_data(self, data, entity_char=False): - """Override handle_data to capture content within preserved tags.""" - if self.preserve_depth > 0: - self.preserved_content.append(data) - return - super().handle_data(data, entity_char) - # Pre-compile regular expressions for Open Graph and Twitter metadata OG_REGEX = re.compile(r'^og:') TWITTER_REGEX = re.compile(r'^twitter:') @@ -164,6 +80,98 @@ class WebScrapingStrategy(ContentScrapingStrategy): async def ascrap(self, url: str, html: str, **kwargs) -> Dict[str, Any]: return await asyncio.to_thread(self._get_content_of_website_optimized, url, html, **kwargs) + + def _generate_markdown_content(self, + cleaned_html: str, + html: str, + url: str, + success: bool, + **kwargs) -> Dict[str, Any]: + """Generate markdown content using either new strategy or legacy method. + + Args: + cleaned_html: Sanitized HTML content + html: Original HTML content + url: Base URL of the page + success: Whether scraping was successful + **kwargs: Additional options including: + - markdown_generator: Optional[MarkdownGenerationStrategy] + - html2text: Dict[str, Any] options for HTML2Text + - content_filter: Optional[RelevantContentFilter] + - fit_markdown: bool + - fit_markdown_user_query: Optional[str] + - fit_markdown_bm25_threshold: float + + Returns: + Dict containing markdown content in various formats + """ + markdown_generator: Optional[MarkdownGenerationStrategy] = kwargs.get('markdown_generator', DefaultMarkdownGenerationStrategy()) + + if markdown_generator: + try: + markdown_result = markdown_generator.generate_markdown( + cleaned_html=cleaned_html, + base_url=url, + html2text_options=kwargs.get('html2text', {}), + content_filter=kwargs.get('content_filter', None) + ) + + markdown_v2 = MarkdownGenerationResult( + raw_markdown=markdown_result.raw_markdown, + markdown_with_citations=markdown_result.markdown_with_citations, + references_markdown=markdown_result.references_markdown, + fit_markdown=markdown_result.fit_markdown + ) + + return { + 'markdown': markdown_result.raw_markdown, + 'fit_markdown': markdown_result.fit_markdown or "Set flag 'fit_markdown' to True to get cleaned HTML content.", + 'fit_html': kwargs.get('content_filter', None).filter_content(html) if kwargs.get('content_filter') else "Set flag 'fit_markdown' to True to get cleaned HTML content.", + 'markdown_v2': markdown_v2 + } + except Exception as e: + self._log('error', + message="Error using new markdown generation strategy: {error}", + tag="SCRAPE", + params={"error": str(e)} + ) + markdown_generator = None + + # Legacy method + h = CustomHTML2Text() + h.update_params(**kwargs.get('html2text', {})) + markdown = h.handle(cleaned_html) + markdown = markdown.replace(' ```', '```') + + fit_markdown = "Set flag 'fit_markdown' to True to get cleaned HTML content." + fit_html = "Set flag 'fit_markdown' to True to get cleaned HTML content." + + if kwargs.get('content_filter', None) or kwargs.get('fit_markdown', False): + content_filter = kwargs.get('content_filter', None) + if not content_filter: + content_filter = BM25ContentFilter( + user_query=kwargs.get('fit_markdown_user_query', None), + bm25_threshold=kwargs.get('fit_markdown_bm25_threshold', 1.0) + ) + fit_html = content_filter.filter_content(html) + fit_html = '\n'.join('
{}
'.format(s) for s in fit_html) + fit_markdown = h.handle(fit_html) + + markdown_v2 = MarkdownGenerationResult( + raw_markdown=markdown, + markdown_with_citations=markdown, + references_markdown=markdown, + fit_markdown=fit_markdown + ) + + return { + 'markdown': markdown, + 'fit_markdown': fit_markdown, + 'fit_html': fit_html, + 'markdown_v2' : markdown_v2 + } + + def _get_content_of_website_optimized(self, url: str, html: str, word_count_threshold: int = MIN_WORD_THRESHOLD, css_selector: str = None, **kwargs) -> Dict[str, Any]: success = True if not html: @@ -242,8 +250,6 @@ class WebScrapingStrategy(ContentScrapingStrategy): #Score an image for it's usefulness def score_image_for_usefulness(img, base_url, index, images_count): - - image_height = img.get('height') height_value, height_unit = parse_dimension(image_height) image_width = img.get('width') @@ -282,7 +288,7 @@ class WebScrapingStrategy(ContentScrapingStrategy): if not is_valid_image(img, img.parent, img.parent.get('class', [])): return None score = score_image_for_usefulness(img, url, index, total_images) - if score <= IMAGE_SCORE_THRESHOLD: + if score <= kwargs.get('image_score_threshold', IMAGE_SCORE_THRESHOLD): return None return { 'src': img.get('src', ''), @@ -545,41 +551,16 @@ class WebScrapingStrategy(ContentScrapingStrategy): cleaned_html = str_body.replace('\n\n', '\n').replace(' ', ' ') - try: - h = CustomHTML2Text() - h.update_params(**kwargs.get('html2text', {})) - markdown = h.handle(cleaned_html) - except Exception as e: - if not h: - h = CustomHTML2Text() - self._log('error', - message="Error converting HTML to markdown: {error}", - tag="SCRAPE", - params={"error": str(e)} - ) - markdown = h.handle(sanitize_html(cleaned_html)) - markdown = markdown.replace(' ```', '```') - + markdown_content = self._generate_markdown_content( + cleaned_html=cleaned_html, + html=html, + url=url, + success=success, + **kwargs + ) - - fit_markdown = "Set flag 'fit_markdown' to True to get cleaned HTML content." - fit_html = "Set flag 'fit_markdown' to True to get cleaned HTML content." - if kwargs.get('content_filter', None) or kwargs.get('fit_markdown', False): - content_filter = kwargs.get('content_filter', None) - if not content_filter: - content_filter = BM25ContentFilter( - user_query= kwargs.get('fit_markdown_user_query', None), - bm25_threshold= kwargs.get('fit_markdown_bm25_threshold', 1.0) - ) - fit_html = content_filter.filter_content(html) - fit_html = '\n'.join('
{}
'.format(s) for s in fit_html) - fit_markdown = h.handle(fit_html) - - cleaned_html = sanitize_html(cleaned_html) return { - 'markdown': markdown, - 'fit_markdown': fit_markdown, - 'fit_html': fit_html, + **markdown_content, 'cleaned_html': cleaned_html, 'success': success, 'media': media, diff --git a/crawl4ai/markdown_generation_strategy.py b/crawl4ai/markdown_generation_strategy.py new file mode 100644 index 00000000..1adb4c28 --- /dev/null +++ b/crawl4ai/markdown_generation_strategy.py @@ -0,0 +1,115 @@ +from abc import ABC, abstractmethod +from typing import Optional, Dict, Any, Tuple +from .models import MarkdownGenerationResult +from .utils import CustomHTML2Text +from .content_filter_strategy import RelevantContentFilter, BM25ContentFilter +import re +from urllib.parse import urljoin + +# Pre-compile the regex pattern +LINK_PATTERN = re.compile(r'!?\[([^\]]+)\]\(([^)]+?)(?:\s+"([^"]*)")?\)') + +class MarkdownGenerationStrategy(ABC): + """Abstract base class for markdown generation strategies.""" + + @abstractmethod + def generate_markdown(self, + cleaned_html: str, + base_url: str = "", + html2text_options: Optional[Dict[str, Any]] = None, + content_filter: Optional[RelevantContentFilter] = None, + citations: bool = True, + **kwargs) -> MarkdownGenerationResult: + """Generate markdown from cleaned HTML.""" + pass + +class DefaultMarkdownGenerationStrategy(MarkdownGenerationStrategy): + """Default implementation of markdown generation strategy.""" + + def convert_links_to_citations(self, markdown: str, base_url: str = "") -> Tuple[str, str]: + link_map = {} + url_cache = {} # Cache for URL joins + parts = [] + last_end = 0 + counter = 1 + + for match in LINK_PATTERN.finditer(markdown): + parts.append(markdown[last_end:match.start()]) + text, url, title = match.groups() + + # Use cached URL if available, otherwise compute and cache + if base_url and not url.startswith(('http://', 'https://', 'mailto:')): + if url not in url_cache: + url_cache[url] = fast_urljoin(base_url, url) + url = url_cache[url] + + if url not in link_map: + desc = [] + if title: desc.append(title) + if text and text != title: desc.append(text) + link_map[url] = (counter, ": " + " - ".join(desc) if desc else "") + counter += 1 + + num = link_map[url][0] + parts.append(f"{text}⟨{num}⟩" if not match.group(0).startswith('!') else f"![{text}⟨{num}⟩]") + last_end = match.end() + + parts.append(markdown[last_end:]) + converted_text = ''.join(parts) + + # Pre-build reference strings + references = ["\n\n## References\n\n"] + references.extend( + f"⟨{num}⟩ {url}{desc}\n" + for url, (num, desc) in sorted(link_map.items(), key=lambda x: x[1][0]) + ) + + return converted_text, ''.join(references) + + def generate_markdown(self, + cleaned_html: str, + base_url: str = "", + html2text_options: Optional[Dict[str, Any]] = None, + content_filter: Optional[RelevantContentFilter] = None, + citations: bool = True, + **kwargs) -> MarkdownGenerationResult: + """Generate markdown with citations from cleaned HTML.""" + # Initialize HTML2Text with options + h = CustomHTML2Text() + if html2text_options: + h.update_params(**html2text_options) + + # Generate raw markdown + raw_markdown = h.handle(cleaned_html) + raw_markdown = raw_markdown.replace(' ```', '```') + + # Convert links to citations + if citations: + markdown_with_citations, references_markdown = self.convert_links_to_citations( + raw_markdown, base_url + ) + + # Generate fit markdown if content filter is provided + fit_markdown: Optional[str] = None + if content_filter: + filtered_html = content_filter.filter_content(cleaned_html) + filtered_html = '\n'.join('
{}
'.format(s) for s in filtered_html) + fit_markdown = h.handle(filtered_html) + + return MarkdownGenerationResult( + raw_markdown=raw_markdown, + markdown_with_citations=markdown_with_citations, + references_markdown=references_markdown, + fit_markdown=fit_markdown + ) + +def fast_urljoin(base: str, url: str) -> str: + """Fast URL joining for common cases.""" + if url.startswith(('http://', 'https://', 'mailto:', '//')): + return url + if url.startswith('/'): + # Handle absolute paths + if base.endswith('/'): + return base[:-1] + url + return base + url + return urljoin(base, url) \ No newline at end of file diff --git a/crawl4ai/models.py b/crawl4ai/models.py index cab4c45b..122434ad 100644 --- a/crawl4ai/models.py +++ b/crawl4ai/models.py @@ -1,5 +1,5 @@ from pydantic import BaseModel, HttpUrl -from typing import List, Dict, Optional, Callable, Awaitable +from typing import List, Dict, Optional, Callable, Awaitable, Union @@ -7,6 +7,12 @@ class UrlModel(BaseModel): url: HttpUrl forced: bool = False +class MarkdownGenerationResult(BaseModel): + raw_markdown: str + markdown_with_citations: str + references_markdown: str + fit_markdown: Optional[str] = None + class CrawlResult(BaseModel): url: str html: str @@ -16,7 +22,8 @@ class CrawlResult(BaseModel): links: Dict[str, List[Dict]] = {} downloaded_files: Optional[List[str]] = None screenshot: Optional[str] = None - markdown: Optional[str] = None + markdown: Optional[Union[str, MarkdownGenerationResult]] = None + markdown_v2: Optional[MarkdownGenerationResult] = None fit_markdown: Optional[str] = None fit_html: Optional[str] = None extracted_content: Optional[str] = None @@ -36,3 +43,5 @@ class AsyncCrawlResponse(BaseModel): class Config: arbitrary_types_allowed = True + + diff --git a/crawl4ai/utils.py b/crawl4ai/utils.py index 9abc5784..b07562df 100644 --- a/crawl4ai/utils.py +++ b/crawl4ai/utils.py @@ -18,6 +18,94 @@ import hashlib from typing import Optional, Tuple, Dict, Any import xxhash + +from .html2text import HTML2Text +class CustomHTML2Text(HTML2Text): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.inside_pre = False + self.inside_code = False + self.preserve_tags = set() # Set of tags to preserve + self.current_preserved_tag = None + self.preserved_content = [] + self.preserve_depth = 0 + + # Configuration options + self.skip_internal_links = False + self.single_line_break = False + self.mark_code = False + self.include_sup_sub = False + self.body_width = 0 + self.ignore_mailto_links = True + self.ignore_links = False + self.escape_backslash = False + self.escape_dot = False + self.escape_plus = False + self.escape_dash = False + self.escape_snob = False + + def update_params(self, **kwargs): + """Update parameters and set preserved tags.""" + for key, value in kwargs.items(): + if key == 'preserve_tags': + self.preserve_tags = set(value) + else: + setattr(self, key, value) + + def handle_tag(self, tag, attrs, start): + # Handle preserved tags + if tag in self.preserve_tags: + if start: + if self.preserve_depth == 0: + self.current_preserved_tag = tag + self.preserved_content = [] + # Format opening tag with attributes + attr_str = ''.join(f' {k}="{v}"' for k, v in attrs.items() if v is not None) + self.preserved_content.append(f'<{tag}{attr_str}>') + self.preserve_depth += 1 + return + else: + self.preserve_depth -= 1 + if self.preserve_depth == 0: + self.preserved_content.append(f'') + # Output the preserved HTML block with proper spacing + preserved_html = ''.join(self.preserved_content) + self.o('\n' + preserved_html + '\n') + self.current_preserved_tag = None + return + + # If we're inside a preserved tag, collect all content + if self.preserve_depth > 0: + if start: + # Format nested tags with attributes + attr_str = ''.join(f' {k}="{v}"' for k, v in attrs.items() if v is not None) + self.preserved_content.append(f'<{tag}{attr_str}>') + else: + self.preserved_content.append(f'') + return + + # Handle pre tags + if tag == 'pre': + if start: + self.o('```\n') + self.inside_pre = True + else: + self.o('\n```') + self.inside_pre = False + # elif tag in ["h1", "h2", "h3", "h4", "h5", "h6"]: + # pass + else: + super().handle_tag(tag, attrs, start) + + def handle_data(self, data, entity_char=False): + """Override handle_data to capture content within preserved tags.""" + if self.preserve_depth > 0: + self.preserved_content.append(data) + return + super().handle_data(data, entity_char) + + + class InvalidCSSSelectorError(Exception): pass diff --git a/crawl4ai/web_crawler.py b/crawl4ai/web_crawler.py index 6cfef6f0..a32a988d 100644 --- a/crawl4ai/web_crawler.py +++ b/crawl4ai/web_crawler.py @@ -10,7 +10,7 @@ from .extraction_strategy import * from .crawler_strategy import * from typing import List from concurrent.futures import ThreadPoolExecutor -from .content_scrapping_strategy import WebScrapingStrategy +from .content_scraping_strategy import WebScrapingStrategy from .config import * import warnings import json diff --git a/tests/async/test_content_scraper_strategy.py b/tests/async/test_content_scraper_strategy.py index 5dfa6362..62c49148 100644 --- a/tests/async/test_content_scraper_strategy.py +++ b/tests/async/test_content_scraper_strategy.py @@ -13,8 +13,8 @@ parent_dir = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__f sys.path.append(parent_dir) __location__ = os.path.realpath(os.path.join(os.getcwd(), os.path.dirname(__file__))) -from crawl4ai.content_scrapping_strategy import WebScrapingStrategy -from crawl4ai.content_scrapping_strategy import WebScrapingStrategy as WebScrapingStrategyCurrent +from crawl4ai.content_scraping_strategy import WebScrapingStrategy +from crawl4ai.content_scraping_strategy import WebScrapingStrategy as WebScrapingStrategyCurrent # from crawl4ai.content_scrapping_strategy_current import WebScrapingStrategy as WebScrapingStrategyCurrent @dataclass diff --git a/tests/async/test_markdown_genertor.py b/tests/async/test_markdown_genertor.py new file mode 100644 index 00000000..025a0318 --- /dev/null +++ b/tests/async/test_markdown_genertor.py @@ -0,0 +1,165 @@ +# ## Issue #236 +# - **Last Updated:** 2024-11-11 01:42:14 +# - **Title:** [user data crawling opens two windows, unable to control correct user browser](https://github.com/unclecode/crawl4ai/issues/236) +# - **State:** open + +import os, sys, time +parent_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) +sys.path.append(parent_dir) +__location__ = os.path.realpath( os.path.join(os.getcwd(), os.path.dirname(__file__))) +import asyncio +import os +import time +from typing import Dict, Any +from crawl4ai.markdown_generation_strategy import DefaultMarkdownGenerationStrategy + +# Get current directory +__location__ = os.path.realpath(os.path.join(os.getcwd(), os.path.dirname(__file__))) + +def print_test_result(name: str, result: Dict[str, Any], execution_time: float): + """Helper function to print test results.""" + print(f"\n{'='*20} {name} {'='*20}") + print(f"Execution time: {execution_time:.4f} seconds") + + + # Save markdown to files + for key, content in result.items(): + if isinstance(content, str): + with open(__location__ + f"/output/{name.lower()}_{key}.md", "w") as f: + f.write(content) + + # # Print first few lines of each markdown version + # for key, content in result.items(): + # if isinstance(content, str): + # preview = '\n'.join(content.split('\n')[:3]) + # print(f"\n{key} (first 3 lines):") + # print(preview) + # print(f"Total length: {len(content)} characters") + +def test_basic_markdown_conversion(): + """Test basic markdown conversion with links.""" + with open(__location__ + "/data/wikipedia.html", "r") as f: + cleaned_html = f.read() + + generator = DefaultMarkdownGenerationStrategy() + + start_time = time.perf_counter() + result = generator.generate_markdown( + cleaned_html=cleaned_html, + base_url="https://en.wikipedia.org" + ) + execution_time = time.perf_counter() - start_time + + print_test_result("Basic Markdown Conversion", { + 'raw': result.raw_markdown, + 'with_citations': result.markdown_with_citations, + 'references': result.references_markdown + }, execution_time) + + # Basic assertions + assert result.raw_markdown, "Raw markdown should not be empty" + assert result.markdown_with_citations, "Markdown with citations should not be empty" + assert result.references_markdown, "References should not be empty" + assert "⟨" in result.markdown_with_citations, "Citations should use ⟨⟩ brackets" + assert "## References" in result.references_markdown, "Should contain references section" + +def test_relative_links(): + """Test handling of relative links with base URL.""" + markdown = """ + Here's a [relative link](/wiki/Apple) and an [absolute link](https://example.com). + Also an [image](/images/test.png) and another [page](/wiki/Banana). + """ + + generator = DefaultMarkdownGenerationStrategy() + result = generator.generate_markdown( + cleaned_html=markdown, + base_url="https://en.wikipedia.org" + ) + + assert "https://en.wikipedia.org/wiki/Apple" in result.references_markdown + assert "https://example.com" in result.references_markdown + assert "https://en.wikipedia.org/images/test.png" in result.references_markdown + +def test_duplicate_links(): + """Test handling of duplicate links.""" + markdown = """ + Here's a [link](/test) and another [link](/test) and a [different link](/other). + """ + + generator = DefaultMarkdownGenerationStrategy() + result = generator.generate_markdown( + cleaned_html=markdown, + base_url="https://example.com" + ) + + # Count citations in markdown + citations = result.markdown_with_citations.count("⟨1⟩") + assert citations == 2, "Same link should use same citation number" + +def test_link_descriptions(): + """Test handling of link titles and descriptions.""" + markdown = """ + Here's a [link with title](/test "Test Title") and a [link with description](/other) to test. + """ + + generator = DefaultMarkdownGenerationStrategy() + result = generator.generate_markdown( + cleaned_html=markdown, + base_url="https://example.com" + ) + + assert "Test Title" in result.references_markdown, "Link title should be in references" + assert "link with description" in result.references_markdown, "Link text should be in references" + +def test_performance_large_document(): + """Test performance with large document.""" + with open(__location__ + "/data/wikipedia.md", "r") as f: + markdown = f.read() + + # Test with multiple iterations + iterations = 5 + times = [] + + generator = DefaultMarkdownGenerationStrategy() + + for i in range(iterations): + start_time = time.perf_counter() + result = generator.generate_markdown( + cleaned_html=markdown, + base_url="https://en.wikipedia.org" + ) + end_time = time.perf_counter() + times.append(end_time - start_time) + + avg_time = sum(times) / len(times) + print(f"\n{'='*20} Performance Test {'='*20}") + print(f"Average execution time over {iterations} iterations: {avg_time:.4f} seconds") + print(f"Min time: {min(times):.4f} seconds") + print(f"Max time: {max(times):.4f} seconds") + +def test_image_links(): + """Test handling of image links.""" + markdown = """ + Here's an ![image](/image.png "Image Title") and another ![image](/other.jpg). + And a regular [link](/page). + """ + + generator = DefaultMarkdownGenerationStrategy() + result = generator.generate_markdown( + cleaned_html=markdown, + base_url="https://example.com" + ) + + assert "![" in result.markdown_with_citations, "Image markdown syntax should be preserved" + assert "Image Title" in result.references_markdown, "Image title should be in references" + +if __name__ == "__main__": + print("Running markdown generation strategy tests...") + + test_basic_markdown_conversion() + test_relative_links() + test_duplicate_links() + test_link_descriptions() + test_performance_large_document() + test_image_links() + \ No newline at end of file