Add model loader, update requirements.txt

This commit is contained in:
unclecode
2024-05-16 20:08:21 +08:00
parent c8589f8da3
commit 8e28eb9efb
3 changed files with 60 additions and 19 deletions

View File

@@ -4,6 +4,7 @@ from pathlib import Path
import subprocess, os import subprocess, os
import shutil import shutil
from .config import MODEL_REPO_BRANCH from .config import MODEL_REPO_BRANCH
import argparse
@lru_cache() @lru_cache()
def load_bert_base_uncased(): def load_bert_base_uncased():
@@ -84,3 +85,39 @@ def load_spacy_model():
print(f"An error occurred: {e}") print(f"An error occurred: {e}")
return spacy.load(model_folder) return spacy.load(model_folder)
def download_all_models(remove_existing=False):
    """Fetch every model Crawl4AI needs by invoking each loader in turn.

    Args:
        remove_existing: When true, delete the ``models/reuters`` and
            ``models`` folders under the home folder (if they exist)
            before calling the loaders.
    """
    if remove_existing:
        print("[LOG] Removing existing models...")
        base_dir = get_home_folder()
        # The nested folder is handled first, then its parent, matching the
        # order in which the folders are listed.
        for relative in ("models/reuters", "models"):
            target = os.path.join(base_dir, relative)
            if Path(target).exists():
                shutil.rmtree(target)
        print("[LOG] Existing models removed.")

    # Calling each loader is what triggers the download; the return values
    # are intentionally discarded.
    print("[LOG] Downloading BERT Base Uncased...")
    load_bert_base_uncased()

    print("[LOG] Downloading BGE Small EN v1.5...")
    load_bge_small_en_v1_5()

    print("[LOG] Downloading spaCy EN Core Web SM...")
    load_spacy_en_core_web_sm()

    print("[LOG] Downloading custom spaCy model...")
    load_spacy_model()

    print("[LOG] ✅ All models downloaded successfully.")
def main():
    """Command-line entry point for the Crawl4AI model downloader.

    Parses the process arguments and forwards the ``--remove-existing``
    flag to ``download_all_models``.
    """
    cli = argparse.ArgumentParser(description="Crawl4AI Model Downloader")
    cli.add_argument(
        '--remove-existing',
        action='store_true',
        help="Remove existing models before downloading",
    )
    options = cli.parse_args()
    download_all_models(remove_existing=options.remove_existing)


# Allow running the module directly as a script.
if __name__ == "__main__":
    main()

View File

@@ -1,18 +1,17 @@
fastapi aiohttp==3.9.5
uvicorn aiosqlite==0.20.0
selenium bs4==0.0.2
pydantic fastapi==0.111.0
aiohttp html2text==2024.2.26
aiosqlite httpx==0.27.0
chromedriver_autoinstaller lazy_import==0.2.2
httpx litellm==1.37.11
requests nltk==3.8.1
bs4 pydantic==2.7.1
html2text python-dotenv==1.0.1
litellm requests==2.31.0
python-dotenv rich==13.7.1
nltk scikit-learn==1.4.2
lazy_import selenium==4.20.0
rich spacy==3.7.4
spacy uvicorn==0.29.0
scikit-learn

View File

@@ -16,6 +16,11 @@ setup(
license="MIT", license="MIT",
packages=find_packages(), packages=find_packages(),
install_requires=requirements, install_requires=requirements,
entry_points={
'console_scripts': [
'crawl4ai-download-models=crawl4ai.model_loader:main',
],
},
classifiers=[ classifiers=[
"Development Status :: 3 - Alpha", "Development Status :: 3 - Alpha",
"Intended Audience :: Developers", "Intended Audience :: Developers",