chore: Add hooks for customizing the LocalSeleniumCrawlerStrategy
This commit is contained in:
@@ -10,7 +10,7 @@ import logging
|
||||
import base64
|
||||
from PIL import Image, ImageDraw, ImageFont
|
||||
from io import BytesIO
|
||||
from typing import List
|
||||
from typing import List, Callable
|
||||
import requests
|
||||
import os
|
||||
from pathlib import Path
|
||||
@@ -48,6 +48,10 @@ class CrawlerStrategy(ABC):
|
||||
@abstractmethod
|
||||
def update_user_agent(self, user_agent: str):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def set_hook(self, hook_type: str, hook: Callable):
|
||||
pass
|
||||
|
||||
class CloudCrawlerStrategy(CrawlerStrategy):
|
||||
def __init__(self, use_cached_html = False):
|
||||
@@ -96,6 +100,14 @@ class LocalSeleniumCrawlerStrategy(CrawlerStrategy):
|
||||
self.use_cached_html = use_cached_html
|
||||
self.js_code = js_code
|
||||
self.verbose = kwargs.get("verbose", False)
|
||||
|
||||
# Hooks
|
||||
self.hooks = {
|
||||
'on_driver_created': None,
|
||||
'before_get_url': None,
|
||||
'after_get_url': None,
|
||||
'before_return_html': None
|
||||
}
|
||||
|
||||
# chromedriver_autoinstaller.install()
|
||||
import chromedriver_autoinstaller
|
||||
@@ -103,10 +115,29 @@ class LocalSeleniumCrawlerStrategy(CrawlerStrategy):
|
||||
self.service.log_path = "NUL"
|
||||
self.driver = webdriver.Chrome(service=self.service, options=self.options)
|
||||
|
||||
def set_hook(self, hook_type: str, hook: Callable):
|
||||
if hook_type in self.hooks:
|
||||
self.hooks[hook_type] = hook
|
||||
else:
|
||||
raise ValueError(f"Invalid hook type: {hook_type}")
|
||||
|
||||
def execute_hook(self, hook_type: str, *args):
|
||||
hook = self.hooks.get(hook_type)
|
||||
if hook:
|
||||
result = hook(*args)
|
||||
if result is not None:
|
||||
if isinstance(result, webdriver.Chrome):
|
||||
return result
|
||||
else:
|
||||
raise TypeError(f"Hook {hook_type} must return an instance of webdriver.Chrome or None.")
|
||||
# If the hook returns None or there is no hook, return self.driver
|
||||
return self.driver
|
||||
|
||||
def update_user_agent(self, user_agent: str):
|
||||
self.options.add_argument(f"user-agent={user_agent}")
|
||||
self.driver.quit()
|
||||
self.driver = webdriver.Chrome(service=self.service, options=self.options)
|
||||
self.driver = self.execute_hook('on_driver_created', self.driver)
|
||||
|
||||
def crawl(self, url: str) -> str:
|
||||
# Create md5 hash of the URL
|
||||
@@ -120,12 +151,14 @@ class LocalSeleniumCrawlerStrategy(CrawlerStrategy):
|
||||
return f.read()
|
||||
|
||||
try:
|
||||
self.driver = self.execute_hook('before_get_url', self.driver)
|
||||
if self.verbose:
|
||||
print(f"[LOG] 🕸️ Crawling {url} using LocalSeleniumCrawlerStrategy...")
|
||||
self.driver.get(url)
|
||||
WebDriverWait(self.driver, 10).until(
|
||||
EC.presence_of_all_elements_located((By.TAG_NAME, "html"))
|
||||
)
|
||||
self.driver = self.execute_hook('after_get_url', self.driver)
|
||||
|
||||
# Execute JS code if provided
|
||||
if self.js_code and type(self.js_code) == str:
|
||||
@@ -142,6 +175,7 @@ class LocalSeleniumCrawlerStrategy(CrawlerStrategy):
|
||||
)
|
||||
|
||||
html = self.driver.page_source
|
||||
self.driver = self.execute_hook('before_return_html', self.driver, html)
|
||||
|
||||
# Store in cache
|
||||
cache_file_path = os.path.join(Path.home(), ".crawl4ai", "cache", url_hash)
|
||||
|
||||
Reference in New Issue
Block a user