From 13682482548df1c8c15d4ce7c318701e8ef95770 Mon Sep 17 00:00:00 2001 From: unclecode Date: Fri, 5 Jul 2024 17:59:26 +0800 Subject: [PATCH] feat: Sanitize input and handle encoding issues in LLMExtractionStrategy --- crawl4ai/chunking_strategy.py | 1 + crawl4ai/crawler_strategy.py | 23 ++++++++++++----------- 2 files changed, 13 insertions(+), 11 deletions(-) diff --git a/crawl4ai/chunking_strategy.py b/crawl4ai/chunking_strategy.py index 5fe9b5e1..59006072 100644 --- a/crawl4ai/chunking_strategy.py +++ b/crawl4ai/chunking_strategy.py @@ -3,6 +3,7 @@ import re from collections import Counter import string from .model_loader import load_nltk_punkt +from .utils import * # Define the abstract base class for chunking strategies class ChunkingStrategy(ABC): diff --git a/crawl4ai/crawler_strategy.py b/crawl4ai/crawler_strategy.py index 21de883e..ae8d93df 100644 --- a/crawl4ai/crawler_strategy.py +++ b/crawl4ai/crawler_strategy.py @@ -18,7 +18,7 @@ from typing import List, Callable import requests import os from pathlib import Path -from .utils import wrap_text +from .utils import * logger = logging.getLogger('selenium.webdriver.remote.remote_connection') logger.setLevel(logging.WARNING) @@ -73,7 +73,7 @@ class CloudCrawlerStrategy(CrawlerStrategy): response = requests.post("http://crawl4ai.uccode.io/crawl", json=data) response = response.json() html = response["results"][0]["html"] - return html + return sanitize_input_encode(html) class LocalSeleniumCrawlerStrategy(CrawlerStrategy): def __init__(self, use_cached_html=False, js_code=None, **kwargs): @@ -200,7 +200,7 @@ class LocalSeleniumCrawlerStrategy(CrawlerStrategy): cache_file_path = os.path.join(Path.home(), ".crawl4ai", "cache", url_hash) if os.path.exists(cache_file_path): with open(cache_file_path, "r") as f: - return f.read() + return sanitize_input_encode(f.read()) try: self.driver = self.execute_hook('before_get_url', self.driver) @@ -215,10 +215,11 @@ class LocalSeleniumCrawlerStrategy(CrawlerStrategy): EC.presence_of_all_elements_located((By.TAG_NAME, "body")) ) self.driver.execute_script("window.scrollTo(0, document.body.scrollHeight);") - html = self._ensure_page_load() # self.driver.page_source + html = sanitize_input_encode(self._ensure_page_load()) # self.driver.page_source can_not_be_done_headless = False # Look at my creativity for naming variables + # TODO: Very ugly way for now but it works - if not kwargs.get('bypass_headless', False) and html == "": + if kwargs.get('bypass_headless', True) or html == "": print("[LOG] 🙌 Page could not be loaded in headless mode. Trying non-headless mode...") can_not_be_done_headless = True options = Options() @@ -227,7 +228,7 @@ class LocalSeleniumCrawlerStrategy(CrawlerStrategy): options.add_argument("--window-size=5,5") driver = webdriver.Chrome(service=self.service, options=options) driver.get(url) - html = driver.page_source + html = sanitize_input_encode(driver.page_source) driver.quit() self.driver = self.execute_hook('after_get_url', self.driver) @@ -247,7 +248,7 @@ class LocalSeleniumCrawlerStrategy(CrawlerStrategy): ) if not can_not_be_done_headless: - html = self.driver.page_source + html = sanitize_input_encode(self.driver.page_source) self.driver = self.execute_hook('before_return_html', self.driver, html) # Store in cache @@ -261,16 +262,16 @@ class LocalSeleniumCrawlerStrategy(CrawlerStrategy): return html except InvalidArgumentException: if not hasattr(e, 'msg'): - e.msg = str(e) + e.msg = sanitize_input_encode(str(e)) raise InvalidArgumentException(f"Failed to crawl {url}: {e.msg}") except WebDriverException as e: # If e does nlt have msg attribute create it and set it to str(e) if not hasattr(e, 'msg'): - e.msg = str(e) + e.msg = sanitize_input_encode(str(e)) raise WebDriverException(f"Failed to crawl {url}: {e.msg}") except Exception as e: if not hasattr(e, 'msg'): - e.msg = str(e) + e.msg = sanitize_input_encode(str(e)) raise Exception(f"Failed to crawl {url}: {e.msg}") def take_screenshot(self) -> str: @@ -299,7 +300,7 @@ class LocalSeleniumCrawlerStrategy(CrawlerStrategy): return img_base64 except Exception as e: - error_message = f"Failed to take screenshot: {str(e)}" + error_message = sanitize_input_encode(f"Failed to take screenshot: {str(e)}") print(error_message) # Generate an image with black background