chore: Update ONNX model loading process
This commit is contained in:
@@ -2,6 +2,7 @@ from functools import lru_cache
|
|||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
import subprocess, os
|
import subprocess, os
|
||||||
import shutil
|
import shutil
|
||||||
|
import tarfile
|
||||||
from crawl4ai.config import MODEL_REPO_BRANCH
|
from crawl4ai.config import MODEL_REPO_BRANCH
|
||||||
import argparse
|
import argparse
|
||||||
import urllib.request
|
import urllib.request
|
||||||
@@ -82,12 +83,19 @@ def load_bge_small_en_v1_5():
|
|||||||
@lru_cache()
|
@lru_cache()
|
||||||
def load_onnx_all_MiniLM_l6_v2():
|
def load_onnx_all_MiniLM_l6_v2():
|
||||||
from crawl4ai.onnx_embedding import DefaultEmbeddingModel
|
from crawl4ai.onnx_embedding import DefaultEmbeddingModel
|
||||||
model_path = "models/onnx/model.onnx"
|
|
||||||
model_url = "https://unclecode-files.s3.us-west-2.amazonaws.com/model.onnx"
|
|
||||||
download_path = os.path.join(__location__, model_path)
|
|
||||||
|
|
||||||
|
model_path = "models/onnx.tar.gz"
|
||||||
|
model_url = "https://unclecode-files.s3.us-west-2.amazonaws.com/onnx.tar.gz"
|
||||||
|
__location__ = os.path.realpath(
|
||||||
|
os.path.join(os.getcwd(), os.path.dirname(__file__)))
|
||||||
|
download_path = os.path.join(__location__, model_path)
|
||||||
|
onnx_dir = os.path.join(__location__, "models/onnx")
|
||||||
|
|
||||||
|
# Create the models directory if it does not exist
|
||||||
|
os.makedirs(os.path.dirname(download_path), exist_ok=True)
|
||||||
|
|
||||||
|
# Download the tar.gz file if it does not exist
|
||||||
if not os.path.exists(download_path):
|
if not os.path.exists(download_path):
|
||||||
# Define a download function with a simple progress display
|
|
||||||
def download_with_progress(url, filename):
|
def download_with_progress(url, filename):
|
||||||
def reporthook(block_num, block_size, total_size):
|
def reporthook(block_num, block_size, total_size):
|
||||||
downloaded = block_num * block_size
|
downloaded = block_num * block_size
|
||||||
@@ -95,12 +103,22 @@ def load_onnx_all_MiniLM_l6_v2():
|
|||||||
if downloaded < total_size:
|
if downloaded < total_size:
|
||||||
print(f"\rDownloading: {percentage:.2f}% ({downloaded / (1024 * 1024):.2f} MB of {total_size / (1024 * 1024):.2f} MB)", end='')
|
print(f"\rDownloading: {percentage:.2f}% ({downloaded / (1024 * 1024):.2f} MB of {total_size / (1024 * 1024):.2f} MB)", end='')
|
||||||
else:
|
else:
|
||||||
print("\rDownload complete! ")
|
print("\rDownload complete!")
|
||||||
|
|
||||||
urllib.request.urlretrieve(url, filename, reporthook)
|
urllib.request.urlretrieve(url, filename, reporthook)
|
||||||
|
|
||||||
download_with_progress(model_url, download_path)
|
download_with_progress(model_url, download_path)
|
||||||
|
|
||||||
|
# Extract the tar.gz file if the onnx directory does not exist
|
||||||
|
if not os.path.exists(onnx_dir):
|
||||||
|
with tarfile.open(download_path, "r:gz") as tar:
|
||||||
|
tar.extractall(path=os.path.join(__location__, "models"))
|
||||||
|
|
||||||
|
# remove the tar.gz file
|
||||||
|
os.remove(download_path)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
model = DefaultEmbeddingModel()
|
model = DefaultEmbeddingModel()
|
||||||
return model
|
return model
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user