chore: Update ONNX model loading process
This commit is contained in:
@@ -2,6 +2,7 @@ from functools import lru_cache
|
||||
from pathlib import Path
|
||||
import subprocess, os
|
||||
import shutil
|
||||
import tarfile
|
||||
from crawl4ai.config import MODEL_REPO_BRANCH
|
||||
import argparse
|
||||
import urllib.request
|
||||
@@ -82,12 +83,19 @@ def load_bge_small_en_v1_5():
|
||||
@lru_cache()
|
||||
def load_onnx_all_MiniLM_l6_v2():
|
||||
from crawl4ai.onnx_embedding import DefaultEmbeddingModel
|
||||
model_path = "models/onnx/model.onnx"
|
||||
model_url = "https://unclecode-files.s3.us-west-2.amazonaws.com/model.onnx"
|
||||
download_path = os.path.join(__location__, model_path)
|
||||
|
||||
model_path = "models/onnx.tar.gz"
|
||||
model_url = "https://unclecode-files.s3.us-west-2.amazonaws.com/onnx.tar.gz"
|
||||
__location__ = os.path.realpath(
|
||||
os.path.join(os.getcwd(), os.path.dirname(__file__)))
|
||||
download_path = os.path.join(__location__, model_path)
|
||||
onnx_dir = os.path.join(__location__, "models/onnx")
|
||||
|
||||
# Create the models directory if it does not exist
|
||||
os.makedirs(os.path.dirname(download_path), exist_ok=True)
|
||||
|
||||
# Download the tar.gz file if it does not exist
|
||||
if not os.path.exists(download_path):
|
||||
# Define a download function with a simple progress display
|
||||
def download_with_progress(url, filename):
|
||||
def reporthook(block_num, block_size, total_size):
|
||||
downloaded = block_num * block_size
|
||||
@@ -95,12 +103,22 @@ def load_onnx_all_MiniLM_l6_v2():
|
||||
if downloaded < total_size:
|
||||
print(f"\rDownloading: {percentage:.2f}% ({downloaded / (1024 * 1024):.2f} MB of {total_size / (1024 * 1024):.2f} MB)", end='')
|
||||
else:
|
||||
print("\rDownload complete! ")
|
||||
print("\rDownload complete!")
|
||||
|
||||
urllib.request.urlretrieve(url, filename, reporthook)
|
||||
|
||||
download_with_progress(model_url, download_path)
|
||||
|
||||
# Extract the tar.gz file if the onnx directory does not exist
|
||||
if not os.path.exists(onnx_dir):
|
||||
with tarfile.open(download_path, "r:gz") as tar:
|
||||
tar.extractall(path=os.path.join(__location__, "models"))
|
||||
|
||||
# remove the tar.gz file
|
||||
os.remove(download_path)
|
||||
|
||||
|
||||
|
||||
model = DefaultEmbeddingModel()
|
||||
return model
|
||||
|
||||
|
||||
Reference in New Issue
Block a user