feat: Sanitize input and handle encoding issues in LLMExtractionStrategy

This commit is contained in:
unclecode
2024-07-05 17:59:26 +08:00
parent b0ec54b9e9
commit 1368248254
2 changed files with 13 additions and 11 deletions

View File

@@ -3,6 +3,7 @@ import re
from collections import Counter
import string
from .model_loader import load_nltk_punkt
from .utils import *
# Define the abstract base class for chunking strategies
class ChunkingStrategy(ABC):

View File

@@ -18,7 +18,7 @@ from typing import List, Callable
import requests
import os
from pathlib import Path
from .utils import wrap_text
from .utils import *
logger = logging.getLogger('selenium.webdriver.remote.remote_connection')
logger.setLevel(logging.WARNING)
@@ -73,7 +73,7 @@ class CloudCrawlerStrategy(CrawlerStrategy):
response = requests.post("http://crawl4ai.uccode.io/crawl", json=data)
response = response.json()
html = response["results"][0]["html"]
return html
return sanitize_input_encode(html)
class LocalSeleniumCrawlerStrategy(CrawlerStrategy):
def __init__(self, use_cached_html=False, js_code=None, **kwargs):
@@ -200,7 +200,7 @@ class LocalSeleniumCrawlerStrategy(CrawlerStrategy):
cache_file_path = os.path.join(Path.home(), ".crawl4ai", "cache", url_hash)
if os.path.exists(cache_file_path):
with open(cache_file_path, "r") as f:
return f.read()
return sanitize_input_encode(f.read())
try:
self.driver = self.execute_hook('before_get_url', self.driver)
@@ -215,10 +215,11 @@ class LocalSeleniumCrawlerStrategy(CrawlerStrategy):
EC.presence_of_all_elements_located((By.TAG_NAME, "body"))
)
self.driver.execute_script("window.scrollTo(0, document.body.scrollHeight);")
html = self._ensure_page_load() # self.driver.page_source
html = sanitize_input_encode(self._ensure_page_load()) # self.driver.page_source
can_not_be_done_headless = False # Look at my creativity for naming variables
# TODO: Very ugly way for now but it works
if not kwargs.get('bypass_headless', False) and html == "<html><head></head><body></body></html>":
if kwargs.get('bypass_headless', True) or html == "<html><head></head><body></body></html>":
print("[LOG] 🙌 Page could not be loaded in headless mode. Trying non-headless mode...")
can_not_be_done_headless = True
options = Options()
@@ -227,7 +228,7 @@ class LocalSeleniumCrawlerStrategy(CrawlerStrategy):
options.add_argument("--window-size=5,5")
driver = webdriver.Chrome(service=self.service, options=options)
driver.get(url)
html = driver.page_source
html = sanitize_input_encode(driver.page_source)
driver.quit()
self.driver = self.execute_hook('after_get_url', self.driver)
@@ -247,7 +248,7 @@ class LocalSeleniumCrawlerStrategy(CrawlerStrategy):
)
if not can_not_be_done_headless:
html = self.driver.page_source
html = sanitize_input_encode(self.driver.page_source)
self.driver = self.execute_hook('before_return_html', self.driver, html)
# Store in cache
@@ -261,16 +262,16 @@ class LocalSeleniumCrawlerStrategy(CrawlerStrategy):
return html
except InvalidArgumentException:
if not hasattr(e, 'msg'):
e.msg = str(e)
e.msg = sanitize_input_encode(str(e))
raise InvalidArgumentException(f"Failed to crawl {url}: {e.msg}")
except WebDriverException as e:
# If e does nlt have msg attribute create it and set it to str(e)
if not hasattr(e, 'msg'):
e.msg = str(e)
e.msg = sanitize_input_encode(str(e))
raise WebDriverException(f"Failed to crawl {url}: {e.msg}")
except Exception as e:
if not hasattr(e, 'msg'):
e.msg = str(e)
e.msg = sanitize_input_encode(str(e))
raise Exception(f"Failed to crawl {url}: {e.msg}")
def take_screenshot(self) -> str:
@@ -299,7 +300,7 @@ class LocalSeleniumCrawlerStrategy(CrawlerStrategy):
return img_base64
except Exception as e:
error_message = f"Failed to take screenshot: {str(e)}"
error_message = sanitize_input_encode(f"Failed to take screenshot: {str(e)}")
print(error_message)
# Generate an image with black background