fix: prevent session closing after each request to maintain connection pool. Fixes: https://github.com/unclecode/crawl4ai/issues/867

This commit is contained in:
Author: Aravind Karnam
Date: 2025-03-25 13:46:55 +05:30
parent 2f0e217751
commit e3111d0a32

View File

@@ -1702,15 +1702,6 @@ class AsyncHTTPCrawlerStrategy(AsyncCrawlerStrategy):
async def __aexit__(self, exc_type, exc_val, exc_tb) -> None: async def __aexit__(self, exc_type, exc_val, exc_tb) -> None:
await self.close() await self.close()
@contextlib.asynccontextmanager
async def _session_context(self):
try:
if not self._session:
await self.start()
yield self._session
finally:
await self.close()
def set_hook(self, hook_type: str, hook_func: Callable) -> None:
    """Register ``hook_func`` under a known hook type.

    The user callable is wrapped with ``functools.partial`` so that
    ``_execute_hook`` receives the hook type and the user function ahead
    of whatever arguments the call site supplies.

    NOTE(review): only keys already present in ``self.hooks`` can be set;
    the visible snippet shows no else branch for unknown hook types —
    confirm against the full source whether they should raise instead.
    """
    if hook_type in self.hooks:
        self.hooks[hook_type] = partial(self._execute_hook, hook_type, hook_func)
@@ -1787,75 +1778,77 @@ class AsyncHTTPCrawlerStrategy(AsyncCrawlerStrategy):
url: str, url: str,
config: CrawlerRunConfig config: CrawlerRunConfig
) -> AsyncCrawlResponse: ) -> AsyncCrawlResponse:
async with self._session_context() as session: if not self._session or self._session.closed:
timeout = ClientTimeout( await self.start()
total=config.page_timeout or self.DEFAULT_TIMEOUT,
connect=10, timeout = ClientTimeout(
sock_read=30 total=config.page_timeout or self.DEFAULT_TIMEOUT,
) connect=10,
sock_read=30
headers = dict(self._BASE_HEADERS) )
if self.browser_config.headers:
headers.update(self.browser_config.headers) headers = dict(self._BASE_HEADERS)
if self.browser_config.headers:
headers.update(self.browser_config.headers)
request_kwargs = { request_kwargs = {
'timeout': timeout, 'timeout': timeout,
'allow_redirects': self.browser_config.follow_redirects, 'allow_redirects': self.browser_config.follow_redirects,
'ssl': self.browser_config.verify_ssl, 'ssl': self.browser_config.verify_ssl,
'headers': headers 'headers': headers
} }
if self.browser_config.method == "POST": if self.browser_config.method == "POST":
if self.browser_config.data: if self.browser_config.data:
request_kwargs['data'] = self.browser_config.data request_kwargs['data'] = self.browser_config.data
if self.browser_config.json: if self.browser_config.json:
request_kwargs['json'] = self.browser_config.json request_kwargs['json'] = self.browser_config.json
await self.hooks['before_request'](url, request_kwargs) await self.hooks['before_request'](url, request_kwargs)
try: try:
async with session.request(self.browser_config.method, url, **request_kwargs) as response: async with self._session.request(self.browser_config.method, url, **request_kwargs) as response:
content = memoryview(await response.read()) content = memoryview(await response.read())
if not (200 <= response.status < 300): if not (200 <= response.status < 300):
raise HTTPStatusError( raise HTTPStatusError(
response.status, response.status,
f"Unexpected status code for {url}" f"Unexpected status code for {url}"
)
encoding = response.charset
if not encoding:
encoding = cchardet.detect(content.tobytes())['encoding'] or 'utf-8'
result = AsyncCrawlResponse(
html=content.tobytes().decode(encoding, errors='replace'),
response_headers=dict(response.headers),
status_code=response.status,
redirected_url=str(response.url)
) )
await self.hooks['after_request'](result) encoding = response.charset
return result if not encoding:
encoding = cchardet.detect(content.tobytes())['encoding'] or 'utf-8'
result = AsyncCrawlResponse(
html=content.tobytes().decode(encoding, errors='replace'),
response_headers=dict(response.headers),
status_code=response.status,
redirected_url=str(response.url)
)
await self.hooks['after_request'](result)
return result
except aiohttp.ServerTimeoutError as e: except aiohttp.ServerTimeoutError as e:
await self.hooks['on_error'](e) await self.hooks['on_error'](e)
raise ConnectionTimeoutError(f"Request timed out: {str(e)}") raise ConnectionTimeoutError(f"Request timed out: {str(e)}")
except aiohttp.ClientConnectorError as e:
await self.hooks['on_error'](e)
raise ConnectionError(f"Connection failed: {str(e)}")
except aiohttp.ClientError as e:
await self.hooks['on_error'](e)
raise HTTPCrawlerError(f"HTTP client error: {str(e)}")
except asyncio.exceptions.TimeoutError as e: except aiohttp.ClientConnectorError as e:
await self.hooks['on_error'](e) await self.hooks['on_error'](e)
raise ConnectionTimeoutError(f"Request timed out: {str(e)}") raise ConnectionError(f"Connection failed: {str(e)}")
except Exception as e: except aiohttp.ClientError as e:
await self.hooks['on_error'](e) await self.hooks['on_error'](e)
raise HTTPCrawlerError(f"HTTP request failed: {str(e)}") raise HTTPCrawlerError(f"HTTP client error: {str(e)}")
except asyncio.exceptions.TimeoutError as e:
await self.hooks['on_error'](e)
raise ConnectionTimeoutError(f"Request timed out: {str(e)}")
except Exception as e:
await self.hooks['on_error'](e)
raise HTTPCrawlerError(f"HTTP request failed: {str(e)}")
async def crawl( async def crawl(
self, self,