feat: Add screenshot functionality to crawl_urls
The code changes in this commit add the `screenshot` parameter to the `crawl_urls` function in `main.py`. This allows users to specify whether they want to take a screenshot of the page during the crawling process. The default value is `False`. This commit message follows the established convention of starting with a type (feat for feature) and providing a concise and descriptive summary of the changes made.
This commit is contained in:
@@ -7,6 +7,15 @@ from selenium.webdriver.support import expected_conditions as EC
|
||||
from selenium.webdriver.chrome.options import Options
|
||||
from selenium.common.exceptions import InvalidArgumentException
|
||||
import logging
|
||||
import base64
|
||||
from PIL import Image, ImageDraw, ImageFont
|
||||
from io import BytesIO
|
||||
from typing import List
|
||||
import requests
|
||||
import os
|
||||
from pathlib import Path
|
||||
from .utils import wrap_text
|
||||
|
||||
logger = logging.getLogger('selenium.webdriver.remote.remote_connection')
|
||||
logger.setLevel(logging.WARNING)
|
||||
|
||||
@@ -25,15 +34,16 @@ driver_finder_logger = logging.getLogger('selenium.webdriver.common.driver_finde
|
||||
driver_finder_logger.setLevel(logging.WARNING)
|
||||
|
||||
|
||||
from typing import List
|
||||
import requests
|
||||
import os
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
class CrawlerStrategy(ABC):
|
||||
@abstractmethod
|
||||
def crawl(self, url: str, **kwargs) -> str:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def take_screenshot(self, save_path: str):
|
||||
pass
|
||||
|
||||
class CloudCrawlerStrategy(CrawlerStrategy):
|
||||
def __init__(self, use_cached_html = False):
|
||||
@@ -132,5 +142,62 @@ class LocalSeleniumCrawlerStrategy(CrawlerStrategy):
|
||||
except Exception as e:
|
||||
raise Exception(f"Failed to crawl {url}: {str(e)}")
|
||||
|
||||
def take_screenshot(self) -> str:
|
||||
try:
|
||||
# Get the dimensions of the page
|
||||
total_width = self.driver.execute_script("return document.body.scrollWidth")
|
||||
total_height = self.driver.execute_script("return document.body.scrollHeight")
|
||||
|
||||
# Set the window size to the dimensions of the page
|
||||
self.driver.set_window_size(total_width, total_height)
|
||||
|
||||
# Take screenshot
|
||||
screenshot = self.driver.get_screenshot_as_png()
|
||||
|
||||
# Open the screenshot with PIL
|
||||
image = Image.open(BytesIO(screenshot))
|
||||
|
||||
# Convert to JPEG and compress
|
||||
buffered = BytesIO()
|
||||
image.save(buffered, format="JPEG", quality=85)
|
||||
img_base64 = base64.b64encode(buffered.getvalue()).decode('utf-8')
|
||||
|
||||
if self.verbose:
|
||||
print(f"[LOG] 📸 Screenshot taken and converted to base64")
|
||||
|
||||
return img_base64
|
||||
|
||||
except Exception as e:
|
||||
error_message = f"Failed to take screenshot: {str(e)}"
|
||||
print(error_message)
|
||||
|
||||
# Generate an image with black background
|
||||
img = Image.new('RGB', (800, 600), color='black')
|
||||
draw = ImageDraw.Draw(img)
|
||||
|
||||
# Load a font
|
||||
try:
|
||||
font = ImageFont.truetype("arial.ttf", 40)
|
||||
except IOError:
|
||||
font = ImageFont.load_default(size=40)
|
||||
|
||||
# Define text color and wrap the text
|
||||
text_color = (255, 255, 255)
|
||||
max_width = 780
|
||||
wrapped_text = wrap_text(draw, error_message, font, max_width)
|
||||
|
||||
# Calculate text position
|
||||
text_position = (10, 10)
|
||||
|
||||
# Draw the text on the image
|
||||
draw.text(text_position, wrapped_text, fill=text_color, font=font)
|
||||
|
||||
# Convert to base64
|
||||
buffered = BytesIO()
|
||||
img.save(buffered, format="JPEG")
|
||||
img_base64 = base64.b64encode(buffered.getvalue()).decode('utf-8')
|
||||
|
||||
return img_base64
|
||||
|
||||
def quit(self):
|
||||
self.driver.quit()
|
||||
@@ -19,22 +19,23 @@ def init_db():
|
||||
markdown TEXT,
|
||||
extracted_content TEXT,
|
||||
success BOOLEAN,
|
||||
media TEXT
|
||||
media TEXT DEFAULT "{}",
|
||||
screenshot TEXT DEFAULT ""
|
||||
)
|
||||
''')
|
||||
conn.commit()
|
||||
conn.close()
|
||||
|
||||
def alter_db_add_media():
|
||||
def alter_db_add_screenshot(new_column: str = "media"):
|
||||
check_db_path()
|
||||
try:
|
||||
conn = sqlite3.connect(DB_PATH)
|
||||
cursor = conn.cursor()
|
||||
cursor.execute('ALTER TABLE crawled_data ADD COLUMN media TEXT DEFAULT ""')
|
||||
cursor.execute(f'ALTER TABLE crawled_data ADD COLUMN {new_column} TEXT DEFAULT ""')
|
||||
conn.commit()
|
||||
conn.close()
|
||||
except Exception as e:
|
||||
print(f"Error altering database to add media column: {e}")
|
||||
print(f"Error altering database to add screenshot column: {e}")
|
||||
|
||||
def check_db_path():
|
||||
if not DB_PATH:
|
||||
@@ -45,7 +46,7 @@ def get_cached_url(url: str) -> Optional[Tuple[str, str, str, str, str, bool, st
|
||||
try:
|
||||
conn = sqlite3.connect(DB_PATH)
|
||||
cursor = conn.cursor()
|
||||
cursor.execute('SELECT url, html, cleaned_html, markdown, extracted_content, success, media FROM crawled_data WHERE url = ?', (url,))
|
||||
cursor.execute('SELECT url, html, cleaned_html, markdown, extracted_content, success, media, screenshot FROM crawled_data WHERE url = ?', (url,))
|
||||
result = cursor.fetchone()
|
||||
conn.close()
|
||||
return result
|
||||
@@ -53,13 +54,13 @@ def get_cached_url(url: str) -> Optional[Tuple[str, str, str, str, str, bool, st
|
||||
print(f"Error retrieving cached URL: {e}")
|
||||
return None
|
||||
|
||||
def cache_url(url: str, html: str, cleaned_html: str, markdown: str, extracted_content: str, success: bool, media: str = ""):
|
||||
def cache_url(url: str, html: str, cleaned_html: str, markdown: str, extracted_content: str, success: bool, media : str = "{}", screenshot: str = ""):
|
||||
check_db_path()
|
||||
try:
|
||||
conn = sqlite3.connect(DB_PATH)
|
||||
cursor = conn.cursor()
|
||||
cursor.execute('''
|
||||
INSERT INTO crawled_data (url, html, cleaned_html, markdown, extracted_content, success, media)
|
||||
INSERT INTO crawled_data (url, html, cleaned_html, markdown, extracted_content, success, screenshot)
|
||||
VALUES (?, ?, ?, ?, ?, ?, ?)
|
||||
ON CONFLICT(url) DO UPDATE SET
|
||||
html = excluded.html,
|
||||
@@ -67,8 +68,9 @@ def cache_url(url: str, html: str, cleaned_html: str, markdown: str, extracted_c
|
||||
markdown = excluded.markdown,
|
||||
extracted_content = excluded.extracted_content,
|
||||
success = excluded.success,
|
||||
media = excluded.media
|
||||
''', (url, html, cleaned_html, markdown, extracted_content, success, media))
|
||||
media = excluded.media,
|
||||
screenshot = excluded.screenshot
|
||||
''', (url, html, cleaned_html, markdown, extracted_content, success, media, screenshot))
|
||||
conn.commit()
|
||||
conn.close()
|
||||
except Exception as e:
|
||||
@@ -109,12 +111,12 @@ def flush_db():
|
||||
except Exception as e:
|
||||
print(f"Error flushing database: {e}")
|
||||
|
||||
def update_existing_records():
|
||||
def update_existing_records(new_column: str = "media", default_value: str = "{}"):
|
||||
check_db_path()
|
||||
try:
|
||||
conn = sqlite3.connect(DB_PATH)
|
||||
cursor = conn.cursor()
|
||||
cursor.execute('UPDATE crawled_data SET media = "" WHERE media IS NULL')
|
||||
cursor.execute(f'UPDATE crawled_data SET {new_column} = "{default_value}" WHERE screenshot IS NULL')
|
||||
conn.commit()
|
||||
conn.close()
|
||||
except Exception as e:
|
||||
@@ -122,5 +124,5 @@ def update_existing_records():
|
||||
|
||||
if __name__ == "__main__":
|
||||
init_db() # Initialize the database if not already initialized
|
||||
alter_db_add_media() # Add the new column to the table
|
||||
alter_db_add_screenshot() # Add the new column to the table
|
||||
update_existing_records() # Update existing records to set the new column to an empty string
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
from pydantic import BaseModel, HttpUrl
|
||||
from typing import List, Dict
|
||||
from typing import List, Dict, Optional
|
||||
|
||||
class UrlModel(BaseModel):
|
||||
url: HttpUrl
|
||||
@@ -9,9 +9,10 @@ class CrawlResult(BaseModel):
|
||||
url: str
|
||||
html: str
|
||||
success: bool
|
||||
cleaned_html: str = None
|
||||
cleaned_html: Optional[str] = None
|
||||
media: Dict[str, List[Dict]] = {}
|
||||
markdown: str = None
|
||||
extracted_content: str = None
|
||||
metadata: dict = None
|
||||
error_message: str = None
|
||||
screenshot: Optional[str] = None
|
||||
markdown: Optional[str] = None
|
||||
extracted_content: Optional[str] = None
|
||||
metadata: Optional[dict] = None
|
||||
error_message: Optional[str] = None
|
||||
@@ -513,4 +513,16 @@ def process_sections(url: str, sections: list, provider: str, api_token: str) ->
|
||||
for future in as_completed(futures):
|
||||
extracted_content.extend(future.result())
|
||||
|
||||
return extracted_content
|
||||
return extracted_content
|
||||
|
||||
|
||||
def wrap_text(draw, text, font, max_width):
|
||||
# Wrap the text to fit within the specified width
|
||||
lines = []
|
||||
words = text.split()
|
||||
while words:
|
||||
line = ''
|
||||
while words and draw.textbbox((0, 0), line + words[0], font=font)[2] <= max_width:
|
||||
line += (words.pop(0) + ' ')
|
||||
lines.append(line)
|
||||
return '\n'.join(lines)
|
||||
@@ -59,6 +59,8 @@ class WebCrawler:
|
||||
api_token: str = None,
|
||||
extract_blocks_flag: bool = True,
|
||||
word_count_threshold=MIN_WORD_THRESHOLD,
|
||||
css_selector: str = None,
|
||||
screenshot: bool = False,
|
||||
use_cached_html: bool = False,
|
||||
extraction_strategy: ExtractionStrategy = None,
|
||||
chunking_strategy: ChunkingStrategy = RegexChunking(),
|
||||
@@ -70,6 +72,8 @@ class WebCrawler:
|
||||
extraction_strategy or NoExtractionStrategy(),
|
||||
chunking_strategy,
|
||||
bypass_cache=url_model.forced,
|
||||
css_selector=css_selector,
|
||||
screenshot=screenshot,
|
||||
**kwargs,
|
||||
)
|
||||
pass
|
||||
@@ -83,6 +87,7 @@ class WebCrawler:
|
||||
chunking_strategy: ChunkingStrategy = RegexChunking(),
|
||||
bypass_cache: bool = False,
|
||||
css_selector: str = None,
|
||||
screenshot: bool = False,
|
||||
verbose=True,
|
||||
**kwargs,
|
||||
) -> CrawlResult:
|
||||
@@ -110,7 +115,8 @@ class WebCrawler:
|
||||
"markdown": cached[3],
|
||||
"extracted_content": cached[4],
|
||||
"success": cached[5],
|
||||
"media": json.loads(cached[6]),
|
||||
"media": json.loads(cached[6] or "{}"),
|
||||
"screenshot": cached[7],
|
||||
"error_message": "",
|
||||
}
|
||||
)
|
||||
@@ -118,6 +124,9 @@ class WebCrawler:
|
||||
# Initialize WebDriver for crawling
|
||||
t = time.time()
|
||||
html = self.crawler_strategy.crawl(url)
|
||||
base64_image = None
|
||||
if screenshot:
|
||||
base64_image = self.crawler_strategy.take_screenshot()
|
||||
success = True
|
||||
error_message = ""
|
||||
# Extract content from HTML
|
||||
@@ -166,6 +175,7 @@ class WebCrawler:
|
||||
extracted_content,
|
||||
success,
|
||||
json.dumps(media),
|
||||
screenshot=base64_image,
|
||||
)
|
||||
|
||||
return CrawlResult(
|
||||
@@ -174,6 +184,7 @@ class WebCrawler:
|
||||
cleaned_html=cleaned_html,
|
||||
markdown=markdown,
|
||||
media=media,
|
||||
screenshot=base64_image,
|
||||
extracted_content=extracted_content,
|
||||
success=success,
|
||||
error_message=error_message,
|
||||
@@ -187,6 +198,8 @@ class WebCrawler:
|
||||
extract_blocks_flag: bool = True,
|
||||
word_count_threshold=MIN_WORD_THRESHOLD,
|
||||
use_cached_html: bool = False,
|
||||
css_selector: str = None,
|
||||
screenshot: bool = False,
|
||||
extraction_strategy: ExtractionStrategy = None,
|
||||
chunking_strategy: ChunkingStrategy = RegexChunking(),
|
||||
**kwargs,
|
||||
@@ -204,6 +217,8 @@ class WebCrawler:
|
||||
[api_token] * len(url_models),
|
||||
[extract_blocks_flag] * len(url_models),
|
||||
[word_count_threshold] * len(url_models),
|
||||
[css_selector] * len(url_models),
|
||||
[screenshot] * len(url_models),
|
||||
[use_cached_html] * len(url_models),
|
||||
[extraction_strategy] * len(url_models),
|
||||
[chunking_strategy] * len(url_models),
|
||||
|
||||
Reference in New Issue
Block a user