feat: Add screenshot functionality to crawl_urls

The code changes in this commit add the `screenshot` parameter to the `crawl_urls` function in `main.py`. This allows users to specify whether they want to take a screenshot of the page during the crawling process. The default value is `False`.

This commit message follows the established convention of starting with a type (feat for feature) and providing a concise and descriptive summary of the changes made.
This commit is contained in:
unclecode
2024-06-07 15:23:32 +08:00
parent 0533aeb814
commit 8e73a482a2
11 changed files with 147 additions and 27 deletions

BIN
.files/screenshot.png Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 1.5 MiB

View File

@@ -1,4 +1,4 @@
# Crawl4AI v0.2.2 🕷️🤖
# Crawl4AI v0.2.3 🕷️🤖
[![GitHub Stars](https://img.shields.io/github/stars/unclecode/crawl4ai?style=social)](https://github.com/unclecode/crawl4ai/stargazers)
[![GitHub Forks](https://img.shields.io/github/forks/unclecode/crawl4ai?style=social)](https://github.com/unclecode/crawl4ai/network/members)
@@ -12,6 +12,10 @@ Crawl4AI has one clear task: to simplify crawling and extract useful information
## Recent Changes
### v0.2.3
- 🎨 Extract and return all media tags (Images, Audio, and Video).
- 🖼️ Take screenshots of the page.
### v0.2.2
- Support multiple JS scripts
- Fixed some bugs

View File

@@ -7,6 +7,15 @@ from selenium.webdriver.support import expected_conditions as EC
from selenium.webdriver.chrome.options import Options
from selenium.common.exceptions import InvalidArgumentException
import logging
import base64
from PIL import Image, ImageDraw, ImageFont
from io import BytesIO
from typing import List
import requests
import os
from pathlib import Path
from .utils import wrap_text
logger = logging.getLogger('selenium.webdriver.remote.remote_connection')
logger.setLevel(logging.WARNING)
@@ -25,15 +34,16 @@ driver_finder_logger = logging.getLogger('selenium.webdriver.common.driver_finde
driver_finder_logger.setLevel(logging.WARNING)
from typing import List
import requests
import os
from pathlib import Path
class CrawlerStrategy(ABC):
    """Abstract interface for page-fetching backends (local Selenium, cloud, ...)."""

    @abstractmethod
    def crawl(self, url: str, **kwargs) -> str:
        """Fetch *url* and return its raw HTML as a string."""
        pass

    @abstractmethod
    def take_screenshot(self) -> str:
        """Capture a screenshot of the current page and return it base64-encoded.

        Signature aligned with LocalSeleniumCrawlerStrategy.take_screenshot,
        which takes no arguments and returns base64 JPEG data (the previous
        `save_path` parameter was never used by any implementation).
        """
        pass
class CloudCrawlerStrategy(CrawlerStrategy):
def __init__(self, use_cached_html = False):
@@ -132,5 +142,62 @@ class LocalSeleniumCrawlerStrategy(CrawlerStrategy):
except Exception as e:
raise Exception(f"Failed to crawl {url}: {str(e)}")
def take_screenshot(self) -> str:
    """Capture a full-page screenshot and return it as a base64-encoded JPEG.

    The browser window is resized to the page's full scrollable dimensions
    before capturing. On any failure a black 800x600 placeholder image with
    the error message drawn on it is returned instead, so callers always
    receive displayable base64 data.
    """
    try:
        # Measure the full scrollable page so the window can show all of it.
        total_width = self.driver.execute_script("return document.body.scrollWidth")
        total_height = self.driver.execute_script("return document.body.scrollHeight")

        # Set the window size to the dimensions of the page.
        self.driver.set_window_size(total_width, total_height)

        # Take screenshot (Selenium returns PNG bytes).
        screenshot = self.driver.get_screenshot_as_png()

        # Open with PIL and force RGB: JPEG cannot encode RGBA or paletted
        # images, and Image.save(format="JPEG") raises OSError for those modes.
        image = Image.open(BytesIO(screenshot)).convert("RGB")

        # Re-encode as compressed JPEG and base64-encode for storage/transport.
        buffered = BytesIO()
        image.save(buffered, format="JPEG", quality=85)
        img_base64 = base64.b64encode(buffered.getvalue()).decode('utf-8')

        if self.verbose:
            print(f"[LOG] 📸 Screenshot taken and converted to base64")

        return img_base64
    except Exception as e:
        error_message = f"Failed to take screenshot: {str(e)}"
        print(error_message)

        # Generate a placeholder image with black background carrying the error.
        img = Image.new('RGB', (800, 600), color='black')
        draw = ImageDraw.Draw(img)

        # Load a font; fall back to PIL's built-in font when Arial is missing.
        try:
            font = ImageFont.truetype("arial.ttf", 40)
        except IOError:
            try:
                font = ImageFont.load_default(size=40)  # Pillow >= 10.1
            except TypeError:
                # Older Pillow has no `size` parameter; use the fixed-size default
                # instead of letting TypeError escape the error handler.
                font = ImageFont.load_default()

        # Define text color and wrap the message to fit the image width.
        text_color = (255, 255, 255)
        max_width = 780
        wrapped_text = wrap_text(draw, error_message, font, max_width)

        # Draw the wrapped text on the placeholder image.
        draw.text((10, 10), wrapped_text, fill=text_color, font=font)

        # Convert the placeholder to base64 as well.
        buffered = BytesIO()
        img.save(buffered, format="JPEG")
        return base64.b64encode(buffered.getvalue()).decode('utf-8')
def quit(self):
    """Shut down the underlying Selenium WebDriver session and free its resources."""
    browser = self.driver
    browser.quit()

View File

@@ -19,22 +19,23 @@ def init_db():
markdown TEXT,
extracted_content TEXT,
success BOOLEAN,
media TEXT
media TEXT DEFAULT "{}",
screenshot TEXT DEFAULT ""
)
''')
conn.commit()
conn.close()
def alter_db_add_screenshot(new_column: str = "screenshot"):
    """Migrate crawled_data by adding *new_column* (TEXT, default empty string).

    The default is "screenshot" — the column this migration was written to
    introduce. (A leftover default of "media" would make the no-argument call
    in the __main__ migration re-add the old column instead.) Errors, e.g.
    the column already existing, are printed rather than raised so repeated
    runs stay harmless.
    """
    check_db_path()
    try:
        conn = sqlite3.connect(DB_PATH)
        try:
            cursor = conn.cursor()
            # Column names cannot be bound as SQL parameters, so new_column is
            # interpolated directly — only pass trusted identifiers here.
            cursor.execute(f'ALTER TABLE crawled_data ADD COLUMN {new_column} TEXT DEFAULT ""')
            conn.commit()
        finally:
            # Close the connection even when ALTER TABLE fails.
            conn.close()
    except Exception as e:
        print(f"Error altering database to add screenshot column: {e}")
def check_db_path():
if not DB_PATH:
@@ -45,7 +46,7 @@ def get_cached_url(url: str) -> Optional[Tuple[str, str, str, str, str, bool, st
try:
conn = sqlite3.connect(DB_PATH)
cursor = conn.cursor()
cursor.execute('SELECT url, html, cleaned_html, markdown, extracted_content, success, media FROM crawled_data WHERE url = ?', (url,))
cursor.execute('SELECT url, html, cleaned_html, markdown, extracted_content, success, media, screenshot FROM crawled_data WHERE url = ?', (url,))
result = cursor.fetchone()
conn.close()
return result
@@ -53,13 +54,13 @@ def get_cached_url(url: str) -> Optional[Tuple[str, str, str, str, str, bool, st
print(f"Error retrieving cached URL: {e}")
return None
def cache_url(url: str, html: str, cleaned_html: str, markdown: str, extracted_content: str, success: bool, media: str = ""):
def cache_url(url: str, html: str, cleaned_html: str, markdown: str, extracted_content: str, success: bool, media : str = "{}", screenshot: str = ""):
check_db_path()
try:
conn = sqlite3.connect(DB_PATH)
cursor = conn.cursor()
cursor.execute('''
INSERT INTO crawled_data (url, html, cleaned_html, markdown, extracted_content, success, media)
INSERT INTO crawled_data (url, html, cleaned_html, markdown, extracted_content, success, screenshot)
VALUES (?, ?, ?, ?, ?, ?, ?)
ON CONFLICT(url) DO UPDATE SET
html = excluded.html,
@@ -67,8 +68,9 @@ def cache_url(url: str, html: str, cleaned_html: str, markdown: str, extracted_c
markdown = excluded.markdown,
extracted_content = excluded.extracted_content,
success = excluded.success,
media = excluded.media
''', (url, html, cleaned_html, markdown, extracted_content, success, media))
media = excluded.media,
screenshot = excluded.screenshot
''', (url, html, cleaned_html, markdown, extracted_content, success, media, screenshot))
conn.commit()
conn.close()
except Exception as e:
@@ -109,12 +111,12 @@ def flush_db():
except Exception as e:
print(f"Error flushing database: {e}")
def update_existing_records():
def update_existing_records(new_column: str = "media", default_value: str = "{}"):
check_db_path()
try:
conn = sqlite3.connect(DB_PATH)
cursor = conn.cursor()
cursor.execute('UPDATE crawled_data SET media = "" WHERE media IS NULL')
cursor.execute(f'UPDATE crawled_data SET {new_column} = "{default_value}" WHERE screenshot IS NULL')
conn.commit()
conn.close()
except Exception as e:
@@ -122,5 +124,5 @@ def update_existing_records():
if __name__ == "__main__":
    init_db()  # Initialize the database if not already initialized
    # Pass the column explicitly: the helpers' defaults still target the older
    # "media" migration, so relying on them would never add/backfill screenshot.
    alter_db_add_screenshot("screenshot")  # Add the screenshot column to the table
    update_existing_records("screenshot", "")  # Backfill NULL screenshots with ""

View File

@@ -1,5 +1,5 @@
from pydantic import BaseModel, HttpUrl
from typing import List, Dict
from typing import List, Dict, Optional
class UrlModel(BaseModel):
url: HttpUrl
@@ -9,9 +9,10 @@ class CrawlResult(BaseModel):
url: str
html: str
success: bool
cleaned_html: str = None
cleaned_html: Optional[str] = None
media: Dict[str, List[Dict]] = {}
markdown: str = None
extracted_content: str = None
metadata: dict = None
error_message: str = None
screenshot: Optional[str] = None
markdown: Optional[str] = None
extracted_content: Optional[str] = None
metadata: Optional[dict] = None
error_message: Optional[str] = None

View File

@@ -513,4 +513,16 @@ def process_sections(url: str, sections: list, provider: str, api_token: str) ->
for future in as_completed(futures):
extracted_content.extend(future.result())
return extracted_content
return extracted_content
def wrap_text(draw, text, font, max_width):
    """Greedily wrap *text* into lines no wider than *max_width* pixels.

    Width is measured via draw.textbbox(...) with *font*. Returns the wrapped
    text joined with newlines; each line keeps its trailing space, matching
    the original greedy-fill behavior.
    """
    lines = []
    words = text.split()
    while words:
        line = ''
        # Keep appending words while the candidate line still fits.
        while words and draw.textbbox((0, 0), line + words[0], font=font)[2] <= max_width:
            line += (words.pop(0) + ' ')
        if not line:
            # A single word wider than max_width would never be consumed and the
            # outer loop would spin forever; emit it on its own line instead.
            line = words.pop(0) + ' '
        lines.append(line)
    return '\n'.join(lines)

View File

@@ -59,6 +59,8 @@ class WebCrawler:
api_token: str = None,
extract_blocks_flag: bool = True,
word_count_threshold=MIN_WORD_THRESHOLD,
css_selector: str = None,
screenshot: bool = False,
use_cached_html: bool = False,
extraction_strategy: ExtractionStrategy = None,
chunking_strategy: ChunkingStrategy = RegexChunking(),
@@ -70,6 +72,8 @@ class WebCrawler:
extraction_strategy or NoExtractionStrategy(),
chunking_strategy,
bypass_cache=url_model.forced,
css_selector=css_selector,
screenshot=screenshot,
**kwargs,
)
pass
@@ -83,6 +87,7 @@ class WebCrawler:
chunking_strategy: ChunkingStrategy = RegexChunking(),
bypass_cache: bool = False,
css_selector: str = None,
screenshot: bool = False,
verbose=True,
**kwargs,
) -> CrawlResult:
@@ -110,7 +115,8 @@ class WebCrawler:
"markdown": cached[3],
"extracted_content": cached[4],
"success": cached[5],
"media": json.loads(cached[6]),
"media": json.loads(cached[6] or "{}"),
"screenshot": cached[7],
"error_message": "",
}
)
@@ -118,6 +124,9 @@ class WebCrawler:
# Initialize WebDriver for crawling
t = time.time()
html = self.crawler_strategy.crawl(url)
base64_image = None
if screenshot:
base64_image = self.crawler_strategy.take_screenshot()
success = True
error_message = ""
# Extract content from HTML
@@ -166,6 +175,7 @@ class WebCrawler:
extracted_content,
success,
json.dumps(media),
screenshot=base64_image,
)
return CrawlResult(
@@ -174,6 +184,7 @@ class WebCrawler:
cleaned_html=cleaned_html,
markdown=markdown,
media=media,
screenshot=base64_image,
extracted_content=extracted_content,
success=success,
error_message=error_message,
@@ -187,6 +198,8 @@ class WebCrawler:
extract_blocks_flag: bool = True,
word_count_threshold=MIN_WORD_THRESHOLD,
use_cached_html: bool = False,
css_selector: str = None,
screenshot: bool = False,
extraction_strategy: ExtractionStrategy = None,
chunking_strategy: ChunkingStrategy = RegexChunking(),
**kwargs,
@@ -204,6 +217,8 @@ class WebCrawler:
[api_token] * len(url_models),
[extract_blocks_flag] * len(url_models),
[word_count_threshold] * len(url_models),
[css_selector] * len(url_models),
[screenshot] * len(url_models),
[use_cached_html] * len(url_models),
[extraction_strategy] * len(url_models),
[chunking_strategy] * len(url_models),

View File

@@ -35,7 +35,7 @@ def cprint(message, press_any_key=False):
def basic_usage(crawler):
cprint("🛠️ [bold cyan]Basic Usage: Simply provide a URL and let Crawl4ai do the magic![/bold cyan]")
result = crawler.run(url="https://www.nbcnews.com/business")
result = crawler.run(url="https://www.nbcnews.com/business", screenshot=True)
cprint("[LOG] 📦 [bold yellow]Basic crawl result:[/bold yellow]")
print_result(result)
@@ -187,6 +187,7 @@ def main():
crawler = create_crawler()
crawler.always_by_pass_cache = True
basic_usage(crawler)
understanding_parameters(crawler)

View File

@@ -56,6 +56,7 @@ class CrawlRequest(BaseModel):
chunking_strategy: Optional[str] = "RegexChunking"
chunking_strategy_args: Optional[dict] = {}
css_selector: Optional[str] = None
screenshot: Optional[bool] = False
verbose: Optional[bool] = True
@@ -125,6 +126,7 @@ async def crawl_urls(crawl_request: CrawlRequest, request: Request):
chunking_strategy,
crawl_request.bypass_cache,
crawl_request.css_selector,
crawl_request.screenshot,
crawl_request.verbose
)
for url in crawl_request.urls

