Improve library loading
This commit is contained in:
@@ -8,8 +8,8 @@ from .config import *
|
||||
from .utils import *
|
||||
from functools import partial
|
||||
from .model_loader import load_bert_base_uncased, load_bge_small_en_v1_5, load_spacy_model
|
||||
from transformers import pipeline
|
||||
from sklearn.metrics.pairwise import cosine_similarity
|
||||
|
||||
|
||||
import numpy as np
|
||||
class ExtractionStrategy(ABC):
|
||||
"""
|
||||
@@ -165,6 +165,7 @@ class CosineStrategy(ExtractionStrategy):
|
||||
:param top_k: Number of top categories to extract.
|
||||
"""
|
||||
super().__init__()
|
||||
|
||||
from transformers import BertTokenizer, BertModel, pipeline
|
||||
from transformers import AutoTokenizer, AutoModel
|
||||
import spacy
|
||||
@@ -196,6 +197,7 @@ class CosineStrategy(ExtractionStrategy):
|
||||
:param threshold: Cosine similarity threshold for filtering documents.
|
||||
:return: Filtered list of documents.
|
||||
"""
|
||||
from sklearn.metrics.pairwise import cosine_similarity
|
||||
if not semantic_filter:
|
||||
return documents
|
||||
# Compute embedding for the keyword filter
|
||||
@@ -409,6 +411,7 @@ class ContentSummarizationStrategy(ExtractionStrategy):
|
||||
|
||||
:param model_name: The model to use for summarization.
|
||||
"""
|
||||
from transformers import pipeline
|
||||
self.summarizer = pipeline("summarization", model=model_name)
|
||||
|
||||
def extract(self, url: str, text: str, provider: str = None, api_token: Optional[str] = None) -> List[Dict[str, Any]]:
|
||||
|
||||
Reference in New Issue
Block a user