Add model loader, update requirements.txt

This commit is contained in:
unclecode
2024-05-16 20:08:21 +08:00
parent c8589f8da3
commit 8e28eb9efb
3 changed files with 60 additions and 19 deletions

View File

@@ -4,6 +4,7 @@ from pathlib import Path
import subprocess, os
import shutil
from .config import MODEL_REPO_BRANCH
import argparse
@lru_cache()
def load_bert_base_uncased():
@@ -83,4 +84,40 @@ def load_spacy_model():
except Exception as e:
print(f"An error occurred: {e}")
return spacy.load(model_folder)
return spacy.load(model_folder)
def download_all_models(remove_existing=False):
"""Download all models required for Crawl4AI."""
if remove_existing:
print("[LOG] Removing existing models...")
home_folder = get_home_folder()
model_folders = [
os.path.join(home_folder, "models/reuters"),
os.path.join(home_folder, "models"),
]
for folder in model_folders:
if Path(folder).exists():
shutil.rmtree(folder)
print("[LOG] Existing models removed.")
# Load each model to trigger download
print("[LOG] Downloading BERT Base Uncased...")
load_bert_base_uncased()
print("[LOG] Downloading BGE Small EN v1.5...")
load_bge_small_en_v1_5()
print("[LOG] Downloading spaCy EN Core Web SM...")
load_spacy_en_core_web_sm()
print("[LOG] Downloading custom spaCy model...")
load_spacy_model()
print("[LOG] ✅ All models downloaded successfully.")
def main():
parser = argparse.ArgumentParser(description="Crawl4AI Model Downloader")
parser.add_argument('--remove-existing', action='store_true', help="Remove existing models before downloading")
args = parser.parse_args()
download_all_models(remove_existing=args.remove_existing)
if __name__ == "__main__":
main()

View File

@@ -1,18 +1,17 @@
fastapi
uvicorn
selenium
pydantic
aiohttp
aiosqlite
chromedriver_autoinstaller
httpx
requests
bs4
html2text
litellm
python-dotenv
nltk
lazy_import
rich
spacy
scikit-learn
aiohttp==3.9.5
aiosqlite==0.20.0
bs4==0.0.2
fastapi==0.111.0
html2text==2024.2.26
httpx==0.27.0
lazy_import==0.2.2
litellm==1.37.11
nltk==3.8.1
pydantic==2.7.1
python-dotenv==1.0.1
requests==2.31.0
rich==13.7.1
scikit-learn==1.4.2
selenium==4.20.0
spacy==3.7.4
uvicorn==0.29.0

View File

@@ -16,6 +16,11 @@ setup(
license="MIT",
packages=find_packages(),
install_requires=requirements,
entry_points={
'console_scripts': [
'crawl4ai-download-models=crawl4ai.model_loader:main',
],
},
classifiers=[
"Development Status :: 3 - Alpha",
"Intended Audience :: Developers",