diff --git a/crawl4ai/model_loader.py b/crawl4ai/model_loader.py index 832b4240..5ff3695e 100644 --- a/crawl4ai/model_loader.py +++ b/crawl4ai/model_loader.py @@ -2,6 +2,7 @@ from functools import lru_cache from pathlib import Path import subprocess, os import shutil +import tarfile from crawl4ai.config import MODEL_REPO_BRANCH import argparse import urllib.request @@ -82,12 +83,19 @@ def load_bge_small_en_v1_5(): @lru_cache() def load_onnx_all_MiniLM_l6_v2(): from crawl4ai.onnx_embedding import DefaultEmbeddingModel - model_path = "models/onnx/model.onnx" - model_url = "https://unclecode-files.s3.us-west-2.amazonaws.com/model.onnx" - download_path = os.path.join(__location__, model_path) + model_path = "models/onnx.tar.gz" + model_url = "https://unclecode-files.s3.us-west-2.amazonaws.com/onnx.tar.gz" + __location__ = os.path.realpath( + os.path.join(os.getcwd(), os.path.dirname(__file__))) + download_path = os.path.join(__location__, model_path) + onnx_dir = os.path.join(__location__, "models/onnx") + + # Create the models directory if it does not exist + os.makedirs(os.path.dirname(download_path), exist_ok=True) + + # Download the tar.gz file if it does not exist if not os.path.exists(download_path): - # Define a download function with a simple progress display def download_with_progress(url, filename): def reporthook(block_num, block_size, total_size): downloaded = block_num * block_size @@ -95,12 +103,22 @@ def load_onnx_all_MiniLM_l6_v2(): if downloaded < total_size: print(f"\rDownloading: {percentage:.2f}% ({downloaded / (1024 * 1024):.2f} MB of {total_size / (1024 * 1024):.2f} MB)", end='') else: - print("\rDownload complete! ") + print("\rDownload complete!") urllib.request.urlretrieve(url, filename, reporthook) download_with_progress(model_url, download_path) + # Extract the tar.gz file if the onnx directory does not exist + if not os.path.exists(onnx_dir): + with tarfile.open(download_path, "r:gz") as tar: + tar.extractall(path=os.path.join(__location__, "models")) + + # remove the tar.gz file + os.remove(download_path) + + + model = DefaultEmbeddingModel() return model