feat: Add screenshot functionality to crawl_urls

The code changes in this commit add the `screenshot` parameter to the `crawl_urls` function in `main.py`. This allows users to specify whether they want to take a screenshot of the page during the crawling process. The default value is `False`.

This commit message follows the established convention of starting with a type (feat for feature) and providing a concise and descriptive summary of the changes made.
This commit is contained in:
unclecode
2024-06-07 15:23:32 +08:00
parent 0533aeb814
commit 8e73a482a2
11 changed files with 147 additions and 27 deletions

BIN
.files/screenshot.png Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 1.5 MiB

View File

@@ -1,4 +1,4 @@
# Crawl4AI v0.2.2 🕷️🤖 # Crawl4AI v0.2.3 🕷️🤖
[![GitHub Stars](https://img.shields.io/github/stars/unclecode/crawl4ai?style=social)](https://github.com/unclecode/crawl4ai/stargazers) [![GitHub Stars](https://img.shields.io/github/stars/unclecode/crawl4ai?style=social)](https://github.com/unclecode/crawl4ai/stargazers)
[![GitHub Forks](https://img.shields.io/github/forks/unclecode/crawl4ai?style=social)](https://github.com/unclecode/crawl4ai/network/members) [![GitHub Forks](https://img.shields.io/github/forks/unclecode/crawl4ai?style=social)](https://github.com/unclecode/crawl4ai/network/members)
@@ -12,6 +12,10 @@ Crawl4AI has one clear task: to simplify crawling and extract useful information
## Recent Changes ## Recent Changes
### v0.2.3
- 🎨 Extract and return all media tags (Images, Audio, and Video).
- 🖼️ Take screenshots of the page.
### v0.2.2 ### v0.2.2
- Support multiple JS scripts - Support multiple JS scripts
- Fixed some bugs - Fixed some bugs

View File

@@ -7,6 +7,15 @@ from selenium.webdriver.support import expected_conditions as EC
from selenium.webdriver.chrome.options import Options from selenium.webdriver.chrome.options import Options
from selenium.common.exceptions import InvalidArgumentException from selenium.common.exceptions import InvalidArgumentException
import logging import logging
import base64
from PIL import Image, ImageDraw, ImageFont
from io import BytesIO
from typing import List
import requests
import os
from pathlib import Path
from .utils import wrap_text
logger = logging.getLogger('selenium.webdriver.remote.remote_connection') logger = logging.getLogger('selenium.webdriver.remote.remote_connection')
logger.setLevel(logging.WARNING) logger.setLevel(logging.WARNING)
@@ -25,15 +34,16 @@ driver_finder_logger = logging.getLogger('selenium.webdriver.common.driver_finde
driver_finder_logger.setLevel(logging.WARNING) driver_finder_logger.setLevel(logging.WARNING)
from typing import List
import requests
import os
from pathlib import Path
class CrawlerStrategy(ABC): class CrawlerStrategy(ABC):
@abstractmethod @abstractmethod
def crawl(self, url: str, **kwargs) -> str: def crawl(self, url: str, **kwargs) -> str:
pass pass
@abstractmethod
def take_screenshot(self, save_path: str):
    """Capture a screenshot of the currently loaded page.

    NOTE(review): the concrete LocalSeleniumCrawlerStrategy override takes no
    save_path argument and returns a base64-encoded string instead of saving
    to disk -- the signatures disagree; confirm which contract is intended.
    """
    pass
class CloudCrawlerStrategy(CrawlerStrategy): class CloudCrawlerStrategy(CrawlerStrategy):
def __init__(self, use_cached_html = False): def __init__(self, use_cached_html = False):
@@ -132,5 +142,62 @@ class LocalSeleniumCrawlerStrategy(CrawlerStrategy):
except Exception as e: except Exception as e:
raise Exception(f"Failed to crawl {url}: {str(e)}") raise Exception(f"Failed to crawl {url}: {str(e)}")
def take_screenshot(self) -> str:
    """Capture a full-page screenshot and return it as a base64-encoded JPEG.

    Returns:
        str: base64-encoded JPEG bytes of the rendered page. On any failure,
        a base64-encoded placeholder image containing the error message is
        returned instead of raising, so callers always receive image data.
    """
    try:
        # Resize the browser window to the full scrollable page dimensions
        # so the screenshot is not clipped to the viewport.
        total_width = self.driver.execute_script("return document.body.scrollWidth")
        total_height = self.driver.execute_script("return document.body.scrollHeight")
        self.driver.set_window_size(total_width, total_height)

        # Take screenshot (PNG bytes) and open with PIL.
        screenshot = self.driver.get_screenshot_as_png()
        image = Image.open(BytesIO(screenshot))

        # PNG screenshots may carry an alpha channel (RGBA); JPEG cannot
        # store alpha and Pillow raises OSError on save, so convert first.
        if image.mode != "RGB":
            image = image.convert("RGB")

        # Convert to JPEG and compress.
        buffered = BytesIO()
        image.save(buffered, format="JPEG", quality=85)
        img_base64 = base64.b64encode(buffered.getvalue()).decode('utf-8')

        if self.verbose:
            print(f"[LOG] 📸 Screenshot taken and converted to base64")

        return img_base64
    except Exception as e:
        error_message = f"Failed to take screenshot: {str(e)}"
        print(error_message)

        # Fall back to a black placeholder image with the error text on it.
        img = Image.new('RGB', (800, 600), color='black')
        draw = ImageDraw.Draw(img)

        # Load a font; fall back gracefully across Pillow versions.
        try:
            font = ImageFont.truetype("arial.ttf", 40)
        except IOError:
            try:
                font = ImageFont.load_default(size=40)  # Pillow >= 10.1
            except TypeError:
                font = ImageFont.load_default()  # older Pillow: no size kwarg

        # Define text color and wrap the text to the image width.
        text_color = (255, 255, 255)
        max_width = 780
        wrapped_text = wrap_text(draw, error_message, font, max_width)

        # Draw the error text onto the placeholder image.
        draw.text((10, 10), wrapped_text, fill=text_color, font=font)

        # Convert the placeholder to base64 as well.
        buffered = BytesIO()
        img.save(buffered, format="JPEG")
        img_base64 = base64.b64encode(buffered.getvalue()).decode('utf-8')
        return img_base64
def quit(self): def quit(self):
self.driver.quit() self.driver.quit()

View File

@@ -19,22 +19,23 @@ def init_db():
markdown TEXT, markdown TEXT,
extracted_content TEXT, extracted_content TEXT,
success BOOLEAN, success BOOLEAN,
media TEXT media TEXT DEFAULT "{}",
screenshot TEXT DEFAULT ""
) )
''') ''')
conn.commit() conn.commit()
conn.close() conn.close()
def alter_db_add_media(): def alter_db_add_screenshot(new_column: str = "media"):
check_db_path() check_db_path()
try: try:
conn = sqlite3.connect(DB_PATH) conn = sqlite3.connect(DB_PATH)
cursor = conn.cursor() cursor = conn.cursor()
cursor.execute('ALTER TABLE crawled_data ADD COLUMN media TEXT DEFAULT ""') cursor.execute(f'ALTER TABLE crawled_data ADD COLUMN {new_column} TEXT DEFAULT ""')
conn.commit() conn.commit()
conn.close() conn.close()
except Exception as e: except Exception as e:
print(f"Error altering database to add media column: {e}") print(f"Error altering database to add screenshot column: {e}")
def check_db_path(): def check_db_path():
if not DB_PATH: if not DB_PATH:
@@ -45,7 +46,7 @@ def get_cached_url(url: str) -> Optional[Tuple[str, str, str, str, str, bool, st
try: try:
conn = sqlite3.connect(DB_PATH) conn = sqlite3.connect(DB_PATH)
cursor = conn.cursor() cursor = conn.cursor()
cursor.execute('SELECT url, html, cleaned_html, markdown, extracted_content, success, media FROM crawled_data WHERE url = ?', (url,)) cursor.execute('SELECT url, html, cleaned_html, markdown, extracted_content, success, media, screenshot FROM crawled_data WHERE url = ?', (url,))
result = cursor.fetchone() result = cursor.fetchone()
conn.close() conn.close()
return result return result
@@ -53,13 +54,13 @@ def get_cached_url(url: str) -> Optional[Tuple[str, str, str, str, str, bool, st
print(f"Error retrieving cached URL: {e}") print(f"Error retrieving cached URL: {e}")
return None return None
def cache_url(url: str, html: str, cleaned_html: str, markdown: str, extracted_content: str, success: bool, media: str = ""): def cache_url(url: str, html: str, cleaned_html: str, markdown: str, extracted_content: str, success: bool, media : str = "{}", screenshot: str = ""):
check_db_path() check_db_path()
try: try:
conn = sqlite3.connect(DB_PATH) conn = sqlite3.connect(DB_PATH)
cursor = conn.cursor() cursor = conn.cursor()
cursor.execute(''' cursor.execute('''
INSERT INTO crawled_data (url, html, cleaned_html, markdown, extracted_content, success, media) INSERT INTO crawled_data (url, html, cleaned_html, markdown, extracted_content, success, screenshot)
VALUES (?, ?, ?, ?, ?, ?, ?) VALUES (?, ?, ?, ?, ?, ?, ?)
ON CONFLICT(url) DO UPDATE SET ON CONFLICT(url) DO UPDATE SET
html = excluded.html, html = excluded.html,
@@ -67,8 +68,9 @@ def cache_url(url: str, html: str, cleaned_html: str, markdown: str, extracted_c
markdown = excluded.markdown, markdown = excluded.markdown,
extracted_content = excluded.extracted_content, extracted_content = excluded.extracted_content,
success = excluded.success, success = excluded.success,
media = excluded.media media = excluded.media,
''', (url, html, cleaned_html, markdown, extracted_content, success, media)) screenshot = excluded.screenshot
''', (url, html, cleaned_html, markdown, extracted_content, success, media, screenshot))
conn.commit() conn.commit()
conn.close() conn.close()
except Exception as e: except Exception as e:
@@ -109,12 +111,12 @@ def flush_db():
except Exception as e: except Exception as e:
print(f"Error flushing database: {e}") print(f"Error flushing database: {e}")
def update_existing_records(): def update_existing_records(new_column: str = "media", default_value: str = "{}"):
check_db_path() check_db_path()
try: try:
conn = sqlite3.connect(DB_PATH) conn = sqlite3.connect(DB_PATH)
cursor = conn.cursor() cursor = conn.cursor()
cursor.execute('UPDATE crawled_data SET media = "" WHERE media IS NULL') cursor.execute(f'UPDATE crawled_data SET {new_column} = "{default_value}" WHERE screenshot IS NULL')
conn.commit() conn.commit()
conn.close() conn.close()
except Exception as e: except Exception as e:
@@ -122,5 +124,5 @@ def update_existing_records():
if __name__ == "__main__": if __name__ == "__main__":
init_db() # Initialize the database if not already initialized init_db() # Initialize the database if not already initialized
alter_db_add_media() # Add the new column to the table alter_db_add_screenshot() # Add the new column to the table
update_existing_records() # Update existing records to set the new column to an empty string update_existing_records() # Update existing records to set the new column to an empty string

View File

@@ -1,5 +1,5 @@
from pydantic import BaseModel, HttpUrl from pydantic import BaseModel, HttpUrl
from typing import List, Dict from typing import List, Dict, Optional
class UrlModel(BaseModel): class UrlModel(BaseModel):
url: HttpUrl url: HttpUrl
@@ -9,9 +9,10 @@ class CrawlResult(BaseModel):
url: str url: str
html: str html: str
success: bool success: bool
cleaned_html: str = None cleaned_html: Optional[str] = None
media: Dict[str, List[Dict]] = {} media: Dict[str, List[Dict]] = {}
markdown: str = None screenshot: Optional[str] = None
extracted_content: str = None markdown: Optional[str] = None
metadata: dict = None extracted_content: Optional[str] = None
error_message: str = None metadata: Optional[dict] = None
error_message: Optional[str] = None

View File

@@ -513,4 +513,16 @@ def process_sections(url: str, sections: list, provider: str, api_token: str) ->
for future in as_completed(futures): for future in as_completed(futures):
extracted_content.extend(future.result()) extracted_content.extend(future.result())
return extracted_content return extracted_content
def wrap_text(draw, text, font, max_width):
    """Greedily wrap *text* so each rendered line fits within *max_width* pixels.

    Args:
        draw: a PIL ImageDraw (or compatible) object providing textbbox().
        text: the string to wrap.
        font: font passed through to draw.textbbox for measurement.
        max_width: maximum pixel width of a rendered line.

    Returns:
        str: the wrapped text, lines joined with newlines.
    """
    lines = []
    words = text.split()
    while words:
        line = ''
        # Pack words while the rendered candidate line still fits.
        while words and draw.textbbox((0, 0), line + words[0], font=font)[2] <= max_width:
            line += (words.pop(0) + ' ')
        if not line:
            # A single word wider than max_width would never be consumed,
            # looping forever; emit it on its own line to guarantee progress.
            line = words.pop(0) + ' '
        lines.append(line)
    return '\n'.join(lines)

View File

@@ -59,6 +59,8 @@ class WebCrawler:
api_token: str = None, api_token: str = None,
extract_blocks_flag: bool = True, extract_blocks_flag: bool = True,
word_count_threshold=MIN_WORD_THRESHOLD, word_count_threshold=MIN_WORD_THRESHOLD,
css_selector: str = None,
screenshot: bool = False,
use_cached_html: bool = False, use_cached_html: bool = False,
extraction_strategy: ExtractionStrategy = None, extraction_strategy: ExtractionStrategy = None,
chunking_strategy: ChunkingStrategy = RegexChunking(), chunking_strategy: ChunkingStrategy = RegexChunking(),
@@ -70,6 +72,8 @@ class WebCrawler:
extraction_strategy or NoExtractionStrategy(), extraction_strategy or NoExtractionStrategy(),
chunking_strategy, chunking_strategy,
bypass_cache=url_model.forced, bypass_cache=url_model.forced,
css_selector=css_selector,
screenshot=screenshot,
**kwargs, **kwargs,
) )
pass pass
@@ -83,6 +87,7 @@ class WebCrawler:
chunking_strategy: ChunkingStrategy = RegexChunking(), chunking_strategy: ChunkingStrategy = RegexChunking(),
bypass_cache: bool = False, bypass_cache: bool = False,
css_selector: str = None, css_selector: str = None,
screenshot: bool = False,
verbose=True, verbose=True,
**kwargs, **kwargs,
) -> CrawlResult: ) -> CrawlResult:
@@ -110,7 +115,8 @@ class WebCrawler:
"markdown": cached[3], "markdown": cached[3],
"extracted_content": cached[4], "extracted_content": cached[4],
"success": cached[5], "success": cached[5],
"media": json.loads(cached[6]), "media": json.loads(cached[6] or "{}"),
"screenshot": cached[7],
"error_message": "", "error_message": "",
} }
) )
@@ -118,6 +124,9 @@ class WebCrawler:
# Initialize WebDriver for crawling # Initialize WebDriver for crawling
t = time.time() t = time.time()
html = self.crawler_strategy.crawl(url) html = self.crawler_strategy.crawl(url)
base64_image = None
if screenshot:
base64_image = self.crawler_strategy.take_screenshot()
success = True success = True
error_message = "" error_message = ""
# Extract content from HTML # Extract content from HTML
@@ -166,6 +175,7 @@ class WebCrawler:
extracted_content, extracted_content,
success, success,
json.dumps(media), json.dumps(media),
screenshot=base64_image,
) )
return CrawlResult( return CrawlResult(
@@ -174,6 +184,7 @@ class WebCrawler:
cleaned_html=cleaned_html, cleaned_html=cleaned_html,
markdown=markdown, markdown=markdown,
media=media, media=media,
screenshot=base64_image,
extracted_content=extracted_content, extracted_content=extracted_content,
success=success, success=success,
error_message=error_message, error_message=error_message,
@@ -187,6 +198,8 @@ class WebCrawler:
extract_blocks_flag: bool = True, extract_blocks_flag: bool = True,
word_count_threshold=MIN_WORD_THRESHOLD, word_count_threshold=MIN_WORD_THRESHOLD,
use_cached_html: bool = False, use_cached_html: bool = False,
css_selector: str = None,
screenshot: bool = False,
extraction_strategy: ExtractionStrategy = None, extraction_strategy: ExtractionStrategy = None,
chunking_strategy: ChunkingStrategy = RegexChunking(), chunking_strategy: ChunkingStrategy = RegexChunking(),
**kwargs, **kwargs,
@@ -204,6 +217,8 @@ class WebCrawler:
[api_token] * len(url_models), [api_token] * len(url_models),
[extract_blocks_flag] * len(url_models), [extract_blocks_flag] * len(url_models),
[word_count_threshold] * len(url_models), [word_count_threshold] * len(url_models),
[css_selector] * len(url_models),
[screenshot] * len(url_models),
[use_cached_html] * len(url_models), [use_cached_html] * len(url_models),
[extraction_strategy] * len(url_models), [extraction_strategy] * len(url_models),
[chunking_strategy] * len(url_models), [chunking_strategy] * len(url_models),

View File

@@ -35,7 +35,7 @@ def cprint(message, press_any_key=False):
def basic_usage(crawler): def basic_usage(crawler):
cprint("🛠️ [bold cyan]Basic Usage: Simply provide a URL and let Crawl4ai do the magic![/bold cyan]") cprint("🛠️ [bold cyan]Basic Usage: Simply provide a URL and let Crawl4ai do the magic![/bold cyan]")
result = crawler.run(url="https://www.nbcnews.com/business") result = crawler.run(url="https://www.nbcnews.com/business", screenshot=True)
cprint("[LOG] 📦 [bold yellow]Basic crawl result:[/bold yellow]") cprint("[LOG] 📦 [bold yellow]Basic crawl result:[/bold yellow]")
print_result(result) print_result(result)
@@ -187,6 +187,7 @@ def main():
crawler = create_crawler() crawler = create_crawler()
crawler.always_by_pass_cache = True
basic_usage(crawler) basic_usage(crawler)
understanding_parameters(crawler) understanding_parameters(crawler)

View File

@@ -56,6 +56,7 @@ class CrawlRequest(BaseModel):
chunking_strategy: Optional[str] = "RegexChunking" chunking_strategy: Optional[str] = "RegexChunking"
chunking_strategy_args: Optional[dict] = {} chunking_strategy_args: Optional[dict] = {}
css_selector: Optional[str] = None css_selector: Optional[str] = None
screenshot: Optional[bool] = False
verbose: Optional[bool] = True verbose: Optional[bool] = True
@@ -125,6 +126,7 @@ async def crawl_urls(crawl_request: CrawlRequest, request: Request):
chunking_strategy, chunking_strategy,
crawl_request.bypass_cache, crawl_request.bypass_cache,
crawl_request.css_selector, crawl_request.css_selector,
crawl_request.screenshot,
crawl_request.verbose crawl_request.verbose
) )
for url in crawl_request.urls for url in crawl_request.urls

