Instantiate CosineStrategy inside the function body instead of as a default argument

This commit is contained in:
unclecode
2024-05-17 15:13:06 +08:00
parent a5f9d07dbf
commit a317dc5e1d

View File

@@ -62,14 +62,14 @@ class WebCrawler:
extract_blocks_flag: bool = True, extract_blocks_flag: bool = True,
word_count_threshold=MIN_WORD_THRESHOLD, word_count_threshold=MIN_WORD_THRESHOLD,
use_cached_html: bool = False, use_cached_html: bool = False,
extraction_strategy: ExtractionStrategy = CosineStrategy(), extraction_strategy: ExtractionStrategy = None,
chunking_strategy: ChunkingStrategy = RegexChunking(), chunking_strategy: ChunkingStrategy = RegexChunking(),
**kwargs, **kwargs,
) -> CrawlResult: ) -> CrawlResult:
return self.run( return self.run(
url_model.url, url_model.url,
word_count_threshold, word_count_threshold,
extraction_strategy, extraction_strategy or CosineStrategy(),
chunking_strategy, chunking_strategy,
bypass_cache=url_model.forced, bypass_cache=url_model.forced,
**kwargs, **kwargs,
@@ -81,13 +81,14 @@ class WebCrawler:
self, self,
url: str, url: str,
word_count_threshold=MIN_WORD_THRESHOLD, word_count_threshold=MIN_WORD_THRESHOLD,
extraction_strategy: ExtractionStrategy = CosineStrategy(), extraction_strategy: ExtractionStrategy = None,
chunking_strategy: ChunkingStrategy = RegexChunking(), chunking_strategy: ChunkingStrategy = RegexChunking(),
bypass_cache: bool = False, bypass_cache: bool = False,
css_selector: str = None, css_selector: str = None,
verbose=True, verbose=True,
**kwargs, **kwargs,
) -> CrawlResult: ) -> CrawlResult:
extraction_strategy = extraction_strategy or CosineStrategy()
# Check if extraction strategy is an instance of ExtractionStrategy if not raise an error # Check if extraction strategy is an instance of ExtractionStrategy if not raise an error
if not isinstance(extraction_strategy, ExtractionStrategy): if not isinstance(extraction_strategy, ExtractionStrategy):
raise ValueError("Unsupported extraction strategy") raise ValueError("Unsupported extraction strategy")
@@ -183,11 +184,11 @@ class WebCrawler:
extract_blocks_flag: bool = True, extract_blocks_flag: bool = True,
word_count_threshold=MIN_WORD_THRESHOLD, word_count_threshold=MIN_WORD_THRESHOLD,
use_cached_html: bool = False, use_cached_html: bool = False,
extraction_strategy: ExtractionStrategy = CosineStrategy(), extraction_strategy: ExtractionStrategy = None,
chunking_strategy: ChunkingStrategy = RegexChunking(), chunking_strategy: ChunkingStrategy = RegexChunking(),
**kwargs, **kwargs,
) -> List[CrawlResult]: ) -> List[CrawlResult]:
extraction_strategy = extraction_strategy or CosineStrategy()
def fetch_page_wrapper(url_model, *args, **kwargs): def fetch_page_wrapper(url_model, *args, **kwargs):
return self.fetch_page(url_model, *args, **kwargs) return self.fetch_page(url_model, *args, **kwargs)