From d19488a82159651946b06faa66fd03401ba4068a Mon Sep 17 00:00:00 2001 From: unclecode Date: Thu, 16 May 2024 21:05:24 +0800 Subject: [PATCH] chore: Update model_loader.py to create necessary folders in the home directory --- crawl4ai/model_loader.py | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/crawl4ai/model_loader.py b/crawl4ai/model_loader.py index e51dfeac..9c50c524 100644 --- a/crawl4ai/model_loader.py +++ b/crawl4ai/model_loader.py @@ -1,11 +1,17 @@ from functools import lru_cache -from .utils import get_home_folder from pathlib import Path import subprocess, os import shutil from .config import MODEL_REPO_BRANCH import argparse +def get_home_folder(): + home_folder = os.path.join(Path.home(), ".crawl4ai") + os.makedirs(home_folder, exist_ok=True) + os.makedirs(f"{home_folder}/cache", exist_ok=True) + os.makedirs(f"{home_folder}/models", exist_ok=True) + return home_folder + @lru_cache() def load_bert_base_uncased(): from transformers import BertTokenizer, BertModel, AutoTokenizer, AutoModel @@ -89,6 +95,8 @@ def load_spacy_model(): def download_all_models(remove_existing=False): """Download all models required for Crawl4AI.""" + print("[LOG] Welcome to the Crawl4AI Model Downloader!") + print("[LOG] This script will download all the models required for Crawl4AI.") if remove_existing: print("[LOG] Removing existing models...") home_folder = get_home_folder()