chore: Add hooks for customizing the LocalSeleniumCrawlerStrategy

This commit is contained in:
unclecode
2024-06-17 15:37:18 +08:00
parent 52daf3936a
commit 9a97aacd85
2 changed files with 84 additions and 1 deletions

View File

@@ -10,7 +10,7 @@ import logging
import base64
from PIL import Image, ImageDraw, ImageFont
from io import BytesIO
from typing import List
from typing import List, Callable
import requests
import os
from pathlib import Path
@@ -48,6 +48,10 @@ class CrawlerStrategy(ABC):
# Abstract interface method: receives the user-agent string to apply.
# The body is only a placeholder; concrete strategies supply the behavior.
@abstractmethod
def update_user_agent(self, user_agent: str):
pass
# Abstract interface method: attaches the callable `hook` under the name
# `hook_type`. The body is only a placeholder.
# NOTE(review): because this is abstract, every concrete strategy must
# define set_hook or it cannot be instantiated - confirm that
# CloudCrawlerStrategy provides an implementation.
@abstractmethod
def set_hook(self, hook_type: str, hook: Callable):
pass
class CloudCrawlerStrategy(CrawlerStrategy):
def __init__(self, use_cached_html = False):
@@ -96,6 +100,14 @@ class LocalSeleniumCrawlerStrategy(CrawlerStrategy):
self.use_cached_html = use_cached_html
self.js_code = js_code
self.verbose = kwargs.get("verbose", False)
# Hooks
self.hooks = {
'on_driver_created': None,
'before_get_url': None,
'after_get_url': None,
'before_return_html': None
}
# chromedriver_autoinstaller.install()
import chromedriver_autoinstaller
@@ -103,10 +115,29 @@ class LocalSeleniumCrawlerStrategy(CrawlerStrategy):
self.service.log_path = "NUL"
self.driver = webdriver.Chrome(service=self.service, options=self.options)
def set_hook(self, hook_type: str, hook: Callable):
    """Register ``hook`` for one of the hook points known to this strategy.

    Args:
        hook_type: Name of the hook point; must already be a key of
            ``self.hooks``.
        hook: Callable stored for that hook point, replacing any
            previously stored value.

    Raises:
        ValueError: If ``hook_type`` is not a key of ``self.hooks``.
    """
    # Reject unknown hook points up front so typos do not go unnoticed.
    if hook_type not in self.hooks:
        raise ValueError(f"Invalid hook type: {hook_type}")
    self.hooks[hook_type] = hook
def execute_hook(self, hook_type: str, *args):
    """Run the hook stored under ``hook_type`` and pick the driver to use.

    Args:
        hook_type: Key looked up in ``self.hooks``.
        *args: Positional arguments forwarded unchanged to the hook.

    Returns:
        The hook's return value when it is a ``webdriver.Chrome``
        instance; otherwise ``self.driver`` (no hook stored, or the hook
        returned ``None``).

    Raises:
        TypeError: If the hook returns anything other than ``None`` or a
            ``webdriver.Chrome`` instance.
    """
    callback = self.hooks.get(hook_type)
    # No hook stored for this point: keep the current driver.
    if not callback:
        return self.driver

    outcome = callback(*args)
    # A hook that returns nothing leaves the current driver in place.
    if outcome is None:
        return self.driver
    if not isinstance(outcome, webdriver.Chrome):
        raise TypeError(f"Hook {hook_type} must return an instance of webdriver.Chrome or None.")
    return outcome
def update_user_agent(self, user_agent: str):
    """Add a user-agent argument to the options and relaunch the driver.

    The current driver is quit, a new ``webdriver.Chrome`` is created from
    ``self.service`` and ``self.options``, and the ``on_driver_created``
    hook is then given the new driver; whatever ``execute_hook`` returns
    becomes ``self.driver``.

    Args:
        user_agent: Value placed after ``user-agent=`` in the option
            argument.
    """
    ua_argument = f"user-agent={user_agent}"
    self.options.add_argument(ua_argument)

    # Replace the running browser with one built from the updated options.
    self.driver.quit()
    fresh_driver = webdriver.Chrome(service=self.service, options=self.options)
    # Store the new driver before running the hook: execute_hook falls
    # back to self.driver when the hook is unset or returns None.
    self.driver = fresh_driver
    self.driver = self.execute_hook('on_driver_created', fresh_driver)
def crawl(self, url: str) -> str:
# Create md5 hash of the URL
@@ -120,12 +151,14 @@ class LocalSeleniumCrawlerStrategy(CrawlerStrategy):
return f.read()
try:
self.driver = self.execute_hook('before_get_url', self.driver)
if self.verbose:
print(f"[LOG] 🕸️ Crawling {url} using LocalSeleniumCrawlerStrategy...")
self.driver.get(url)
WebDriverWait(self.driver, 10).until(
EC.presence_of_all_elements_located((By.TAG_NAME, "html"))
)
self.driver = self.execute_hook('after_get_url', self.driver)
# Execute JS code if provided
if self.js_code and type(self.js_code) == str:
@@ -142,6 +175,7 @@ class LocalSeleniumCrawlerStrategy(CrawlerStrategy):
)
html = self.driver.page_source
self.driver = self.execute_hook('before_return_html', self.driver, html)
# Store in cache
cache_file_path = os.path.join(Path.home(), ".crawl4ai", "cache", url_hash)