View File

@@ -104,6 +104,7 @@ document.getElementById("crawl-btn").addEventListener("click", () => {
chunking_strategy: document.getElementById("chunking-strategy-select").value,
chunking_strategy_args: {},
css_selector: document.getElementById("css-selector").value,
screenshot: document.getElementById("screenshot-checkbox").checked,
// instruction: document.getElementById("instruction").value,
// semantic_filter: document.getElementById("semantic_filter").value,
verbose: true,
@@ -138,7 +139,14 @@ document.getElementById("crawl-btn").addEventListener("click", () => {
document.getElementById("cleaned-html-result").textContent = result.cleaned_html;
document.getElementById("markdown-result").textContent = result.markdown;
document.getElementById("media-result").textContent = JSON.stringify( result.media, null, 2);
if (result.screenshot){
const imgElement = document.createElement("img");
// Set the src attribute with the base64 data
imgElement.src = `data:image/png;base64,${result.screenshot}`;
document.getElementById("screenshot-result").innerHTML = "";
document.getElementById("screenshot-result").appendChild(imgElement);
}
// Update code examples dynamically
const extractionStrategy = data.extraction_strategy;
const isLLMExtraction = extractionStrategy === "LLMExtractionStrategy";

View File

@@ -124,6 +124,10 @@
<input type="checkbox" id="bypass-cache-checkbox" checked />
<label for="bypass-cache-checkbox" class="text-lime-500 font-bold">Bypass Cache</label>
</div>
<div class="flex items-center gap-2">
<input type="checkbox" id="screenshot-checkbox" checked />
<label for="screenshot-checkbox" class="text-lime-500 font-bold">Screenshot</label>
</div>
<div class="flex items-center gap-2 hidden">
<input type="checkbox" id="extract-blocks-checkbox" />
<label for="extract-blocks-checkbox" class="text-lime-500 font-bold">Extract Blocks</label>
@@ -152,12 +156,16 @@
<button class="tab-btn px-4 py-1 text-sm bg-zinc-700 rounded-t text-lime-500" data-tab="media">
Medias
</button>
<button class="tab-btn px-4 py-1 text-sm bg-zinc-700 rounded-t text-lime-500" data-tab="screenshot">
Screenshot
</button>
</div>
<div class="tab-content code bg-zinc-900 p-2 rounded h-full border border-zinc-700 text-sm">
<pre class="h-full flex"><code id="json-result" class="language-json"></code></pre>
<pre class="hidden h-full flex"><code id="cleaned-html-result" class="language-html"></code></pre>
<pre class="hidden h-full flex"><code id="markdown-result" class="language-markdown"></code></pre>
<pre class="hidden h-full flex"><code id="media-result" class="language-json"></code></pre>
<pre class="hidden h-full flex"><code id="screenshot-result"></code></pre>
</div>
</div>