From 320afdea64f92c9a5942e901f4a9016ea7ab13f1 Mon Sep 17 00:00:00 2001 From: unclecode Date: Mon, 14 Oct 2024 21:03:28 +0800 Subject: [PATCH] feat: Enhance crawler flexibility and LLM extraction capabilities - Add browser type selection (Chromium, Firefox, WebKit) - Implement iframe content extraction - Improve image processing and dimension updates - Add custom headers support in AsyncPlaywrightCrawlerStrategy - Enhance delayed content retrieval with new parameter - Optimize HTML sanitization and Markdown conversion - Update examples in quickstart_async.py for new features --- .gitignore | 3 +- crawl4ai/async_crawler_strategy.py | 125 ++++++++++++++++++- crawl4ai/content_scrapping_strategy.py | 13 +- crawl4ai/prompts.py | 4 +- crawl4ai/utils.py | 160 ++++++++++++------------- crawl4ai/web_crawler.py | 1 + docs/examples/quickstart_async.py | 25 ++++ 7 files changed, 238 insertions(+), 93 deletions(-) diff --git a/.gitignore b/.gitignore index 8b8f014c..e5718a14 100644 --- a/.gitignore +++ b/.gitignore @@ -203,4 +203,5 @@ git_changes.py git_changes.md pypi_build.sh -.tests/ \ No newline at end of file +.tests/ +git_changes.py \ No newline at end of file diff --git a/crawl4ai/async_crawler_strategy.py b/crawl4ai/async_crawler_strategy.py index c74aff13..e9699953 100644 --- a/crawl4ai/async_crawler_strategy.py +++ b/crawl4ai/async_crawler_strategy.py @@ -50,7 +50,8 @@ class AsyncPlaywrightCrawlerStrategy(AsyncCrawlerStrategy): self.user_agent = kwargs.get("user_agent", "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.124 Safari/537.36") self.proxy = kwargs.get("proxy") self.headless = kwargs.get("headless", True) - self.headers = {} + self.browser_type = kwargs.get("browser_type", "chromium") # New parameter + self.headers = kwargs.get("headers", {}) self.sessions = {} self.session_ttl = 1800 self.js_code = js_code @@ -80,7 +81,6 @@ class AsyncPlaywrightCrawlerStrategy(AsyncCrawlerStrategy): if self.browser is None: browser_args = { "headless": self.headless, - # "headless": False, "args": [ "--disable-gpu", "--disable-dev-shm-usage", @@ -95,7 +95,14 @@ class AsyncPlaywrightCrawlerStrategy(AsyncCrawlerStrategy): browser_args["proxy"] = proxy_settings - self.browser = await self.playwright.chromium.launch(**browser_args) + # Select the appropriate browser based on the browser_type + if self.browser_type == "firefox": + self.browser = await self.playwright.firefox.launch(**browser_args) + elif self.browser_type == "webkit": + self.browser = await self.playwright.webkit.launch(**browser_args) + else: + self.browser = await self.playwright.chromium.launch(**browser_args) + await self.execute_hook('on_browser_created', self.browser) async def close(self): @@ -145,7 +152,6 @@ class AsyncPlaywrightCrawlerStrategy(AsyncCrawlerStrategy): for sid in expired_sessions: asyncio.create_task(self.kill_session(sid)) - async def smart_wait(self, page: Page, wait_for: str, timeout: float = 30000): wait_for = wait_for.strip() @@ -209,6 +215,48 @@ class AsyncPlaywrightCrawlerStrategy(AsyncCrawlerStrategy): except Exception as e: raise RuntimeError(f"Error in wait condition: {str(e)}") + async def process_iframes(self, page): + # Find all iframes + iframes = await page.query_selector_all('iframe') + + for i, iframe in enumerate(iframes): + try: + # Add a unique identifier to the iframe + await iframe.evaluate(f'(element) => element.id = "iframe-{i}"') + + # Get the frame associated with this iframe + frame = await iframe.content_frame() + + if frame: + # Wait for the frame to load + await frame.wait_for_load_state('load', timeout=30000) # 30 seconds timeout + + # Extract the content of the iframe's body + iframe_content = await frame.evaluate('() => document.body.innerHTML') + + # Generate a unique class name for this iframe + class_name = f'extracted-iframe-content-{i}' + + # Replace the iframe with a div containing the extracted content + _iframe = iframe_content.replace('`', '\\`') + await page.evaluate(f""" + () => {{ + const iframe = document.getElementById('iframe-{i}'); + const div = document.createElement('div'); + div.innerHTML = `{_iframe}`; + div.className = '{class_name}'; + iframe.replaceWith(div); + }} + """) + else: + print(f"Warning: Could not access content frame for iframe {i}") + except Exception as e: + print(f"Error processing iframe {i}: {str(e)}") + + # Return the page object + return page + + async def crawl(self, url: str, **kwargs) -> AsyncCrawlResponse: response_headers = {} status_code = None @@ -263,6 +311,7 @@ class AsyncPlaywrightCrawlerStrategy(AsyncCrawlerStrategy): status_code = 200 response_headers = {} + await page.wait_for_selector('body') await page.evaluate("window.scrollTo(0, document.body.scrollHeight)") @@ -305,11 +354,78 @@ class AsyncPlaywrightCrawlerStrategy(AsyncCrawlerStrategy): if kwargs.get("screenshot"): screenshot_data = await self.take_screenshot(url) + + # New code to update image dimensions + update_image_dimensions_js = """ + () => { + return new Promise((resolve) => { + const filterImage = (img) => { + // Filter out images that are too small + if (img.width < 100 && img.height < 100) return false; + + // Filter out images that are not visible + const rect = img.getBoundingClientRect(); + if (rect.width === 0 || rect.height === 0) return false; + + // Filter out images with certain class names (e.g., icons, thumbnails) + if (img.classList.contains('icon') || img.classList.contains('thumbnail')) return false; + + // Filter out images with certain patterns in their src (e.g., placeholder images) + if (img.src.includes('placeholder') || img.src.includes('icon')) return false; + + return true; + }; + + const images = Array.from(document.querySelectorAll('img')).filter(filterImage); + let imagesLeft = images.length; + + if (imagesLeft === 0) { + resolve(); + return; + } + + const checkImage = (img) => { + if (img.complete && img.naturalWidth !== 0) { + img.setAttribute('width', img.naturalWidth); + img.setAttribute('height', img.naturalHeight); + imagesLeft--; + if (imagesLeft === 0) resolve(); + } + }; + + images.forEach(img => { + checkImage(img); + if (!img.complete) { + img.onload = () => { + checkImage(img); + }; + img.onerror = () => { + imagesLeft--; + if (imagesLeft === 0) resolve(); + }; + } + }); + + // Fallback timeout of 5 seconds + setTimeout(() => resolve(), 5000); + }); + } + """ + await page.evaluate(update_image_dimensions_js) + + # Wait a bit for any onload events to complete + await page.wait_for_timeout(100) + + # Process iframes + if kwargs.get("process_iframes", False): + page = await self.process_iframes(page) + await self.execute_hook('before_retrieve_html', page) # Check if delay_before_return_html is set then wait for that time delay_before_return_html = kwargs.get("delay_before_return_html") if delay_before_return_html: await asyncio.sleep(delay_before_return_html) + html = await page.content() await self.execute_hook('before_return_html', page, html) @@ -398,7 +514,6 @@ class AsyncPlaywrightCrawlerStrategy(AsyncCrawlerStrategy): except Error as e: raise Error(f"Failed to execute JavaScript or wait for condition in session {session_id}: {str(e)}") - async def crawl_many(self, urls: List[str], **kwargs) -> List[AsyncCrawlResponse]: semaphore_count = kwargs.get('semaphore_count', calculate_semaphore_count()) semaphore = asyncio.Semaphore(semaphore_count) diff --git a/crawl4ai/content_scrapping_strategy.py b/crawl4ai/content_scrapping_strategy.py index afd75892..68f03412 100644 --- a/crawl4ai/content_scrapping_strategy.py +++ b/crawl4ai/content_scrapping_strategy.py @@ -16,8 +16,6 @@ from .utils import ( CustomHTML2Text ) - - class ContentScrappingStrategy(ABC): @abstractmethod def scrap(self, url: str, html: str, **kwargs) -> Dict[str, Any]: @@ -129,7 +127,7 @@ class WebScrappingStrategy(ContentScrappingStrategy): image_size = 0 #int(fetch_image_file_size(img,base_url) or 0) image_format = os.path.splitext(img.get('src',''))[1].lower() # Remove . from format - image_format = image_format.strip('.') + image_format = image_format.strip('.').split('?')[0] score = 0 if height_value: if height_unit == 'px' and height_value > 150: @@ -158,6 +156,7 @@ class WebScrappingStrategy(ContentScrappingStrategy): return None return { 'src': img.get('src', ''), + 'data-src': img.get('data-src', ''), 'alt': img.get('alt', ''), 'desc': find_closest_parent_with_useful_text(img), 'score': score, @@ -275,11 +274,14 @@ class WebScrappingStrategy(ContentScrappingStrategy): # Replace base64 data with empty string img['src'] = base64_pattern.sub('', src) cleaned_html = str(body).replace('\n\n', '\n').replace(' ', ' ') - cleaned_html = sanitize_html(cleaned_html) h = CustomHTML2Text() h.ignore_links = True - markdown = h.handle(cleaned_html) + h.body_width = 0 + try: + markdown = h.handle(cleaned_html) + except Exception as e: + markdown = h.handle(sanitize_html(cleaned_html)) markdown = markdown.replace(' ```', '```') try: @@ -288,6 +290,7 @@ class WebScrappingStrategy(ContentScrappingStrategy): print('Error extracting metadata:', str(e)) meta = {} + cleaned_html = sanitize_html(cleaned_html) return { 'markdown': markdown, 'cleaned_html': cleaned_html, diff --git a/crawl4ai/prompts.py b/crawl4ai/prompts.py index a55d6fca..7a963e6d 100644 --- a/crawl4ai/prompts.py +++ b/crawl4ai/prompts.py @@ -1,4 +1,4 @@ -PROMPT_EXTRACT_BLOCKS = """YHere is the URL of the webpage: +PROMPT_EXTRACT_BLOCKS = """Here is the URL of the webpage: {URL} And here is the cleaned HTML content of that webpage: @@ -79,7 +79,7 @@ To generate the JSON objects: 2. For each block: a. Assign it an index based on its order in the content. b. Analyze the content and generate ONE semantic tag that describe what the block is about. - c. Extract the text content, EXACTLY SAME AS GIVE DATA, clean it up if needed, and store it as a list of strings in the "content" field. + c. Extract the text content, EXACTLY SAME AS THE GIVE DATA, clean it up if needed, and store it as a list of strings in the "content" field. 3. Ensure that the order of the JSON objects matches the order of the blocks as they appear in the original HTML content. diff --git a/crawl4ai/utils.py b/crawl4ai/utils.py index 77671a20..efb5d79b 100644 --- a/crawl4ai/utils.py +++ b/crawl4ai/utils.py @@ -131,7 +131,7 @@ def split_and_parse_json_objects(json_string): return parsed_objects, unparsed_segments def sanitize_html(html): - # Replace all weird and special characters with an empty string + # Replace all unwanted and special characters with an empty string sanitized_html = html # sanitized_html = re.sub(r'[^\w\s.,;:!?=\[\]{}()<>\/\\\-"]', '', html) @@ -301,7 +301,7 @@ def get_content_of_website(url, html, word_count_threshold = MIN_WORD_THRESHOLD, if tag.name != 'img': tag.attrs = {} - # Extract all img tgas inti [{src: '', alt: ''}] + # Extract all img tgas int0 [{src: '', alt: ''}] media = { 'images': [], 'videos': [], @@ -339,7 +339,7 @@ def get_content_of_website(url, html, word_count_threshold = MIN_WORD_THRESHOLD, img.decompose() - # Create a function that replace content of all"pre" tage with its inner text + # Create a function that replace content of all"pre" tag with its inner text def replace_pre_tags_with_text(node): for child in node.find_all('pre'): # set child inner html to its text @@ -502,7 +502,7 @@ def get_content_of_website_optimized(url: str, html: str, word_count_threshold: current_tag = tag while current_tag: current_tag = current_tag.parent - # Get the text content of the parent tag + # Get the text content from the parent tag if current_tag: text_content = current_tag.get_text(separator=' ',strip=True) # Check if the text content has at least word_count_threshold @@ -511,88 +511,88 @@ def get_content_of_website_optimized(url: str, html: str, word_count_threshold: return None def process_image(img, url, index, total_images): - #Check if an image has valid display and inside undesired html elements - def is_valid_image(img, parent, parent_classes): - style = img.get('style', '') - src = img.get('src', '') - classes_to_check = ['button', 'icon', 'logo'] - tags_to_check = ['button', 'input'] - return all([ - 'display:none' not in style, - src, - not any(s in var for var in [src, img.get('alt', ''), *parent_classes] for s in classes_to_check), - parent.name not in tags_to_check - ]) + #Check if an image has valid display and inside undesired html elements + def is_valid_image(img, parent, parent_classes): + style = img.get('style', '') + src = img.get('src', '') + classes_to_check = ['button', 'icon', 'logo'] + tags_to_check = ['button', 'input'] + return all([ + 'display:none' not in style, + src, + not any(s in var for var in [src, img.get('alt', ''), *parent_classes] for s in classes_to_check), + parent.name not in tags_to_check + ]) - #Score an image for it's usefulness - def score_image_for_usefulness(img, base_url, index, images_count): - # Function to parse image height/width value and units - def parse_dimension(dimension): - if dimension: - match = re.match(r"(\d+)(\D*)", dimension) - if match: - number = int(match.group(1)) - unit = match.group(2) or 'px' # Default unit is 'px' if not specified - return number, unit - return None, None + #Score an image for it's usefulness + def score_image_for_usefulness(img, base_url, index, images_count): + # Function to parse image height/width value and units + def parse_dimension(dimension): + if dimension: + match = re.match(r"(\d+)(\D*)", dimension) + if match: + number = int(match.group(1)) + unit = match.group(2) or 'px' # Default unit is 'px' if not specified + return number, unit + return None, None - # Fetch image file metadata to extract size and extension - def fetch_image_file_size(img, base_url): - #If src is relative path construct full URL, if not it may be CDN URL - img_url = urljoin(base_url,img.get('src')) - try: - response = requests.head(img_url) - if response.status_code == 200: - return response.headers.get('Content-Length',None) - else: - print(f"Failed to retrieve file size for {img_url}") - return None - except InvalidSchema as e: + # Fetch image file metadata to extract size and extension + def fetch_image_file_size(img, base_url): + #If src is relative path construct full URL, if not it may be CDN URL + img_url = urljoin(base_url,img.get('src')) + try: + response = requests.head(img_url) + if response.status_code == 200: + return response.headers.get('Content-Length',None) + else: + print(f"Failed to retrieve file size for {img_url}") return None - finally: - return + except InvalidSchema as e: + return None + finally: + return - image_height = img.get('height') - height_value, height_unit = parse_dimension(image_height) - image_width = img.get('width') - width_value, width_unit = parse_dimension(image_width) - image_size = 0 #int(fetch_image_file_size(img,base_url) or 0) - image_format = os.path.splitext(img.get('src',''))[1].lower() - # Remove . from format - image_format = image_format.strip('.') - score = 0 - if height_value: - if height_unit == 'px' and height_value > 150: - score += 1 - if height_unit in ['%','vh','vmin','vmax'] and height_value >30: - score += 1 - if width_value: - if width_unit == 'px' and width_value > 150: - score += 1 - if width_unit in ['%','vh','vmin','vmax'] and width_value >30: - score += 1 - if image_size > 10000: + image_height = img.get('height') + height_value, height_unit = parse_dimension(image_height) + image_width = img.get('width') + width_value, width_unit = parse_dimension(image_width) + image_size = 0 #int(fetch_image_file_size(img,base_url) or 0) + image_format = os.path.splitext(img.get('src',''))[1].lower() + # Remove . from format + image_format = image_format.strip('.') + score = 0 + if height_value: + if height_unit == 'px' and height_value > 150: score += 1 - if img.get('alt') != '': - score+=1 - if any(image_format==format for format in ['jpg','png','webp']): - score+=1 - if index/images_count<0.5: - score+=1 - return score + if height_unit in ['%','vh','vmin','vmax'] and height_value >30: + score += 1 + if width_value: + if width_unit == 'px' and width_value > 150: + score += 1 + if width_unit in ['%','vh','vmin','vmax'] and width_value >30: + score += 1 + if image_size > 10000: + score += 1 + if img.get('alt') != '': + score+=1 + if any(image_format==format for format in ['jpg','png','webp']): + score+=1 + if index/images_count<0.5: + score+=1 + return score - if not is_valid_image(img, img.parent, img.parent.get('class', [])): - return None - score = score_image_for_usefulness(img, url, index, total_images) - if score <= IMAGE_SCORE_THRESHOLD: - return None - return { - 'src': img.get('src', ''), - 'alt': img.get('alt', ''), - 'desc': find_closest_parent_with_useful_text(img), - 'score': score, - 'type': 'image' - } + if not is_valid_image(img, img.parent, img.parent.get('class', [])): + return None + score = score_image_for_usefulness(img, url, index, total_images) + if score <= IMAGE_SCORE_THRESHOLD: + return None + return { + 'src': img.get('src', '').replace('\\"', '"').strip(), + 'alt': img.get('alt', ''), + 'desc': find_closest_parent_with_useful_text(img), + 'score': score, + 'type': 'image' + } def process_element(element: element.PageElement) -> bool: try: diff --git a/crawl4ai/web_crawler.py b/crawl4ai/web_crawler.py index 7dea56ca..20e9b04e 100644 --- a/crawl4ai/web_crawler.py +++ b/crawl4ai/web_crawler.py @@ -12,6 +12,7 @@ from typing import List from concurrent.futures import ThreadPoolExecutor from .config import * import warnings +import json warnings.filterwarnings("ignore", message='Field "model_name" has conflict with protected namespace "model_".') diff --git a/docs/examples/quickstart_async.py b/docs/examples/quickstart_async.py index 27a162e3..f6c16a4e 100644 --- a/docs/examples/quickstart_async.py +++ b/docs/examples/quickstart_async.py @@ -357,6 +357,28 @@ async def crawl_dynamic_content_pages_method_3(): await crawler.crawler_strategy.kill_session(session_id) print(f"Successfully crawled {len(all_commits)} commits across 3 pages") +async def crawl_custom_browser_type(): + # Use Firefox + start = time.time() + async with AsyncWebCrawler(browser_type="firefox", verbose=True, headless = True) as crawler: + result = await crawler.arun(url="https://www.example.com", bypass_cache=True) + print(result.markdown[:500]) + print("Time taken: ", time.time() - start) + + # Use WebKit + start = time.time() + async with AsyncWebCrawler(browser_type="webkit", verbose=True, headless = True) as crawler: + result = await crawler.arun(url="https://www.example.com", bypass_cache=True) + print(result.markdown[:500]) + print("Time taken: ", time.time() - start) + + # Use Chromium (default) + start = time.time() + async with AsyncWebCrawler(verbose=True, headless = True) as crawler: + result = await crawler.arun(url="https://www.example.com", bypass_cache=True) + print(result.markdown[:500]) + print("Time taken: ", time.time() - start) + async def speed_comparison(): # print("\n--- Speed Comparison ---") # print("Firecrawl (simulated):") @@ -446,6 +468,9 @@ async def main(): # await crawl_dynamic_content_pages_method_1() # await crawl_dynamic_content_pages_method_2() await crawl_dynamic_content_pages_method_3() + + await crawl_custom_browser_type() + await speed_comparison()