New async database manager and migration support

- Introduced AsyncDatabaseManager for async database management.
- Added a migration feature to transition to file-based content storage.
- Enhanced the web crawler with improved caching logic (see the usage sketch after the commit metadata below).
- Updated requirements and setup for async processing.
Author: UncleCode
Date: 2024-11-16 14:54:41 +08:00
parent ae7ebc0bd8
commit d0014c6793
8 changed files with 685 additions and 119 deletions
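For orientation, a minimal usage sketch of the reworked cache API follows. It uses only the public names visible in this diff (the async_db_manager singleton, initialize(), acache_url(), aget_cached_url(), cleanup(), and the CrawlResult model); the import paths and the exact set of CrawlResult constructor arguments are assumptions, not verified against the rest of the repository.

import asyncio

from crawl4ai.async_database import async_db_manager  # module path assumed
from crawl4ai.models import CrawlResult               # constructor fields assumed

async def main():
    # Set up the connection pool; runs the file-storage migration when NEED_MIGRATION is set
    await async_db_manager.initialize()

    # Cache a result: large text fields are written to content files,
    # and only their hashes are stored in the crawl4ai.db row
    result = CrawlResult(
        url="https://example.com",
        html="<html><body>Example</body></html>",
        success=True,
        markdown="# Example",
    )
    await async_db_manager.acache_url(result)

    # Read it back as a CrawlResult instead of a raw tuple
    cached = await async_db_manager.aget_cached_url("https://example.com")
    if cached:
        print(cached.url, len(cached.html))

    await async_db_manager.cleanup()

asyncio.run(main())

The main interface change of this commit is visible here: acache_url(result) replaces the old twelve-argument call, and aget_cached_url() returns a hydrated CrawlResult whose content fields are loaded back from the hash-named files.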


@@ -6,7 +6,11 @@ from typing import Optional, Tuple, Dict
 from contextlib import asynccontextmanager
 import logging
 import json # Added for serialization/deserialization
+from .utils import ensure_content_dirs, generate_content_hash
+from .models import CrawlResult
+import xxhash
+import aiofiles
+from .config import NEED_MIGRATION

 # Set up logging
 logging.basicConfig(level=logging.INFO)
 logger = logging.getLogger(__name__)
@@ -18,6 +22,7 @@ DB_PATH = os.path.join(DB_PATH, "crawl4ai.db")
 class AsyncDatabaseManager:
     def __init__(self, pool_size: int = 10, max_retries: int = 3):
         self.db_path = DB_PATH
+        self.content_paths = ensure_content_dirs(os.path.dirname(DB_PATH))
         self.pool_size = pool_size
         self.max_retries = max_retries
         self.connection_pool: Dict[int, aiosqlite.Connection] = {}
@@ -26,8 +31,20 @@ class AsyncDatabaseManager:

     async def initialize(self):
         """Initialize the database and connection pool"""
-        await self.ainit_db()
+        try:
+            logger.info("Initializing database...")
+            await self.ainit_db()
+            if NEED_MIGRATION:
+                await self.update_db_schema()
+                from .migrations import run_migration # Import here to avoid circular imports
+                await run_migration()
+                logger.info("Database initialization and migration completed successfully")
+            else:
+                logger.info("Database initialization completed successfully")
+        except Exception as e:
+            logger.error(f"Database initialization error: {e}")
+            logger.info("Database will be initialized on first use")

     async def cleanup(self):
         """Cleanup connections when shutting down"""
         async with self.pool_lock:
@@ -97,7 +114,7 @@ class AsyncDatabaseManager:
             ''')

         await self.execute_with_retry(_init)
-        await self.update_db_schema()

     async def update_db_schema(self):
         """Update database schema if needed"""
@@ -126,34 +143,59 @@ class AsyncDatabaseManager:
         await self.execute_with_retry(_alter)

-    async def aget_cached_url(self, url: str) -> Optional[Tuple[str, str, str, str, str, bool, str, str, str, str]]:
-        """Retrieve cached URL data"""
+    async def aget_cached_url(self, url: str) -> Optional[CrawlResult]:
+        """Retrieve cached URL data as CrawlResult"""
         async def _get(db):
             async with db.execute(
-                '''
-                SELECT url, html, cleaned_html, markdown, extracted_content, success, media, links, metadata, screenshot, response_headers, downloaded_files
-                FROM crawled_data WHERE url = ?
-                ''',
-                (url,)
+                'SELECT * FROM crawled_data WHERE url = ?', (url,)
             ) as cursor:
                 row = await cursor.fetchone()
-                if row:
-                    # Deserialize JSON fields
-                    return (
-                        row[0],  # url
-                        row[1],  # html
-                        row[2],  # cleaned_html
-                        row[3],  # markdown
-                        row[4],  # extracted_content
-                        row[5],  # success
-                        json.loads(row[6] or '{}'),  # media
-                        json.loads(row[7] or '{}'),  # links
-                        json.loads(row[8] or '{}'),  # metadata
-                        row[9],  # screenshot
-                        json.loads(row[10] or '{}'),  # response_headers
-                        json.loads(row[11] or '[]')  # downloaded_files
-                    )
-                return None
+                if not row:
+                    return None
+
+                # Get column names
+                columns = [description[0] for description in cursor.description]
+                # Create dict from row data
+                row_dict = dict(zip(columns, row))
+
+                # Load content from files using stored hashes
+                content_fields = {
+                    'html': row_dict['html'],
+                    'cleaned_html': row_dict['cleaned_html'],
+                    'markdown': row_dict['markdown'],
+                    'extracted_content': row_dict['extracted_content'],
+                    'screenshot': row_dict['screenshot']
+                }
+                for field, hash_value in content_fields.items():
+                    if hash_value:
+                        content = await self._load_content(
+                            hash_value,
+                            field.split('_')[0] # Get content type from field name
+                        )
+                        row_dict[field] = content or ""
+                    else:
+                        row_dict[field] = ""
+
+                # Parse JSON fields
+                json_fields = ['media', 'links', 'metadata', 'response_headers']
+                for field in json_fields:
+                    try:
+                        row_dict[field] = json.loads(row_dict[field]) if row_dict[field] else {}
+                    except json.JSONDecodeError:
+                        row_dict[field] = {}
+
+                # Parse downloaded_files
+                try:
+                    row_dict['downloaded_files'] = json.loads(row_dict['downloaded_files']) if row_dict['downloaded_files'] else []
+                except json.JSONDecodeError:
+                    row_dict['downloaded_files'] = []
+
+                # Remove any fields not in CrawlResult model
+                valid_fields = CrawlResult.__annotations__.keys()
+                filtered_dict = {k: v for k, v in row_dict.items() if k in valid_fields}
+
+                return CrawlResult(**filtered_dict)

         try:
             return await self.execute_with_retry(_get)
@@ -161,26 +203,27 @@ class AsyncDatabaseManager:
             logger.error(f"Error retrieving cached URL: {e}")
             return None

-    async def acache_url(
-        self,
-        url: str,
-        html: str,
-        cleaned_html: str,
-        markdown: str,
-        extracted_content: str,
-        success: bool,
-        media: str = "{}",
-        links: str = "{}",
-        metadata: str = "{}",
-        screenshot: str = "",
-        response_headers: str = "{}",
-        downloaded_files: str = "[]"
-    ):
-        """Cache URL data with retry logic"""
+    async def acache_url(self, result: CrawlResult):
+        """Cache CrawlResult data"""
+        # Store content files and get hashes
+        content_map = {
+            'html': (result.html, 'html'),
+            'cleaned_html': (result.cleaned_html or "", 'cleaned'),
+            'markdown': (result.markdown or "", 'markdown'),
+            'extracted_content': (result.extracted_content or "", 'extracted'),
+            'screenshot': (result.screenshot or "", 'screenshots')
+        }
+        content_hashes = {}
+        for field, (content, content_type) in content_map.items():
+            content_hashes[field] = await self._store_content(content, content_type)
+
         async def _cache(db):
             await db.execute('''
                 INSERT INTO crawled_data (
-                    url, html, cleaned_html, markdown, extracted_content, success, media, links, metadata, screenshot, response_headers, downloaded_files
+                    url, html, cleaned_html, markdown,
+                    extracted_content, success, media, links, metadata,
+                    screenshot, response_headers, downloaded_files
                 )
                 VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
                 ON CONFLICT(url) DO UPDATE SET
@@ -189,13 +232,26 @@ class AsyncDatabaseManager:
                     markdown = excluded.markdown,
                     extracted_content = excluded.extracted_content,
                     success = excluded.success,
                     media = excluded.media,
                     links = excluded.links,
                     metadata = excluded.metadata,
                     screenshot = excluded.screenshot,
-                    response_headers = excluded.response_headers, -- Update response_headers
+                    response_headers = excluded.response_headers,
                     downloaded_files = excluded.downloaded_files
-            ''', (url, html, cleaned_html, markdown, extracted_content, success, media, links, metadata, screenshot, response_headers, downloaded_files))
+            ''', (
+                result.url,
+                content_hashes['html'],
+                content_hashes['cleaned_html'],
+                content_hashes['markdown'],
+                content_hashes['extracted_content'],
+                result.success,
+                json.dumps(result.media),
+                json.dumps(result.links),
+                json.dumps(result.metadata or {}),
+                content_hashes['screenshot'],
+                json.dumps(result.response_headers or {}),
+                json.dumps(result.downloaded_files or [])
+            ))

         try:
             await self.execute_with_retry(_cache)
@@ -234,6 +290,35 @@ class AsyncDatabaseManager:
             await self.execute_with_retry(_flush)
         except Exception as e:
             logger.error(f"Error flushing database: {e}")

+    async def _store_content(self, content: str, content_type: str) -> str:
+        """Store content in filesystem and return hash"""
+        if not content:
+            return ""
+
+        content_hash = generate_content_hash(content)
+        file_path = os.path.join(self.content_paths[content_type], content_hash)
+
+        # Only write if file doesn't exist
+        if not os.path.exists(file_path):
+            async with aiofiles.open(file_path, 'w', encoding='utf-8') as f:
+                await f.write(content)
+
+        return content_hash
+
+    async def _load_content(self, content_hash: str, content_type: str) -> Optional[str]:
+        """Load content from filesystem by hash"""
+        if not content_hash:
+            return None
+
+        file_path = os.path.join(self.content_paths[content_type], content_hash)
+        try:
+            async with aiofiles.open(file_path, 'r', encoding='utf-8') as f:
+                return await f.read()
+        except:
+            logger.error(f"Failed to load content: {file_path}")
+            return None

 # Create a singleton instance
 async_db_manager = AsyncDatabaseManager()
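The _store_content and _load_content methods above depend on ensure_content_dirs and generate_content_hash from .utils, which are among the changed files but not shown in this view. The sketch below is only an illustration of what they plausibly do, inferred from the xxhash import and from how self.content_paths and the returned hashes are used; the actual directory layout and hash function may differ.

# Hypothetical sketch of the .utils helpers; the content subdirectory names
# and the choice of xxh64 are assumptions, not taken from this commit.
import os
import xxhash

def ensure_content_dirs(base_dir: str) -> dict:
    """Create one subdirectory per content type and return a type -> path mapping."""
    content_types = ['html', 'cleaned', 'markdown', 'extracted', 'screenshots']
    paths = {}
    for content_type in content_types:
        path = os.path.join(base_dir, 'content', content_type)
        os.makedirs(path, exist_ok=True)
        paths[content_type] = path
    return paths

def generate_content_hash(content: str) -> str:
    """Hash page content so identical documents are written to disk only once."""
    return xxhash.xxh64(content.encode('utf-8')).hexdigest()

Content-addressed storage of this kind keeps writes idempotent (_store_content skips the write when the hash-named file already exists) and deduplicates identical pages cached under different URLs.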