- Add ONNX embedding model for CPU devices, update the similarity threshold, and improve embedding speed.
@@ -4,6 +4,8 @@ import subprocess, os
 import shutil
 from crawl4ai.config import MODEL_REPO_BRANCH
 import argparse
+import urllib.request
+__location__ = os.path.realpath(os.path.join(os.getcwd(), os.path.dirname(__file__)))
 
 @lru_cache()
 def get_available_memory(device):
@@ -23,18 +25,20 @@ def calculate_batch_size(device)
         return 16
     elif device.type in ['cuda', 'mps']:
         # Adjust these thresholds based on your model size and available memory
-        if available_memory > 32 * 1024 ** 3: # > 16GB
+        if available_memory >= 31 * 1024 ** 3: # > 32GB
             return 256
-        elif available_memory > 16 * 1024 ** 3: # > 16GB
+        elif available_memory >= 15 * 1024 ** 3: # > 16GB to 32GB
             return 128
-        elif available_memory > 8 * 1024 ** 3: # 8GB to 16GB
+        elif available_memory >= 8 * 1024 ** 3: # 8GB to 16GB
             return 64
         else:
             return 32
     else:
         return 16 # Default batch size
 
-def set_model_device(model):
+
+@lru_cache()
+def get_device():
     import torch
     if torch.cuda.is_available():
         device = torch.device('cuda')
@@ -42,7 +46,10 @@ def set_model_device(model):
         device = torch.device('mps')
     else:
         device = torch.device('cpu')
 
+    return device
+
+def set_model_device(model):
+    device = get_device()
     model.to(device)
     return model, device
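Reviewer note: the new floors (31 GiB and 15 GiB instead of strict > 32 GiB / > 16 GiB) leave headroom for GPUs that advertise 32 GB or 16 GB but report slightly less as available. A standalone sketch of the bucket logic, using a hypothetical `batch_for` helper that mirrors the thresholds above (sizes in bytes):

# Hypothetical helper mirroring the new thresholds; not part of the commit.
def batch_for(available_memory):
    if available_memory >= 31 * 1024 ** 3:    # ~32 GB and up
        return 256
    elif available_memory >= 15 * 1024 ** 3:  # ~16 GB to 32 GB
        return 128
    elif available_memory >= 8 * 1024 ** 3:   # 8 GB to 16 GB
        return 64
    return 32

assert batch_for(64 * 1024 ** 3) == 256  # 64 GB card
assert batch_for(16 * 1024 ** 3) == 128  # a "16 GB" card clears the 15 GiB floor
assert batch_for(12 * 1024 ** 3) == 64
assert batch_for(4 * 1024 ** 3) == 32

Splitting `get_device()` out of `set_model_device()` and wrapping it in `@lru_cache()` also means the torch import and device probe run once per process, which is presumably part of the "improve embedding speed" claim in the commit message.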
@@ -72,6 +79,31 @@ def load_bge_small_en_v1_5():
     model, device = set_model_device(model)
     return tokenizer, model
 
+@lru_cache()
+def load_onnx_all_MiniLM_l6_v2():
+    from crawl4ai.onnx_embedding import DefaultEmbeddingModel
+    model_path = "models/onnx/model.onnx"
+    model_url = "https://unclecode-files.s3.us-west-2.amazonaws.com/model.onnx"
+    download_path = os.path.join(__location__, model_path)
+
+    if not os.path.exists(download_path):
+        # Define a download function with a simple progress display
+        def download_with_progress(url, filename):
+            def reporthook(block_num, block_size, total_size):
+                downloaded = block_num * block_size
+                percentage = 100 * downloaded / total_size
+                if downloaded < total_size:
+                    print(f"\rDownloading: {percentage:.2f}% ({downloaded / (1024 * 1024):.2f} MB of {total_size / (1024 * 1024):.2f} MB)", end='')
+                else:
+                    print("\rDownload complete! ")
+
+            urllib.request.urlretrieve(url, filename, reporthook)
+
+        download_with_progress(model_url, download_path)
+
+    model = DefaultEmbeddingModel()
+    return model
+
 @lru_cache()
 def load_text_classifier():
     from transformers import AutoTokenizer, AutoModelForSequenceClassification
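For context, `reporthook` above follows the standard `urllib.request.urlretrieve` callback contract: it is called as `reporthook(block_num, block_size, total_size)` after each block is received. A minimal standalone sketch of that contract; the `total_size > 0` guard is an addition not in the commit (servers that omit Content-Length report -1):

import urllib.request

def reporthook(block_num, block_size, total_size):
    # urlretrieve reports progress in fixed-size blocks.
    downloaded = block_num * block_size
    if total_size > 0:  # total_size is -1 when Content-Length is missing
        pct = min(100.0, 100 * downloaded / total_size)
        print(f"\r{pct:.1f}%", end='')

# Hypothetical call, mirroring the diff:
# urllib.request.urlretrieve(model_url, download_path, reporthook)

Two review observations: `urlretrieve` opens the target file directly, so the download will fail with FileNotFoundError if the `models/onnx/` directory does not already exist; and since the loader is wrapped in `@lru_cache()`, the existence check and download run at most once per process.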
@@ -204,10 +236,12 @@ def download_all_models(remove_existing=False):
         print("[LOG] Existing models removed.")
 
     # Load each model to trigger download
-    print("[LOG] Downloading BERT Base Uncased...")
-    load_bert_base_uncased()
-    print("[LOG] Downloading BGE Small EN v1.5...")
-    load_bge_small_en_v1_5()
+    # print("[LOG] Downloading BERT Base Uncased...")
+    # load_bert_base_uncased()
+    # print("[LOG] Downloading BGE Small EN v1.5...")
+    # load_bge_small_en_v1_5()
+    print("[LOG] Downloading ONNX model...")
+    load_onnx_all_MiniLM_l6_v2()
     print("[LOG] Downloading text classifier...")
     _, device = load_text_multilabel_classifier()
     print(f"[LOG] Text classifier loaded on {device}")
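Downstream, assuming this diff is against crawl4ai's model loader module, usage would look roughly like the sketch below; the module path and the embedding call signature are assumptions, since `DefaultEmbeddingModel` is defined in `crawl4ai.onnx_embedding`, which is not shown in this diff:

# Hypothetical usage; module path and call signature are assumptions.
from crawl4ai.model_loader import load_onnx_all_MiniLM_l6_v2

model = load_onnx_all_MiniLM_l6_v2()  # first call downloads models/onnx/model.onnx
# embeddings = model(["some text to embed"])  # see crawl4ai.onnx_embedding

Net effect: `download_all_models` now skips the BERT and BGE downloads (commented out above) in favor of the CPU-friendly ONNX MiniLM model.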