From e3111d0a328ae2a0c78464de83cfc986f807c28b Mon Sep 17 00:00:00 2001 From: Aravind Karnam Date: Tue, 25 Mar 2025 13:46:55 +0530 Subject: [PATCH] fix: prevent session closing after each request to maintain connection pool. Fixes: https://github.com/unclecode/crawl4ai/issues/867 --- crawl4ai/async_crawler_strategy.py | 133 ++++++++++++++--------------- 1 file changed, 63 insertions(+), 70 deletions(-) diff --git a/crawl4ai/async_crawler_strategy.py b/crawl4ai/async_crawler_strategy.py index 37aa0962..2330b3f3 100644 --- a/crawl4ai/async_crawler_strategy.py +++ b/crawl4ai/async_crawler_strategy.py @@ -1702,15 +1702,6 @@ class AsyncHTTPCrawlerStrategy(AsyncCrawlerStrategy): async def __aexit__(self, exc_type, exc_val, exc_tb) -> None: await self.close() - @contextlib.asynccontextmanager - async def _session_context(self): - try: - if not self._session: - await self.start() - yield self._session - finally: - await self.close() - def set_hook(self, hook_type: str, hook_func: Callable) -> None: if hook_type in self.hooks: self.hooks[hook_type] = partial(self._execute_hook, hook_type, hook_func) @@ -1787,75 +1778,77 @@ class AsyncHTTPCrawlerStrategy(AsyncCrawlerStrategy): url: str, config: CrawlerRunConfig ) -> AsyncCrawlResponse: - async with self._session_context() as session: - timeout = ClientTimeout( - total=config.page_timeout or self.DEFAULT_TIMEOUT, - connect=10, - sock_read=30 - ) - - headers = dict(self._BASE_HEADERS) - if self.browser_config.headers: - headers.update(self.browser_config.headers) + if not self._session or self._session.closed: + await self.start() + + timeout = ClientTimeout( + total=config.page_timeout or self.DEFAULT_TIMEOUT, + connect=10, + sock_read=30 + ) + + headers = dict(self._BASE_HEADERS) + if self.browser_config.headers: + headers.update(self.browser_config.headers) - request_kwargs = { - 'timeout': timeout, - 'allow_redirects': self.browser_config.follow_redirects, - 'ssl': self.browser_config.verify_ssl, - 'headers': headers - } + request_kwargs = { + 'timeout': timeout, + 'allow_redirects': self.browser_config.follow_redirects, + 'ssl': self.browser_config.verify_ssl, + 'headers': headers + } - if self.browser_config.method == "POST": - if self.browser_config.data: - request_kwargs['data'] = self.browser_config.data - if self.browser_config.json: - request_kwargs['json'] = self.browser_config.json + if self.browser_config.method == "POST": + if self.browser_config.data: + request_kwargs['data'] = self.browser_config.data + if self.browser_config.json: + request_kwargs['json'] = self.browser_config.json - await self.hooks['before_request'](url, request_kwargs) + await self.hooks['before_request'](url, request_kwargs) - try: - async with session.request(self.browser_config.method, url, **request_kwargs) as response: - content = memoryview(await response.read()) - - if not (200 <= response.status < 300): - raise HTTPStatusError( - response.status, - f"Unexpected status code for {url}" - ) - - encoding = response.charset - if not encoding: - encoding = cchardet.detect(content.tobytes())['encoding'] or 'utf-8' - - result = AsyncCrawlResponse( - html=content.tobytes().decode(encoding, errors='replace'), - response_headers=dict(response.headers), - status_code=response.status, - redirected_url=str(response.url) + try: + async with self._session.request(self.browser_config.method, url, **request_kwargs) as response: + content = memoryview(await response.read()) + + if not (200 <= response.status < 300): + raise HTTPStatusError( + response.status, + f"Unexpected status code for {url}" ) - - await self.hooks['after_request'](result) - return result + + encoding = response.charset + if not encoding: + encoding = cchardet.detect(content.tobytes())['encoding'] or 'utf-8' + + result = AsyncCrawlResponse( + html=content.tobytes().decode(encoding, errors='replace'), + response_headers=dict(response.headers), + status_code=response.status, + redirected_url=str(response.url) + ) + + await self.hooks['after_request'](result) + return result - except aiohttp.ServerTimeoutError as e: - await self.hooks['on_error'](e) - raise ConnectionTimeoutError(f"Request timed out: {str(e)}") - - except aiohttp.ClientConnectorError as e: - await self.hooks['on_error'](e) - raise ConnectionError(f"Connection failed: {str(e)}") - - except aiohttp.ClientError as e: - await self.hooks['on_error'](e) - raise HTTPCrawlerError(f"HTTP client error: {str(e)}") + except aiohttp.ServerTimeoutError as e: + await self.hooks['on_error'](e) + raise ConnectionTimeoutError(f"Request timed out: {str(e)}") - except asyncio.exceptions.TimeoutError as e: - await self.hooks['on_error'](e) - raise ConnectionTimeoutError(f"Request timed out: {str(e)}") + except aiohttp.ClientConnectorError as e: + await self.hooks['on_error'](e) + raise ConnectionError(f"Connection failed: {str(e)}") - except Exception as e: - await self.hooks['on_error'](e) - raise HTTPCrawlerError(f"HTTP request failed: {str(e)}") + except aiohttp.ClientError as e: + await self.hooks['on_error'](e) + raise HTTPCrawlerError(f"HTTP client error: {str(e)}") + + except asyncio.exceptions.TimeoutError as e: + await self.hooks['on_error'](e) + raise ConnectionTimeoutError(f"Request timed out: {str(e)}") + + except Exception as e: + await self.hooks['on_error'](e) + raise HTTPCrawlerError(f"HTTP request failed: {str(e)}") async def crawl( self,