- Add ONNX embedding model for CPU devices, Update the similarithy threshold, improve the embedding speed.
This commit is contained in:
@@ -157,7 +157,7 @@ class LLMExtractionStrategy(ExtractionStrategy):
|
|||||||
return extracted_content
|
return extracted_content
|
||||||
|
|
||||||
class CosineStrategy(ExtractionStrategy):
|
class CosineStrategy(ExtractionStrategy):
|
||||||
def __init__(self, semantic_filter = None, word_count_threshold=10, max_dist=0.2, linkage_method='ward', top_k=3, model_name = 'BAAI/bge-small-en-v1.5', **kwargs):
|
def __init__(self, semantic_filter = None, word_count_threshold=10, max_dist=0.2, linkage_method='ward', top_k=3, model_name = 'sentence-transformers/all-MiniLM-L6-v2', sim_threshold = 0.3, **kwargs):
|
||||||
"""
|
"""
|
||||||
Initialize the strategy with clustering parameters.
|
Initialize the strategy with clustering parameters.
|
||||||
|
|
||||||
@@ -174,56 +174,96 @@ class CosineStrategy(ExtractionStrategy):
|
|||||||
self.max_dist = max_dist
|
self.max_dist = max_dist
|
||||||
self.linkage_method = linkage_method
|
self.linkage_method = linkage_method
|
||||||
self.top_k = top_k
|
self.top_k = top_k
|
||||||
|
self.sim_threshold = sim_threshold
|
||||||
self.timer = time.time()
|
self.timer = time.time()
|
||||||
self.verbose = kwargs.get("verbose", False)
|
self.verbose = kwargs.get("verbose", False)
|
||||||
|
|
||||||
self.buffer_embeddings = np.array([])
|
self.buffer_embeddings = np.array([])
|
||||||
|
self.get_embedding_method = "direct"
|
||||||
|
|
||||||
|
self.device = get_device()
|
||||||
|
self.default_batch_size = calculate_batch_size(self.device)
|
||||||
|
|
||||||
if self.verbose:
|
if self.verbose:
|
||||||
print(f"[LOG] Loading Extraction Model {model_name}")
|
print(f"[LOG] Loading Extraction Model for {self.device.type} device.")
|
||||||
|
|
||||||
if model_name == "bert-base-uncased":
|
if self.device.type == "cpu":
|
||||||
self.tokenizer, self.model = load_bert_base_uncased()
|
self.model = load_onnx_all_MiniLM_l6_v2()
|
||||||
elif model_name == "BAAI/bge-small-en-v1.5":
|
self.tokenizer = self.model.tokenizer
|
||||||
|
self.get_embedding_method = "direct"
|
||||||
|
else:
|
||||||
self.tokenizer, self.model = load_bge_small_en_v1_5()
|
self.tokenizer, self.model = load_bge_small_en_v1_5()
|
||||||
|
self.model.eval()
|
||||||
self.model.eval() # Ensure the model is in evaluation mode
|
self.get_embedding_method = "batch"
|
||||||
self.buffer_embeddings = None
|
|
||||||
|
|
||||||
|
self.buffer_embeddings = np.array([])
|
||||||
|
|
||||||
|
# if model_name == "bert-base-uncased":
|
||||||
|
# self.tokenizer, self.model = load_bert_base_uncased()
|
||||||
|
# self.model.eval() # Ensure the model is in evaluation mode
|
||||||
|
# self.get_embedding_method = "batch"
|
||||||
|
# elif model_name == "BAAI/bge-small-en-v1.5":
|
||||||
|
# self.tokenizer, self.model = load_bge_small_en_v1_5()
|
||||||
|
# self.model.eval() # Ensure the model is in evaluation mode
|
||||||
|
# self.get_embedding_method = "batch"
|
||||||
|
# elif model_name == "sentence-transformers/all-MiniLM-L6-v2":
|
||||||
|
# self.model = load_onnx_all_MiniLM_l6_v2()
|
||||||
|
# self.tokenizer = self.model.tokenizer
|
||||||
|
# self.get_embedding_method = "direct"
|
||||||
|
|
||||||
|
|
||||||
|
if self.verbose:
|
||||||
|
print(f"[LOG] Loading Multilabel Classifier for {self.device.type} device.")
|
||||||
|
|
||||||
self.nlp, self.device = load_text_multilabel_classifier()
|
self.nlp, self.device = load_text_multilabel_classifier()
|
||||||
# self.default_batch_size = 16 if self.device.type == 'cpu' else 64
|
# self.default_batch_size = 16 if self.device.type == 'cpu' else 64
|
||||||
self.default_batch_size = calculate_batch_size(self.device)
|
|
||||||
|
|
||||||
if self.verbose:
|
if self.verbose:
|
||||||
print(f"[LOG] Model loaded {model_name}, models/reuters, took " + str(time.time() - self.timer) + " seconds")
|
print(f"[LOG] Model loaded {model_name}, models/reuters, took " + str(time.time() - self.timer) + " seconds")
|
||||||
|
|
||||||
def filter_documents_embeddings(self, documents: List[str], semantic_filter: str, threshold: float = 0.5) -> List[str]:
|
def filter_documents_embeddings(self, documents: List[str], semantic_filter: str, at_least_k: int = 20) -> List[str]:
|
||||||
"""
|
"""
|
||||||
Filter documents based on the cosine similarity of their embeddings with the semantic_filter embedding.
|
Filter and sort documents based on the cosine similarity of their embeddings with the semantic_filter embedding.
|
||||||
|
|
||||||
:param documents: List of text chunks (documents).
|
:param documents: List of text chunks (documents).
|
||||||
:param semantic_filter: A string containing the keywords for filtering.
|
:param semantic_filter: A string containing the keywords for filtering.
|
||||||
:param threshold: Cosine similarity threshold for filtering documents.
|
:param threshold: Cosine similarity threshold for filtering documents.
|
||||||
:return: Filtered list of documents.
|
:param at_least_k: Minimum number of documents to return.
|
||||||
|
:return: List of filtered documents, ensuring at least `at_least_k` documents.
|
||||||
"""
|
"""
|
||||||
from sklearn.metrics.pairwise import cosine_similarity
|
|
||||||
if not semantic_filter:
|
if not semantic_filter:
|
||||||
return documents
|
return documents
|
||||||
|
|
||||||
|
if len(documents) < at_least_k:
|
||||||
|
at_least_k = len(documents) // 2
|
||||||
|
|
||||||
|
from sklearn.metrics.pairwise import cosine_similarity
|
||||||
|
|
||||||
# Compute embedding for the keyword filter
|
# Compute embedding for the keyword filter
|
||||||
query_embedding = self.get_embeddings([semantic_filter])[0]
|
query_embedding = self.get_embeddings([semantic_filter])[0]
|
||||||
|
|
||||||
# Compute embeddings for the docu ments
|
# Compute embeddings for the documents
|
||||||
document_embeddings = self.get_embeddings(documents)
|
document_embeddings = self.get_embeddings(documents)
|
||||||
|
|
||||||
# Calculate cosine similarity between the query embedding and document embeddings
|
# Calculate cosine similarity between the query embedding and document embeddings
|
||||||
similarities = cosine_similarity([query_embedding], document_embeddings).flatten()
|
similarities = cosine_similarity([query_embedding], document_embeddings).flatten()
|
||||||
|
|
||||||
# Filter documents based on the similarity threshold
|
# Filter documents based on the similarity threshold
|
||||||
filtered_docs = [doc for doc, sim in zip(documents, similarities) if sim >= threshold]
|
filtered_docs = [(doc, sim) for doc, sim in zip(documents, similarities) if sim >= self.sim_threshold]
|
||||||
|
|
||||||
return filtered_docs
|
# If the number of filtered documents is less than at_least_k, sort remaining documents by similarity
|
||||||
|
if len(filtered_docs) < at_least_k:
|
||||||
def get_embeddings(self, sentences: List[str], batch_size=None, bypass_buffer=True):
|
remaining_docs = [(doc, sim) for doc, sim in zip(documents, similarities) if sim < self.sim_threshold]
|
||||||
|
remaining_docs.sort(key=lambda x: x[1], reverse=True)
|
||||||
|
filtered_docs.extend(remaining_docs[:at_least_k - len(filtered_docs)])
|
||||||
|
|
||||||
|
# Extract the document texts from the tuples
|
||||||
|
filtered_docs = [doc for doc, _ in filtered_docs]
|
||||||
|
|
||||||
|
return filtered_docs[:at_least_k]
|
||||||
|
|
||||||
|
def get_embeddings(self, sentences: List[str], batch_size=None, bypass_buffer=False):
|
||||||
"""
|
"""
|
||||||
Get BERT embeddings for a list of sentences.
|
Get BERT embeddings for a list of sentences.
|
||||||
|
|
||||||
@@ -233,29 +273,32 @@ class CosineStrategy(ExtractionStrategy):
|
|||||||
# if self.buffer_embeddings.any() and not bypass_buffer:
|
# if self.buffer_embeddings.any() and not bypass_buffer:
|
||||||
# return self.buffer_embeddings
|
# return self.buffer_embeddings
|
||||||
|
|
||||||
import torch
|
if self.device.type in ["gpu", "cuda", "mps"]:
|
||||||
# Tokenize sentences and convert to tensor
|
import torch
|
||||||
if batch_size is None:
|
# Tokenize sentences and convert to tensor
|
||||||
batch_size = self.default_batch_size
|
if batch_size is None:
|
||||||
|
batch_size = self.default_batch_size
|
||||||
all_embeddings = []
|
|
||||||
for i in range(0, len(sentences), batch_size):
|
all_embeddings = []
|
||||||
batch_sentences = sentences[i:i + batch_size]
|
for i in range(0, len(sentences), batch_size):
|
||||||
encoded_input = self.tokenizer(batch_sentences, padding=True, truncation=True, return_tensors='pt')
|
batch_sentences = sentences[i:i + batch_size]
|
||||||
encoded_input = {key: tensor.to(self.device) for key, tensor in encoded_input.items()}
|
encoded_input = self.tokenizer(batch_sentences, padding=True, truncation=True, return_tensors='pt')
|
||||||
|
encoded_input = {key: tensor.to(self.device) for key, tensor in encoded_input.items()}
|
||||||
|
|
||||||
|
# Ensure no gradients are calculated
|
||||||
|
with torch.no_grad():
|
||||||
|
model_output = self.model(**encoded_input)
|
||||||
|
|
||||||
|
# Get embeddings from the last hidden state (mean pooling)
|
||||||
|
embeddings = model_output.last_hidden_state.mean(dim=1).cpu().numpy()
|
||||||
|
all_embeddings.append(embeddings)
|
||||||
|
|
||||||
# Ensure no gradients are calculated
|
self.buffer_embeddings = np.vstack(all_embeddings)
|
||||||
with torch.no_grad():
|
elif self.device.type == "cpu":
|
||||||
model_output = self.model(**encoded_input)
|
self.buffer_embeddings = self.model(sentences)
|
||||||
|
|
||||||
# Get embeddings from the last hidden state (mean pooling)
|
|
||||||
embeddings = model_output.last_hidden_state.mean(dim=1).cpu().numpy()
|
|
||||||
all_embeddings.append(embeddings)
|
|
||||||
|
|
||||||
self.buffer_embeddings = np.vstack(all_embeddings)
|
|
||||||
return self.buffer_embeddings
|
return self.buffer_embeddings
|
||||||
|
|
||||||
def hierarchical_clustering(self, sentences: List[str]):
|
def hierarchical_clustering(self, sentences: List[str], embeddings = None):
|
||||||
"""
|
"""
|
||||||
Perform hierarchical clustering on sentences and return cluster labels.
|
Perform hierarchical clustering on sentences and return cluster labels.
|
||||||
|
|
||||||
@@ -266,7 +309,7 @@ class CosineStrategy(ExtractionStrategy):
|
|||||||
from scipy.cluster.hierarchy import linkage, fcluster
|
from scipy.cluster.hierarchy import linkage, fcluster
|
||||||
from scipy.spatial.distance import pdist
|
from scipy.spatial.distance import pdist
|
||||||
self.timer = time.time()
|
self.timer = time.time()
|
||||||
embeddings = self.get_embeddings(sentences, bypass_buffer=False)
|
embeddings = self.get_embeddings(sentences, bypass_buffer=True)
|
||||||
# print(f"[LOG] 🚀 Embeddings computed in {time.time() - self.timer:.2f} seconds")
|
# print(f"[LOG] 🚀 Embeddings computed in {time.time() - self.timer:.2f} seconds")
|
||||||
# Compute pairwise cosine distances
|
# Compute pairwise cosine distances
|
||||||
distance_matrix = pdist(embeddings, 'cosine')
|
distance_matrix = pdist(embeddings, 'cosine')
|
||||||
|
|||||||
@@ -4,6 +4,8 @@ import subprocess, os
|
|||||||
import shutil
|
import shutil
|
||||||
from crawl4ai.config import MODEL_REPO_BRANCH
|
from crawl4ai.config import MODEL_REPO_BRANCH
|
||||||
import argparse
|
import argparse
|
||||||
|
import urllib.request
|
||||||
|
__location__ = os.path.realpath(os.path.join(os.getcwd(), os.path.dirname(__file__)))
|
||||||
|
|
||||||
@lru_cache()
|
@lru_cache()
|
||||||
def get_available_memory(device):
|
def get_available_memory(device):
|
||||||
@@ -23,18 +25,20 @@ def calculate_batch_size(device):
|
|||||||
return 16
|
return 16
|
||||||
elif device.type in ['cuda', 'mps']:
|
elif device.type in ['cuda', 'mps']:
|
||||||
# Adjust these thresholds based on your model size and available memory
|
# Adjust these thresholds based on your model size and available memory
|
||||||
if available_memory > 32 * 1024 ** 3: # > 16GB
|
if available_memory >= 31 * 1024 ** 3: # > 32GB
|
||||||
return 256
|
return 256
|
||||||
elif available_memory > 16 * 1024 ** 3: # > 16GB
|
elif available_memory >= 15 * 1024 ** 3: # > 16GB to 32GB
|
||||||
return 128
|
return 128
|
||||||
elif available_memory > 8 * 1024 ** 3: # 8GB to 16GB
|
elif available_memory >= 8 * 1024 ** 3: # 8GB to 16GB
|
||||||
return 64
|
return 64
|
||||||
else:
|
else:
|
||||||
return 32
|
return 32
|
||||||
else:
|
else:
|
||||||
return 16 # Default batch size
|
return 16 # Default batch size
|
||||||
|
|
||||||
def set_model_device(model):
|
|
||||||
|
@lru_cache()
|
||||||
|
def get_device():
|
||||||
import torch
|
import torch
|
||||||
if torch.cuda.is_available():
|
if torch.cuda.is_available():
|
||||||
device = torch.device('cuda')
|
device = torch.device('cuda')
|
||||||
@@ -42,7 +46,10 @@ def set_model_device(model):
|
|||||||
device = torch.device('mps')
|
device = torch.device('mps')
|
||||||
else:
|
else:
|
||||||
device = torch.device('cpu')
|
device = torch.device('cpu')
|
||||||
|
return device
|
||||||
|
|
||||||
|
def set_model_device(model):
|
||||||
|
device = get_device()
|
||||||
model.to(device)
|
model.to(device)
|
||||||
return model, device
|
return model, device
|
||||||
|
|
||||||
@@ -72,6 +79,31 @@ def load_bge_small_en_v1_5():
|
|||||||
model, device = set_model_device(model)
|
model, device = set_model_device(model)
|
||||||
return tokenizer, model
|
return tokenizer, model
|
||||||
|
|
||||||
|
@lru_cache()
|
||||||
|
def load_onnx_all_MiniLM_l6_v2():
|
||||||
|
from crawl4ai.onnx_embedding import DefaultEmbeddingModel
|
||||||
|
model_path = "models/onnx/model.onnx"
|
||||||
|
model_url = "https://unclecode-files.s3.us-west-2.amazonaws.com/model.onnx"
|
||||||
|
download_path = os.path.join(__location__, model_path)
|
||||||
|
|
||||||
|
if not os.path.exists(download_path):
|
||||||
|
# Define a download function with a simple progress display
|
||||||
|
def download_with_progress(url, filename):
|
||||||
|
def reporthook(block_num, block_size, total_size):
|
||||||
|
downloaded = block_num * block_size
|
||||||
|
percentage = 100 * downloaded / total_size
|
||||||
|
if downloaded < total_size:
|
||||||
|
print(f"\rDownloading: {percentage:.2f}% ({downloaded / (1024 * 1024):.2f} MB of {total_size / (1024 * 1024):.2f} MB)", end='')
|
||||||
|
else:
|
||||||
|
print("\rDownload complete! ")
|
||||||
|
|
||||||
|
urllib.request.urlretrieve(url, filename, reporthook)
|
||||||
|
|
||||||
|
download_with_progress(model_url, download_path)
|
||||||
|
|
||||||
|
model = DefaultEmbeddingModel()
|
||||||
|
return model
|
||||||
|
|
||||||
@lru_cache()
|
@lru_cache()
|
||||||
def load_text_classifier():
|
def load_text_classifier():
|
||||||
from transformers import AutoTokenizer, AutoModelForSequenceClassification
|
from transformers import AutoTokenizer, AutoModelForSequenceClassification
|
||||||
@@ -204,10 +236,12 @@ def download_all_models(remove_existing=False):
|
|||||||
print("[LOG] Existing models removed.")
|
print("[LOG] Existing models removed.")
|
||||||
|
|
||||||
# Load each model to trigger download
|
# Load each model to trigger download
|
||||||
print("[LOG] Downloading BERT Base Uncased...")
|
# print("[LOG] Downloading BERT Base Uncased...")
|
||||||
load_bert_base_uncased()
|
# load_bert_base_uncased()
|
||||||
print("[LOG] Downloading BGE Small EN v1.5...")
|
# print("[LOG] Downloading BGE Small EN v1.5...")
|
||||||
load_bge_small_en_v1_5()
|
# load_bge_small_en_v1_5()
|
||||||
|
print("[LOG] Downloading ONNX model...")
|
||||||
|
load_onnx_all_MiniLM_l6_v2()
|
||||||
print("[LOG] Downloading text classifier...")
|
print("[LOG] Downloading text classifier...")
|
||||||
_, device = load_text_multilabel_classifier()
|
_, device = load_text_multilabel_classifier()
|
||||||
print(f"[LOG] Text classifier loaded on {device}")
|
print(f"[LOG] Text classifier loaded on {device}")
|
||||||
|
|||||||
25
crawl4ai/models/onnx/config.json
Normal file
25
crawl4ai/models/onnx/config.json
Normal file
@@ -0,0 +1,25 @@
|
|||||||
|
{
|
||||||
|
"_name_or_path": "sentence-transformers/all-MiniLM-L6-v2",
|
||||||
|
"architectures": [
|
||||||
|
"BertModel"
|
||||||
|
],
|
||||||
|
"attention_probs_dropout_prob": 0.1,
|
||||||
|
"classifier_dropout": null,
|
||||||
|
"gradient_checkpointing": false,
|
||||||
|
"hidden_act": "gelu",
|
||||||
|
"hidden_dropout_prob": 0.1,
|
||||||
|
"hidden_size": 384,
|
||||||
|
"initializer_range": 0.02,
|
||||||
|
"intermediate_size": 1536,
|
||||||
|
"layer_norm_eps": 1e-12,
|
||||||
|
"max_position_embeddings": 512,
|
||||||
|
"model_type": "bert",
|
||||||
|
"num_attention_heads": 12,
|
||||||
|
"num_hidden_layers": 6,
|
||||||
|
"pad_token_id": 0,
|
||||||
|
"position_embedding_type": "absolute",
|
||||||
|
"transformers_version": "4.27.4",
|
||||||
|
"type_vocab_size": 2,
|
||||||
|
"use_cache": true,
|
||||||
|
"vocab_size": 30522
|
||||||
|
}
|
||||||
BIN
crawl4ai/models/onnx/model.onnx
Normal file
BIN
crawl4ai/models/onnx/model.onnx
Normal file
Binary file not shown.
7
crawl4ai/models/onnx/special_tokens_map.json
Normal file
7
crawl4ai/models/onnx/special_tokens_map.json
Normal file
@@ -0,0 +1,7 @@
|
|||||||
|
{
|
||||||
|
"cls_token": "[CLS]",
|
||||||
|
"mask_token": "[MASK]",
|
||||||
|
"pad_token": "[PAD]",
|
||||||
|
"sep_token": "[SEP]",
|
||||||
|
"unk_token": "[UNK]"
|
||||||
|
}
|
||||||
30686
crawl4ai/models/onnx/tokenizer.json
Normal file
30686
crawl4ai/models/onnx/tokenizer.json
Normal file
File diff suppressed because it is too large
Load Diff
15
crawl4ai/models/onnx/tokenizer_config.json
Normal file
15
crawl4ai/models/onnx/tokenizer_config.json
Normal file
@@ -0,0 +1,15 @@
|
|||||||
|
{
|
||||||
|
"cls_token": "[CLS]",
|
||||||
|
"do_basic_tokenize": true,
|
||||||
|
"do_lower_case": true,
|
||||||
|
"mask_token": "[MASK]",
|
||||||
|
"model_max_length": 512,
|
||||||
|
"never_split": null,
|
||||||
|
"pad_token": "[PAD]",
|
||||||
|
"sep_token": "[SEP]",
|
||||||
|
"special_tokens_map_file": "/Users/hammad/.cache/huggingface/hub/models--sentence-transformers--all-MiniLM-L6-v2/snapshots/7dbbc90392e2f80f3d3c277d6e90027e55de9125/special_tokens_map.json",
|
||||||
|
"strip_accents": null,
|
||||||
|
"tokenize_chinese_chars": true,
|
||||||
|
"tokenizer_class": "BertTokenizer",
|
||||||
|
"unk_token": "[UNK]"
|
||||||
|
}
|
||||||
30522
crawl4ai/models/onnx/vocab.txt
Normal file
30522
crawl4ai/models/onnx/vocab.txt
Normal file
File diff suppressed because it is too large
Load Diff
50
crawl4ai/onnx_embedding.py
Normal file
50
crawl4ai/onnx_embedding.py
Normal file
@@ -0,0 +1,50 @@
|
|||||||
|
# A dependency-light way to run the onnx model
|
||||||
|
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
from typing import List
|
||||||
|
import os
|
||||||
|
|
||||||
|
__location__ = os.path.realpath(os.path.join(os.getcwd(), os.path.dirname(__file__)))
|
||||||
|
MODEL_ID = "sentence-transformers/all-MiniLM-L6-v2"
|
||||||
|
|
||||||
|
def normalize(v):
|
||||||
|
norm = np.linalg.norm(v, axis=1)
|
||||||
|
norm[norm == 0] = 1e-12
|
||||||
|
return v / norm[:, np.newaxis]
|
||||||
|
|
||||||
|
# Sampel implementation of the default sentence-transformers model using ONNX
|
||||||
|
class DefaultEmbeddingModel():
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
from tokenizers import Tokenizer
|
||||||
|
import onnxruntime as ort
|
||||||
|
# max_seq_length = 256, for some reason sentence-transformers uses 256 even though the HF config has a max length of 128
|
||||||
|
# https://github.com/UKPLab/sentence-transformers/blob/3e1929fddef16df94f8bc6e3b10598a98f46e62d/docs/_static/html/models_en_sentence_embeddings.html#LL480
|
||||||
|
self.tokenizer = Tokenizer.from_file(os.path.join(__location__, "models/onnx/tokenizer.json"))
|
||||||
|
self.tokenizer.enable_truncation(max_length=256)
|
||||||
|
self.tokenizer.enable_padding(pad_id=0, pad_token="[PAD]", length=256)
|
||||||
|
self.model = ort.InferenceSession(os.path.join(__location__,"models/onnx/model.onnx"))
|
||||||
|
|
||||||
|
|
||||||
|
def __call__(self, documents: List[str], batch_size: int = 32):
|
||||||
|
all_embeddings = []
|
||||||
|
for i in range(0, len(documents), batch_size):
|
||||||
|
batch = documents[i:i + batch_size]
|
||||||
|
encoded = [self.tokenizer.encode(d) for d in batch]
|
||||||
|
input_ids = np.array([e.ids for e in encoded])
|
||||||
|
attention_mask = np.array([e.attention_mask for e in encoded])
|
||||||
|
onnx_input = {
|
||||||
|
"input_ids": np.array(input_ids, dtype=np.int64),
|
||||||
|
"attention_mask": np.array(attention_mask, dtype=np.int64),
|
||||||
|
"token_type_ids": np.array([np.zeros(len(e), dtype=np.int64) for e in input_ids], dtype=np.int64),
|
||||||
|
}
|
||||||
|
model_output = self.model.run(None, onnx_input)
|
||||||
|
last_hidden_state = model_output[0]
|
||||||
|
# Perform mean pooling with attention weighting
|
||||||
|
input_mask_expanded = np.broadcast_to(np.expand_dims(attention_mask, -1), last_hidden_state.shape)
|
||||||
|
embeddings = np.sum(last_hidden_state * input_mask_expanded, 1) / np.clip(input_mask_expanded.sum(1), a_min=1e-9, a_max=None)
|
||||||
|
embeddings = normalize(embeddings).astype(np.float32)
|
||||||
|
all_embeddings.append(embeddings)
|
||||||
|
return np.concatenate(all_embeddings)
|
||||||
|
|
||||||
@@ -86,7 +86,7 @@ def add_extraction_strategy(crawler):
|
|||||||
cprint("CosineStrategy uses cosine similarity to extract semantically similar blocks of text. Let's see it in action!")
|
cprint("CosineStrategy uses cosine similarity to extract semantically similar blocks of text. Let's see it in action!")
|
||||||
result = crawler.run(
|
result = crawler.run(
|
||||||
url="https://www.nbcnews.com/business",
|
url="https://www.nbcnews.com/business",
|
||||||
extraction_strategy=CosineStrategy(word_count_threshold=10, max_dist=0.2, linkage_method="ward", top_k=3, verbose=True)
|
extraction_strategy=CosineStrategy(word_count_threshold=10, max_dist=0.2, linkage_method="ward", top_k=3, sim_threshold = 0.3, verbose=True)
|
||||||
)
|
)
|
||||||
cprint("[LOG] 📦 [bold yellow]CosineStrategy result:[/bold yellow]")
|
cprint("[LOG] 📦 [bold yellow]CosineStrategy result:[/bold yellow]")
|
||||||
print_result(result)
|
print_result(result)
|
||||||
|
|||||||
@@ -16,3 +16,5 @@ uvicorn==0.29.0
|
|||||||
transformers==4.40.2
|
transformers==4.40.2
|
||||||
chromedriver-autoinstaller==0.6.4
|
chromedriver-autoinstaller==0.6.4
|
||||||
torch==2.3.0
|
torch==2.3.0
|
||||||
|
onnxruntime==1.14.1
|
||||||
|
tokenizers==0.13.2
|
||||||
Reference in New Issue
Block a user