Refactor AsyncCrawlerStrategy to return AsyncCrawlResponse

This commit refactors the AsyncCrawlerStrategy class in async_crawler_strategy.py so that the crawl and crawl_many methods return AsyncCrawlResponse instances instead of raw HTML strings. AsyncCrawlResponse is a new Pydantic model that bundles the crawled HTML with the response headers and status code. This change improves the clarity and consistency of the code.
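
For illustration, a minimal sketch of the new return type from the caller's side; the field values below are placeholders, not output from a real crawl:

    from crawl4ai.async_crawler_strategy import AsyncCrawlResponse

    # AsyncCrawlResponse is a plain Pydantic model, so it can be constructed
    # and inspected independently of the crawler itself.
    resp = AsyncCrawlResponse(
        html="<html><body>hello</body></html>",
        response_headers={"content-type": "text/html"},
        status_code=200,
    )
    print(resp.status_code)        # 200
    print(resp.response_headers)   # {'content-type': 'text/html'}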
Author: unclecode
Date:   2024-09-12 15:49:49 +08:00
Parent: eb131bebdf
Commit: 396f430022

3 changed files with 44 additions and 10 deletions

crawl4ai/async_crawler_strategy.py

@@ -12,6 +12,7 @@ import json, uuid
 import hashlib
 from pathlib import Path
 from playwright.async_api import ProxySettings
+from pydantic import BaseModel
 
 def calculate_semaphore_count():
     cpu_count = os.cpu_count()
@@ -20,13 +21,18 @@ def calculate_semaphore_count():
     memory_based_cap = int(memory_gb / 2)  # Assume 2GB per instance
     return min(base_count, memory_based_cap)
 
+class AsyncCrawlResponse(BaseModel):
+    html: str
+    response_headers: Dict[str, str]
+    status_code: int
+
 class AsyncCrawlerStrategy(ABC):
     @abstractmethod
-    async def crawl(self, url: str, **kwargs) -> str:
+    async def crawl(self, url: str, **kwargs) -> AsyncCrawlResponse:
         pass
 
     @abstractmethod
-    async def crawl_many(self, urls: List[str], **kwargs) -> List[str]:
+    async def crawl_many(self, urls: List[str], **kwargs) -> List[AsyncCrawlResponse]:
         pass
 
     @abstractmethod
@@ -140,7 +146,10 @@ class AsyncPlaywrightCrawlerStrategy(AsyncCrawlerStrategy):
         for sid in expired_sessions:
             asyncio.create_task(self.kill_session(sid))
 
-    async def crawl(self, url: str, **kwargs) -> str:
+    async def crawl(self, url: str, **kwargs) -> AsyncCrawlResponse:
+        response_headers = {}
+        status_code = None
         self._cleanup_expired_sessions()
         session_id = kwargs.get("session_id")
         if session_id:
@@ -168,13 +177,25 @@ class AsyncPlaywrightCrawlerStrategy(AsyncCrawlerStrategy):
             if self.use_cached_html:
                 cache_file_path = os.path.join(Path.home(), ".crawl4ai", "cache", hashlib.md5(url.encode()).hexdigest())
                 if os.path.exists(cache_file_path):
+                    html = ""
                     with open(cache_file_path, "r") as f:
-                        return f.read()
+                        html = f.read()
+                    # retrieve response headers and status code from cache
+                    with open(cache_file_path + ".meta", "r") as f:
+                        meta = json.load(f)
+                        response_headers = meta.get("response_headers", {})
+                        status_code = meta.get("status_code")
+                    response = AsyncCrawlResponse(html=html, response_headers=response_headers, status_code=status_code)
+                    return response
 
             if not kwargs.get("js_only", False):
                 await self.execute_hook('before_goto', page)
-                await page.goto(url, wait_until="domcontentloaded", timeout=60000)
+                response = await page.goto(url, wait_until="domcontentloaded", timeout=60000)
                 await self.execute_hook('after_goto', page)
+
+                # Get status code and headers
+                status_code = response.status
+                response_headers = response.headers
 
             await page.wait_for_selector('body')
             await page.evaluate("window.scrollTo(0, document.body.scrollHeight)")
@@ -202,8 +223,15 @@ class AsyncPlaywrightCrawlerStrategy(AsyncCrawlerStrategy):
                 cache_file_path = os.path.join(Path.home(), ".crawl4ai", "cache", hashlib.md5(url.encode()).hexdigest())
                 with open(cache_file_path, "w", encoding="utf-8") as f:
                     f.write(html)
+                # store response headers and status code in cache
+                with open(cache_file_path + ".meta", "w", encoding="utf-8") as f:
+                    json.dump({
+                        "response_headers": response_headers,
+                        "status_code": status_code
+                    }, f)
 
-            return html
+            response = AsyncCrawlResponse(html=html, response_headers=response_headers, status_code=status_code)
+            return response
         except Error as e:
             raise Error(f"Failed to crawl {url}: {str(e)}")
         finally:
@@ -218,7 +246,7 @@ class AsyncPlaywrightCrawlerStrategy(AsyncCrawlerStrategy):
         # except Exception as e:
         #     raise Exception(f"Failed to crawl {url}: {str(e)}")
 
-    async def crawl_many(self, urls: List[str], **kwargs) -> List[str]:
+    async def crawl_many(self, urls: List[str], **kwargs) -> List[AsyncCrawlResponse]:
         semaphore_count = kwargs.get('semaphore_count', calculate_semaphore_count())
         semaphore = asyncio.Semaphore(semaphore_count)
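
For context, a minimal sketch of how a semaphore-bounded crawl_many fan-out typically looks. The body below is an illustrative assumption; only the signature and the two semaphore lines appear in this diff:

    async def crawl_many(self, urls: List[str], **kwargs) -> List[AsyncCrawlResponse]:
        semaphore_count = kwargs.get('semaphore_count', calculate_semaphore_count())
        semaphore = asyncio.Semaphore(semaphore_count)

        async def crawl_with_semaphore(url: str) -> AsyncCrawlResponse:
            # Cap the number of concurrent crawls at semaphore_count.
            async with semaphore:
                return await self.crawl(url, **kwargs)

        # Each task now resolves to an AsyncCrawlResponse rather than a string.
        return await asyncio.gather(*(crawl_with_semaphore(url) for url in urls))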

crawl4ai/async_webcrawler.py

@@ -8,7 +8,7 @@ from .models import CrawlResult
 from .async_database import async_db_manager
 from .chunking_strategy import *
 from .extraction_strategy import *
-from .async_crawler_strategy import AsyncCrawlerStrategy, AsyncPlaywrightCrawlerStrategy
+from .async_crawler_strategy import AsyncCrawlerStrategy, AsyncPlaywrightCrawlerStrategy, AsyncCrawlResponse
 from .content_scrapping_strategy import WebScrappingStrategy
 from .config import MIN_WORD_THRESHOLD, IMAGE_DESCRIPTION_MIN_WORD_THRESHOLD
 from .utils import (
@@ -101,7 +101,8 @@ class AsyncWebCrawler:
         t1 = time.time()
         if user_agent:
             self.crawler_strategy.update_user_agent(user_agent)
-        html = await self.crawler_strategy.crawl(url, **kwargs)
+        async_response : AsyncCrawlResponse = await self.crawler_strategy.crawl(url, **kwargs)
+        html = sanitize_input_encode(async_response.html)
         t2 = time.time()
         if verbose:
             print(
@@ -121,8 +122,11 @@ class AsyncWebCrawler:
             screenshot_data,
             verbose,
             bool(cached),
+            async_response=async_response,
             **kwargs,
         )
+        crawl_result.status_code = async_response.status_code
+        crawl_result.responser_headers = async_response.response_headers
         crawl_result.success = bool(html)
         crawl_result.session_id = kwargs.get("session_id", None)
         return crawl_result

crawl4ai/models.py

@@ -17,4 +17,6 @@ class CrawlResult(BaseModel):
     extracted_content: Optional[str] = None
     metadata: Optional[dict] = None
     error_message: Optional[str] = None
     session_id: Optional[str] = None
+    responser_headers: Optional[dict] = None
+    status_code: Optional[int] = None
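
Taken together, a caller of AsyncWebCrawler can now read the HTTP metadata off the crawl result. A minimal sketch, assuming the usual arun entry point and async-context usage, with a placeholder URL:

    import asyncio
    from crawl4ai import AsyncWebCrawler

    async def main():
        async with AsyncWebCrawler() as crawler:
            result = await crawler.arun(url="https://example.com")
            # Fields added in this commit, populated from AsyncCrawlResponse:
            print(result.status_code)        # e.g. 200
            print(result.responser_headers)  # field name as spelled in this commit

    asyncio.run(main())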