diff --git a/crawl4ai/chunking_strategy.py b/crawl4ai/chunking_strategy.py
index a98884e6..d6f0e5d5 100644
--- a/crawl4ai/chunking_strategy.py
+++ b/crawl4ai/chunking_strategy.py
@@ -1,9 +1,9 @@
 from abc import ABC, abstractmethod
 import re
-import spacy
-import nltk
-from nltk.corpus import stopwords
-from nltk.tokenize import word_tokenize, TextTilingTokenizer
+# spacy = lazy_import.lazy_module('spacy')
+# nl = lazy_import.lazy_module('nltk')
+# from nltk.corpus import stopwords
+# from nltk.tokenize import word_tokenize, TextTilingTokenizer
 from collections import Counter
 import string
@@ -34,8 +34,10 @@ class RegexChunking(ChunkingStrategy):
         return paragraphs
 
 # NLP-based sentence chunking using spaCy
+
 class NlpSentenceChunking(ChunkingStrategy):
     def __init__(self, model='en_core_web_sm'):
+        import spacy
         self.nlp = spacy.load(model)
 
     def chunk(self, text: str) -> list:
@@ -44,8 +46,10 @@ class NlpSentenceChunking(ChunkingStrategy):
 
 # Topic-based segmentation using TextTiling
 class TopicSegmentationChunking(ChunkingStrategy):
+
     def __init__(self, num_keywords=3):
-        self.tokenizer = TextTilingTokenizer()
+        import nltk as nl
+        self.tokenizer = nl.tokenize.TextTilingTokenizer()
         self.num_keywords = num_keywords
 
     def chunk(self, text: str) -> list:
@@ -55,8 +59,9 @@ class TopicSegmentationChunking(ChunkingStrategy):
 
     def extract_keywords(self, text: str) -> list:
         # Tokenize and remove stopwords and punctuation
-        tokens = word_tokenize(text)
-        tokens = [token.lower() for token in tokens if token not in stopwords.words('english') and token not in string.punctuation]
+        import nltk as nl
+        tokens = nl.tokenize.word_tokenize(text)
+        tokens = [token.lower() for token in tokens if token not in nl.corpus.stopwords.words('english') and token not in string.punctuation]
 
         # Calculate frequency distribution
         freq_dist = Counter(tokens)
diff --git a/crawl4ai/extraction_strategy.py b/crawl4ai/extraction_strategy.py
index 6c265b47..46069919 100644
--- a/crawl4ai/extraction_strategy.py
+++ b/crawl4ai/extraction_strategy.py
@@ -1,20 +1,12 @@
 from abc import ABC, abstractmethod
 from typing import Any, List, Dict, Optional, Union
-from scipy.cluster.hierarchy import linkage, fcluster
-from scipy.spatial.distance import pdist
-from transformers import BertTokenizer, BertModel, pipeline
-from transformers import AutoTokenizer, AutoModel
 from concurrent.futures import ThreadPoolExecutor, as_completed
-import nltk
-from nltk.tokenize import TextTilingTokenizer
 import json, time
-import torch
-import spacy
 # from optimum.intel import IPEXModel
-
 from .prompts import PROMPT_EXTRACT_BLOCKS
 from .config import *
 from .utils import *
+from functools import partial
 
 class ExtractionStrategy(ABC):
     """
@@ -50,23 +42,32 @@ class ExtractionStrategy(ABC):
                 parsed_json.extend(future.result())
         return parsed_json
 
+class NoExtractionStrategy(ExtractionStrategy):
+    def extract(self, url: str, html: str, *q, **kwargs) -> List[Dict[str, Any]]:
+        return [{"index": 0, "content": html}]
+
+    def run(self, url: str, sections: List[str], *q, **kwargs) -> List[Dict[str, Any]]:
+        return [{"index": i, "tags": [], "content": section} for i, section in enumerate(sections)]
+
 class LLMExtractionStrategy(ExtractionStrategy):
     def __init__(self, provider: str = DEFAULT_PROVIDER, api_token: Optional[str] = None):
-        """
-        Initialize the strategy with clustering parameters.
-        :param word_count_threshold: Minimum number of words per cluster.
-        :param max_dist: The maximum cophenetic distance on the dendrogram to form clusters.
-        :param linkage_method: The linkage method for hierarchical clustering.
-        """
-        super().__init__()
-        self.provider = provider
-        self.api_token = api_token
-
-
-    def extract(self, url: str, html: str, provider: str = DEFAULT_PROVIDER, api_token: Optional[str] = None) -> List[Dict[str, Any]]:
-        api_token = PROVIDER_MODELS.get(provider, None) if not api_token else api_token
+        """
+        Initialize the strategy with an LLM provider and API token.
+        :param provider: The completion provider string passed to litellm.
+        :param api_token: API token for the provider; falls back to PROVIDER_MODELS
+            or the OPENAI_API_KEY environment variable.
+        """
+        super().__init__()
+        self.provider = provider
+        self.api_token = api_token or PROVIDER_MODELS.get(provider, None) or os.getenv("OPENAI_API_KEY")
+        if not self.api_token:
+            raise ValueError("API token must be provided for LLMExtractionStrategy. Update config.py or set the OPENAI_API_KEY environment variable.")
+
+
+    def extract(self, url: str, html: str) -> List[Dict[str, Any]]:
+        print("Extracting blocks ...")
         variable_values = {
             "URL": url,
             "HTML": escape_json_string(sanitize_html(html)),
@@ -78,7 +79,7 @@ class LLMExtractionStrategy(ExtractionStrategy):
                 "{" + variable + "}", variable_values[variable]
             )
 
-        response = perform_completion_with_backoff(provider, prompt_with_variables, api_token)
+        response = perform_completion_with_backoff(self.provider, prompt_with_variables, self.api_token)
 
         try:
             blocks = extract_xml_data(["blocks"], response.choices[0].message.content)['blocks']
@@ -96,28 +97,54 @@
                     "tags": ["error"],
                     "content": unparsed
                 })
+
+        print("Extracted", len(blocks), "blocks.")
         return blocks
 
-    def run(self, url: str, sections: List[str], provider: str, api_token: Optional[str]) -> List[Dict[str, Any]]:
+    def _merge(self, documents):
+        chunks = []
+        sections = []
+        total_token_so_far = 0
+
+        for document in documents:
+            if total_token_so_far < CHUNK_TOKEN_THRESHOLD:
+                chunk = document.split(' ')
+                total_token_so_far += len(chunk) * 1.3
+                chunks.append(document)
+            else:
+                sections.append('\n\n'.join(chunks))
+                chunks = [document]
+                total_token_so_far = len(document.split(' ')) * 1.3
+
+        if chunks:
+            sections.append('\n\n'.join(chunks))
+
+        return sections
+
+    def run(self, url: str, sections: List[str]) -> List[Dict[str, Any]]:
         """
         Process sections sequentially with a delay for rate limiting issues, specifically for LLMExtractionStrategy.
""" + merged_sections = self._merge(sections) parsed_json = [] - if provider.startswith("groq/"): + if self.provider.startswith("groq/"): # Sequential processing with a delay - for section in sections: - parsed_json.extend(self.extract(url, section, provider, api_token)) + for section in merged_sections: + parsed_json.extend(self.extract(url, section)) time.sleep(0.5) # 500 ms delay between each processing else: # Parallel processing using ThreadPoolExecutor - with ThreadPoolExecutor() as executor: - futures = [executor.submit(self.extract, url, section, provider, api_token) for section in sections] + with ThreadPoolExecutor(max_workers=4) as executor: + extract_func = partial(self.extract, url) + futures = [executor.submit(extract_func, section) for section in merged_sections] + for future in as_completed(futures): parsed_json.extend(future.result()) + return parsed_json -class HierarchicalClusteringStrategy(ExtractionStrategy): +class CosinegStrategy(ExtractionStrategy): def __init__(self, word_count_threshold=20, max_dist=0.2, linkage_method='ward', top_k=3, model_name = 'BAAI/bge-small-en-v1.5'): """ Initialize the strategy with clustering parameters. @@ -128,6 +155,10 @@ class HierarchicalClusteringStrategy(ExtractionStrategy): :param top_k: Number of top categories to extract. """ super().__init__() + from transformers import BertTokenizer, BertModel, pipeline + from transformers import AutoTokenizer, AutoModel + import spacy + self.word_count_threshold = word_count_threshold self.max_dist = max_dist self.linkage_method = linkage_method @@ -156,6 +187,7 @@ class HierarchicalClusteringStrategy(ExtractionStrategy): :param sentences: List of text chunks (sentences). :return: NumPy array of embeddings. """ + import torch # Tokenize sentences and convert to tensor encoded_input = self.tokenizer(sentences, padding=True, truncation=True, return_tensors='pt') # Compute token embeddings @@ -174,9 +206,11 @@ class HierarchicalClusteringStrategy(ExtractionStrategy): :return: NumPy array of cluster labels. """ # Get embeddings + from scipy.cluster.hierarchy import linkage, fcluster + from scipy.spatial.distance import pdist self.timer = time.time() embeddings = self.get_embeddings(sentences) - print(f"[LOG] 🚀 Embeddings computed in {time.time() - self.timer:.2f} seconds") + # print(f"[LOG] 🚀 Embeddings computed in {time.time() - self.timer:.2f} seconds") # Compute pairwise cosine distances distance_matrix = pdist(embeddings, 'cosine') # Perform agglomerative clustering respecting order @@ -219,7 +253,7 @@ class HierarchicalClusteringStrategy(ExtractionStrategy): # Perform clustering labels = self.hierarchical_clustering(text_chunks) - print(f"[LOG] 🚀 Clustering done in {time.time() - t:.2f} seconds") + # print(f"[LOG] 🚀 Clustering done in {time.time() - t:.2f} seconds") # Organize texts by their cluster labels, retaining order t = time.time() @@ -240,7 +274,7 @@ class HierarchicalClusteringStrategy(ExtractionStrategy): top_categories = sorted(doc.cats.items(), key=lambda x: x[1], reverse=True)[:tok_k] cluster['tags'] = [cat for cat, _ in top_categories] - print(f"[LOG] 🚀 Categorization done in {time.time() - t:.2f} seconds") + # print(f"[LOG] 🚀 Categorization done in {time.time() - t:.2f} seconds") return cluster_list @@ -265,9 +299,10 @@ class TopicExtractionStrategy(ExtractionStrategy): :param num_keywords: Number of keywords to represent each topic segment. 
""" + import nltk super().__init__() self.num_keywords = num_keywords - self.tokenizer = TextTilingTokenizer() + self.tokenizer = nltk.TextTilingTokenizer() def extract_keywords(self, text: str) -> List[str]: """ @@ -276,6 +311,7 @@ class TopicExtractionStrategy(ExtractionStrategy): :param text: The text segment from which to extract keywords. :return: A list of keyword strings. """ + import nltk # Tokenize the text and compute word frequency words = nltk.word_tokenize(text) freq_dist = nltk.FreqDist(words) diff --git a/crawl4ai/utils.py b/crawl4ai/utils.py index fe9aad0a..7cdaf538 100644 --- a/crawl4ai/utils.py +++ b/crawl4ai/utils.py @@ -3,15 +3,12 @@ from concurrent.futures import ThreadPoolExecutor, as_completed from bs4 import BeautifulSoup, Comment, element, Tag, NavigableString import html2text import json +import html import re import os -import litellm -from litellm import completion, batch_completion +from html2text import HTML2Text from .prompts import PROMPT_EXTRACT_BLOCKS from .config import * -import re -import html -from html2text import HTML2Text def beautify_html(escaped_html): @@ -303,17 +300,16 @@ def extract_xml_data(tags, string): return data -import time -import litellm - # Function to perform the completion with exponential backoff def perform_completion_with_backoff(provider, prompt_with_variables, api_token): + from litellm import completion + from litellm.exceptions import RateLimitError max_attempts = 3 base_delay = 2 # Base delay in seconds, you can adjust this based on your needs for attempt in range(max_attempts): try: - response = completion( + response =completion( model=provider, messages=[ {"role": "user", "content": prompt_with_variables} @@ -322,7 +318,7 @@ def perform_completion_with_backoff(provider, prompt_with_variables, api_token): api_key=api_token ) return response # Return the successful response - except litellm.exceptions.RateLimitError as e: + except RateLimitError as e: print("Rate limit error:", str(e)) # Check if we have exhausted our max attempts @@ -378,7 +374,7 @@ def extract_blocks(url, html, provider = DEFAULT_PROVIDER, api_token = None): def extract_blocks_batch(batch_data, provider = "groq/llama3-70b-8192", api_token = None): api_token = os.getenv('GROQ_API_KEY', None) if not api_token else api_token - + from litellm import batch_completion messages = [] for url, html in batch_data: diff --git a/crawl4ai/web_crawler.py b/crawl4ai/web_crawler.py index 34c3afc6..0952adf5 100644 --- a/crawl4ai/web_crawler.py +++ b/crawl4ai/web_crawler.py @@ -9,49 +9,93 @@ from .extraction_strategy import * from .crawler_strategy import * from typing import List from concurrent.futures import ThreadPoolExecutor -from .config import * +from .config import * + class WebCrawler: - def __init__(self, db_path: str, crawler_strategy: CrawlerStrategy = LocalSeleniumCrawlerStrategy()): + def __init__( + self, + db_path: str = None, + crawler_strategy: CrawlerStrategy = LocalSeleniumCrawlerStrategy(), + ): self.db_path = db_path - init_db(self.db_path) self.crawler_strategy = crawler_strategy - + # Create the .crawl4ai folder in the user's home directory if it doesn't exist self.crawl4ai_folder = os.path.join(Path.home(), ".crawl4ai") - os.makedirs(self.crawl4ai_folder, exist_ok=True) + os.makedirs(self.crawl4ai_folder, exist_ok=True) os.makedirs(f"{self.crawl4ai_folder}/cache", exist_ok=True) - - def fetch_page(self, - url_model: UrlModel, - provider: str = DEFAULT_PROVIDER, - api_token: str = None, - extract_blocks_flag: bool = True, - word_count_threshold = 
-                   use_cached_html: bool = False,
-                   extraction_strategy: ExtractionStrategy = LLMExtractionStrategy(),
-                   chunking_strategy: ChunkingStrategy = RegexChunking(),
-                   **kwargs
-                   ) -> CrawlResult:
+        # If db_path is not provided, use the default path
+        if not db_path:
+            self.db_path = f"{self.crawl4ai_folder}/crawl4ai.db"
+
+        init_db(self.db_path)
+
+        self.ready = False
+
+    def warmup(self):
+        print("[LOG] 🌤️ Warming up the WebCrawler")
+        single_url = UrlModel(url='https://crawl4ai.uccode.io/', forced=False)
+        result = self.run(
+            single_url,
+            word_count_threshold=5,
+            extraction_strategy=CosineStrategy(),
+            verbose=False
+        )
+        self.ready = True
+        print("[LOG] 🌞 WebCrawler is ready to crawl")
+
+    def fetch_page(
+        self,
+        url_model: UrlModel,
+        provider: str = DEFAULT_PROVIDER,
+        api_token: str = None,
+        extract_blocks_flag: bool = True,
+        word_count_threshold=MIN_WORD_THRESHOLD,
+        use_cached_html: bool = False,
+        extraction_strategy: ExtractionStrategy = LLMExtractionStrategy(),
+        chunking_strategy: ChunkingStrategy = RegexChunking(),
+        **kwargs,
+    ) -> CrawlResult:
+        return self.run(
+            url_model,
+            word_count_threshold,
+            extraction_strategy,
+            chunking_strategy,
+            **kwargs,
+        )
+        pass
+
+
+    def run(
+        self,
+        url_model: UrlModel,
+        word_count_threshold=MIN_WORD_THRESHOLD,
+        extraction_strategy: ExtractionStrategy = NoExtractionStrategy(),
+        chunking_strategy: ChunkingStrategy = RegexChunking(),
+        verbose=True,
+        **kwargs,
+    ) -> CrawlResult:
         # make sure word_count_threshold is not lesser than MIN_WORD_THRESHOLD
         if word_count_threshold < MIN_WORD_THRESHOLD:
             word_count_threshold = MIN_WORD_THRESHOLD
-
+
         # Check cache first
         cached = get_cached_url(self.db_path, str(url_model.url))
         if cached and not url_model.forced:
-            return CrawlResult(**{
-                "url": cached[0],
-                "html": cached[1],
-                "cleaned_html": cached[2],
-                "markdown": cached[3],
-                "parsed_json": cached[4],
-                "success": cached[5],
-                "error_message": ""
-            })
-
+            return CrawlResult(
+                **{
+                    "url": cached[0],
+                    "html": cached[1],
+                    "cleaned_html": cached[2],
+                    "markdown": cached[3],
+                    "parsed_json": cached[4],
+                    "success": cached[5],
+                    "error_message": "",
+                }
+            )
 
         # Initialize WebDriver for crawling
         t = time.time()
@@ -62,65 +106,89 @@ class WebCrawler:
         except Exception as e:
             html = ""
             success = False
-            error_message = str(e)
-
+            error_message = str(e)
+
         # Extract content from HTML
         result = get_content_of_website(html, word_count_threshold)
-        cleaned_html = result.get('cleaned_html', html)
-        markdown = result.get('markdown', "")
-
+        cleaned_html = result.get("cleaned_html", html)
+        markdown = result.get("markdown", "")
+
         # Print a profession LOG style message, show time taken and say crawling is done
-        print(f"[LOG] 🚀 Crawling done for {url_model.url}, success: {success}, time taken: {time.time() - t} seconds")
-
+        if verbose:
+            print(
+                f"[LOG] 🚀 Crawling done for {url_model.url}, success: {success}, time taken: {time.time() - t} seconds"
+            )
 
         parsed_json = []
-        if extract_blocks_flag:
+        if verbose:
             print(f"[LOG] 🔥 Extracting semantic blocks for {url_model.url}")
-            t = time.time()
-            # Split markdown into sections
-            sections = chunking_strategy.chunk(markdown)
-            # sections = merge_chunks_based_on_token_threshold(sections, CHUNK_TOKEN_THRESHOLD)
+        t = time.time()
+        # Split markdown into sections
+        sections = chunking_strategy.chunk(markdown)
+        # sections = merge_chunks_based_on_token_threshold(sections, CHUNK_TOKEN_THRESHOLD)
 
-            parsed_json = extraction_strategy.run(str(url_model.url), sections, provider, api_token)
-            parsed_json = json.dumps(parsed_json)
-
-
-            print(f"[LOG] 🚀 Extraction done for {url_model.url}, time taken: {time.time() - t} seconds.")
-        else:
-            parsed_json = "{}"
-            print(f"[LOG] 🚀 Skipping extraction for {url_model.url}")
+        parsed_json = extraction_strategy.run(
+            str(url_model.url), sections,
+        )
+        parsed_json = json.dumps(parsed_json)
+
+        if verbose:
+            print(
+                f"[LOG] 🚀 Extraction done for {url_model.url}, time taken: {time.time() - t} seconds."
+            )
 
         # Cache the result
         cleaned_html = beautify_html(cleaned_html)
-        cache_url(self.db_path, str(url_model.url), html, cleaned_html, markdown, parsed_json, success)
-
-        return CrawlResult(
-            url=str(url_model.url),
-            html=html,
-            cleaned_html=cleaned_html,
-            markdown=markdown,
-            parsed_json=parsed_json,
-            success=success,
-            error_message=error_message
+        cache_url(
+            self.db_path,
+            str(url_model.url),
+            html,
+            cleaned_html,
+            markdown,
+            parsed_json,
+            success,
         )
 
-    def fetch_pages(self, url_models: List[UrlModel], provider: str = DEFAULT_PROVIDER, api_token: str = None,
-                    extract_blocks_flag: bool = True, word_count_threshold=MIN_WORD_THRESHOLD,
-                    use_cached_html: bool = False, extraction_strategy: ExtractionStrategy = LLMExtractionStrategy(),
-                    chunking_strategy: ChunkingStrategy = RegexChunking(), **kwargs) -> List[CrawlResult]:
-
+        return CrawlResult(
+            url=str(url_model.url),
+            html=html,
+            cleaned_html=cleaned_html,
+            markdown=markdown,
+            parsed_json=parsed_json,
+            success=success,
+            error_message=error_message,
+        )
+
+    def fetch_pages(
+        self,
+        url_models: List[UrlModel],
+        provider: str = DEFAULT_PROVIDER,
+        api_token: str = None,
+        extract_blocks_flag: bool = True,
+        word_count_threshold=MIN_WORD_THRESHOLD,
+        use_cached_html: bool = False,
+        extraction_strategy: ExtractionStrategy = LLMExtractionStrategy(),
+        chunking_strategy: ChunkingStrategy = RegexChunking(),
+        **kwargs,
+    ) -> List[CrawlResult]:
+
         def fetch_page_wrapper(url_model, *args, **kwargs):
             return self.fetch_page(url_model, *args, **kwargs)
 
         with ThreadPoolExecutor() as executor:
-            results = list(executor.map(fetch_page_wrapper, url_models,
-                                        [provider] * len(url_models),
-                                        [api_token] * len(url_models),
-                                        [extract_blocks_flag] * len(url_models),
-                                        [word_count_threshold] * len(url_models),
-                                        [use_cached_html] * len(url_models),
-                                        [extraction_strategy] * len(url_models),
-                                        [chunking_strategy] * len(url_models),
-                                        *[kwargs] * len(url_models)))
+            results = list(
+                executor.map(
+                    fetch_page_wrapper,
+                    url_models,
+                    [provider] * len(url_models),
+                    [api_token] * len(url_models),
+                    [extract_blocks_flag] * len(url_models),
+                    [word_count_threshold] * len(url_models),
+                    [use_cached_html] * len(url_models),
+                    [extraction_strategy] * len(url_models),
+                    [chunking_strategy] * len(url_models),
+                    *[kwargs] * len(url_models),
+                )
+            )
 
-        return results
\ No newline at end of file
+        return results
diff --git a/requirements.txt b/requirements.txt
index 39a4226f..7995d648 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -12,4 +12,5 @@ html2text
 litellm
 python-dotenv
 nltk
+lazy_import
 # spacy
\ No newline at end of file
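
Review note: below is a minimal usage sketch of the API as it stands after this patch, not part of the diff itself. The class names and signatures (WebCrawler, run, warmup, NoExtractionStrategy, LLMExtractionStrategy, RegexChunking) come from the changes above; the crawl4ai.models import path for UrlModel, the example URL, and the provider string are illustrative assumptions.

```python
# Usage sketch under stated assumptions -- UrlModel is assumed to live in
# crawl4ai.models; the URL and provider string below are placeholders.
from crawl4ai.web_crawler import WebCrawler
from crawl4ai.models import UrlModel  # assumed module path
from crawl4ai.extraction_strategy import LLMExtractionStrategy, NoExtractionStrategy
from crawl4ai.chunking_strategy import RegexChunking

# db_path now defaults to ~/.crawl4ai/crawl4ai.db instead of being required.
crawler = WebCrawler()
crawler.warmup()  # optional: one throwaway crawl so the heavy models load up front

# Markdown-only crawl: NoExtractionStrategy (the new default for run) simply
# passes the chunked sections through instead of calling an LLM.
result = crawler.run(
    UrlModel(url="https://example.com", forced=False),
    extraction_strategy=NoExtractionStrategy(),
    chunking_strategy=RegexChunking(),
)
print(result.markdown)

# LLM-backed extraction: provider and api_token now live on the strategy
# object rather than being threaded through fetch_page/run as arguments.
# Requires OPENAI_API_KEY (or a PROVIDER_MODELS entry), per __init__ above.
result = crawler.run(
    UrlModel(url="https://example.com", forced=False),
    extraction_strategy=LLMExtractionStrategy(provider="openai/gpt-4o-mini"),
)
print(result.parsed_json)  # JSON string of extracted blocks
```

Net effect of the patch: per-call provider/api_token plumbing is gone, heavyweight imports (torch, transformers, scipy, spacy, nltk, litellm) are deferred to the strategies and helpers that need them so importing the package stays cheap, and run() becomes the primary entry point with fetch_page() kept as a thin backward-compatible wrapper.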