View File

@@ -104,6 +104,7 @@ document.getElementById("crawl-btn").addEventListener("click", () => {
chunking_strategy: document.getElementById("chunking-strategy-select").value, chunking_strategy: document.getElementById("chunking-strategy-select").value,
chunking_strategy_args: {}, chunking_strategy_args: {},
css_selector: document.getElementById("css-selector").value, css_selector: document.getElementById("css-selector").value,
screenshot: document.getElementById("screenshot-checkbox").checked,
// instruction: document.getElementById("instruction").value, // instruction: document.getElementById("instruction").value,
// semantic_filter: document.getElementById("semantic_filter").value, // semantic_filter: document.getElementById("semantic_filter").value,
verbose: true, verbose: true,
@@ -138,7 +139,14 @@ document.getElementById("crawl-btn").addEventListener("click", () => {
document.getElementById("cleaned-html-result").textContent = result.cleaned_html; document.getElementById("cleaned-html-result").textContent = result.cleaned_html;
document.getElementById("markdown-result").textContent = result.markdown; document.getElementById("markdown-result").textContent = result.markdown;
document.getElementById("media-result").textContent = JSON.stringify( result.media, null, 2); document.getElementById("media-result").textContent = JSON.stringify( result.media, null, 2);
if (result.screenshot){
const imgElement = document.createElement("img");
// Set the src attribute with the base64 data
imgElement.src = `data:image/png;base64,${result.screenshot}`;
document.getElementById("screenshot-result").innerHTML = "";
document.getElementById("screenshot-result").appendChild(imgElement);
}
// Update code examples dynamically // Update code examples dynamically
const extractionStrategy = data.extraction_strategy; const extractionStrategy = data.extraction_strategy;
const isLLMExtraction = extractionStrategy === "LLMExtractionStrategy"; const isLLMExtraction = extractionStrategy === "LLMExtractionStrategy";

View File

@@ -124,6 +124,10 @@
<input type="checkbox" id="bypass-cache-checkbox" checked /> <input type="checkbox" id="bypass-cache-checkbox" checked />
<label for="bypass-cache-checkbox" class="text-lime-500 font-bold">Bypass Cache</label> <label for="bypass-cache-checkbox" class="text-lime-500 font-bold">Bypass Cache</label>
</div> </div>
<div class="flex items-center gap-2">
<input type="checkbox" id="screenshot-checkbox" checked />
<label for="screenshot-checkbox" class="text-lime-500 font-bold">Screenshot</label>
</div>
<div class="flex items-center gap-2 hidden"> <div class="flex items-center gap-2 hidden">
<input type="checkbox" id="extract-blocks-checkbox" /> <input type="checkbox" id="extract-blocks-checkbox" />
<label for="extract-blocks-checkbox" class="text-lime-500 font-bold">Extract Blocks</label> <label for="extract-blocks-checkbox" class="text-lime-500 font-bold">Extract Blocks</label>
@@ -152,12 +156,16 @@
<button class="tab-btn px-4 py-1 text-sm bg-zinc-700 rounded-t text-lime-500" data-tab="media"> <button class="tab-btn px-4 py-1 text-sm bg-zinc-700 rounded-t text-lime-500" data-tab="media">
Medias Medias
</button> </button>
<button class="tab-btn px-4 py-1 text-sm bg-zinc-700 rounded-t text-lime-500" data-tab="screenshot">
Screenshot
</button>
</div> </div>
<div class="tab-content code bg-zinc-900 p-2 rounded h-full border border-zinc-700 text-sm"> <div class="tab-content code bg-zinc-900 p-2 rounded h-full border border-zinc-700 text-sm">
<pre class="h-full flex"><code id="json-result" class="language-json"></code></pre> <pre class="h-full flex"><code id="json-result" class="language-json"></code></pre>
<pre class="hidden h-full flex"><code id="cleaned-html-result" class="language-html"></code></pre> <pre class="hidden h-full flex"><code id="cleaned-html-result" class="language-html"></code></pre>
<pre class="hidden h-full flex"><code id="markdown-result" class="language-markdown"></code></pre> <pre class="hidden h-full flex"><code id="markdown-result" class="language-markdown"></code></pre>
<pre class="hidden h-full flex"><code id="media-result" class="language-json"></code></pre> <pre class="hidden h-full flex"><code id="media-result" class="language-json"></code></pre>
<pre class="hidden h-full flex"><code id="screenshot-result"></code></pre>
</div> </div>
</div> </div>