Add model loader, update requirements.txt
This commit is contained in:
@@ -4,6 +4,7 @@ from pathlib import Path
|
|||||||
import subprocess, os
|
import subprocess, os
|
||||||
import shutil
|
import shutil
|
||||||
from .config import MODEL_REPO_BRANCH
|
from .config import MODEL_REPO_BRANCH
|
||||||
|
import argparse
|
||||||
|
|
||||||
@lru_cache()
|
@lru_cache()
|
||||||
def load_bert_base_uncased():
|
def load_bert_base_uncased():
|
||||||
@@ -83,4 +84,40 @@ def load_spacy_model():
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"An error occurred: {e}")
|
print(f"An error occurred: {e}")
|
||||||
|
|
||||||
return spacy.load(model_folder)
|
return spacy.load(model_folder)
|
||||||
|
|
||||||
|
|
||||||
|
def download_all_models(remove_existing=False):
    """Download all models required for Crawl4AI.

    Each model is obtained by calling its loader; loading is what triggers
    the download.

    Args:
        remove_existing: When true, delete the ``models`` directory under the
            Crawl4AI home folder before the loaders are called.
    """
    loaders = (
        ("BERT Base Uncased", load_bert_base_uncased),
        ("BGE Small EN v1.5", load_bge_small_en_v1_5),
        ("spaCy EN Core Web SM", load_spacy_en_core_web_sm),
        ("custom spaCy model", load_spacy_model),
    )

    if remove_existing:
        print("[LOG] Removing existing models...")
        home_folder = get_home_folder()
        # "models/reuters" lives inside "models", so removing the parent
        # directory covers both.
        models_folder = os.path.join(home_folder, "models")
        if Path(models_folder).exists():
            shutil.rmtree(models_folder)
        # A loader wrapped in lru_cache would otherwise return the object it
        # cached earlier in this process instead of running again against the
        # now-empty folder. Loaders without a cache are left untouched.
        for _, loader in loaders:
            clear_cache = getattr(loader, "cache_clear", None)
            if clear_cache is not None:
                clear_cache()
        print("[LOG] Existing models removed.")

    # Load each model to trigger download
    for label, loader in loaders:
        print(f"[LOG] Downloading {label}...")
        loader()
    print("[LOG] ✅ All models downloaded successfully.")
|
||||||
|
|
||||||
|
def main():
    """Command-line entry point: parse the options, then download the models."""
    cli = argparse.ArgumentParser(description="Crawl4AI Model Downloader")
    cli.add_argument(
        '--remove-existing',
        action='store_true',
        help="Remove existing models before downloading",
    )
    options = cli.parse_args()

    download_all_models(remove_existing=options.remove_existing)


if __name__ == "__main__":
    main()
|
||||||
@@ -1,18 +1,17 @@
|
|||||||
fastapi
|
aiohttp==3.9.5
|
||||||
uvicorn
|
aiosqlite==0.20.0
|
||||||
selenium
|
bs4==0.0.2
|
||||||
pydantic
|
fastapi==0.111.0
|
||||||
aiohttp
|
html2text==2024.2.26
|
||||||
aiosqlite
|
httpx==0.27.0
|
||||||
chromedriver_autoinstaller
|
lazy_import==0.2.2
|
||||||
httpx
|
litellm==1.37.11
|
||||||
requests
|
nltk==3.8.1
|
||||||
bs4
|
pydantic==2.7.1
|
||||||
html2text
|
python-dotenv==1.0.1
|
||||||
litellm
|
requests==2.31.0
|
||||||
python-dotenv
|
rich==13.7.1
|
||||||
nltk
|
scikit-learn==1.4.2
|
||||||
lazy_import
|
selenium==4.20.0
|
||||||
rich
|
spacy==3.7.4
|
||||||
spacy
|
uvicorn==0.29.0
|
||||||
scikit-learn
|
|
||||||
|
|||||||
5
setup.py
5
setup.py
@@ -16,6 +16,11 @@ setup(
|
|||||||
license="MIT",
|
license="MIT",
|
||||||
packages=find_packages(),
|
packages=find_packages(),
|
||||||
install_requires=requirements,
|
install_requires=requirements,
|
||||||
|
entry_points={
|
||||||
|
'console_scripts': [
|
||||||
|
'crawl4ai-download-models=crawl4ai.model_loader:main',
|
||||||
|
],
|
||||||
|
},
|
||||||
classifiers=[
|
classifiers=[
|
||||||
"Development Status :: 3 - Alpha",
|
"Development Status :: 3 - Alpha",
|
||||||
"Intended Audience :: Developers",
|
"Intended Audience :: Developers",
|
||||||
|
|||||||
Reference in New Issue
Block a user