Add Async Version, JsonCss Extractor
crawl4ai/async_crawler_strategy.py (new file, 254 lines)
@@ -0,0 +1,254 @@
import asyncio
import base64, time
from abc import ABC, abstractmethod
from typing import Callable, Dict, Any, List, Optional
import os
import psutil
from playwright.async_api import async_playwright, Page, Browser, Error
from io import BytesIO
from PIL import Image, ImageDraw, ImageFont
from .utils import sanitize_input_encode
import json, uuid
import hashlib
from pathlib import Path
from playwright.async_api import ProxySettings

def calculate_semaphore_count():
    cpu_count = os.cpu_count()
    memory_gb = psutil.virtual_memory().total / (1024 ** 3)  # Convert to GB
    base_count = max(1, cpu_count // 2)
    memory_based_cap = int(memory_gb / 2)  # Assume 2GB per instance
    return min(base_count, memory_based_cap)

class AsyncCrawlerStrategy(ABC):
    @abstractmethod
    async def crawl(self, url: str, **kwargs) -> str:
        pass

    @abstractmethod
    async def crawl_many(self, urls: List[str], **kwargs) -> List[str]:
        pass

    @abstractmethod
    async def take_screenshot(self, url: str) -> str:
        pass

    @abstractmethod
    def update_user_agent(self, user_agent: str):
        pass

    @abstractmethod
    def set_hook(self, hook_type: str, hook: Callable):
        pass

class AsyncPlaywrightCrawlerStrategy(AsyncCrawlerStrategy):
    def __init__(self, use_cached_html=False, js_code=None, **kwargs):
        self.use_cached_html = use_cached_html
        self.user_agent = kwargs.get("user_agent", "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.124 Safari/537.36")
        self.proxy = kwargs.get("proxy")
        self.headers = {}
        self.sessions = {}
        self.session_ttl = 1800
        self.js_code = js_code
        self.verbose = kwargs.get("verbose", False)
        self.playwright = None
        self.browser = None
        self.hooks = {
            'on_browser_created': None,
            'on_user_agent_updated': None,
            'on_execution_started': None,
            'before_goto': None,
            'after_goto': None,
            'before_return_html': None
        }

    async def __aenter__(self):
        await self.start()
        return self

    async def __aexit__(self, exc_type, exc_val, exc_tb):
        await self.close()

    async def start(self):
        if self.playwright is None:
            self.playwright = await async_playwright().start()
        if self.browser is None:
            browser_args = {
                "headless": True,
                # "headless": False,
                "args": [
                    "--disable-gpu",
                    "--disable-dev-shm-usage",
                    "--disable-setuid-sandbox",
                    "--no-sandbox",
                ]
            }

            # Add proxy settings if a proxy is specified
            if self.proxy:
                proxy_settings = ProxySettings(server=self.proxy)
                browser_args["proxy"] = proxy_settings

            self.browser = await self.playwright.chromium.launch(**browser_args)
            await self.execute_hook('on_browser_created', self.browser)

    async def close(self):
        if self.browser:
            await self.browser.close()
            self.browser = None
        if self.playwright:
            await self.playwright.stop()
            self.playwright = None

    def __del__(self):
        if self.browser or self.playwright:
            asyncio.get_event_loop().run_until_complete(self.close())

    def set_hook(self, hook_type: str, hook: Callable):
        if hook_type in self.hooks:
            self.hooks[hook_type] = hook
        else:
            raise ValueError(f"Invalid hook type: {hook_type}")

    async def execute_hook(self, hook_type: str, *args):
        hook = self.hooks.get(hook_type)
        if hook:
            if asyncio.iscoroutinefunction(hook):
                return await hook(*args)
            else:
                return hook(*args)
        return args[0] if args else None

    def update_user_agent(self, user_agent: str):
        self.user_agent = user_agent

    def set_custom_headers(self, headers: Dict[str, str]):
        self.headers = headers

    async def kill_session(self, session_id: str):
        if session_id in self.sessions:
            context, page, _ = self.sessions[session_id]
            await page.close()
            await context.close()
            del self.sessions[session_id]

    def _cleanup_expired_sessions(self):
        current_time = time.time()
        expired_sessions = [sid for sid, (_, _, last_used) in self.sessions.items()
                            if current_time - last_used > self.session_ttl]
        for sid in expired_sessions:
            asyncio.create_task(self.kill_session(sid))

    async def crawl(self, url: str, **kwargs) -> str:
        self._cleanup_expired_sessions()
        session_id = kwargs.get("session_id")
        if session_id:
            context, page, _ = self.sessions.get(session_id, (None, None, None))
            if not context:
                context = await self.browser.new_context(
                    user_agent=self.user_agent,
                    proxy={"server": self.proxy} if self.proxy else None
                )
                await context.set_extra_http_headers(self.headers)
                page = await context.new_page()
                self.sessions[session_id] = (context, page, time.time())
        else:
            context = await self.browser.new_context(
                user_agent=self.user_agent,
                proxy={"server": self.proxy} if self.proxy else None
            )
            await context.set_extra_http_headers(self.headers)
            page = await context.new_page()

        try:
            if self.verbose:
                print(f"[LOG] 🕸️ Crawling {url} using AsyncPlaywrightCrawlerStrategy...")

            if self.use_cached_html:
                cache_file_path = os.path.join(Path.home(), ".crawl4ai", "cache", hashlib.md5(url.encode()).hexdigest())
                if os.path.exists(cache_file_path):
                    with open(cache_file_path, "r") as f:
                        return f.read()

            if not kwargs.get("js_only", False):
                await self.execute_hook('before_goto', page)
                await page.goto(url, wait_until="domcontentloaded", timeout=60000)
                await self.execute_hook('after_goto', page)

            await page.wait_for_selector('body')
            await page.evaluate("window.scrollTo(0, document.body.scrollHeight)")

            js_code = kwargs.get("js_code", kwargs.get("js", self.js_code))
            if js_code:
                if isinstance(js_code, str):
                    await page.evaluate(js_code)
                elif isinstance(js_code, list):
                    for js in js_code:
                        await page.evaluate(js)

                # await page.wait_for_timeout(100)
                await page.wait_for_load_state('networkidle')
                # Check for on execution even
                await self.execute_hook('on_execution_started', page)

            html = await page.content()
            page = await self.execute_hook('before_return_html', page, html)

            if self.verbose:
                print(f"[LOG] ✅ Crawled {url} successfully!")

            if self.use_cached_html:
                cache_file_path = os.path.join(Path.home(), ".crawl4ai", "cache", hashlib.md5(url.encode()).hexdigest())
                with open(cache_file_path, "w", encoding="utf-8") as f:
                    f.write(html)

            return html
        except Error as e:
            raise Error(f"Failed to crawl {url}: {str(e)}")
        finally:
            if not session_id:
                await page.close()

        # try:
        #     html = await _crawl()
        #     return sanitize_input_encode(html)
        # except Error as e:
        #     raise Error(f"Failed to crawl {url}: {str(e)}")
        # except Exception as e:
        #     raise Exception(f"Failed to crawl {url}: {str(e)}")

    async def crawl_many(self, urls: List[str], **kwargs) -> List[str]:
        semaphore_count = kwargs.get('semaphore_count', calculate_semaphore_count())
        semaphore = asyncio.Semaphore(semaphore_count)

        async def crawl_with_semaphore(url):
            async with semaphore:
                return await self.crawl(url, **kwargs)

        tasks = [crawl_with_semaphore(url) for url in urls]
        results = await asyncio.gather(*tasks, return_exceptions=True)
        return [result if not isinstance(result, Exception) else str(result) for result in results]

    async def take_screenshot(self, url: str) -> str:
        async with await self.browser.new_context(user_agent=self.user_agent) as context:
            page = await context.new_page()
            try:
                await page.goto(url, wait_until="domcontentloaded")
                screenshot = await page.screenshot(full_page=True)
                return base64.b64encode(screenshot).decode('utf-8')
            except Exception as e:
                error_message = f"Failed to take screenshot: {str(e)}"
                print(error_message)

                # Generate an error image
                img = Image.new('RGB', (800, 600), color='black')
                draw = ImageDraw.Draw(img)
                font = ImageFont.load_default()
                draw.text((10, 10), error_message, fill=(255, 255, 255), font=font)

                buffered = BytesIO()
                img.save(buffered, format="JPEG")
                return base64.b64encode(buffered.getvalue()).decode('utf-8')
            finally:
                await page.close()
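
For reviewers who want to poke at the new strategy outside the crawler, a minimal usage sketch follows. It only relies on the hooks and methods defined above (set_hook, crawl, crawl_many); the URLs are placeholders.

import asyncio
from crawl4ai.async_crawler_strategy import AsyncPlaywrightCrawlerStrategy

async def main():
    async def log_navigation(page):
        # 'before_goto' receives the Playwright page that is about to navigate
        print("navigating with page:", page)
        return page

    async with AsyncPlaywrightCrawlerStrategy(verbose=True) as strategy:
        strategy.set_hook("before_goto", log_navigation)

        html = await strategy.crawl("https://example.com")
        print(len(html))

        # Concurrency defaults to calculate_semaphore_count() unless overridden
        pages = await strategy.crawl_many(
            ["https://example.com", "https://example.org"],
            semaphore_count=2,
        )
        print([len(p) for p in pages])

asyncio.run(main())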
crawl4ai/async_database.py (new file, 97 lines)
@@ -0,0 +1,97 @@
import os
from pathlib import Path
import aiosqlite
import asyncio
from typing import Optional, Tuple

DB_PATH = os.path.join(Path.home(), ".crawl4ai")
os.makedirs(DB_PATH, exist_ok=True)
DB_PATH = os.path.join(DB_PATH, "crawl4ai.db")

class AsyncDatabaseManager:
    def __init__(self):
        self.db_path = DB_PATH

    async def ainit_db(self):
        async with aiosqlite.connect(self.db_path) as db:
            await db.execute('''
                CREATE TABLE IF NOT EXISTS crawled_data (
                    url TEXT PRIMARY KEY,
                    html TEXT,
                    cleaned_html TEXT,
                    markdown TEXT,
                    extracted_content TEXT,
                    success BOOLEAN,
                    media TEXT DEFAULT "{}",
                    links TEXT DEFAULT "{}",
                    metadata TEXT DEFAULT "{}",
                    screenshot TEXT DEFAULT ""
                )
            ''')
            await db.commit()

    async def aalter_db_add_screenshot(self, new_column: str = "media"):
        try:
            async with aiosqlite.connect(self.db_path) as db:
                await db.execute(f'ALTER TABLE crawled_data ADD COLUMN {new_column} TEXT DEFAULT ""')
                await db.commit()
        except Exception as e:
            print(f"Error altering database to add screenshot column: {e}")

    async def aget_cached_url(self, url: str) -> Optional[Tuple[str, str, str, str, str, str, str, bool, str]]:
        try:
            async with aiosqlite.connect(self.db_path) as db:
                async with db.execute('SELECT url, html, cleaned_html, markdown, extracted_content, success, media, links, metadata, screenshot FROM crawled_data WHERE url = ?', (url,)) as cursor:
                    return await cursor.fetchone()
        except Exception as e:
            print(f"Error retrieving cached URL: {e}")
            return None

    async def acache_url(self, url: str, html: str, cleaned_html: str, markdown: str, extracted_content: str, success: bool, media: str = "{}", links: str = "{}", metadata: str = "{}", screenshot: str = ""):
        try:
            async with aiosqlite.connect(self.db_path) as db:
                await db.execute('''
                    INSERT INTO crawled_data (url, html, cleaned_html, markdown, extracted_content, success, media, links, metadata, screenshot)
                    VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
                    ON CONFLICT(url) DO UPDATE SET
                        html = excluded.html,
                        cleaned_html = excluded.cleaned_html,
                        markdown = excluded.markdown,
                        extracted_content = excluded.extracted_content,
                        success = excluded.success,
                        media = excluded.media,
                        links = excluded.links,
                        metadata = excluded.metadata,
                        screenshot = excluded.screenshot
                ''', (url, html, cleaned_html, markdown, extracted_content, success, media, links, metadata, screenshot))
                await db.commit()
        except Exception as e:
            print(f"Error caching URL: {e}")

    async def aget_total_count(self) -> int:
        try:
            async with aiosqlite.connect(self.db_path) as db:
                async with db.execute('SELECT COUNT(*) FROM crawled_data') as cursor:
                    result = await cursor.fetchone()
                    return result[0] if result else 0
        except Exception as e:
            print(f"Error getting total count: {e}")
            return 0

    async def aclear_db(self):
        try:
            async with aiosqlite.connect(self.db_path) as db:
                await db.execute('DELETE FROM crawled_data')
                await db.commit()
        except Exception as e:
            print(f"Error clearing database: {e}")

    async def aflush_db(self):
        try:
            async with aiosqlite.connect(self.db_path) as db:
                await db.execute('DROP TABLE IF EXISTS crawled_data')
                await db.commit()
        except Exception as e:
            print(f"Error flushing database: {e}")

async_db_manager = AsyncDatabaseManager()
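
A short sketch of how the cache manager above is meant to be driven; the values are placeholders, and ainit_db has to run once before reads or writes (AsyncWebCrawler.awarmup below does this in normal use).

import asyncio
from crawl4ai.async_database import async_db_manager

async def main():
    await async_db_manager.ainit_db()
    # media/links/metadata are stored as JSON strings, matching the column defaults
    await async_db_manager.acache_url(
        url="https://example.com",
        html="<html>...</html>",
        cleaned_html="<div>...</div>",
        markdown="# Example",
        extracted_content="[]",
        success=True,
    )
    row = await async_db_manager.aget_cached_url("https://example.com")
    print(row[0] if row else "not cached")
    print(await async_db_manager.aget_total_count())

asyncio.run(main())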
crawl4ai/async_webcrawler.py (new file, 269 lines)
@@ -0,0 +1,269 @@
import os
import time
from pathlib import Path
from typing import Optional
import json
import asyncio
from .models import CrawlResult
from .async_database import async_db_manager
from .chunking_strategy import *
from .extraction_strategy import *
from .async_crawler_strategy import AsyncCrawlerStrategy, AsyncPlaywrightCrawlerStrategy
from .content_scrapping_strategy import WebScrappingStrategy
from .config import MIN_WORD_THRESHOLD, IMAGE_DESCRIPTION_MIN_WORD_THRESHOLD
from .utils import (
    sanitize_input_encode,
    InvalidCSSSelectorError,
    format_html
)


class AsyncWebCrawler:
    def __init__(
        self,
        crawler_strategy: Optional[AsyncCrawlerStrategy] = None,
        always_by_pass_cache: bool = False,
        verbose: bool = False,
    ):
        self.crawler_strategy = crawler_strategy or AsyncPlaywrightCrawlerStrategy(
            verbose=verbose
        )
        self.always_by_pass_cache = always_by_pass_cache
        self.crawl4ai_folder = os.path.join(Path.home(), ".crawl4ai")
        os.makedirs(self.crawl4ai_folder, exist_ok=True)
        os.makedirs(f"{self.crawl4ai_folder}/cache", exist_ok=True)
        self.ready = False
        self.verbose = verbose

    async def __aenter__(self):
        await self.crawler_strategy.__aenter__()
        await self.awarmup()
        return self

    async def __aexit__(self, exc_type, exc_val, exc_tb):
        await self.crawler_strategy.__aexit__(exc_type, exc_val, exc_tb)

    async def awarmup(self):
        print("[LOG] 🌤️ Warming up the AsyncWebCrawler")
        await async_db_manager.ainit_db()
        await self.arun(
            url="https://google.com/",
            word_count_threshold=5,
            bypass_cache=False,
            verbose=False,
        )
        self.ready = True
        print("[LOG] 🌞 AsyncWebCrawler is ready to crawl")

    async def arun(
        self,
        url: str,
        word_count_threshold=MIN_WORD_THRESHOLD,
        extraction_strategy: ExtractionStrategy = None,
        chunking_strategy: ChunkingStrategy = RegexChunking(),
        bypass_cache: bool = False,
        css_selector: str = None,
        screenshot: bool = False,
        user_agent: str = None,
        verbose=True,
        **kwargs,
    ) -> CrawlResult:
        try:
            extraction_strategy = extraction_strategy or NoExtractionStrategy()
            extraction_strategy.verbose = verbose
            if not isinstance(extraction_strategy, ExtractionStrategy):
                raise ValueError("Unsupported extraction strategy")
            if not isinstance(chunking_strategy, ChunkingStrategy):
                raise ValueError("Unsupported chunking strategy")

            word_count_threshold = max(word_count_threshold, MIN_WORD_THRESHOLD)

            cached = None
            screenshot_data = None
            extracted_content = None
            if not bypass_cache and not self.always_by_pass_cache:
                cached = await async_db_manager.aget_cached_url(url)

            if kwargs.get("warmup", True) and not self.ready:
                return None

            if cached:
                html = sanitize_input_encode(cached[1])
                extracted_content = sanitize_input_encode(cached[4])
                if screenshot:
                    screenshot_data = cached[9]
                    if not screenshot_data:
                        cached = None

            if not cached or not html:
                t1 = time.time()
                if user_agent:
                    self.crawler_strategy.update_user_agent(user_agent)
                html = await self.crawler_strategy.crawl(url, **kwargs)
                t2 = time.time()
                if verbose:
                    print(
                        f"[LOG] 🚀 Crawling done for {url}, success: {bool(html)}, time taken: {t2 - t1:.2f} seconds"
                    )
                if screenshot:
                    screenshot_data = await self.crawler_strategy.take_screenshot(url)

            crawl_result = await self.aprocess_html(
                url,
                html,
                extracted_content,
                word_count_threshold,
                extraction_strategy,
                chunking_strategy,
                css_selector,
                screenshot_data,
                verbose,
                bool(cached),
                **kwargs,
            )
            crawl_result.success = bool(html)
            crawl_result.session_id = kwargs.get("session_id", None)
            return crawl_result
        except Exception as e:
            if not hasattr(e, "msg"):
                e.msg = str(e)
            print(f"[ERROR] 🚫 Failed to crawl {url}, error: {e.msg}")
            return CrawlResult(url=url, html="", success=False, error_message=e.msg)

    async def arun_many(
        self,
        urls: List[str],
        word_count_threshold=MIN_WORD_THRESHOLD,
        extraction_strategy: ExtractionStrategy = None,
        chunking_strategy: ChunkingStrategy = RegexChunking(),
        bypass_cache: bool = False,
        css_selector: str = None,
        screenshot: bool = False,
        user_agent: str = None,
        verbose=True,
        **kwargs,
    ) -> List[CrawlResult]:
        tasks = [
            self.arun(
                url,
                word_count_threshold,
                extraction_strategy,
                chunking_strategy,
                bypass_cache,
                css_selector,
                screenshot,
                user_agent,
                verbose,
                **kwargs
            )
            for url in urls
        ]
        return await asyncio.gather(*tasks)

    async def aprocess_html(
        self,
        url: str,
        html: str,
        extracted_content: str,
        word_count_threshold: int,
        extraction_strategy: ExtractionStrategy,
        chunking_strategy: ChunkingStrategy,
        css_selector: str,
        screenshot: str,
        verbose: bool,
        is_cached: bool,
        **kwargs,
    ) -> CrawlResult:
        t = time.time()
        # Extract content from HTML
        try:
            t1 = time.time()
            scrapping_strategy = WebScrappingStrategy()
            result = await scrapping_strategy.ascrap(
                url,
                html,
                word_count_threshold=word_count_threshold,
                css_selector=css_selector,
                only_text=kwargs.get("only_text", False),
                image_description_min_word_threshold=kwargs.get(
                    "image_description_min_word_threshold", IMAGE_DESCRIPTION_MIN_WORD_THRESHOLD
                ),
            )
            if verbose:
                print(
                    f"[LOG] 🚀 Content extracted for {url}, success: True, time taken: {time.time() - t1:.2f} seconds"
                )

            if result is None:
                raise ValueError(f"Failed to extract content from the website: {url}")
        except InvalidCSSSelectorError as e:
            raise ValueError(str(e))
        except Exception as e:
            raise ValueError(f"Failed to extract content from the website: {url}, error: {str(e)}")

        cleaned_html = sanitize_input_encode(result.get("cleaned_html", ""))
        markdown = sanitize_input_encode(result.get("markdown", ""))
        media = result.get("media", [])
        links = result.get("links", [])
        metadata = result.get("metadata", {})

        if extracted_content is None and extraction_strategy and chunking_strategy:
            if verbose:
                print(
                    f"[LOG] 🔥 Extracting semantic blocks for {url}, Strategy: {self.__class__.__name__}"
                )

            # Check if extraction strategy is type of JsonCssExtractionStrategy
            if isinstance(extraction_strategy, JsonCssExtractionStrategy) or isinstance(extraction_strategy, EnhancedJsonCssExtractionStrategy):
                extraction_strategy.verbose = verbose
                extracted_content = extraction_strategy.run(url, [html])
                extracted_content = json.dumps(extracted_content, indent=4, default=str)
            else:
                sections = chunking_strategy.chunk(markdown)
                extracted_content = extraction_strategy.run(url, sections)
                extracted_content = json.dumps(extracted_content, indent=4, default=str)

            if verbose:
                print(
                    f"[LOG] 🚀 Extraction done for {url}, time taken: {time.time() - t:.2f} seconds."
                )

        screenshot = None if not screenshot else screenshot

        if not is_cached:
            await async_db_manager.acache_url(
                url,
                html,
                cleaned_html,
                markdown,
                extracted_content,
                True,
                json.dumps(media),
                json.dumps(links),
                json.dumps(metadata),
                screenshot=screenshot,
            )

        return CrawlResult(
            url=url,
            html=html,
            cleaned_html=format_html(cleaned_html),
            markdown=markdown,
            media=media,
            links=links,
            metadata=metadata,
            screenshot=screenshot,
            extracted_content=extracted_content,
            success=True,
            error_message="",
        )

    async def aclear_cache(self):
        await async_db_manager.aclear_db()

    async def aflush_cache(self):
        await async_db_manager.aflush_db()

    async def aget_cache_size(self):
        return await async_db_manager.aget_total_count()
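
The tests further down exercise this class end to end; the shortest intended call path looks roughly like the sketch below. The URL and selector are placeholders, and arun returns the CrawlResult model extended later in this commit.

import asyncio
from crawl4ai.async_webcrawler import AsyncWebCrawler

async def main():
    async with AsyncWebCrawler(verbose=True) as crawler:
        result = await crawler.arun(
            url="https://www.example.com",
            bypass_cache=True,
            css_selector="h1, p",
            screenshot=False,
        )
        if result.success:
            print(result.markdown[:200])
        else:
            print("crawl failed:", result.error_message)

asyncio.run(main())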
crawl4ai/content_scrapping_strategy.py (new file, 283 lines)
@@ -0,0 +1,283 @@
from abc import ABC, abstractmethod
from typing import Dict, Any
from bs4 import BeautifulSoup
from concurrent.futures import ThreadPoolExecutor
import asyncio, requests, re, os
from .config import *
from bs4 import element, NavigableString, Comment
from urllib.parse import urljoin
from requests.exceptions import InvalidSchema

from .utils import (
    sanitize_input_encode,
    sanitize_html,
    extract_metadata,
    InvalidCSSSelectorError,
    CustomHTML2Text
)


class ContentScrappingStrategy(ABC):
    @abstractmethod
    def scrap(self, url: str, html: str, **kwargs) -> Dict[str, Any]:
        pass

    @abstractmethod
    async def ascrap(self, url: str, html: str, **kwargs) -> Dict[str, Any]:
        pass

class WebScrappingStrategy(ContentScrappingStrategy):
    def scrap(self, url: str, html: str, **kwargs) -> Dict[str, Any]:
        return self._get_content_of_website_optimized(url, html, is_async=False, **kwargs)

    async def ascrap(self, url: str, html: str, **kwargs) -> Dict[str, Any]:
        return await asyncio.to_thread(self._get_content_of_website_optimized, url, html, **kwargs)

    def _get_content_of_website_optimized(self, url: str, html: str, word_count_threshold: int = MIN_WORD_THRESHOLD, css_selector: str = None, **kwargs) -> Dict[str, Any]:
        if not html:
            return None

        soup = BeautifulSoup(html, 'html.parser')
        body = soup.body

        image_description_min_word_threshold = kwargs.get('image_description_min_word_threshold', IMAGE_DESCRIPTION_MIN_WORD_THRESHOLD)

        if css_selector:
            selected_elements = body.select(css_selector)
            if not selected_elements:
                raise InvalidCSSSelectorError(f"Invalid CSS selector, No elements found for CSS selector: {css_selector}")
            body = soup.new_tag('div')
            for el in selected_elements:
                body.append(el)

        links = {'internal': [], 'external': []}
        media = {'images': [], 'videos': [], 'audios': []}

        # Extract meaningful text for media files from closest parent
        def find_closest_parent_with_useful_text(tag):
            current_tag = tag
            while current_tag:
                current_tag = current_tag.parent
                # Get the text content of the parent tag
                if current_tag:
                    text_content = current_tag.get_text(separator=' ',strip=True)
                    # Check if the text content has at least word_count_threshold
                    if len(text_content.split()) >= image_description_min_word_threshold:
                        return text_content
            return None

        def process_image(img, url, index, total_images):
            # Check if an image has valid display and inside undesired html elements
            def is_valid_image(img, parent, parent_classes):
                style = img.get('style', '')
                src = img.get('src', '')
                classes_to_check = ['button', 'icon', 'logo']
                tags_to_check = ['button', 'input']
                return all([
                    'display:none' not in style,
                    src,
                    not any(s in var for var in [src, img.get('alt', ''), *parent_classes] for s in classes_to_check),
                    parent.name not in tags_to_check
                ])

            # Score an image for it's usefulness
            def score_image_for_usefulness(img, base_url, index, images_count):
                # Function to parse image height/width value and units
                def parse_dimension(dimension):
                    if dimension:
                        match = re.match(r"(\d+)(\D*)", dimension)
                        if match:
                            number = int(match.group(1))
                            unit = match.group(2) or 'px'  # Default unit is 'px' if not specified
                            return number, unit
                    return None, None

                # Fetch image file metadata to extract size and extension
                def fetch_image_file_size(img, base_url):
                    # If src is relative path construct full URL, if not it may be CDN URL
                    img_url = urljoin(base_url,img.get('src'))
                    try:
                        response = requests.head(img_url)
                        if response.status_code == 200:
                            return response.headers.get('Content-Length',None)
                        else:
                            print(f"Failed to retrieve file size for {img_url}")
                            return None
                    except InvalidSchema as e:
                        return None
                    finally:
                        return

                image_height = img.get('height')
                height_value, height_unit = parse_dimension(image_height)
                image_width = img.get('width')
                width_value, width_unit = parse_dimension(image_width)
                image_size = 0  # int(fetch_image_file_size(img,base_url) or 0)
                image_format = os.path.splitext(img.get('src',''))[1].lower()
                # Remove . from format
                image_format = image_format.strip('.')
                score = 0
                if height_value:
                    if height_unit == 'px' and height_value > 150:
                        score += 1
                    if height_unit in ['%','vh','vmin','vmax'] and height_value >30:
                        score += 1
                if width_value:
                    if width_unit == 'px' and width_value > 150:
                        score += 1
                    if width_unit in ['%','vh','vmin','vmax'] and width_value >30:
                        score += 1
                if image_size > 10000:
                    score += 1
                if img.get('alt') != '':
                    score+=1
                if any(image_format==format for format in ['jpg','png','webp']):
                    score+=1
                if index/images_count<0.5:
                    score+=1
                return score

            if not is_valid_image(img, img.parent, img.parent.get('class', [])):
                return None
            score = score_image_for_usefulness(img, url, index, total_images)
            if score <= IMAGE_SCORE_THRESHOLD:
                return None
            return {
                'src': img.get('src', ''),
                'alt': img.get('alt', ''),
                'desc': find_closest_parent_with_useful_text(img),
                'score': score,
                'type': 'image'
            }

        def process_element(element: element.PageElement) -> bool:
            try:
                if isinstance(element, NavigableString):
                    if isinstance(element, Comment):
                        element.extract()
                    return False

                if element.name in ['script', 'style', 'link', 'meta', 'noscript']:
                    if element.name == 'img':
                        process_image(element, url, 0, 1)
                    element.decompose()
                    return False

                keep_element = False

                if element.name == 'a' and element.get('href'):
                    href = element['href']
                    url_base = url.split('/')[2]
                    link_data = {'href': href, 'text': element.get_text()}
                    if href.startswith('http') and url_base not in href:
                        links['external'].append(link_data)
                    else:
                        links['internal'].append(link_data)
                    keep_element = True

                elif element.name == 'img':
                    return True  # Always keep image elements

                elif element.name in ['video', 'audio']:
                    media[f"{element.name}s"].append({
                        'src': element.get('src'),
                        'alt': element.get('alt'),
                        'type': element.name,
                        'description': find_closest_parent_with_useful_text(element)
                    })
                    source_tags = element.find_all('source')
                    for source_tag in source_tags:
                        media[f"{element.name}s"].append({
                            'src': source_tag.get('src'),
                            'alt': element.get('alt'),
                            'type': element.name,
                            'description': find_closest_parent_with_useful_text(element)
                        })
                    return True  # Always keep video and audio elements

                if element.name != 'pre':
                    if element.name in ['b', 'i', 'u', 'span', 'del', 'ins', 'sub', 'sup', 'strong', 'em', 'code', 'kbd', 'var', 's', 'q', 'abbr', 'cite', 'dfn', 'time', 'small', 'mark']:
                        if kwargs.get('only_text', False):
                            element.replace_with(element.get_text())
                        else:
                            element.unwrap()
                    elif element.name != 'img':
                        element.attrs = {}

                # Process children
                for child in list(element.children):
                    if isinstance(child, NavigableString) and not isinstance(child, Comment):
                        if len(child.strip()) > 0:
                            keep_element = True
                    else:
                        if process_element(child):
                            keep_element = True

                # Check word count
                if not keep_element:
                    word_count = len(element.get_text(strip=True).split())
                    keep_element = word_count >= word_count_threshold

                if not keep_element:
                    element.decompose()

                return keep_element
            except Exception as e:
                print('Error processing element:', str(e))
                return False

        # process images by filtering and extracting contextual text from the page
        # imgs = body.find_all('img')
        # media['images'] = [
        #     result for result in
        #     (process_image(img, url, i, len(imgs)) for i, img in enumerate(imgs))
        #     if result is not None
        # ]

        process_element(body)

        # # Process images using ThreadPoolExecutor
        imgs = body.find_all('img')
        with ThreadPoolExecutor() as executor:
            image_results = list(executor.map(process_image, imgs, [url]*len(imgs), range(len(imgs)), [len(imgs)]*len(imgs)))
            media['images'] = [result for result in image_results if result is not None]

        def flatten_nested_elements(node):
            if isinstance(node, NavigableString):
                return node
            if len(node.contents) == 1 and isinstance(node.contents[0], element.Tag) and node.contents[0].name == node.name:
                return flatten_nested_elements(node.contents[0])
            node.contents = [flatten_nested_elements(child) for child in node.contents]
            return node

        body = flatten_nested_elements(body)
        base64_pattern = re.compile(r'data:image/[^;]+;base64,([^"]+)')
        for img in imgs:
            src = img.get('src', '')
            if base64_pattern.match(src):
                # Replace base64 data with empty string
                img['src'] = base64_pattern.sub('', src)
        cleaned_html = str(body).replace('\n\n', '\n').replace('  ', ' ')
        cleaned_html = sanitize_html(cleaned_html)

        h = CustomHTML2Text()
        h.ignore_links = True
        markdown = h.handle(cleaned_html)
        markdown = markdown.replace('    ```', '```')

        try:
            meta = extract_metadata(html, soup)
        except Exception as e:
            print('Error extracting metadata:', str(e))
            meta = {}

        return {
            'markdown': markdown,
            'cleaned_html': cleaned_html,
            'success': True,
            'media': media,
            'links': links,
            'metadata': meta
        }
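
The scraping strategy can also be exercised on raw HTML without launching a browser, which is how AsyncWebCrawler.aprocess_html uses it. A minimal sketch with inline HTML (the threshold value and markup are chosen arbitrarily for illustration):

import asyncio
from crawl4ai.content_scrapping_strategy import WebScrappingStrategy

async def main():
    html = """
    <html><body>
      <h1>Title</h1>
      <p>A paragraph with more than a handful of words so it clears the threshold.</p>
      <a href="https://example.org/page">external link</a>
    </body></html>
    """
    strategy = WebScrappingStrategy()
    # ascrap offloads the BeautifulSoup work to a thread via asyncio.to_thread
    result = await strategy.ascrap("https://example.com/", html, word_count_threshold=5)
    print(result["markdown"])
    print(result["links"])

asyncio.run(main())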
@@ -623,3 +623,158 @@ class ContentSummarizationStrategy(ExtractionStrategy):
        # Sort summaries by the original section index to maintain order
        summaries.sort(key=lambda x: x[0])
        return [summary for _, summary in summaries]


class JsonCssExtractionStrategy(ExtractionStrategy):
    def __init__(self, schema: Dict[str, Any], **kwargs):
        super().__init__(**kwargs)
        self.schema = schema

    def extract(self, url: str, html: str, *q, **kwargs) -> List[Dict[str, Any]]:
        soup = BeautifulSoup(html, 'html.parser')
        base_elements = soup.select(self.schema['baseSelector'])

        results = []
        for element in base_elements:
            item = {}
            for field in self.schema['fields']:
                value = self._extract_field(element, field)
                if value is not None:
                    item[field['name']] = value
            if item:
                results.append(item)

        return results

    def _extract_field(self, element, field):
        try:
            selected = element.select_one(field['selector'])
            if not selected:
                return None

            if field['type'] == 'text':
                return selected.get_text(strip=True)
            elif field['type'] == 'attribute':
                return selected.get(field['attribute'])
            elif field['type'] == 'html':
                return str(selected)
            elif field['type'] == 'regex':
                text = selected.get_text(strip=True)
                match = re.search(field['pattern'], text)
                return match.group(1) if match else None
        except Exception as e:
            if self.verbose:
                print(f"Error extracting field {field['name']}: {str(e)}")
            return None

    def run(self, url: str, sections: List[str], *q, **kwargs) -> List[Dict[str, Any]]:
        combined_html = self.DEL.join(sections)
        return self.extract(url, combined_html, **kwargs)


class EnhancedJsonCssExtractionStrategy(ExtractionStrategy):
    def __init__(self, schema: Dict[str, Any], **kwargs):
        super().__init__(**kwargs)
        self.schema = schema

    def extract(self, url: str, html: str, *q, **kwargs) -> List[Dict[str, Any]]:
        soup = BeautifulSoup(html, 'html.parser')
        base_elements = soup.select(self.schema['baseSelector'])

        results = []
        for element in base_elements:
            item = self._extract_item(element, self.schema['fields'])
            if item:
                results.append(item)

        return results

    def _extract_field(self, element, field):
        try:
            if field['type'] == 'nested':
                nested_element = element.select_one(field['selector'])
                return self._extract_item(nested_element, field['fields']) if nested_element else {}

            if field['type'] == 'list':
                elements = element.select(field['selector'])
                return [self._extract_list_item(el, field['fields']) for el in elements]

            if field['type'] == 'nested_list':
                elements = element.select(field['selector'])
                return [self._extract_item(el, field['fields']) for el in elements]

            return self._extract_single_field(element, field)
        except Exception as e:
            if self.verbose:
                print(f"Error extracting field {field['name']}: {str(e)}")
            return field.get('default')

    def _extract_list_item(self, element, fields):
        item = {}
        for field in fields:
            value = self._extract_single_field(element, field)
            if value is not None:
                item[field['name']] = value
        return item

    def _extract_single_field(self, element, field):
        if 'selector' in field:
            selected = element.select_one(field['selector'])
            if not selected:
                return field.get('default')
        else:
            selected = element

        value = None
        if field['type'] == 'text':
            value = selected.get_text(strip=True)
        elif field['type'] == 'attribute':
            value = selected.get(field['attribute'])
        elif field['type'] == 'html':
            value = str(selected)
        elif field['type'] == 'regex':
            text = selected.get_text(strip=True)
            match = re.search(field['pattern'], text)
            value = match.group(1) if match else None

        if 'transform' in field:
            value = self._apply_transform(value, field['transform'])

        return value if value is not None else field.get('default')

    def _extract_item(self, element, fields):
        item = {}
        for field in fields:
            if field['type'] == 'computed':
                value = self._compute_field(item, field)
            else:
                value = self._extract_field(element, field)
            if value is not None:
                item[field['name']] = value
        return item

    def _apply_transform(self, value, transform):
        if transform == 'lowercase':
            return value.lower()
        elif transform == 'uppercase':
            return value.upper()
        elif transform == 'strip':
            return value.strip()
        return value

    def _compute_field(self, item, field):
        try:
            if 'expression' in field:
                return eval(field['expression'], {}, item)
            elif 'function' in field:
                return field['function'](item)
        except Exception as e:
            if self.verbose:
                print(f"Error computing field {field['name']}: {str(e)}")
            return field.get('default')

    def run(self, url: str, sections: List[str], *q, **kwargs) -> List[Dict[str, Any]]:
        combined_html = self.DEL.join(sections)
        return self.extract(url, combined_html, **kwargs)
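
This commit does not include a sample schema, so the following is an illustrative sketch of the shape JsonCssExtractionStrategy expects: baseSelector picks the repeated elements, and each field names a sub-selector plus a type of text, attribute, html, or regex (the enhanced variant additionally understands nested, list, nested_list, computed, transform, and default). The selectors, field names, and HTML below are invented for the example, and the import path is assumed from the "from .extraction_strategy import *" line in async_webcrawler.py.

from crawl4ai.extraction_strategy import JsonCssExtractionStrategy  # module path assumed, see note above

html_string = """
<article class="post"><h2><a href="/a">First post</a></h2><p class="excerpt">Short summary.</p></article>
<article class="post"><h2><a href="/b">Second post</a></h2><p class="excerpt">Another one.</p></article>
"""

schema = {
    "baseSelector": "article.post",  # one result item per matching element
    "fields": [
        {"name": "title", "selector": "h2 a", "type": "text"},
        {"name": "url", "selector": "h2 a", "type": "attribute", "attribute": "href"},
        {"name": "summary", "selector": "p.excerpt", "type": "text"},
    ],
}

strategy = JsonCssExtractionStrategy(schema)
items = strategy.run("https://example.com/blog", [html_string])
print(items)  # e.g. [{'title': 'First post', 'url': '/a', 'summary': 'Short summary.'}, ...]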
@@ -16,4 +16,5 @@ class CrawlResult(BaseModel):
     markdown: Optional[str] = None
     extracted_content: Optional[str] = None
     metadata: Optional[dict] = None
-    error_message: Optional[str] = None
+    error_message: Optional[str] = None
+    session_id: Optional[str] = None
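
The new session_id field mirrors the session support added to the crawler strategy: passing the same session_id to arun reuses one Playwright context and page across calls, and the result now records which session produced it. A hedged sketch (the session name is arbitrary):

import asyncio
from crawl4ai.async_webcrawler import AsyncWebCrawler

async def main():
    async with AsyncWebCrawler() as crawler:
        # Reusing the same session_id keeps one browser context/page alive between calls
        first = await crawler.arun(url="https://example.com", session_id="login-flow", bypass_cache=True)
        again = await crawler.arun(url="https://example.com", session_id="login-flow", js_only=True, bypass_cache=True)
        print(first.session_id, again.session_id)  # both "login-flow"

asyncio.run(main())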
@@ -122,7 +122,7 @@ class WebCrawler:
         if not isinstance(chunking_strategy, ChunkingStrategy):
             raise ValueError("Unsupported chunking strategy")

-        word_count_threshold = max(word_count_threshold, 0)
+        word_count_threshold = max(word_count_threshold, MIN_WORD_THRESHOLD)

         cached = None
         screenshot_data = None
tests/async/test_basic_crawling.py (new file, 81 lines)
@@ -0,0 +1,81 @@
import os
import sys
import pytest
import asyncio
import time

# Add the parent directory to the Python path
parent_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
sys.path.append(parent_dir)

from crawl4ai.async_webcrawler import AsyncWebCrawler

@pytest.mark.asyncio
async def test_successful_crawl():
    async with AsyncWebCrawler(verbose=True) as crawler:
        url = "https://www.nbcnews.com/business"
        result = await crawler.arun(url=url, bypass_cache=True)
        assert result.success
        assert result.url == url
        assert result.html
        assert result.markdown
        assert result.cleaned_html

@pytest.mark.asyncio
async def test_invalid_url():
    async with AsyncWebCrawler(verbose=True) as crawler:
        url = "https://www.invalidurl12345.com"
        result = await crawler.arun(url=url, bypass_cache=True)
        assert not result.success
        assert result.error_message

@pytest.mark.asyncio
async def test_multiple_urls():
    async with AsyncWebCrawler(verbose=True) as crawler:
        urls = [
            "https://www.nbcnews.com/business",
            "https://www.example.com",
            "https://www.python.org"
        ]
        results = await crawler.arun_many(urls=urls, bypass_cache=True)
        assert len(results) == len(urls)
        assert all(result.success for result in results)
        assert all(result.html for result in results)

@pytest.mark.asyncio
async def test_javascript_execution():
    async with AsyncWebCrawler(verbose=True) as crawler:
        js_code = "document.body.innerHTML = '<h1>Modified by JS</h1>';"
        url = "https://www.example.com"
        result = await crawler.arun(url=url, bypass_cache=True, js_code=js_code)
        assert result.success
        assert "<h1>Modified by JS</h1>" in result.html

@pytest.mark.asyncio
async def test_concurrent_crawling_performance():
    async with AsyncWebCrawler(verbose=True) as crawler:
        urls = [
            "https://www.nbcnews.com/business",
            "https://www.example.com",
            "https://www.python.org",
            "https://www.github.com",
            "https://www.stackoverflow.com"
        ]

        start_time = time.time()
        results = await crawler.arun_many(urls=urls, bypass_cache=True)
        end_time = time.time()

        total_time = end_time - start_time
        print(f"Total time for concurrent crawling: {total_time:.2f} seconds")

        assert all(result.success for result in results)
        assert len(results) == len(urls)

        # Assert that concurrent crawling is faster than sequential
        # This multiplier may need adjustment based on the number of URLs and their complexity
        assert total_time < len(urls) * 5, f"Concurrent crawling not significantly faster: {total_time:.2f} seconds"

# Entry point for debugging
if __name__ == "__main__":
    pytest.main([__file__, "-v"])
tests/async/test_caching.py (new file, 82 lines)
@@ -0,0 +1,82 @@
import os
import sys
import pytest
import asyncio

# Add the parent directory to the Python path
parent_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
sys.path.append(parent_dir)

from crawl4ai.async_webcrawler import AsyncWebCrawler

@pytest.mark.asyncio
async def test_caching():
    async with AsyncWebCrawler(verbose=True) as crawler:
        url = "https://www.nbcnews.com/business"

        # First crawl (should not use cache)
        start_time = asyncio.get_event_loop().time()
        result1 = await crawler.arun(url=url, bypass_cache=True)
        end_time = asyncio.get_event_loop().time()
        time_taken1 = end_time - start_time

        assert result1.success

        # Second crawl (should use cache)
        start_time = asyncio.get_event_loop().time()
        result2 = await crawler.arun(url=url, bypass_cache=False)
        end_time = asyncio.get_event_loop().time()
        time_taken2 = end_time - start_time

        assert result2.success
        assert time_taken2 < time_taken1  # Cached result should be faster

@pytest.mark.asyncio
async def test_bypass_cache():
    async with AsyncWebCrawler(verbose=True) as crawler:
        url = "https://www.nbcnews.com/business"

        # First crawl
        result1 = await crawler.arun(url=url, bypass_cache=False)
        assert result1.success

        # Second crawl with bypass_cache=True
        result2 = await crawler.arun(url=url, bypass_cache=True)
        assert result2.success

        # Content should be different (or at least, not guaranteed to be the same)
        assert result1.html != result2.html or result1.markdown != result2.markdown

@pytest.mark.asyncio
async def test_clear_cache():
    async with AsyncWebCrawler(verbose=True) as crawler:
        url = "https://www.nbcnews.com/business"

        # Crawl and cache
        await crawler.arun(url=url, bypass_cache=False)

        # Clear cache
        await crawler.aclear_cache()

        # Check cache size
        cache_size = await crawler.aget_cache_size()
        assert cache_size == 0

@pytest.mark.asyncio
async def test_flush_cache():
    async with AsyncWebCrawler(verbose=True) as crawler:
        url = "https://www.nbcnews.com/business"

        # Crawl and cache
        await crawler.arun(url=url, bypass_cache=False)

        # Flush cache
        await crawler.aflush_cache()

        # Check cache size
        cache_size = await crawler.aget_cache_size()
        assert cache_size == 0

# Entry point for debugging
if __name__ == "__main__":
    pytest.main([__file__, "-v"])
tests/async/test_chunking_and_extraction_strategies.py (new file, 87 lines)
@@ -0,0 +1,87 @@
import os
import sys
import pytest
import asyncio
import json

# Add the parent directory to the Python path
parent_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
sys.path.append(parent_dir)

from crawl4ai.async_webcrawler import AsyncWebCrawler
from crawl4ai.chunking_strategy import RegexChunking, NlpSentenceChunking
from crawl4ai.extraction_strategy import CosineStrategy, LLMExtractionStrategy

@pytest.mark.asyncio
async def test_regex_chunking():
    async with AsyncWebCrawler(verbose=True) as crawler:
        url = "https://www.nbcnews.com/business"
        chunking_strategy = RegexChunking(patterns=["\n\n"])
        result = await crawler.arun(
            url=url,
            chunking_strategy=chunking_strategy,
            bypass_cache=True
        )
        assert result.success
        assert result.extracted_content
        chunks = json.loads(result.extracted_content)
        assert len(chunks) > 1  # Ensure multiple chunks were created

@pytest.mark.asyncio
async def test_cosine_strategy():
    async with AsyncWebCrawler(verbose=True) as crawler:
        url = "https://www.nbcnews.com/business"
        extraction_strategy = CosineStrategy(word_count_threshold=10, max_dist=0.2, linkage_method="ward", top_k=3, sim_threshold=0.3)
        result = await crawler.arun(
            url=url,
            extraction_strategy=extraction_strategy,
            bypass_cache=True
        )
        assert result.success
        assert result.extracted_content
        extracted_data = json.loads(result.extracted_content)
        assert len(extracted_data) > 0
        assert all('tags' in item for item in extracted_data)

@pytest.mark.asyncio
async def test_llm_extraction_strategy():
    async with AsyncWebCrawler(verbose=True) as crawler:
        url = "https://www.nbcnews.com/business"
        extraction_strategy = LLMExtractionStrategy(
            provider="openai/gpt-4o-mini",
            api_token=os.getenv('OPENAI_API_KEY'),
            instruction="Extract only content related to technology"
        )
        result = await crawler.arun(
            url=url,
            extraction_strategy=extraction_strategy,
            bypass_cache=True
        )
        assert result.success
        assert result.extracted_content
        extracted_data = json.loads(result.extracted_content)
        assert len(extracted_data) > 0
        assert all('content' in item for item in extracted_data)

@pytest.mark.asyncio
async def test_combined_chunking_and_extraction():
    async with AsyncWebCrawler(verbose=True) as crawler:
        url = "https://www.nbcnews.com/business"
        chunking_strategy = RegexChunking(patterns=["\n\n"])
        extraction_strategy = CosineStrategy(word_count_threshold=10, max_dist=0.2, linkage_method="ward", top_k=3, sim_threshold=0.3)
        result = await crawler.arun(
            url=url,
            chunking_strategy=chunking_strategy,
            extraction_strategy=extraction_strategy,
            bypass_cache=True
        )
        assert result.success
        assert result.extracted_content
        extracted_data = json.loads(result.extracted_content)
        assert len(extracted_data) > 0
        assert all('tags' in item for item in extracted_data)
        assert all('content' in item for item in extracted_data)

# Entry point for debugging
if __name__ == "__main__":
    pytest.main([__file__, "-v"])
tests/async/test_content_extraction.py (new file, 90 lines)
@@ -0,0 +1,90 @@
import os
import sys
import pytest
import asyncio
import json

# Add the parent directory to the Python path
parent_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
sys.path.append(parent_dir)

from crawl4ai.async_webcrawler import AsyncWebCrawler

@pytest.mark.asyncio
async def test_extract_markdown():
    async with AsyncWebCrawler(verbose=True) as crawler:
        url = "https://www.nbcnews.com/business"
        result = await crawler.arun(url=url, bypass_cache=True)
        assert result.success
        assert result.markdown
        assert isinstance(result.markdown, str)
        assert len(result.markdown) > 0

@pytest.mark.asyncio
async def test_extract_cleaned_html():
    async with AsyncWebCrawler(verbose=True) as crawler:
        url = "https://www.nbcnews.com/business"
        result = await crawler.arun(url=url, bypass_cache=True)
        assert result.success
        assert result.cleaned_html
        assert isinstance(result.cleaned_html, str)
        assert len(result.cleaned_html) > 0

@pytest.mark.asyncio
async def test_extract_media():
    async with AsyncWebCrawler(verbose=True) as crawler:
        url = "https://www.nbcnews.com/business"
        result = await crawler.arun(url=url, bypass_cache=True)
        assert result.success
        assert result.media
        media = result.media
        assert isinstance(media, dict)
        assert "images" in media
        assert isinstance(media["images"], list)
        for image in media["images"]:
            assert "src" in image
            assert "alt" in image
            assert "type" in image

@pytest.mark.asyncio
async def test_extract_links():
    async with AsyncWebCrawler(verbose=True) as crawler:
        url = "https://www.nbcnews.com/business"
        result = await crawler.arun(url=url, bypass_cache=True)
        assert result.success
        assert result.links
        links = result.links
        assert isinstance(links, dict)
        assert "internal" in links
        assert "external" in links
        assert isinstance(links["internal"], list)
        assert isinstance(links["external"], list)
        for link in links["internal"] + links["external"]:
            assert "href" in link
            assert "text" in link

@pytest.mark.asyncio
async def test_extract_metadata():
    async with AsyncWebCrawler(verbose=True) as crawler:
        url = "https://www.nbcnews.com/business"
        result = await crawler.arun(url=url, bypass_cache=True)
        assert result.success
        assert result.metadata
        metadata = result.metadata
        assert isinstance(metadata, dict)
        assert "title" in metadata
        assert isinstance(metadata["title"], str)

@pytest.mark.asyncio
async def test_css_selector_extraction():
    async with AsyncWebCrawler(verbose=True) as crawler:
        url = "https://www.nbcnews.com/business"
        css_selector = "h1, h2, h3"
        result = await crawler.arun(url=url, bypass_cache=True, css_selector=css_selector)
        assert result.success
        assert result.markdown
        assert all(heading in result.markdown for heading in ["#", "##", "###"])

# Entry point for debugging
if __name__ == "__main__":
    pytest.main([__file__, "-v"])
tests/async/test_crawler_strategy.py (new file, 68 lines)
@@ -0,0 +1,68 @@
import os
import sys
import pytest
import asyncio

# Add the parent directory to the Python path
parent_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
sys.path.append(parent_dir)

from crawl4ai.async_webcrawler import AsyncWebCrawler
from crawl4ai.async_crawler_strategy import AsyncPlaywrightCrawlerStrategy

@pytest.mark.asyncio
async def test_custom_user_agent():
    async with AsyncWebCrawler(verbose=True) as crawler:
        custom_user_agent = "MyCustomUserAgent/1.0"
        crawler.crawler_strategy.update_user_agent(custom_user_agent)
        url = "https://httpbin.org/user-agent"
        result = await crawler.arun(url=url, bypass_cache=True)
        assert result.success
        assert custom_user_agent in result.html

@pytest.mark.asyncio
async def test_custom_headers():
    async with AsyncWebCrawler(verbose=True) as crawler:
        custom_headers = {"X-Test-Header": "TestValue"}
        crawler.crawler_strategy.set_custom_headers(custom_headers)
        url = "https://httpbin.org/headers"
        result = await crawler.arun(url=url, bypass_cache=True)
        assert result.success
        assert "X-Test-Header" in result.html
        assert "TestValue" in result.html

@pytest.mark.asyncio
async def test_javascript_execution():
    async with AsyncWebCrawler(verbose=True) as crawler:
        js_code = "document.body.innerHTML = '<h1>Modified by JS</h1>';"
        url = "https://www.example.com"
        result = await crawler.arun(url=url, bypass_cache=True, js_code=js_code)
        assert result.success
        assert "<h1>Modified by JS</h1>" in result.html

@pytest.mark.asyncio
async def test_hook_execution():
    async with AsyncWebCrawler(verbose=True) as crawler:
        async def test_hook(page):
            await page.evaluate("document.body.style.backgroundColor = 'red';")
            return page

        crawler.crawler_strategy.set_hook('after_goto', test_hook)
        url = "https://www.example.com"
        result = await crawler.arun(url=url, bypass_cache=True)
        assert result.success
        assert "background-color: red" in result.html

@pytest.mark.asyncio
async def test_screenshot():
    async with AsyncWebCrawler(verbose=True) as crawler:
        url = "https://www.example.com"
        result = await crawler.arun(url=url, bypass_cache=True, screenshot=True)
        assert result.success
        assert result.screenshot
        assert isinstance(result.screenshot, str)
        assert len(result.screenshot) > 0

# Entry point for debugging
if __name__ == "__main__":
    pytest.main([__file__, "-v"])
tests/async/test_database_operations.py (new file, 82 lines)
@@ -0,0 +1,82 @@
import os
import sys
import pytest
import asyncio
import json

# Add the parent directory to the Python path
parent_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
sys.path.append(parent_dir)

from crawl4ai.async_webcrawler import AsyncWebCrawler

@pytest.mark.asyncio
async def test_cache_url():
    async with AsyncWebCrawler(verbose=True) as crawler:
        url = "https://www.example.com"
        # First run to cache the URL
        result1 = await crawler.arun(url=url, bypass_cache=True)
        assert result1.success

        # Second run to retrieve from cache
        result2 = await crawler.arun(url=url, bypass_cache=False)
        assert result2.success
        assert result2.html == result1.html

@pytest.mark.asyncio
async def test_bypass_cache():
    async with AsyncWebCrawler(verbose=True) as crawler:
        url = "https://www.python.org"
        # First run to cache the URL
        result1 = await crawler.arun(url=url, bypass_cache=True)
        assert result1.success

        # Second run bypassing cache
        result2 = await crawler.arun(url=url, bypass_cache=True)
        assert result2.success
        assert result2.html != result1.html  # Content might be different due to dynamic nature of websites

@pytest.mark.asyncio
async def test_cache_size():
    async with AsyncWebCrawler(verbose=True) as crawler:
        initial_size = await crawler.aget_cache_size()

        url = "https://www.nbcnews.com/business"
        await crawler.arun(url=url, bypass_cache=True)

        new_size = await crawler.aget_cache_size()
        assert new_size == initial_size + 1

@pytest.mark.asyncio
async def test_clear_cache():
    async with AsyncWebCrawler(verbose=True) as crawler:
        url = "https://www.example.org"
        await crawler.arun(url=url, bypass_cache=True)

        initial_size = await crawler.aget_cache_size()
        assert initial_size > 0

        await crawler.aclear_cache()
        new_size = await crawler.aget_cache_size()
        assert new_size == 0

@pytest.mark.asyncio
async def test_flush_cache():
    async with AsyncWebCrawler(verbose=True) as crawler:
        url = "https://www.example.net"
        await crawler.arun(url=url, bypass_cache=True)

        initial_size = await crawler.aget_cache_size()
        assert initial_size > 0

        await crawler.aflush_cache()
        new_size = await crawler.aget_cache_size()
        assert new_size == 0

        # Try to retrieve the previously cached URL
        result = await crawler.arun(url=url, bypass_cache=False)
        assert result.success  # The crawler should still succeed, but it will fetch the content anew

# Entry point for debugging
if __name__ == "__main__":
    pytest.main([__file__, "-v"])
127
tests/async/test_edge_cases.py
Normal file
@@ -0,0 +1,127 @@
import os
import re
import sys
import pytest
import json
from bs4 import BeautifulSoup
import asyncio
# Add the parent directory to the Python path
parent_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
sys.path.append(parent_dir)

from crawl4ai.async_webcrawler import AsyncWebCrawler

# @pytest.mark.asyncio
# async def test_large_content_page():
#     async with AsyncWebCrawler(verbose=True) as crawler:
#         url = "https://en.wikipedia.org/wiki/List_of_largest_known_stars"  # A page with a large table
#         result = await crawler.arun(url=url, bypass_cache=True)
#         assert result.success
#         assert len(result.html) > 1000000  # Expecting more than 1MB of content

# @pytest.mark.asyncio
# async def test_minimal_content_page():
#     async with AsyncWebCrawler(verbose=True) as crawler:
#         url = "https://example.com"  # A very simple page
#         result = await crawler.arun(url=url, bypass_cache=True)
#         assert result.success
#         assert len(result.html) < 10000  # Expecting less than 10KB of content

# @pytest.mark.asyncio
# async def test_single_page_application():
#     async with AsyncWebCrawler(verbose=True) as crawler:
#         url = "https://reactjs.org/"  # React's website is a SPA
#         result = await crawler.arun(url=url, bypass_cache=True)
#         assert result.success
#         assert "react" in result.html.lower()

# @pytest.mark.asyncio
# async def test_page_with_infinite_scroll():
#     async with AsyncWebCrawler(verbose=True) as crawler:
#         url = "https://news.ycombinator.com/"  # Hacker News has infinite scroll
#         result = await crawler.arun(url=url, bypass_cache=True)
#         assert result.success
#         assert "hacker news" in result.html.lower()

# @pytest.mark.asyncio
# async def test_page_with_heavy_javascript():
#     async with AsyncWebCrawler(verbose=True) as crawler:
#         url = "https://www.airbnb.com/"  # Airbnb uses a lot of JavaScript
#         result = await crawler.arun(url=url, bypass_cache=True)
#         assert result.success
#         assert "airbnb" in result.html.lower()

# @pytest.mark.asyncio
# async def test_page_with_mixed_content():
#     async with AsyncWebCrawler(verbose=True) as crawler:
#         url = "https://github.com/"  # GitHub has a mix of static and dynamic content
#         result = await crawler.arun(url=url, bypass_cache=True)
#         assert result.success
#         assert "github" in result.html.lower()

# Add this test to your existing test file
@pytest.mark.asyncio
async def test_typescript_commits_multi_page():
    first_commit = ""
    async def on_execution_started(page):
        nonlocal first_commit
        try:
            # Wait until the first commit <h4> on the page
            # (document.querySelector('li.Box-sc-g0xbh4-0 h4')) differs from the
            # previously recorded first commit
            while True:
                await page.wait_for_selector('li.Box-sc-g0xbh4-0 h4')
                commit = await page.query_selector('li.Box-sc-g0xbh4-0 h4')
                commit = await commit.evaluate('(element) => element.textContent')
                commit = re.sub(r'\s+', '', commit)
                if commit and commit != first_commit:
                    first_commit = commit
                    break
                await asyncio.sleep(0.5)
        except Exception as e:
            print(f"Warning: New content didn't appear after JavaScript execution: {e}")

    async with AsyncWebCrawler(verbose=True) as crawler:
        crawler.crawler_strategy.set_hook('on_execution_started', on_execution_started)

        url = "https://github.com/microsoft/TypeScript/commits/main"
        session_id = "typescript_commits_session"
        all_commits = []

        js_next_page = """
        const button = document.querySelector('a[data-testid="pagination-next-button"]');
        if (button) button.click();
        """

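        # Note: the loop below reuses a single browser session via `session_id`, and on
        # pages 2-3 it passes `js_only=True` so `js_next_page` runs in that live page
        # instead of re-navigating; `kill_session` releases the page once all three
        # pages have been collected.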
        for page in range(3):  # Crawl 3 pages
            result = await crawler.arun(
                url=url,  # Only use URL for the first page
                session_id=session_id,
                css_selector="li.Box-sc-g0xbh4-0",
                js=js_next_page if page > 0 else None,  # Don't click 'next' on the first page
                bypass_cache=True,
                js_only=page > 0  # Use js_only for subsequent pages
            )

            assert result.success, f"Failed to crawl page {page + 1}"

            # Parse the HTML and extract commits
            soup = BeautifulSoup(result.cleaned_html, 'html.parser')
            commits = soup.select("li")
            # Record the first commit title: find the first <h4> and extract its text
            first_commit = commits[0].find("h4").text
            first_commit = re.sub(r'\s+', '', first_commit)
            all_commits.extend(commits)

            print(f"Page {page + 1}: Found {len(commits)} commits")

        # Clean up the session
        await crawler.crawler_strategy.kill_session(session_id)

        # Assertions
        assert len(all_commits) >= 90, f"Expected at least 90 commits, but got {len(all_commits)}"

        print(f"Successfully crawled {len(all_commits)} commits across 3 pages")

# Entry point for debugging
if __name__ == "__main__":
    pytest.main([__file__, "-v"])
78
tests/async/test_error_handling.py
Normal file
@@ -0,0 +1,78 @@
# import os
# import sys
# import pytest
# import asyncio

# # Add the parent directory to the Python path
# parent_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
# sys.path.append(parent_dir)

# from crawl4ai.async_webcrawler import AsyncWebCrawler
# from crawl4ai.utils import InvalidCSSSelectorError

# class AsyncCrawlerWrapper:
#     def __init__(self):
#         self.crawler = None

#     async def setup(self):
#         self.crawler = AsyncWebCrawler(verbose=True)
#         await self.crawler.awarmup()

#     async def cleanup(self):
#         if self.crawler:
#             await self.crawler.aclear_cache()

# @pytest.fixture(scope="module")
# def crawler_wrapper():
#     wrapper = AsyncCrawlerWrapper()
#     asyncio.get_event_loop().run_until_complete(wrapper.setup())
#     yield wrapper
#     asyncio.get_event_loop().run_until_complete(wrapper.cleanup())

# @pytest.mark.asyncio
# async def test_network_error(crawler_wrapper):
#     url = "https://www.nonexistentwebsite123456789.com"
#     result = await crawler_wrapper.crawler.arun(url=url, bypass_cache=True)
#     assert not result.success
#     assert "Failed to crawl" in result.error_message

# # @pytest.mark.asyncio
# # async def test_timeout_error(crawler_wrapper):
# #     # Simulating a timeout by using a very short timeout value
# #     url = "https://www.nbcnews.com/business"
# #     result = await crawler_wrapper.crawler.arun(url=url, bypass_cache=True, timeout=0.001)
# #     assert not result.success
# #     assert "timeout" in result.error_message.lower()

# # @pytest.mark.asyncio
# # async def test_invalid_css_selector(crawler_wrapper):
# #     url = "https://www.nbcnews.com/business"
# #     with pytest.raises(InvalidCSSSelectorError):
# #         await crawler_wrapper.crawler.arun(url=url, bypass_cache=True, css_selector="invalid>>selector")

# # @pytest.mark.asyncio
# # async def test_js_execution_error(crawler_wrapper):
# #     url = "https://www.nbcnews.com/business"
# #     invalid_js = "This is not valid JavaScript code;"
# #     result = await crawler_wrapper.crawler.arun(url=url, bypass_cache=True, js=invalid_js)
# #     assert not result.success
# #     assert "JavaScript" in result.error_message

# # @pytest.mark.asyncio
# # async def test_empty_page(crawler_wrapper):
# #     # Use a URL that typically returns an empty page
# #     url = "http://example.com/empty"
# #     result = await crawler_wrapper.crawler.arun(url=url, bypass_cache=True)
# #     assert result.success  # The crawl itself should succeed
# #     assert not result.markdown.strip()  # The markdown content should be empty or just whitespace

# # @pytest.mark.asyncio
# # async def test_rate_limiting(crawler_wrapper):
# #     # Simulate rate limiting by making multiple rapid requests
# #     url = "https://www.nbcnews.com/business"
# #     results = await asyncio.gather(*[crawler_wrapper.crawler.arun(url=url, bypass_cache=True) for _ in range(10)])
# #     assert any(not result.success and "rate limit" in result.error_message.lower() for result in results)

# # Entry point for debugging
# if __name__ == "__main__":
#     pytest.main([__file__, "-v"])
94
tests/async/test_parameters_and_options.py
Normal file
@@ -0,0 +1,94 @@
import os
import sys
import pytest
import asyncio
import json

# Add the parent directory to the Python path
parent_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
sys.path.append(parent_dir)

from crawl4ai.async_webcrawler import AsyncWebCrawler

@pytest.mark.asyncio
async def test_word_count_threshold():
    async with AsyncWebCrawler(verbose=True) as crawler:
        url = "https://www.nbcnews.com/business"
        result_no_threshold = await crawler.arun(url=url, word_count_threshold=0, bypass_cache=True)
        result_with_threshold = await crawler.arun(url=url, word_count_threshold=50, bypass_cache=True)

        assert len(result_no_threshold.markdown) > len(result_with_threshold.markdown)

@pytest.mark.asyncio
async def test_css_selector():
    async with AsyncWebCrawler(verbose=True) as crawler:
        url = "https://www.nbcnews.com/business"
        css_selector = "h1, h2, h3"
        result = await crawler.arun(url=url, css_selector=css_selector, bypass_cache=True)

        assert result.success
        assert "<h1" in result.cleaned_html or "<h2" in result.cleaned_html or "<h3" in result.cleaned_html

@pytest.mark.asyncio
async def test_javascript_execution():
    async with AsyncWebCrawler(verbose=True) as crawler:
        url = "https://www.nbcnews.com/business"

        # Crawl without JS
        result_without_more = await crawler.arun(url=url, bypass_cache=True)

        js_code = ["const loadMoreButton = Array.from(document.querySelectorAll('button')).find(button => button.textContent.includes('Load More')); loadMoreButton && loadMoreButton.click();"]
        result_with_more = await crawler.arun(url=url, js=js_code, bypass_cache=True)

        assert result_with_more.success
        assert len(result_with_more.markdown) > len(result_without_more.markdown)

@pytest.mark.asyncio
async def test_screenshot():
    async with AsyncWebCrawler(verbose=True) as crawler:
        url = "https://www.nbcnews.com/business"
        result = await crawler.arun(url=url, screenshot=True, bypass_cache=True)

        assert result.success
        assert result.screenshot
        assert isinstance(result.screenshot, str)  # Should be a base64 encoded string
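        # A minimal sketch (not asserted here): since the screenshot comes back as a
        # base64-encoded string, it could be decoded and written out for manual review:
        #   import base64
        #   with open("screenshot.png", "wb") as f:
        #       f.write(base64.b64decode(result.screenshot))
        # The .png extension is an assumption about the capture format.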

@pytest.mark.asyncio
async def test_custom_user_agent():
    async with AsyncWebCrawler(verbose=True) as crawler:
        url = "https://www.nbcnews.com/business"
        custom_user_agent = "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.124 Safari/537.36 Crawl4AI/1.0"
        result = await crawler.arun(url=url, user_agent=custom_user_agent, bypass_cache=True)

        assert result.success
        # Note: We can't directly verify the user agent in the result, but we can check if the crawl was successful
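
# A hedged sketch (commented out, not part of the suite above): one way to confirm the
# custom user agent end-to-end is to fetch https://httpbin.org/user-agent, which echoes
# the UA header back in the response body. This assumes `arun(user_agent=...)` forwards
# the value to the underlying browser.
# @pytest.mark.asyncio
# async def test_custom_user_agent_echoed():
#     async with AsyncWebCrawler(verbose=True) as crawler:
#         custom_user_agent = "Crawl4AI-Test-Agent/1.0"
#         result = await crawler.arun(url="https://httpbin.org/user-agent", user_agent=custom_user_agent, bypass_cache=True)
#         assert result.success
#         assert custom_user_agent in result.html
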
@pytest.mark.asyncio
async def test_extract_media_and_links():
    async with AsyncWebCrawler(verbose=True) as crawler:
        url = "https://www.nbcnews.com/business"
        result = await crawler.arun(url=url, bypass_cache=True)

        assert result.success
        assert result.media
        assert isinstance(result.media, dict)
        assert 'images' in result.media
        assert result.links
        assert isinstance(result.links, dict)
        assert 'internal' in result.links and 'external' in result.links

@pytest.mark.asyncio
async def test_metadata_extraction():
    async with AsyncWebCrawler(verbose=True) as crawler:
        url = "https://www.nbcnews.com/business"
        result = await crawler.arun(url=url, bypass_cache=True)

        assert result.success
        assert result.metadata
        assert isinstance(result.metadata, dict)
        # Check for common metadata fields
        assert any(key in result.metadata for key in ['title', 'description', 'keywords'])

# Entry point for debugging
if __name__ == "__main__":
    pytest.main([__file__, "-v"])
72
tests/async/test_performance.py
Normal file
@@ -0,0 +1,72 @@
import os
import sys
import pytest
import asyncio
import time

# Add the parent directory to the Python path
parent_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
sys.path.append(parent_dir)

from crawl4ai.async_webcrawler import AsyncWebCrawler

@pytest.mark.asyncio
async def test_crawl_speed():
    async with AsyncWebCrawler(verbose=True) as crawler:
        url = "https://www.nbcnews.com/business"
        start_time = time.time()
        result = await crawler.arun(url=url, bypass_cache=True)
        end_time = time.time()

        assert result.success
        crawl_time = end_time - start_time
        print(f"Crawl time: {crawl_time:.2f} seconds")

        assert crawl_time < 10, f"Crawl took too long: {crawl_time:.2f} seconds"

@pytest.mark.asyncio
async def test_concurrent_crawling_performance():
    async with AsyncWebCrawler(verbose=True) as crawler:
        urls = [
            "https://www.nbcnews.com/business",
            "https://www.example.com",
            "https://www.python.org",
            "https://www.github.com",
            "https://www.stackoverflow.com"
        ]

        start_time = time.time()
        results = await crawler.arun_many(urls=urls, bypass_cache=True)
        end_time = time.time()

        total_time = end_time - start_time
        print(f"Total time for concurrent crawling: {total_time:.2f} seconds")

        assert all(result.success for result in results)
        assert len(results) == len(urls)

        assert total_time < len(urls) * 5, f"Concurrent crawling not significantly faster: {total_time:.2f} seconds"

@pytest.mark.asyncio
async def test_crawl_speed_with_caching():
    async with AsyncWebCrawler(verbose=True) as crawler:
        url = "https://www.nbcnews.com/business"

        start_time = time.time()
        result1 = await crawler.arun(url=url, bypass_cache=True)
        end_time = time.time()
        first_crawl_time = end_time - start_time

        start_time = time.time()
        result2 = await crawler.arun(url=url, bypass_cache=False)
        end_time = time.time()
        second_crawl_time = end_time - start_time

        assert result1.success and result2.success
        print(f"First crawl time: {first_crawl_time:.2f} seconds")
        print(f"Second crawl time (cached): {second_crawl_time:.2f} seconds")

        assert second_crawl_time < first_crawl_time / 2, "Cached crawl not significantly faster"

if __name__ == "__main__":
    pytest.main([__file__, "-v"])