feat(database): implement version management and migration checks during initialization
This commit is contained in:
@@ -11,6 +11,7 @@ from .models import CrawlResult
|
||||
import xxhash
|
||||
import aiofiles
|
||||
from .config import NEED_MIGRATION
|
||||
from .version_manager import VersionManager
|
||||
# Set up logging
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -28,22 +29,49 @@ class AsyncDatabaseManager:
|
||||
self.connection_pool: Dict[int, aiosqlite.Connection] = {}
|
||||
self.pool_lock = asyncio.Lock()
|
||||
self.connection_semaphore = asyncio.Semaphore(pool_size)
|
||||
self._initialized = False
|
||||
self.version_manager = VersionManager()
|
||||
|
||||
|
||||
async def initialize(self):
|
||||
"""Initialize the database and connection pool"""
|
||||
try:
|
||||
logger.info("Initializing database...")
|
||||
# Ensure the database file exists
|
||||
os.makedirs(os.path.dirname(self.db_path), exist_ok=True)
|
||||
|
||||
# Check if version update is needed
|
||||
needs_update = self.version_manager.needs_update()
|
||||
|
||||
# Always ensure base table exists
|
||||
await self.ainit_db()
|
||||
if NEED_MIGRATION:
|
||||
|
||||
# Verify the table exists
|
||||
async def verify_table(db):
|
||||
async with db.execute(
|
||||
"SELECT name FROM sqlite_master WHERE type='table' AND name='crawled_data'"
|
||||
) as cursor:
|
||||
result = await cursor.fetchone()
|
||||
if not result:
|
||||
raise Exception("crawled_data table was not created")
|
||||
|
||||
await self.execute_with_retry(verify_table)
|
||||
|
||||
# If version changed or fresh install, run updates
|
||||
if needs_update:
|
||||
logger.info("New version detected, running updates...")
|
||||
await self.update_db_schema()
|
||||
from .migrations import run_migration # Import here to avoid circular imports
|
||||
await run_migration()
|
||||
logger.info("Database initialization and migration completed successfully")
|
||||
self.version_manager.update_version() # Update stored version after successful migration
|
||||
logger.info("Version update completed successfully")
|
||||
else:
|
||||
logger.info("Database initialization completed successfully")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Database initialization error: {e}")
|
||||
logger.info("Database will be initialized on first use")
|
||||
raise
|
||||
|
||||
async def cleanup(self):
|
||||
"""Cleanup connections when shutting down"""
|
||||
@@ -55,6 +83,12 @@ class AsyncDatabaseManager:
|
||||
@asynccontextmanager
|
||||
async def get_connection(self):
|
||||
"""Connection pool manager"""
|
||||
if not self._initialized:
|
||||
async with self.pool_lock: # Prevent multiple simultaneous initializations
|
||||
if not self._initialized: # Double-check after acquiring lock
|
||||
await self.initialize()
|
||||
self._initialized = True
|
||||
|
||||
async with self.connection_semaphore:
|
||||
task_id = id(asyncio.current_task())
|
||||
try:
|
||||
@@ -79,6 +113,7 @@ class AsyncDatabaseManager:
|
||||
await self.connection_pool[task_id].close()
|
||||
del self.connection_pool[task_id]
|
||||
|
||||
|
||||
async def execute_with_retry(self, operation, *args):
|
||||
"""Execute database operations with retry logic"""
|
||||
for attempt in range(self.max_retries):
|
||||
|
||||
30
crawl4ai/version_manager.py
Normal file
30
crawl4ai/version_manager.py
Normal file
@@ -0,0 +1,30 @@
|
||||
# version_manager.py
|
||||
import os
|
||||
from pathlib import Path
|
||||
from packaging import version
|
||||
from . import __version__
|
||||
|
||||
class VersionManager:
|
||||
def __init__(self):
|
||||
self.home_dir = Path.home() / ".crawl4ai"
|
||||
self.version_file = self.home_dir / "version.txt"
|
||||
|
||||
def get_installed_version(self):
|
||||
"""Get the version recorded in home directory"""
|
||||
if not self.version_file.exists():
|
||||
return None
|
||||
try:
|
||||
return version.parse(self.version_file.read_text().strip())
|
||||
except:
|
||||
return None
|
||||
|
||||
def update_version(self):
|
||||
"""Update the version file to current library version"""
|
||||
self.version_file.write_text(__version__)
|
||||
|
||||
def needs_update(self):
|
||||
"""Check if database needs update based on version"""
|
||||
installed = self.get_installed_version()
|
||||
current = version.parse(__version__)
|
||||
return installed is None or installed < current
|
||||
|
||||
Reference in New Issue
Block a user