New async database manager and migration support

- Introduced AsyncDatabaseManager for asynchronous database management.
- Added a migration feature to transition to file-based content storage.
- Enhanced the web crawler with improved caching logic.
- Updated requirements and setup for async processing.
UncleCode
2024-11-16 14:54:41 +08:00
parent ae7ebc0bd8
commit d0014c6793
8 changed files with 685 additions and 119 deletions

View File

@@ -0,0 +1,285 @@
import os
from pathlib import Path
import aiosqlite
import asyncio
from typing import Optional, Tuple, Dict
from contextlib import asynccontextmanager
import logging
import json # Added for serialization/deserialization
from .utils import ensure_content_dirs, generate_content_hash
import xxhash
import aiofiles
# Set up logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
DB_PATH = os.path.join(Path.home(), ".crawl4ai")
os.makedirs(DB_PATH, exist_ok=True)
DB_PATH = os.path.join(DB_PATH, "crawl4ai.db")
class AsyncDatabaseManager:
def __init__(self, pool_size: int = 10, max_retries: int = 3):
self.db_path = DB_PATH
self.content_paths = ensure_content_dirs(os.path.dirname(DB_PATH))
self.pool_size = pool_size
self.max_retries = max_retries
self.connection_pool: Dict[int, aiosqlite.Connection] = {}
self.pool_lock = asyncio.Lock()
self.connection_semaphore = asyncio.Semaphore(pool_size)
async def initialize(self):
"""Initialize the database and connection pool"""
await self.ainit_db()
async def cleanup(self):
"""Cleanup connections when shutting down"""
async with self.pool_lock:
for conn in self.connection_pool.values():
await conn.close()
self.connection_pool.clear()
@asynccontextmanager
async def get_connection(self):
"""Connection pool manager"""
async with self.connection_semaphore:
task_id = id(asyncio.current_task())
try:
async with self.pool_lock:
if task_id not in self.connection_pool:
conn = await aiosqlite.connect(
self.db_path,
timeout=30.0
)
await conn.execute('PRAGMA journal_mode = WAL')
await conn.execute('PRAGMA busy_timeout = 5000')
self.connection_pool[task_id] = conn
yield self.connection_pool[task_id]
except Exception as e:
logger.error(f"Connection error: {e}")
raise
finally:
async with self.pool_lock:
if task_id in self.connection_pool:
await self.connection_pool[task_id].close()
del self.connection_pool[task_id]
async def execute_with_retry(self, operation, *args):
"""Execute database operations with retry logic"""
for attempt in range(self.max_retries):
try:
async with self.get_connection() as db:
result = await operation(db, *args)
await db.commit()
return result
except Exception as e:
if attempt == self.max_retries - 1:
logger.error(f"Operation failed after {self.max_retries} attempts: {e}")
raise
await asyncio.sleep(1 * (attempt + 1)) # Linear backoff between retries
async def ainit_db(self):
"""Initialize database schema"""
async def _init(db):
await db.execute('''
CREATE TABLE IF NOT EXISTS crawled_data (
url TEXT PRIMARY KEY,
html TEXT,
cleaned_html TEXT,
markdown TEXT,
extracted_content TEXT,
success BOOLEAN,
media TEXT DEFAULT "{}",
links TEXT DEFAULT "{}",
metadata TEXT DEFAULT "{}",
screenshot TEXT DEFAULT "",
response_headers TEXT DEFAULT "{}",
downloaded_files TEXT DEFAULT "{}" -- New column added
)
''')
await self.execute_with_retry(_init)
await self.update_db_schema()
async def update_db_schema(self):
"""Update database schema if needed"""
async def _check_columns(db):
cursor = await db.execute("PRAGMA table_info(crawled_data)")
columns = await cursor.fetchall()
return [column[1] for column in columns]
column_names = await self.execute_with_retry(_check_columns)
# List of new columns to add
new_columns = ['media', 'links', 'metadata', 'screenshot', 'response_headers', 'downloaded_files']
for column in new_columns:
if column not in column_names:
await self.aalter_db_add_column(column)
async def aalter_db_add_column(self, new_column: str):
"""Add new column to the database"""
async def _alter(db):
if new_column == 'response_headers':
await db.execute(f'ALTER TABLE crawled_data ADD COLUMN {new_column} TEXT DEFAULT "{{}}"')
else:
await db.execute(f'ALTER TABLE crawled_data ADD COLUMN {new_column} TEXT DEFAULT ""')
logger.info(f"Added column '{new_column}' to the database.")
await self.execute_with_retry(_alter)
async def aget_cached_url(self, url: str) -> Optional[Tuple[str, str, str, str, str, bool, Dict, Dict, Dict, str, Dict, list]]:
"""Retrieve cached URL data"""
async def _get(db):
async with db.execute(
'''
SELECT url, html, cleaned_html, markdown,
extracted_content, success, media, links,
metadata, screenshot, response_headers,
downloaded_files
FROM crawled_data WHERE url = ?
''',
(url,)
) as cursor:
row = await cursor.fetchone()
if row:
# Load content from files using stored hashes
html = await self._load_content(row[1], 'html') if row[1] else ""
cleaned = await self._load_content(row[2], 'cleaned') if row[2] else ""
markdown = await self._load_content(row[3], 'markdown') if row[3] else ""
extracted = await self._load_content(row[4], 'extracted') if row[4] else ""
screenshot = await self._load_content(row[9], 'screenshots') if row[9] else ""
return (
row[0], # url
html or "", # Return empty string if file not found
cleaned or "",
markdown or "",
extracted or "",
row[5], # success
json.loads(row[6] or '{}'), # media
json.loads(row[7] or '{}'), # links
json.loads(row[8] or '{}'), # metadata
screenshot or "",
json.loads(row[10] or '{}'), # response_headers
json.loads(row[11] or '[]') # downloaded_files
)
return None
try:
return await self.execute_with_retry(_get)
except Exception as e:
logger.error(f"Error retrieving cached URL: {e}")
return None
async def acache_url(self, url: str, html: str, cleaned_html: str,
markdown: str, extracted_content: str, success: bool,
media: str = "{}", links: str = "{}",
metadata: str = "{}", screenshot: str = "",
response_headers: str = "{}", downloaded_files: str = "[]"):
"""Cache URL data with content stored in filesystem"""
# Store content files and get hashes
html_hash = await self._store_content(html, 'html')
cleaned_hash = await self._store_content(cleaned_html, 'cleaned')
markdown_hash = await self._store_content(markdown, 'markdown')
extracted_hash = await self._store_content(extracted_content, 'extracted')
screenshot_hash = await self._store_content(screenshot, 'screenshots')
async def _cache(db):
await db.execute('''
INSERT INTO crawled_data (
url, html, cleaned_html, markdown,
extracted_content, success, media, links, metadata,
screenshot, response_headers, downloaded_files
)
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
ON CONFLICT(url) DO UPDATE SET
html = excluded.html,
cleaned_html = excluded.cleaned_html,
markdown = excluded.markdown,
extracted_content = excluded.extracted_content,
success = excluded.success,
media = excluded.media,
links = excluded.links,
metadata = excluded.metadata,
screenshot = excluded.screenshot,
response_headers = excluded.response_headers,
downloaded_files = excluded.downloaded_files
''', (url, html_hash, cleaned_hash, markdown_hash, extracted_hash,
success, media, links, metadata, screenshot_hash,
response_headers, downloaded_files))
try:
await self.execute_with_retry(_cache)
except Exception as e:
logger.error(f"Error caching URL: {e}")
async def aget_total_count(self) -> int:
"""Get total number of cached URLs"""
async def _count(db):
async with db.execute('SELECT COUNT(*) FROM crawled_data') as cursor:
result = await cursor.fetchone()
return result[0] if result else 0
try:
return await self.execute_with_retry(_count)
except Exception as e:
logger.error(f"Error getting total count: {e}")
return 0
async def aclear_db(self):
"""Clear all data from the database"""
async def _clear(db):
await db.execute('DELETE FROM crawled_data')
try:
await self.execute_with_retry(_clear)
except Exception as e:
logger.error(f"Error clearing database: {e}")
async def aflush_db(self):
"""Drop the entire table"""
async def _flush(db):
await db.execute('DROP TABLE IF EXISTS crawled_data')
try:
await self.execute_with_retry(_flush)
except Exception as e:
logger.error(f"Error flushing database: {e}")
async def _store_content(self, content: str, content_type: str) -> str:
"""Store content in filesystem and return hash"""
if not content:
return ""
content_hash = generate_content_hash(content)
file_path = os.path.join(self.content_paths[content_type], content_hash)
# Only write if file doesn't exist
if not os.path.exists(file_path):
async with aiofiles.open(file_path, 'w', encoding='utf-8') as f:
await f.write(content)
return content_hash
async def _load_content(self, content_hash: str, content_type: str) -> Optional[str]:
"""Load content from filesystem by hash"""
if not content_hash:
return None
file_path = os.path.join(self.content_paths[content_type], content_hash)
try:
async with aiofiles.open(file_path, 'r', encoding='utf-8') as f:
return await f.read()
except Exception as e:
logger.error(f"Failed to load content: {file_path} ({e})")
return None
# Create a singleton instance
async_db_manager = AsyncDatabaseManager()
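
Taken together, the file above exposes a single shared manager whose public surface is initialize(), acache_url(...), aget_cached_url(url), and cleanup(). A minimal usage sketch of that surface follows; the module path and the example URL/content are assumptions for illustration, not part of the commit:

```python
import asyncio
import json

from crawl4ai.async_database import async_db_manager  # assumed module path

async def demo():
    await async_db_manager.initialize()
    try:
        # Cache a page; large bodies are written to hashed files on disk.
        await async_db_manager.acache_url(
            url="https://example.com",            # placeholder URL
            html="<html><body>hi</body></html>",  # placeholder content
            cleaned_html="<body>hi</body>",
            markdown="hi",
            extracted_content="",
            success=True,
            metadata=json.dumps({"title": "Example"}),
        )
        # Read it back; this version returns a tuple in column order.
        row = await async_db_manager.aget_cached_url("https://example.com")
        if row:
            print(row[0], row[5])  # url, success flag
    finally:
        await async_db_manager.cleanup()

asyncio.run(demo())
```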

View File

@@ -6,7 +6,11 @@ from typing import Optional, Tuple, Dict
from contextlib import asynccontextmanager
import logging
import json # Added for serialization/deserialization
from .utils import ensure_content_dirs, generate_content_hash
from .models import CrawlResult
import xxhash
import aiofiles
from .config import NEED_MIGRATION
# Set up logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
@@ -18,6 +22,7 @@ DB_PATH = os.path.join(DB_PATH, "crawl4ai.db")
class AsyncDatabaseManager:
def __init__(self, pool_size: int = 10, max_retries: int = 3):
self.db_path = DB_PATH
self.content_paths = ensure_content_dirs(os.path.dirname(DB_PATH))
self.pool_size = pool_size
self.max_retries = max_retries
self.connection_pool: Dict[int, aiosqlite.Connection] = {}
@@ -26,8 +31,20 @@ class AsyncDatabaseManager:
async def initialize(self):
"""Initialize the database and connection pool"""
await self.ainit_db()
try:
logger.info("Initializing database...")
await self.ainit_db()
if NEED_MIGRATION:
await self.update_db_schema()
from .migrations import run_migration # Import here to avoid circular imports
await run_migration()
logger.info("Database initialization and migration completed successfully")
else:
logger.info("Database initialization completed successfully")
except Exception as e:
logger.error(f"Database initialization error: {e}")
logger.info("Database will be initialized on first use")
async def cleanup(self):
"""Cleanup connections when shutting down"""
async with self.pool_lock:
@@ -97,7 +114,7 @@ class AsyncDatabaseManager:
''')
await self.execute_with_retry(_init)
await self.update_db_schema()
async def update_db_schema(self):
"""Update database schema if needed"""
@@ -126,34 +143,59 @@ class AsyncDatabaseManager:
await self.execute_with_retry(_alter)
async def aget_cached_url(self, url: str) -> Optional[Tuple[str, str, str, str, str, bool, str, str, str, str]]:
"""Retrieve cached URL data"""
async def aget_cached_url(self, url: str) -> Optional[CrawlResult]:
"""Retrieve cached URL data as CrawlResult"""
async def _get(db):
async with db.execute(
'''
SELECT url, html, cleaned_html, markdown, extracted_content, success, media, links, metadata, screenshot, response_headers, downloaded_files
FROM crawled_data WHERE url = ?
''',
(url,)
'SELECT * FROM crawled_data WHERE url = ?', (url,)
) as cursor:
row = await cursor.fetchone()
if row:
# Deserialize JSON fields
return (
row[0], # url
row[1], # html
row[2], # cleaned_html
row[3], # markdown
row[4], # extracted_content
row[5], # success
json.loads(row[6] or '{}'), # media
json.loads(row[7] or '{}'), # links
json.loads(row[8] or '{}'), # metadata
row[9], # screenshot
json.loads(row[10] or '{}'), # response_headers
json.loads(row[11] or '[]') # downloaded_files
)
return None
if not row:
return None
# Get column names
columns = [description[0] for description in cursor.description]
# Create dict from row data
row_dict = dict(zip(columns, row))
# Load content from files using stored hashes
content_fields = {
'html': row_dict['html'],
'cleaned_html': row_dict['cleaned_html'],
'markdown': row_dict['markdown'],
'extracted_content': row_dict['extracted_content'],
'screenshot': row_dict['screenshot']
}
for field, hash_value in content_fields.items():
if hash_value:
content = await self._load_content(
hash_value,
# Map the field name to its content-dir key; note the screenshot
# directory is keyed 'screenshots', not 'screenshot'
'screenshots' if field == 'screenshot' else field.split('_')[0]
)
row_dict[field] = content or ""
else:
row_dict[field] = ""
# Parse JSON fields
json_fields = ['media', 'links', 'metadata', 'response_headers']
for field in json_fields:
try:
row_dict[field] = json.loads(row_dict[field]) if row_dict[field] else {}
except json.JSONDecodeError:
row_dict[field] = {}
# Parse downloaded_files
try:
row_dict['downloaded_files'] = json.loads(row_dict['downloaded_files']) if row_dict['downloaded_files'] else []
except json.JSONDecodeError:
row_dict['downloaded_files'] = []
# Remove any fields not in CrawlResult model
valid_fields = CrawlResult.__annotations__.keys()
filtered_dict = {k: v for k, v in row_dict.items() if k in valid_fields}
return CrawlResult(**filtered_dict)
try:
return await self.execute_with_retry(_get)
@@ -161,26 +203,27 @@ class AsyncDatabaseManager:
logger.error(f"Error retrieving cached URL: {e}")
return None
async def acache_url(
self,
url: str,
html: str,
cleaned_html: str,
markdown: str,
extracted_content: str,
success: bool,
media: str = "{}",
links: str = "{}",
metadata: str = "{}",
screenshot: str = "",
response_headers: str = "{}",
downloaded_files: str = "[]"
):
"""Cache URL data with retry logic"""
async def acache_url(self, result: CrawlResult):
"""Cache CrawlResult data"""
# Store content files and get hashes
content_map = {
'html': (result.html, 'html'),
'cleaned_html': (result.cleaned_html or "", 'cleaned'),
'markdown': (result.markdown or "", 'markdown'),
'extracted_content': (result.extracted_content or "", 'extracted'),
'screenshot': (result.screenshot or "", 'screenshots')
}
content_hashes = {}
for field, (content, content_type) in content_map.items():
content_hashes[field] = await self._store_content(content, content_type)
async def _cache(db):
await db.execute('''
INSERT INTO crawled_data (
url, html, cleaned_html, markdown, extracted_content, success, media, links, metadata, screenshot, response_headers, downloaded_files
url, html, cleaned_html, markdown,
extracted_content, success, media, links, metadata,
screenshot, response_headers, downloaded_files
)
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
ON CONFLICT(url) DO UPDATE SET
@@ -189,13 +232,26 @@ class AsyncDatabaseManager:
markdown = excluded.markdown,
extracted_content = excluded.extracted_content,
success = excluded.success,
media = excluded.media,
links = excluded.links,
metadata = excluded.metadata,
media = excluded.media,
links = excluded.links,
metadata = excluded.metadata,
screenshot = excluded.screenshot,
response_headers = excluded.response_headers, -- Update response_headers
response_headers = excluded.response_headers,
downloaded_files = excluded.downloaded_files
''', (url, html, cleaned_html, markdown, extracted_content, success, media, links, metadata, screenshot, response_headers, downloaded_files))
''', (
result.url,
content_hashes['html'],
content_hashes['cleaned_html'],
content_hashes['markdown'],
content_hashes['extracted_content'],
result.success,
json.dumps(result.media),
json.dumps(result.links),
json.dumps(result.metadata or {}),
content_hashes['screenshot'],
json.dumps(result.response_headers or {}),
json.dumps(result.downloaded_files or [])
))
try:
await self.execute_with_retry(_cache)
@@ -234,6 +290,35 @@ class AsyncDatabaseManager:
await self.execute_with_retry(_flush)
except Exception as e:
logger.error(f"Error flushing database: {e}")
async def _store_content(self, content: str, content_type: str) -> str:
"""Store content in filesystem and return hash"""
if not content:
return ""
content_hash = generate_content_hash(content)
file_path = os.path.join(self.content_paths[content_type], content_hash)
# Only write if file doesn't exist
if not os.path.exists(file_path):
async with aiofiles.open(file_path, 'w', encoding='utf-8') as f:
await f.write(content)
return content_hash
async def _load_content(self, content_hash: str, content_type: str) -> Optional[str]:
"""Load content from filesystem by hash"""
if not content_hash:
return None
file_path = os.path.join(self.content_paths[content_type], content_hash)
try:
async with aiofiles.open(file_path, 'r', encoding='utf-8') as f:
return await f.read()
except Exception as e:
logger.error(f"Failed to load content: {file_path} ({e})")
return None
# Create a singleton instance
async_db_manager = AsyncDatabaseManager()
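
With this change the manager speaks CrawlResult end to end: acache_url() takes a model instance and aget_cached_url() returns one. A sketch of the new call shape, assuming CrawlResult can be constructed directly with the fields named in the diff and that the import paths match the package layout:

```python
import asyncio

from crawl4ai.async_database import async_db_manager  # assumed module path
from crawl4ai.models import CrawlResult

async def demo():
    await async_db_manager.initialize()
    result = CrawlResult(
        url="https://example.com",            # placeholder URL
        html="<html><body>hi</body></html>",  # placeholder content
        cleaned_html="<body>hi</body>",
        markdown="hi",
        success=True,
        media={}, links={}, metadata={},
        response_headers={}, downloaded_files=[],
    )
    await async_db_manager.acache_url(result)   # bodies stored as hashed files
    cached = await async_db_manager.aget_cached_url("https://example.com")
    if cached:
        print(cached.url, cached.success)       # a CrawlResult, not a tuple
    await async_db_manager.cleanup()

asyncio.run(demo())
```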

View File

@@ -47,17 +47,17 @@ class AsyncWebCrawler:
async def awarmup(self):
# Print a message for crawl4ai and its version
print(f"[LOG] 🚀 Crawl4AI {crawl4ai_version}")
if self.verbose:
print(f"[LOG] 🚀 Crawl4AI {crawl4ai_version}")
print("[LOG] 🌤️ Warming up the AsyncWebCrawler")
# await async_db_manager.ainit_db()
await async_db_manager.initialize()
await self.arun(
url="https://google.com/",
word_count_threshold=5,
bypass_cache=False,
verbose=False,
)
# # await async_db_manager.initialize()
# await self.arun(
# url="https://google.com/",
# word_count_threshold=5,
# bypass_cache=False,
# verbose=False,
# )
self.ready = True
if self.verbose:
print("[LOG] 🌞 AsyncWebCrawler is ready to crawl")
@@ -73,6 +73,9 @@ class AsyncWebCrawler:
screenshot: bool = False,
user_agent: str = None,
verbose=True,
disable_cache: bool = False,
no_cache_read: bool = False,
no_cache_write: bool = False,
**kwargs,
) -> CrawlResult:
"""
@@ -89,6 +92,11 @@ class AsyncWebCrawler:
CrawlResult: The result of the crawling and processing.
"""
try:
if disable_cache:
bypass_cache = True
no_cache_read = True
no_cache_write = True
extraction_strategy = extraction_strategy or NoExtractionStrategy()
extraction_strategy.verbose = verbose
if not isinstance(extraction_strategy, ExtractionStrategy):
@@ -108,36 +116,39 @@ class AsyncWebCrawler:
is_raw_html = url.startswith("raw:")
_url = url if not is_raw_html else "Raw HTML"
if is_web_url and not bypass_cache and not self.always_by_pass_cache:
cached = await async_db_manager.aget_cached_url(url)
start_time = time.perf_counter()
cached_result = None
if is_web_url and not bypass_cache and not no_cache_read and not self.always_by_pass_cache:
cached_result = await async_db_manager.aget_cached_url(url)
# if not bypass_cache and not self.always_by_pass_cache:
# cached = await async_db_manager.aget_cached_url(url)
if kwargs.get("warmup", True) and not self.ready:
return None
if cached:
html = sanitize_input_encode(cached[1])
extracted_content = sanitize_input_encode(cached[4])
if cached_result:
html = sanitize_input_encode(cached_result.html)
extracted_content = sanitize_input_encode(cached_result.extracted_content or "")
if screenshot:
screenshot_data = cached[9]
screenshot_data = cached_result.screenshot
if not screenshot_data:
cached = None
cached_result = None
if verbose:
print(
f"[LOG] 1⃣ ✅ Page fetched (cache) for {_url}, success: {bool(html)}, time taken: {time.perf_counter() - start_time:.2f} seconds"
)
if not cached_result or not html:
t1 = time.time()
t1 = time.perf_counter()
if user_agent:
self.crawler_strategy.update_user_agent(user_agent)
async_response: AsyncCrawlResponse = await self.crawler_strategy.crawl(url, screenshot=screenshot, **kwargs)
html = sanitize_input_encode(async_response.html)
screenshot_data = async_response.screenshot
t2 = time.time()
t2 = time.perf_counter()
if verbose:
print(
f"[LOG] 🚀 Crawling done for {_url}, success: {bool(html)}, time taken: {t2 - t1:.2f} seconds"
f"[LOG] 1⃣ ✅ Page fetched (no-cache) for {_url}, success: {bool(html)}, time taken: {t2 - t1:.2f} seconds"
)
t1 = time.perf_counter()
crawl_result = await self.aprocess_html(
url=url,
html=html,
@@ -163,30 +174,19 @@ class AsyncWebCrawler:
crawl_result.downloaded_files = async_response.downloaded_files
else:
crawl_result.status_code = 200
crawl_result.response_headers = cached[10]
# crawl_result.downloaded_files = cached[11]
crawl_result.response_headers = cached_result.response_headers if cached_result else {}
crawl_result.success = bool(html)
crawl_result.session_id = kwargs.get("session_id", None)
if verbose:
print(
f"[LOG] 🔥 🚀 Crawling done for {_url}, success: {crawl_result.success}, time taken: {time.perf_counter() - start_time:.2f} seconds"
)
if not is_raw_html:
if not bool(cached) or kwargs.get("bypass_cache", False) or self.always_by_pass_cache:
await async_db_manager.acache_url(
url = url,
html = html,
cleaned_html = crawl_result.cleaned_html,
markdown = crawl_result.markdown,
extracted_content = extracted_content,
success = True,
media = json.dumps(crawl_result.media),
links = json.dumps(crawl_result.links),
metadata = json.dumps(crawl_result.metadata),
screenshot=screenshot,
response_headers=json.dumps(crawl_result.response_headers),
downloaded_files=json.dumps(crawl_result.downloaded_files),
)
if not is_raw_html and not no_cache_write:
if not bool(cached_result) or kwargs.get("bypass_cache", False) or self.always_by_pass_cache:
await async_db_manager.acache_url(crawl_result)
return crawl_result
@@ -258,11 +258,11 @@ class AsyncWebCrawler:
verbose: bool,
**kwargs,
) -> CrawlResult:
t = time.time()
t = time.perf_counter()
# Extract content from HTML
try:
_url = url if not kwargs.get("is_raw_html", False) else "Raw HTML"
t1 = time.time()
t1 = time.perf_counter()
scrapping_strategy = WebScrapingStrategy()
# result = await scrapping_strategy.ascrap(
result = scrapping_strategy.scrap(
@@ -276,10 +276,6 @@ class AsyncWebCrawler:
),
**kwargs,
)
if verbose:
print(
f"[LOG] 🚀 Content extracted for {_url}, success: True, time taken: {time.time() - t1:.2f} seconds"
)
if result is None:
raise ValueError(f"Process HTML, Failed to extract content from the website: {url}")
@@ -295,13 +291,14 @@ class AsyncWebCrawler:
media = result.get("media", [])
links = result.get("links", [])
metadata = result.get("metadata", {})
if verbose:
print(
f"[LOG] 2⃣ ✅ Scraping done for {_url}, success: True, time taken: {time.perf_counter() - t1:.2f} seconds"
)
if extracted_content is None and extraction_strategy and chunking_strategy:
if verbose:
print(
f"[LOG] 🔥 Extracting semantic blocks for {_url}, Strategy: {self.__class__.__name__}"
)
if extracted_content is None and extraction_strategy and chunking_strategy and not isinstance(extraction_strategy, NoExtractionStrategy):
t1 = time.perf_counter()
# Check if extraction strategy is type of JsonCssExtractionStrategy
if isinstance(extraction_strategy, JsonCssExtractionStrategy):
extraction_strategy.verbose = verbose
@@ -311,11 +308,10 @@ class AsyncWebCrawler:
sections = chunking_strategy.chunk(markdown)
extracted_content = extraction_strategy.run(url, sections)
extracted_content = json.dumps(extracted_content, indent=4, default=str, ensure_ascii=False)
if verbose:
print(
f"[LOG] 🚀 Extraction done for {_url}, time taken: {time.time() - t:.2f} seconds."
)
if verbose:
print(
f"[LOG] 3⃣ ✅ Extraction done for {_url}, time taken: {time.perf_counter() - t1:.2f} seconds"
)
screenshot = None if not screenshot else screenshot
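
The three new cache flags compose as shown above: disable_cache=True forces all of bypass_cache, no_cache_read, and no_cache_write, while the individual flags can be set independently. A sketch of how a caller might exercise them; the top-level import and the async-context-manager usage are assumptions about the surrounding API:

```python
import asyncio

from crawl4ai import AsyncWebCrawler  # assumed top-level export

async def demo():
    # Assumes the crawler supports `async with` and warms itself up on entry.
    async with AsyncWebCrawler(verbose=True) as crawler:
        # Default: read from cache when present, write the result back.
        r1 = await crawler.arun(url="https://example.com")

        # Force a fresh fetch but still record the result for later runs.
        r2 = await crawler.arun(url="https://example.com", no_cache_read=True)

        # Fully cache-less: implies bypass_cache, no_cache_read, no_cache_write.
        r3 = await crawler.arun(url="https://example.com", disable_cache=True)

        print(r1.success, r2.success, r3.success)

asyncio.run(demo())
```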

View File

@@ -52,4 +52,6 @@ SOCIAL_MEDIA_DOMAINS = [
# If image is in the first half of the total images extracted from the page
IMAGE_SCORE_THRESHOLD = 2
MAX_METRICS_HISTORY = 1000
NEED_MIGRATION = True
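
NEED_MIGRATION is the switch that AsyncDatabaseManager.initialize() consults before importing and running the migration (see the async_database diff above). A minimal sketch of that gate, with import paths assumed from the package layout:

```python
from crawl4ai.config import NEED_MIGRATION
from crawl4ai.migrations import run_migration

async def maybe_migrate():
    # One-time move of cached bodies into file-based storage; flip
    # NEED_MIGRATION to False once existing installs have migrated.
    if NEED_MIGRATION:
        await run_migration()
```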

crawl4ai/migrations.py (new file, 152 additions)
View File

@@ -0,0 +1,152 @@
import os
import asyncio
import logging
from pathlib import Path
import aiosqlite
from typing import Optional
import xxhash
import aiofiles
import shutil
import time
from datetime import datetime
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
class DatabaseMigration:
def __init__(self, db_path: str):
self.db_path = db_path
self.content_paths = self._ensure_content_dirs(os.path.dirname(db_path))
def _ensure_content_dirs(self, base_path: str) -> dict:
dirs = {
'html': 'html_content',
'cleaned': 'cleaned_html',
'markdown': 'markdown_content',
'extracted': 'extracted_content',
'screenshots': 'screenshots'
}
content_paths = {}
for key, dirname in dirs.items():
path = os.path.join(base_path, dirname)
os.makedirs(path, exist_ok=True)
content_paths[key] = path
return content_paths
def _generate_content_hash(self, content: str) -> str:
x = xxhash.xxh64()
x.update(content.encode())
content_hash = x.hexdigest()
return content_hash
# return hashlib.sha256(content.encode()).hexdigest()
async def _store_content(self, content: str, content_type: str) -> str:
if not content:
return ""
content_hash = self._generate_content_hash(content)
file_path = os.path.join(self.content_paths[content_type], content_hash)
if not os.path.exists(file_path):
async with aiofiles.open(file_path, 'w', encoding='utf-8') as f:
await f.write(content)
return content_hash
async def migrate_database(self):
"""Migrate existing database to file-based storage"""
logger.info("Starting database migration...")
try:
async with aiosqlite.connect(self.db_path) as db:
# Get all rows
async with db.execute(
'''SELECT url, html, cleaned_html, markdown,
extracted_content, screenshot FROM crawled_data'''
) as cursor:
rows = await cursor.fetchall()
migrated_count = 0
for row in rows:
url, html, cleaned_html, markdown, extracted_content, screenshot = row
# Store content in files and get hashes
html_hash = await self._store_content(html, 'html')
cleaned_hash = await self._store_content(cleaned_html, 'cleaned')
markdown_hash = await self._store_content(markdown, 'markdown')
extracted_hash = await self._store_content(extracted_content, 'extracted')
screenshot_hash = await self._store_content(screenshot, 'screenshots')
# Update database with hashes
await db.execute('''
UPDATE crawled_data
SET html = ?,
cleaned_html = ?,
markdown = ?,
extracted_content = ?,
screenshot = ?
WHERE url = ?
''', (html_hash, cleaned_hash, markdown_hash,
extracted_hash, screenshot_hash, url))
migrated_count += 1
if migrated_count % 100 == 0:
logger.info(f"Migrated {migrated_count} records...")
await db.commit()
logger.info(f"Migration completed. {migrated_count} records processed.")
except Exception as e:
logger.error(f"Migration failed: {e}")
raise
async def backup_database(db_path: str) -> str:
"""Create backup of existing database"""
if not os.path.exists(db_path):
logger.info("No existing database found. Skipping backup.")
return None
# Create backup with timestamp
timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
backup_path = f"{db_path}.backup_{timestamp}"
try:
# Wait for any potential write operations to finish
await asyncio.sleep(1)
# Create backup
shutil.copy2(db_path, backup_path)
logger.info(f"Database backup created at: {backup_path}")
return backup_path
except Exception as e:
logger.error(f"Backup failed: {e}")
raise
async def run_migration(db_path: Optional[str] = None):
"""Run database migration"""
if db_path is None:
db_path = os.path.join(Path.home(), ".crawl4ai", "crawl4ai.db")
if not os.path.exists(db_path):
logger.info("No existing database found. Skipping migration.")
return
# Create backup first
backup_path = await backup_database(db_path)
if not backup_path:
return
migration = DatabaseMigration(db_path)
await migration.migrate_database()
def main():
"""CLI entry point for migration"""
import argparse
parser = argparse.ArgumentParser(description='Migrate Crawl4AI database to file-based storage')
parser.add_argument('--db-path', help='Custom database path')
args = parser.parse_args()
asyncio.run(run_migration(args.db_path))
if __name__ == "__main__":
main()
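
Besides the CLI entry point, run_migration() can be invoked programmatically; it backs up the SQLite file first and exits quietly when no database exists. A short sketch, with the custom path being a hypothetical example:

```python
import asyncio
from pathlib import Path

from crawl4ai.migrations import run_migration

# Default location (~/.crawl4ai/crawl4ai.db) when no path is given.
asyncio.run(run_migration())

# Or target a custom database file (hypothetical path).
asyncio.run(run_migration(str(Path.home() / "data" / "crawl4ai.db")))
```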

View File

@@ -14,6 +14,9 @@ from typing import Dict, Any
from urllib.parse import urljoin
import requests
from requests.exceptions import InvalidSchema
import hashlib
from typing import Optional, Tuple, Dict, Any
import xxhash
class InvalidCSSSelectorError(Exception):
pass
@@ -1109,3 +1112,27 @@ def clean_tokens(tokens: list[str]) -> list[str]:
and not token.startswith('')
and not token.startswith('')
and not token.startswith('')]
def generate_content_hash(content: str) -> str:
"""Generate a unique hash for content"""
return xxhash.xxh64(content.encode()).hexdigest()
# return hashlib.sha256(content.encode()).hexdigest()
def ensure_content_dirs(base_path: str) -> Dict[str, str]:
"""Create content directories if they don't exist"""
dirs = {
'html': 'html_content',
'cleaned': 'cleaned_html',
'markdown': 'markdown_content',
'extracted': 'extracted_content',
'screenshots': 'screenshots'
}
content_paths = {}
for key, dirname in dirs.items():
path = os.path.join(base_path, dirname)
os.makedirs(path, exist_ok=True)
content_paths[key] = path
return content_paths
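
These two helpers underpin the file-based cache: ensure_content_dirs() lays out one directory per content type and generate_content_hash() yields the xxh64 digest used as the filename, so identical content dedups to a single file. A sketch of that flow; the base directory is an example, and the import path is assumed from the package layout:

```python
import os

from crawl4ai.utils import ensure_content_dirs, generate_content_hash

paths = ensure_content_dirs("/tmp/crawl4ai_demo")  # example base directory
content = "<html><body>hello</body></html>"
name = generate_content_hash(content)              # stable xxh64 hex digest

# Identical content always maps to the same path, so rewrites are no-ops.
target = os.path.join(paths['html'], name)
if not os.path.exists(target):
    with open(target, 'w', encoding='utf-8') as f:
        f.write(content)
print(target)
```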