From 396f430022164698f8f240da5efe6bf2e1623077 Mon Sep 17 00:00:00 2001 From: unclecode Date: Thu, 12 Sep 2024 15:49:49 +0800 Subject: [PATCH] Refactor AsyncCrawlerStrategy to return AsyncCrawlResponse This commit refactors the AsyncCrawlerStrategy class in the async_crawler_strategy.py file to modify the return types of the crawl and crawl_many methods. Instead of returning strings, these methods now return instances of the AsyncCrawlResponse class from the pydantic module. The AsyncCrawlResponse class contains the crawled HTML, response headers, and status code. This change improves the clarity and consistency of the code. --- crawl4ai/async_crawler_strategy.py | 42 +++++++++++++++++++++++++----- crawl4ai/async_webcrawler.py | 8 ++++-- crawl4ai/models.py | 4 ++- 3 files changed, 44 insertions(+), 10 deletions(-) diff --git a/crawl4ai/async_crawler_strategy.py b/crawl4ai/async_crawler_strategy.py index 3840260e..7b24620c 100644 --- a/crawl4ai/async_crawler_strategy.py +++ b/crawl4ai/async_crawler_strategy.py @@ -12,6 +12,7 @@ import json, uuid import hashlib from pathlib import Path from playwright.async_api import ProxySettings +from pydantic import BaseModel def calculate_semaphore_count(): cpu_count = os.cpu_count() @@ -20,13 +21,18 @@ def calculate_semaphore_count(): memory_based_cap = int(memory_gb / 2) # Assume 2GB per instance return min(base_count, memory_based_cap) +class AsyncCrawlResponse(BaseModel): + html: str + response_headers: Dict[str, str] + status_code: int + class AsyncCrawlerStrategy(ABC): @abstractmethod - async def crawl(self, url: str, **kwargs) -> str: + async def crawl(self, url: str, **kwargs) -> AsyncCrawlResponse: pass @abstractmethod - async def crawl_many(self, urls: List[str], **kwargs) -> List[str]: + async def crawl_many(self, urls: List[str], **kwargs) -> List[AsyncCrawlResponse]: pass @abstractmethod @@ -140,7 +146,10 @@ class AsyncPlaywrightCrawlerStrategy(AsyncCrawlerStrategy): for sid in expired_sessions: asyncio.create_task(self.kill_session(sid)) - async def crawl(self, url: str, **kwargs) -> str: + async def crawl(self, url: str, **kwargs) -> AsyncCrawlResponse: + response_headers = {} + status_code = None + self._cleanup_expired_sessions() session_id = kwargs.get("session_id") if session_id: @@ -168,13 +177,25 @@ class AsyncPlaywrightCrawlerStrategy(AsyncCrawlerStrategy): if self.use_cached_html: cache_file_path = os.path.join(Path.home(), ".crawl4ai", "cache", hashlib.md5(url.encode()).hexdigest()) if os.path.exists(cache_file_path): + html = "" with open(cache_file_path, "r") as f: - return f.read() + html = f.read() + # retrieve response headers and status code from cache + with open(cache_file_path + ".meta", "r") as f: + meta = json.load(f) + response_headers = meta.get("response_headers", {}) + status_code = meta.get("status_code") + response = AsyncCrawlResponse(html=html, response_headers=response_headers, status_code=status_code) + return response if not kwargs.get("js_only", False): await self.execute_hook('before_goto', page) - await page.goto(url, wait_until="domcontentloaded", timeout=60000) + response = await page.goto(url, wait_until="domcontentloaded", timeout=60000) await self.execute_hook('after_goto', page) + + # Get status code and headers + status_code = response.status + response_headers = response.headers await page.wait_for_selector('body') await page.evaluate("window.scrollTo(0, document.body.scrollHeight)") @@ -202,8 +223,15 @@ class AsyncPlaywrightCrawlerStrategy(AsyncCrawlerStrategy): cache_file_path = os.path.join(Path.home(), ".crawl4ai", "cache", hashlib.md5(url.encode()).hexdigest()) with open(cache_file_path, "w", encoding="utf-8") as f: f.write(html) + # store response headers and status code in cache + with open(cache_file_path + ".meta", "w", encoding="utf-8") as f: + json.dump({ + "response_headers": response_headers, + "status_code": status_code + }, f) - return html + response = AsyncCrawlResponse(html=html, response_headers=response_headers, status_code=status_code) + return response except Error as e: raise Error(f"Failed to crawl {url}: {str(e)}") finally: @@ -218,7 +246,7 @@ class AsyncPlaywrightCrawlerStrategy(AsyncCrawlerStrategy): # except Exception as e: # raise Exception(f"Failed to crawl {url}: {str(e)}") - async def crawl_many(self, urls: List[str], **kwargs) -> List[str]: + async def crawl_many(self, urls: List[str], **kwargs) -> List[AsyncCrawlResponse]: semaphore_count = kwargs.get('semaphore_count', calculate_semaphore_count()) semaphore = asyncio.Semaphore(semaphore_count) diff --git a/crawl4ai/async_webcrawler.py b/crawl4ai/async_webcrawler.py index 3cdc9ac1..ceca09f5 100644 --- a/crawl4ai/async_webcrawler.py +++ b/crawl4ai/async_webcrawler.py @@ -8,7 +8,7 @@ from .models import CrawlResult from .async_database import async_db_manager from .chunking_strategy import * from .extraction_strategy import * -from .async_crawler_strategy import AsyncCrawlerStrategy, AsyncPlaywrightCrawlerStrategy +from .async_crawler_strategy import AsyncCrawlerStrategy, AsyncPlaywrightCrawlerStrategy, AsyncCrawlResponse from .content_scrapping_strategy import WebScrappingStrategy from .config import MIN_WORD_THRESHOLD, IMAGE_DESCRIPTION_MIN_WORD_THRESHOLD from .utils import ( @@ -101,7 +101,8 @@ class AsyncWebCrawler: t1 = time.time() if user_agent: self.crawler_strategy.update_user_agent(user_agent) - html = await self.crawler_strategy.crawl(url, **kwargs) + async_response : AsyncCrawlResponse = await self.crawler_strategy.crawl(url, **kwargs) + html = sanitize_input_encode(async_response.html) t2 = time.time() if verbose: print( @@ -121,8 +122,11 @@ class AsyncWebCrawler: screenshot_data, verbose, bool(cached), + async_response=async_response, **kwargs, ) + crawl_result.status_code = async_response.status_code + crawl_result.responser_headers = async_response.response_headers crawl_result.success = bool(html) crawl_result.session_id = kwargs.get("session_id", None) return crawl_result diff --git a/crawl4ai/models.py b/crawl4ai/models.py index e48441b8..eefb0cb9 100644 --- a/crawl4ai/models.py +++ b/crawl4ai/models.py @@ -17,4 +17,6 @@ class CrawlResult(BaseModel): extracted_content: Optional[str] = None metadata: Optional[dict] = None error_message: Optional[str] = None - session_id: Optional[str] = None \ No newline at end of file + session_id: Optional[str] = None + responser_headers: Optional[dict] = None + status_code: Optional[int] = None \ No newline at end of file