chore: Update extraction strategy to support GPU, MPS, and CPU; add batch processing for CPU devices
@@ -187,8 +187,12 @@ class CosineStrategy(ExtractionStrategy):
         elif model_name == "BAAI/bge-small-en-v1.5":
             self.tokenizer, self.model = load_bge_small_en_v1_5()
 
+            self.model.eval()  # Ensure the model is in evaluation mode
+            self.buffer_embeddings = None
 
             self.nlp, self.device = load_text_multilabel_classifier()
+            # self.default_batch_size = 16 if self.device.type == 'cpu' else 64
+            self.default_batch_size = calculate_batch_size(self.device)
 
         if self.verbose:
             print(f"[LOG] Model loaded {model_name}, models/reuters, took " + str(time.time() - self.timer) + " seconds")
@@ -219,7 +223,7 @@ class CosineStrategy(ExtractionStrategy):
 
         return filtered_docs
 
-    def get_embeddings(self, sentences: List[str], bypass_buffer=True):
+    def get_embeddings(self, sentences: List[str], batch_size=None, bypass_buffer=True):
         """
         Get BERT embeddings for a list of sentences.
 
@@ -231,15 +235,25 @@ class CosineStrategy(ExtractionStrategy):
 
         import torch
         # Tokenize sentences and convert to tensor
-        encoded_input = self.tokenizer(sentences, padding=True, truncation=True, return_tensors='pt')
-        # Compute token embeddings
-        with torch.no_grad():
-            model_output = self.model(**encoded_input)
+        if batch_size is None:
+            batch_size = self.default_batch_size
+
+        all_embeddings = []
+        for i in range(0, len(sentences), batch_size):
+            batch_sentences = sentences[i:i + batch_size]
+            encoded_input = self.tokenizer(batch_sentences, padding=True, truncation=True, return_tensors='pt')
+            encoded_input = {key: tensor.to(self.device) for key, tensor in encoded_input.items()}
 
-        # Get embeddings from the last hidden state (mean pooling)
-        embeddings = model_output.last_hidden_state.mean(1)
-        self.buffer_embeddings = embeddings.numpy()
-        return embeddings.numpy()
+            # Ensure no gradients are calculated
+            with torch.no_grad():
+                model_output = self.model(**encoded_input)
+
+            # Get embeddings from the last hidden state (mean pooling)
+            embeddings = model_output.last_hidden_state.mean(dim=1).cpu().numpy()
+            all_embeddings.append(embeddings)
+
+        self.buffer_embeddings = np.vstack(all_embeddings)
+        return self.buffer_embeddings
 
     def hierarchical_clustering(self, sentences: List[str]):
         """
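
Review note: the new loop is the standard batched-embedding pattern; peak memory now scales with batch_size rather than with the full sentence list, because padding happens per batch and each batch is moved to the device, embedded under no_grad, and pooled back to CPU. A minimal standalone sketch of the same pattern (the embed helper and its defaults are illustrative, not part of this commit):

import numpy as np
import torch
from transformers import AutoTokenizer, AutoModel

tokenizer = AutoTokenizer.from_pretrained('BAAI/bge-small-en-v1.5')
model = AutoModel.from_pretrained('BAAI/bge-small-en-v1.5')
model.eval()

def embed(sentences, batch_size=16, device=torch.device('cpu')):
    all_embeddings = []
    for i in range(0, len(sentences), batch_size):
        batch = sentences[i:i + batch_size]
        # Padding is per batch, so peak memory tracks batch_size,
        # not the total number of sentences.
        enc = tokenizer(batch, padding=True, truncation=True, return_tensors='pt')
        enc = {k: v.to(device) for k, v in enc.items()}
        with torch.no_grad():  # inference only, no gradient buffers
            out = model(**enc)
        # Mean-pool token states into one vector per sentence.
        all_embeddings.append(out.last_hidden_state.mean(dim=1).cpu().numpy())
    return np.vstack(all_embeddings)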
@@ -319,7 +333,7 @@ class CosineStrategy(ExtractionStrategy):
         if self.verbose:
             print(f"[LOG] 🚀 Assign tags using {self.device}")
 
-        if self.device == "gpu":
+        if self.device.type in ["gpu", "cuda", "mps"]:
             labels = self.nlp([cluster['content'] for cluster in cluster_list])
 
         for cluster, label in zip(cluster_list, labels):
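
Review note: torch.device.type is only ever a backend name such as 'cpu', 'cuda', or 'mps'; it is never 'gpu', so the "gpu" entry in the membership test is harmless but dead. A quick check:

import torch
assert torch.device('cuda', 0).type == 'cuda'   # not 'gpu'
assert torch.device('cpu').type == 'cpu'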
@@ -5,6 +5,48 @@ import shutil
 from crawl4ai.config import MODEL_REPO_BRANCH
 import argparse
 
+@lru_cache()
+def get_available_memory(device):
+    import torch
+    if device.type == 'cuda':
+        return torch.cuda.get_device_properties(device).total_memory
+    elif device.type == 'mps':
+        return 48 * 1024 ** 3  # Assume 48GB of unified memory for MPS devices
+    else:
+        return 0
+
+@lru_cache()
+def calculate_batch_size(device):
+    available_memory = get_available_memory(device)
+
+    if device.type == 'cpu':
+        return 16
+    elif device.type in ['cuda', 'mps']:
+        # Adjust these thresholds based on your model size and available memory
+        if available_memory > 32 * 1024 ** 3:  # > 32GB
+            return 256
+        elif available_memory > 16 * 1024 ** 3:  # 16GB to 32GB
+            return 128
+        elif available_memory > 8 * 1024 ** 3:  # 8GB to 16GB
+            return 64
+        else:
+            return 32
+    else:
+        return 16  # Default batch size
+
+def set_model_device(model):
+    import torch
+    if torch.cuda.is_available():
+        device = torch.device('cuda')
+    elif torch.backends.mps.is_available():
+        device = torch.device('mps')
+    else:
+        device = torch.device('cpu')
+
+    model.to(device)
+    return model, device
+
+@lru_cache()
 def get_home_folder():
     home_folder = os.path.join(Path.home(), ".crawl4ai")
     os.makedirs(home_folder, exist_ok=True)
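
Review note: taken together, set_model_device picks the best available backend (cuda, then mps, then cpu) and calculate_batch_size maps that device's memory to a batch size (16 on CPU; 32-256 on cuda/mps depending on reported memory). A hedged usage sketch with a stand-in model; nothing below is part of the commit itself:

import torch
from torch import nn

model = nn.Linear(4, 2)                    # stand-in model for illustration
model, device = set_model_device(model)    # cuda > mps > cpu
batch_size = calculate_batch_size(device)  # e.g. 16 on a CPU-only machine
print(device.type, batch_size)

Since both new helpers are wrapped in @lru_cache(), repeated calls with the same torch.device are cached; torch.device is hashable, which is what makes the cache legal.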
@@ -17,6 +59,8 @@ def load_bert_base_uncased():
     from transformers import BertTokenizer, BertModel, AutoTokenizer, AutoModel
     tokenizer = BertTokenizer.from_pretrained('bert-base-uncased', resume_download=None)
     model = BertModel.from_pretrained('bert-base-uncased', resume_download=None)
+    model.eval()
+    model, device = set_model_device(model)
     return tokenizer, model
 
 @lru_cache()
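
Review note: the same two-line pattern (model.eval() plus set_model_device) is applied to each loader below. eval() matters because it turns dropout into a no-op, making repeated embedding calls deterministic; a minimal demonstration:

import torch
from torch import nn

m = nn.Dropout(p=0.5)
m.eval()                      # eval mode disables dropout
x = torch.ones(3)
assert torch.equal(m(x), x)   # identity: input passes through unchanged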
@@ -25,17 +69,20 @@ def load_bge_small_en_v1_5():
     tokenizer = AutoTokenizer.from_pretrained('BAAI/bge-small-en-v1.5', resume_download=None)
     model = AutoModel.from_pretrained('BAAI/bge-small-en-v1.5', resume_download=None)
     model.eval()
+    model, device = set_model_device(model)
     return tokenizer, model
 
 @lru_cache()
 def load_text_classifier():
     from transformers import AutoTokenizer, AutoModelForSequenceClassification
     from transformers import pipeline
+    import torch
 
     tokenizer = AutoTokenizer.from_pretrained("dstefa/roberta-base_topic_classification_nyt_news")
     model = AutoModelForSequenceClassification.from_pretrained("dstefa/roberta-base_topic_classification_nyt_news")
+    model.eval()
+    model, device = set_model_device(model)
     pipe = pipeline("text-classification", model=model, tokenizer=tokenizer)
 
     return pipe
 
 @lru_cache()
@@ -51,18 +98,16 @@ def load_text_multilabel_classifier():
     elif torch.backends.mps.is_available():
         device = torch.device("mps")
     else:
-        return load_spacy_model()
-        # device = torch.device("cpu")
+        return load_spacy_model(), torch.device("cpu")
 
     MODEL = "cardiffnlp/tweet-topic-21-multi"
     tokenizer = AutoTokenizer.from_pretrained(MODEL, resume_download=None)
     model = AutoModelForSequenceClassification.from_pretrained(MODEL, resume_download=None)
+    model.eval()
+    model, device = set_model_device(model)
     class_mapping = model.config.id2label
 
-    model.to(device)
-
     def _classifier(texts, threshold=0.5, max_length=64):
         tokens = tokenizer(texts, return_tensors='pt', padding=True, truncation=True, max_length=max_length)
         tokens = {key: val.to(device) for key, val in tokens.items()}  # Move tokens to the selected device
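
Review note: the loader's contract changes here. Instead of the hard-coded string "gpu", it now returns the actual torch.device, and the CPU fallback returns the spaCy model paired with torch.device("cpu"), so both branches yield a (classifier, device) pair; this matches the load_spacy_model change further down, which now returns only the model. A hedged sketch of a caller, mirroring the strategy code above:

nlp, device = load_text_multilabel_classifier()
if device.type in ["cuda", "mps"]:
    labels = nlp(["some text to tag"])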
@@ -81,7 +126,7 @@ def load_text_multilabel_classifier():
 
         return batch_labels
 
-    return _classifier, "gpu"
+    return _classifier, device
 
 @lru_cache()
 def load_nltk_punkt():
@@ -142,7 +187,7 @@ def load_spacy_model():
     except Exception as e:
         print(f"An error occurred: {e}")
 
-    return spacy.load(model_folder), "cpu"
+    return spacy.load(model_folder)
 
 def download_all_models(remove_existing=False):
     """Download all models required for Crawl4AI."""