feat: Sanitize input and handle encoding issues in LLMExtractionStrategy
This commit is contained in:
@@ -3,6 +3,7 @@ import re
|
|||||||
from collections import Counter
|
from collections import Counter
|
||||||
import string
|
import string
|
||||||
from .model_loader import load_nltk_punkt
|
from .model_loader import load_nltk_punkt
|
||||||
|
from .utils import *
|
||||||
|
|
||||||
# Define the abstract base class for chunking strategies
|
# Define the abstract base class for chunking strategies
|
||||||
class ChunkingStrategy(ABC):
|
class ChunkingStrategy(ABC):
|
||||||
|
|||||||
@@ -18,7 +18,7 @@ from typing import List, Callable
|
|||||||
import requests
|
import requests
|
||||||
import os
|
import os
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from .utils import wrap_text
|
from .utils import *
|
||||||
|
|
||||||
logger = logging.getLogger('selenium.webdriver.remote.remote_connection')
|
logger = logging.getLogger('selenium.webdriver.remote.remote_connection')
|
||||||
logger.setLevel(logging.WARNING)
|
logger.setLevel(logging.WARNING)
|
||||||
@@ -73,7 +73,7 @@ class CloudCrawlerStrategy(CrawlerStrategy):
|
|||||||
response = requests.post("http://crawl4ai.uccode.io/crawl", json=data)
|
response = requests.post("http://crawl4ai.uccode.io/crawl", json=data)
|
||||||
response = response.json()
|
response = response.json()
|
||||||
html = response["results"][0]["html"]
|
html = response["results"][0]["html"]
|
||||||
return html
|
return sanitize_input_encode(html)
|
||||||
|
|
||||||
class LocalSeleniumCrawlerStrategy(CrawlerStrategy):
|
class LocalSeleniumCrawlerStrategy(CrawlerStrategy):
|
||||||
def __init__(self, use_cached_html=False, js_code=None, **kwargs):
|
def __init__(self, use_cached_html=False, js_code=None, **kwargs):
|
||||||
@@ -200,7 +200,7 @@ class LocalSeleniumCrawlerStrategy(CrawlerStrategy):
|
|||||||
cache_file_path = os.path.join(Path.home(), ".crawl4ai", "cache", url_hash)
|
cache_file_path = os.path.join(Path.home(), ".crawl4ai", "cache", url_hash)
|
||||||
if os.path.exists(cache_file_path):
|
if os.path.exists(cache_file_path):
|
||||||
with open(cache_file_path, "r") as f:
|
with open(cache_file_path, "r") as f:
|
||||||
return f.read()
|
return sanitize_input_encode(f.read())
|
||||||
|
|
||||||
try:
|
try:
|
||||||
self.driver = self.execute_hook('before_get_url', self.driver)
|
self.driver = self.execute_hook('before_get_url', self.driver)
|
||||||
@@ -215,10 +215,11 @@ class LocalSeleniumCrawlerStrategy(CrawlerStrategy):
|
|||||||
EC.presence_of_all_elements_located((By.TAG_NAME, "body"))
|
EC.presence_of_all_elements_located((By.TAG_NAME, "body"))
|
||||||
)
|
)
|
||||||
self.driver.execute_script("window.scrollTo(0, document.body.scrollHeight);")
|
self.driver.execute_script("window.scrollTo(0, document.body.scrollHeight);")
|
||||||
html = self._ensure_page_load() # self.driver.page_source
|
html = sanitize_input_encode(self._ensure_page_load()) # self.driver.page_source
|
||||||
can_not_be_done_headless = False # Look at my creativity for naming variables
|
can_not_be_done_headless = False # Look at my creativity for naming variables
|
||||||
|
|
||||||
# TODO: Very ugly way for now but it works
|
# TODO: Very ugly way for now but it works
|
||||||
if not kwargs.get('bypass_headless', False) and html == "<html><head></head><body></body></html>":
|
if kwargs.get('bypass_headless', True) or html == "<html><head></head><body></body></html>":
|
||||||
print("[LOG] 🙌 Page could not be loaded in headless mode. Trying non-headless mode...")
|
print("[LOG] 🙌 Page could not be loaded in headless mode. Trying non-headless mode...")
|
||||||
can_not_be_done_headless = True
|
can_not_be_done_headless = True
|
||||||
options = Options()
|
options = Options()
|
||||||
@@ -227,7 +228,7 @@ class LocalSeleniumCrawlerStrategy(CrawlerStrategy):
|
|||||||
options.add_argument("--window-size=5,5")
|
options.add_argument("--window-size=5,5")
|
||||||
driver = webdriver.Chrome(service=self.service, options=options)
|
driver = webdriver.Chrome(service=self.service, options=options)
|
||||||
driver.get(url)
|
driver.get(url)
|
||||||
html = driver.page_source
|
html = sanitize_input_encode(driver.page_source)
|
||||||
driver.quit()
|
driver.quit()
|
||||||
|
|
||||||
self.driver = self.execute_hook('after_get_url', self.driver)
|
self.driver = self.execute_hook('after_get_url', self.driver)
|
||||||
@@ -247,7 +248,7 @@ class LocalSeleniumCrawlerStrategy(CrawlerStrategy):
|
|||||||
)
|
)
|
||||||
|
|
||||||
if not can_not_be_done_headless:
|
if not can_not_be_done_headless:
|
||||||
html = self.driver.page_source
|
html = sanitize_input_encode(self.driver.page_source)
|
||||||
self.driver = self.execute_hook('before_return_html', self.driver, html)
|
self.driver = self.execute_hook('before_return_html', self.driver, html)
|
||||||
|
|
||||||
# Store in cache
|
# Store in cache
|
||||||
@@ -261,16 +262,16 @@ class LocalSeleniumCrawlerStrategy(CrawlerStrategy):
|
|||||||
return html
|
return html
|
||||||
except InvalidArgumentException:
|
except InvalidArgumentException:
|
||||||
if not hasattr(e, 'msg'):
|
if not hasattr(e, 'msg'):
|
||||||
e.msg = str(e)
|
e.msg = sanitize_input_encode(str(e))
|
||||||
raise InvalidArgumentException(f"Failed to crawl {url}: {e.msg}")
|
raise InvalidArgumentException(f"Failed to crawl {url}: {e.msg}")
|
||||||
except WebDriverException as e:
|
except WebDriverException as e:
|
||||||
# If e does nlt have msg attribute create it and set it to str(e)
|
# If e does nlt have msg attribute create it and set it to str(e)
|
||||||
if not hasattr(e, 'msg'):
|
if not hasattr(e, 'msg'):
|
||||||
e.msg = str(e)
|
e.msg = sanitize_input_encode(str(e))
|
||||||
raise WebDriverException(f"Failed to crawl {url}: {e.msg}")
|
raise WebDriverException(f"Failed to crawl {url}: {e.msg}")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
if not hasattr(e, 'msg'):
|
if not hasattr(e, 'msg'):
|
||||||
e.msg = str(e)
|
e.msg = sanitize_input_encode(str(e))
|
||||||
raise Exception(f"Failed to crawl {url}: {e.msg}")
|
raise Exception(f"Failed to crawl {url}: {e.msg}")
|
||||||
|
|
||||||
def take_screenshot(self) -> str:
|
def take_screenshot(self) -> str:
|
||||||
@@ -299,7 +300,7 @@ class LocalSeleniumCrawlerStrategy(CrawlerStrategy):
|
|||||||
return img_base64
|
return img_base64
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
error_message = f"Failed to take screenshot: {str(e)}"
|
error_message = sanitize_input_encode(f"Failed to take screenshot: {str(e)}")
|
||||||
print(error_message)
|
print(error_message)
|
||||||
|
|
||||||
# Generate an image with black background
|
# Generate an image with black background
|
||||||
|
|||||||
Reference in New Issue
Block a user