Refactor AsyncCrawlerStrategy to return AsyncCrawlResponse
This commit refactors the AsyncCrawlerStrategy class in the async_crawler_strategy.py file to change the return types of the crawl and crawl_many methods. Instead of returning raw HTML strings, these methods now return instances of AsyncCrawlResponse, a new pydantic BaseModel defined in the same module that bundles the crawled HTML with the response headers and the HTTP status code. AsyncWebCrawler is updated to consume the new response object, and CrawlResult gains matching responser_headers and status_code fields. This change improves the clarity and consistency of the code.
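For illustration, a minimal sketch of how the new response type surfaces to callers. The AsyncWebCrawler context-manager usage and the example URL are assumptions for demonstration; the field names come from the diff below (note that CrawlResult spells its header field responser_headers):

import asyncio
from crawl4ai import AsyncWebCrawler

async def main():
    async with AsyncWebCrawler() as crawler:
        # AsyncCrawlerStrategy.crawl() now returns an AsyncCrawlResponse
        # (html, response_headers, status_code) instead of a bare HTML string;
        # AsyncWebCrawler.arun() copies that metadata onto the CrawlResult.
        result = await crawler.arun(url="https://example.com")
        print(result.status_code)        # e.g. 200
        print(result.responser_headers)  # dict of HTTP response headers

asyncio.run(main())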
async_crawler_strategy.py
@@ -12,6 +12,7 @@ import json, uuid
 import hashlib
 from pathlib import Path
 from playwright.async_api import ProxySettings
+from pydantic import BaseModel
 
 def calculate_semaphore_count():
     cpu_count = os.cpu_count()
@@ -20,13 +21,18 @@ def calculate_semaphore_count():
     memory_based_cap = int(memory_gb / 2) # Assume 2GB per instance
     return min(base_count, memory_based_cap)
 
+class AsyncCrawlResponse(BaseModel):
+    html: str
+    response_headers: Dict[str, str]
+    status_code: int
+
 class AsyncCrawlerStrategy(ABC):
     @abstractmethod
-    async def crawl(self, url: str, **kwargs) -> str:
+    async def crawl(self, url: str, **kwargs) -> AsyncCrawlResponse:
         pass
 
     @abstractmethod
-    async def crawl_many(self, urls: List[str], **kwargs) -> List[str]:
+    async def crawl_many(self, urls: List[str], **kwargs) -> List[AsyncCrawlResponse]:
         pass
 
     @abstractmethod
@@ -140,7 +146,10 @@ class AsyncPlaywrightCrawlerStrategy(AsyncCrawlerStrategy):
         for sid in expired_sessions:
             asyncio.create_task(self.kill_session(sid))
 
-    async def crawl(self, url: str, **kwargs) -> str:
+    async def crawl(self, url: str, **kwargs) -> AsyncCrawlResponse:
+        response_headers = {}
+        status_code = None
+
         self._cleanup_expired_sessions()
         session_id = kwargs.get("session_id")
         if session_id:
@@ -168,13 +177,25 @@ class AsyncPlaywrightCrawlerStrategy(AsyncCrawlerStrategy):
             if self.use_cached_html:
                 cache_file_path = os.path.join(Path.home(), ".crawl4ai", "cache", hashlib.md5(url.encode()).hexdigest())
                 if os.path.exists(cache_file_path):
+                    html = ""
                     with open(cache_file_path, "r") as f:
-                        return f.read()
+                        html = f.read()
+                    # retrieve response headers and status code from cache
+                    with open(cache_file_path + ".meta", "r") as f:
+                        meta = json.load(f)
+                        response_headers = meta.get("response_headers", {})
+                        status_code = meta.get("status_code")
+                    response = AsyncCrawlResponse(html=html, response_headers=response_headers, status_code=status_code)
+                    return response
 
             if not kwargs.get("js_only", False):
                 await self.execute_hook('before_goto', page)
-                await page.goto(url, wait_until="domcontentloaded", timeout=60000)
+                response = await page.goto(url, wait_until="domcontentloaded", timeout=60000)
                 await self.execute_hook('after_goto', page)
 
+                # Get status code and headers
+                status_code = response.status
+                response_headers = response.headers
+
             await page.wait_for_selector('body')
             await page.evaluate("window.scrollTo(0, document.body.scrollHeight)")
@@ -202,8 +223,15 @@ class AsyncPlaywrightCrawlerStrategy(AsyncCrawlerStrategy):
                 cache_file_path = os.path.join(Path.home(), ".crawl4ai", "cache", hashlib.md5(url.encode()).hexdigest())
                 with open(cache_file_path, "w", encoding="utf-8") as f:
                     f.write(html)
+                # store response headers and status code in cache
+                with open(cache_file_path + ".meta", "w", encoding="utf-8") as f:
+                    json.dump({
+                        "response_headers": response_headers,
+                        "status_code": status_code
+                    }, f)
 
-            return html
+            response = AsyncCrawlResponse(html=html, response_headers=response_headers, status_code=status_code)
+            return response
         except Error as e:
             raise Error(f"Failed to crawl {url}: {str(e)}")
         finally:
@@ -218,7 +246,7 @@ class AsyncPlaywrightCrawlerStrategy(AsyncCrawlerStrategy):
         # except Exception as e:
         #     raise Exception(f"Failed to crawl {url}: {str(e)}")
 
-    async def crawl_many(self, urls: List[str], **kwargs) -> List[str]:
+    async def crawl_many(self, urls: List[str], **kwargs) -> List[AsyncCrawlResponse]:
         semaphore_count = kwargs.get('semaphore_count', calculate_semaphore_count())
         semaphore = asyncio.Semaphore(semaphore_count)
 
async_webcrawler.py
@@ -8,7 +8,7 @@ from .models import CrawlResult
 from .async_database import async_db_manager
 from .chunking_strategy import *
 from .extraction_strategy import *
-from .async_crawler_strategy import AsyncCrawlerStrategy, AsyncPlaywrightCrawlerStrategy
+from .async_crawler_strategy import AsyncCrawlerStrategy, AsyncPlaywrightCrawlerStrategy, AsyncCrawlResponse
 from .content_scrapping_strategy import WebScrappingStrategy
 from .config import MIN_WORD_THRESHOLD, IMAGE_DESCRIPTION_MIN_WORD_THRESHOLD
 from .utils import (
@@ -101,7 +101,8 @@ class AsyncWebCrawler:
         t1 = time.time()
         if user_agent:
             self.crawler_strategy.update_user_agent(user_agent)
-        html = await self.crawler_strategy.crawl(url, **kwargs)
+        async_response : AsyncCrawlResponse = await self.crawler_strategy.crawl(url, **kwargs)
+        html = sanitize_input_encode(async_response.html)
         t2 = time.time()
         if verbose:
             print(
@@ -121,8 +122,11 @@ class AsyncWebCrawler:
             screenshot_data,
             verbose,
             bool(cached),
+            async_response=async_response,
             **kwargs,
         )
+        crawl_result.status_code = async_response.status_code
+        crawl_result.responser_headers = async_response.response_headers
         crawl_result.success = bool(html)
         crawl_result.session_id = kwargs.get("session_id", None)
         return crawl_result
models.py
@@ -17,4 +17,6 @@ class CrawlResult(BaseModel):
     extracted_content: Optional[str] = None
     metadata: Optional[dict] = None
     error_message: Optional[str] = None
     session_id: Optional[str] = None
+    responser_headers: Optional[dict] = None
+    status_code: Optional[int] = None
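The batch path changes the same way: crawl_many now returns a list of AsyncCrawlResponse objects rather than strings. A short sketch (direct strategy instantiation and the semaphore_count value are illustrative assumptions):

import asyncio
from crawl4ai.async_crawler_strategy import AsyncPlaywrightCrawlerStrategy

async def batch(urls):
    strategy = AsyncPlaywrightCrawlerStrategy()
    # semaphore_count caps concurrency; it defaults to calculate_semaphore_count()
    responses = await strategy.crawl_many(urls, semaphore_count=4)
    for r in responses:
        print(r.status_code, len(r.html), len(r.response_headers))

asyncio.run(batch(["https://example.com"]))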