Generalized the `query` parameter of ContentRelevanceFilter to accept either a str or a list of str
This commit is contained in:
@@ -509,18 +509,22 @@ class DomainFilter(URLFilter):
|
|||||||
class ContentRelevanceFilter(URLFilter):
|
class ContentRelevanceFilter(URLFilter):
|
||||||
"""BM25-based relevance filter using head section content"""
|
"""BM25-based relevance filter using head section content"""
|
||||||
|
|
||||||
__slots__ = ("query_terms", "threshold", "k1", "b", "avgdl")
|
__slots__ = ("query_terms", "threshold", "k1", "b", "avgdl", "query")
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
query: str,
|
query: Union[str, List[str]],
|
||||||
threshold: float,
|
threshold: float,
|
||||||
k1: float = 1.2,
|
k1: float = 1.2,
|
||||||
b: float = 0.75,
|
b: float = 0.75,
|
||||||
avgdl: int = 1000,
|
avgdl: int = 1000,
|
||||||
):
|
):
|
||||||
super().__init__(name="BM25RelevanceFilter")
|
super().__init__(name="BM25RelevanceFilter")
|
||||||
self.query_terms = self._tokenize(query)
|
if isinstance(query, list):
|
||||||
|
self.query = " ".join(query)
|
||||||
|
else:
|
||||||
|
self.query = query
|
||||||
|
self.query_terms = self._tokenize(self.query)
|
||||||
self.threshold = threshold
|
self.threshold = threshold
|
||||||
self.k1 = k1 # TF saturation parameter
|
self.k1 = k1 # TF saturation parameter
|
||||||
self.b = b # Length normalization parameter
|
self.b = b # Length normalization parameter
|
||||||
|
|||||||
Reference in New Issue
Block a user