From 3846648c1207819e3e4ac40a78b5d3b77f5e9933 Mon Sep 17 00:00:00 2001 From: unclecode Date: Sat, 18 May 2024 15:42:19 +0800 Subject: [PATCH] chore: Update extraction strategy to support GPU, MPS, and CPU, add batch procesing for CPU devices --- crawl4ai/extraction_strategy.py | 36 +++++++++++++------ crawl4ai/model_loader.py | 61 ++++++++++++++++++++++++++++----- 2 files changed, 78 insertions(+), 19 deletions(-) diff --git a/crawl4ai/extraction_strategy.py b/crawl4ai/extraction_strategy.py index a5d4b447..d5cd3747 100644 --- a/crawl4ai/extraction_strategy.py +++ b/crawl4ai/extraction_strategy.py @@ -187,8 +187,12 @@ class CosineStrategy(ExtractionStrategy): elif model_name == "BAAI/bge-small-en-v1.5": self.tokenizer, self.model = load_bge_small_en_v1_5() - + self.model.eval() # Ensure the model is in evaluation mode + self.buffer_embeddings = None + self.nlp, self.device = load_text_multilabel_classifier() + # self.default_batch_size = 16 if self.device.type == 'cpu' else 64 + self.default_batch_size = calculate_batch_size(self.device) if self.verbose: print(f"[LOG] Model loaded {model_name}, models/reuters, took " + str(time.time() - self.timer) + " seconds") @@ -219,7 +223,7 @@ class CosineStrategy(ExtractionStrategy): return filtered_docs - def get_embeddings(self, sentences: List[str], bypass_buffer=True): + def get_embeddings(self, sentences: List[str], batch_size=None, bypass_buffer=True): """ Get BERT embeddings for a list of sentences. 
@@ -231,15 +235,25 @@ class CosineStrategy(ExtractionStrategy): import torch # Tokenize sentences and convert to tensor - encoded_input = self.tokenizer(sentences, padding=True, truncation=True, return_tensors='pt') - # Compute token embeddings - with torch.no_grad(): - model_output = self.model(**encoded_input) + if batch_size is None: + batch_size = self.default_batch_size + + all_embeddings = [] + for i in range(0, len(sentences), batch_size): + batch_sentences = sentences[i:i + batch_size] + encoded_input = self.tokenizer(batch_sentences, padding=True, truncation=True, return_tensors='pt') + encoded_input = {key: tensor.to(self.device) for key, tensor in encoded_input.items()} - # Get embeddings from the last hidden state (mean pooling) - embeddings = model_output.last_hidden_state.mean(1) - self.buffer_embeddings = embeddings.numpy() - return embeddings.numpy() + # Ensure no gradients are calculated + with torch.no_grad(): + model_output = self.model(**encoded_input) + + # Get embeddings from the last hidden state (mean pooling) + embeddings = model_output.last_hidden_state.mean(dim=1).cpu().numpy() + all_embeddings.append(embeddings) + + self.buffer_embeddings = np.vstack(all_embeddings) + return self.buffer_embeddings def hierarchical_clustering(self, sentences: List[str]): """ @@ -319,7 +333,7 @@ class CosineStrategy(ExtractionStrategy): if self.verbose: print(f"[LOG] 🚀 Assign tags using {self.device}") - if self.device == "gpu": + if self.device.type in ["gpu", "cuda", "mps"]: labels = self.nlp([cluster['content'] for cluster in cluster_list]) for cluster, label in zip(cluster_list, labels): diff --git a/crawl4ai/model_loader.py b/crawl4ai/model_loader.py index 1d7181d7..8955fa4b 100644 --- a/crawl4ai/model_loader.py +++ b/crawl4ai/model_loader.py @@ -5,6 +5,48 @@ import shutil from crawl4ai.config import MODEL_REPO_BRANCH import argparse +@lru_cache() +def get_available_memory(device): + import torch + if device.type == 'cuda': + return 
torch.cuda.get_device_properties(device).total_memory + elif device.type == 'mps': + return 48 * 1024 ** 3 # Assuming 48GB for MPS (no public API to query; verify against target hardware) + else: + return 0 + +@lru_cache() +def calculate_batch_size(device): + available_memory = get_available_memory(device) + + if device.type == 'cpu': + return 16 + elif device.type in ['cuda', 'mps']: + # Adjust these thresholds based on your model size and available memory + if available_memory > 32 * 1024 ** 3: # > 32GB + return 256 + elif available_memory > 16 * 1024 ** 3: # 16GB to 32GB + return 128 + elif available_memory > 8 * 1024 ** 3: # 8GB to 16GB + return 64
AutoTokenizer.from_pretrained("dstefa/roberta-base_topic_classification_nyt_news") model = AutoModelForSequenceClassification.from_pretrained("dstefa/roberta-base_topic_classification_nyt_news") + model.eval() + model, device = set_model_device(model) pipe = pipeline("text-classification", model=model, tokenizer=tokenizer) - return pipe @lru_cache() @@ -51,18 +98,16 @@ def load_text_multilabel_classifier(): elif torch.backends.mps.is_available(): device = torch.device("mps") else: - return load_spacy_model() - # device = torch.device("cpu") + return load_spacy_model(), torch.device("cpu") MODEL = "cardiffnlp/tweet-topic-21-multi" tokenizer = AutoTokenizer.from_pretrained(MODEL, resume_download=None) model = AutoModelForSequenceClassification.from_pretrained(MODEL, resume_download=None) + model.eval() + model, device = set_model_device(model) class_mapping = model.config.id2label - - model.to(device) - def _classifier(texts, threshold=0.5, max_length=64): tokens = tokenizer(texts, return_tensors='pt', padding=True, truncation=True, max_length=max_length) tokens = {key: val.to(device) for key, val in tokens.items()} # Move tokens to the selected device @@ -81,7 +126,7 @@ def load_text_multilabel_classifier(): return batch_labels - return _classifier, "gpu" + return _classifier, device @lru_cache() def load_nltk_punkt(): @@ -142,7 +187,7 @@ def load_spacy_model(): except Exception as e: print(f"An error occurred: {e}") - return spacy.load(model_folder), "cpu" + return spacy.load(model_folder) def download_all_models(remove_existing=False): """Download all models required for Crawl4AI."""