diff --git a/.gitignore b/.gitignore index 3d8aeb81..a794706f 100644 --- a/.gitignore +++ b/.gitignore @@ -166,5 +166,6 @@ Crawl4AI.egg-info/* crawler_data.db .vscode/ test_pad.py +test_pad*.py .data/ Crawl4AI.egg-info/ \ No newline at end of file diff --git a/README.md b/README.md index 8a49e1e4..90922299 100644 --- a/README.md +++ b/README.md @@ -56,40 +56,28 @@ pip install -e . 2. Import the necessary modules in your Python script: ```python from crawl4ai.web_crawler import WebCrawler -from crawl4ai.models import UrlModel +from crawl4ai.chunking_strategy import * +from crawl4ai.extraction_strategy import * import os -crawler = WebCrawler(db_path='crawler_data.db') +crawler = WebCrawler() +crawler.warmup() # IMPORTANT: Warmup the engine before running the first crawl # Single page crawl -single_url = UrlModel(url='https://kidocode.com', forced=False) -result = crawl4ai.fetch_page( - single_url, - provider= "openai/gpt-3.5-turbo", - api_token = os.getenv('OPENAI_API_KEY'), - # Set `extract_blocks_flag` to True to enable the LLM to generate semantically clustered chunks - # and return them as JSON. Depending on the model and data size, this may take up to 1 minute. - # Without this setting, it will take between 5 to 20 seconds. 
- extract_blocks_flag=False - word_count_threshold=5 # Minimum word count for a HTML tag to be considered as a worthy block +result = crawler.run( + url='https://www.nbcnews.com/business', + word_count_threshold=5, # Minimum word count for a HTML tag to be considered as a worthy block + chunking_strategy= RegexChunking( patterns = ["\n\n"]), # Default is RegexChunking + extraction_strategy= CosineStrategy(word_count_threshold=20, max_dist=0.2, linkage_method='ward', top_k=3) # Default is CosineStrategy + # extraction_strategy= LLMExtractionStrategy(provider= "openai/gpt-4o", api_token = os.getenv('OPENAI_API_KEY')), + bypass_cache=False, + extract_blocks =True, # Whether to extract semantical blocks of text from the HTML + css_selector = "", # Eg: "div.article-body" + verbose=True, + include_raw_html=True, # Whether to include the raw HTML content in the response ) + print(result.model_dump()) - -# Multiple page crawl -urls = [ - UrlModel(url='http://example.com', forced=False), - UrlModel(url='http://example.org', forced=False) -] -results = crawl4ai.fetch_pages( - urls, - provider= "openai/gpt-3.5-turbo", - api_token = os.getenv('OPENAI_API_KEY'), - extract_blocks_flag=True, - word_count_threshold=5 -) - -for res in results: - print(res.model_dump()) ``` Running for the first time will download the chrome driver for selenium. Also creates a SQLite database file `crawler_data.db` in the current directory. This file will store the crawled data for future reference. 
@@ -150,23 +138,22 @@ Set `extract_blocks_flag` to True to enable the LLM to generate semantically clu import requests import os -url = "http://localhost:8000/crawl" # Replace with the appropriate server URL data = { "urls": [ - "https://example.com" + "https://www.nbcnews.com/business" ], "provider_model": "groq/llama3-70b-8192", - "api_token": "your_api_token", "include_raw_html": true, - "forced": false, - # Set `extract_blocks_flag` to True to enable the LLM to generate semantically clustered chunks - # and return them as JSON. Depending on the model and data size, this may take up to 1 minute. - # Without this setting, it will take between 5 to 20 seconds. - "extract_blocks_flag": False, - "word_count_threshold": 5 + "bypass_cache": false, + "extract_blocks": true, + "word_count_threshold": 10, + "extraction_strategy": "CosineStrategy", + "chunking_strategy": "RegexChunking", + "css_selector": "", + "verbose": true } -response = requests.post(url, json=data) +response = requests.post("http://crawl4ai.uccode.io/crawl", json=data) # OR http://localhost:8000 if your run locally if response.status_code == 200: result = response.json()["results"][0] @@ -180,9 +167,9 @@ else: print("Error:", response.status_code, response.text) ``` -This code sends a POST request to the Crawl4AI server running on localhost, specifying the target URL (`https://example.com`) and the desired options (`grq_api_token`, `include_raw_html`, and `forced`). The server processes the request and returns the crawled data in JSON format. +This code sends a POST request to the Crawl4AI server running on localhost, specifying the target URL (`http://crawl4ai.uccode.io/crawl`) and the desired options. The server processes the request and returns the crawled data in JSON format. -The response from the server includes the parsed JSON, cleaned HTML, and markdown representations of the crawled webpage. You can access and use this data in your Python application as needed. 
+The response from the server includes the semantical clusters, cleaned HTML, and markdown representations of the crawled webpage. You can access and use this data in your Python application as needed. Make sure to replace `"http://localhost:8000/crawl"` with the appropriate server URL if your Crawl4AI server is running on a different host or port. @@ -194,15 +181,17 @@ That's it! You can now integrate Crawl4AI into your Python projects and leverage ## 📖 Parameters -| Parameter | Description | Required | Default Value | -|----------------------|-------------------------------------------------------------------------------------------------|----------|---------------| -| `urls` | A list of URLs to crawl and extract data from. | Yes | - | -| `provider_model` | The provider and model to use for extracting relevant information (e.g., "groq/llama3-70b-8192"). | Yes | - | -| `api_token` | Your API token for the specified provider. | Yes | - | -| `include_raw_html` | Whether to include the raw HTML content in the response. | No | `false` | -| `forced` | Whether to force a fresh crawl even if the URL has been previously crawled. | No | `false` | -| `extract_blocks_flag`| Whether to extract semantical blocks of text from the HTML. | No | `false` | -| `word_count_threshold` | The minimum number of words a block must contain to be considered meaningful (minimum value is 5). | No | `5` | +| Parameter | Description | Required | Default Value | +|-----------------------|-------------------------------------------------------------------------------------------------------|----------|---------------------| +| `urls` | A list of URLs to crawl and extract data from. | Yes | - | +| `include_raw_html` | Whether to include the raw HTML content in the response. | No | `false` | +| `bypass_cache` | Whether to force a fresh crawl even if the URL has been previously crawled. | No | `false` | +| `extract_blocks` | Whether to extract semantical blocks of text from the HTML. 
| No | `true` | +| `word_count_threshold`| The minimum number of words a block must contain to be considered meaningful (minimum value is 5). | No | `5` | +| `extraction_strategy` | The strategy to use for extracting content from the HTML (e.g., "CosineStrategy"). | No | `CosineStrategy` | +| `chunking_strategy` | The strategy to use for chunking the text before processing (e.g., "RegexChunking"). | No | `RegexChunking` | +| `css_selector` | The CSS selector to target specific parts of the HTML for extraction. | No | `None` | +| `verbose` | Whether to enable verbose logging. | No | `true` | ## 🛠️ Configuration Crawl4AI allows you to configure various parameters and settings in the `crawler/config.py` file. Here's an example of how you can adjust the parameters: @@ -213,15 +202,17 @@ from dotenv import load_dotenv load_dotenv() # Load environment variables from .env file -# Default provider +# Default provider, ONLY used when the extraction strategy is LLMExtractionStrategy DEFAULT_PROVIDER = "openai/gpt-4-turbo" -# Provider-model dictionary +# Provider-model dictionary, ONLY used when the extraction strategy is LLMExtractionStrategy PROVIDER_MODELS = { + "ollama/llama3": "no-token-needed", # Any model from Ollama no need for API token "groq/llama3-70b-8192": os.getenv("GROQ_API_KEY"), "groq/llama3-8b-8192": os.getenv("GROQ_API_KEY"), "openai/gpt-3.5-turbo": os.getenv("OPENAI_API_KEY"), "openai/gpt-4-turbo": os.getenv("OPENAI_API_KEY"), + "openai/gpt-4o": os.getenv("OPENAI_API_KEY"), "anthropic/claude-3-haiku-20240307": os.getenv("ANTHROPIC_API_KEY"), "anthropic/claude-3-opus-20240229": os.getenv("ANTHROPIC_API_KEY"), "anthropic/claude-3-sonnet-20240229": os.getenv("ANTHROPIC_API_KEY"), @@ -229,12 +220,14 @@ PROVIDER_MODELS = { # Chunk token threshold CHUNK_TOKEN_THRESHOLD = 1000 - # Threshold for the minimum number of words in an HTML tag to be considered MIN_WORD_THRESHOLD = 5 ``` + In the `crawler/config.py` file, you can: +REMEMBER: You only need to set the API 
keys for the providers in case you choose LLMExtractionStrategy as the extraction strategy. If you choose CosineStrategy, you don't need to set the API keys. + - Set the default provider using the `DEFAULT_PROVIDER` variable. - Add or modify the provider-model dictionary (`PROVIDER_MODELS`) to include your desired providers and their corresponding API keys. Crawl4AI supports various providers such as Groq, OpenAI, Anthropic, and more. You can add any provider supported by LiteLLM, as well as Ollama. - Adjust the `CHUNK_TOKEN_THRESHOLD` value to control the splitting of web content into chunks for parallel processing. A higher value means fewer chunks and faster processing, but it may cause issues with weaker LLMs during extraction. diff --git a/crawl4ai/config.py b/crawl4ai/config.py index b29325f1..5132079c 100644 --- a/crawl4ai/config.py +++ b/crawl4ai/config.py @@ -3,15 +3,17 @@ from dotenv import load_dotenv load_dotenv() # Load environment variables from .env file -# Default provider +# Default provider, ONLY used when the extraction strategy is LLMExtractionStrategy DEFAULT_PROVIDER = "openai/gpt-4-turbo" -# Provider-model dictionary +# Provider-model dictionary, ONLY used when the extraction strategy is LLMExtractionStrategy PROVIDER_MODELS = { + "ollama/llama3": "no-token-needed", # Any model from Ollama no need for API token "groq/llama3-70b-8192": os.getenv("GROQ_API_KEY"), "groq/llama3-8b-8192": os.getenv("GROQ_API_KEY"), "openai/gpt-3.5-turbo": os.getenv("OPENAI_API_KEY"), "openai/gpt-4-turbo": os.getenv("OPENAI_API_KEY"), + "openai/gpt-4o": os.getenv("OPENAI_API_KEY"), "anthropic/claude-3-haiku-20240307": os.getenv("ANTHROPIC_API_KEY"), "anthropic/claude-3-opus-20240229": os.getenv("ANTHROPIC_API_KEY"), "anthropic/claude-3-sonnet-20240229": os.getenv("ANTHROPIC_API_KEY"), diff --git a/crawl4ai/crawler_strategy.py b/crawl4ai/crawler_strategy.py index 2540c077..8d183e38 100644 --- a/crawl4ai/crawler_strategy.py +++ b/crawl4ai/crawler_strategy.py @@ -5,18 
+5,20 @@ from selenium.webdriver.common.by import By from selenium.webdriver.support.ui import WebDriverWait from selenium.webdriver.support import expected_conditions as EC from selenium.webdriver.chrome.options import Options +from selenium.common.exceptions import InvalidArgumentException import chromedriver_autoinstaller from typing import List import requests - +import os +from pathlib import Path class CrawlerStrategy(ABC): @abstractmethod - def crawl(self, url: str) -> str: + def crawl(self, url: str, **kwargs) -> str: pass class CloudCrawlerStrategy(CrawlerStrategy): - def crawl(self, url: str) -> str: + def crawl(self, url: str, use_cached_html = False, css_selector = None) -> str: data = { "urls": [url], "provider_model": "", @@ -40,19 +42,34 @@ class LocalSeleniumCrawlerStrategy(CrawlerStrategy): self.options.add_argument("--disable-dev-shm-usage") self.options.add_argument("--headless") - chromedriver_autoinstaller.install() + # chromedriver_autoinstaller.install() self.service = Service(chromedriver_autoinstaller.install()) self.driver = webdriver.Chrome(service=self.service, options=self.options) - def crawl(self, url: str, use_cached_html = False) -> str: + def crawl(self, url: str, use_cached_html = False, css_selector = None) -> str: if use_cached_html: - return get_content_of_website(url) - self.driver.get(url) - WebDriverWait(self.driver, 10).until( - EC.presence_of_all_elements_located((By.TAG_NAME, "html")) - ) - html = self.driver.page_source - return html + cache_file_path = os.path.join(Path.home(), ".crawl4ai", "cache", url.replace("/", "_")) + if os.path.exists(cache_file_path): + with open(cache_file_path, "r") as f: + return f.read() + + try: + self.driver.get(url) + WebDriverWait(self.driver, 10).until( + EC.presence_of_all_elements_located((By.TAG_NAME, "html")) + ) + html = self.driver.page_source + + # Store in cache + cache_file_path = os.path.join(Path.home(), ".crawl4ai", "cache", url.replace("/", "_")) + with 
open(cache_file_path, "w") as f: + f.write(html) + + return html + except InvalidArgumentException: + raise InvalidArgumentException(f"Invalid URL {url}") + except Exception as e: + raise Exception(f"Failed to crawl {url}: {str(e)}") def quit(self): self.driver.quit() \ No newline at end of file diff --git a/crawl4ai/database.py b/crawl4ai/database.py index 89048d05..b2169c84 100644 --- a/crawl4ai/database.py +++ b/crawl4ai/database.py @@ -1,7 +1,15 @@ +import os +from pathlib import Path import sqlite3 from typing import Optional +from typing import Optional, Tuple +DB_PATH = os.path.join(Path.home(), ".crawl4ai") +os.makedirs(DB_PATH, exist_ok=True) +DB_PATH = os.path.join(DB_PATH, "crawl4ai.db") + def init_db(db_path: str): + global DB_PATH conn = sqlite3.connect(db_path) cursor = conn.cursor() cursor.execute(''' @@ -16,46 +24,65 @@ def init_db(db_path: str): ''') conn.commit() conn.close() + DB_PATH = db_path -def get_cached_url(db_path: str, url: str) -> Optional[tuple]: - conn = sqlite3.connect(db_path) - cursor = conn.cursor() - cursor.execute('SELECT url, html, cleaned_html, markdown, parsed_json, success FROM crawled_data WHERE url = ?', (url,)) - result = cursor.fetchone() - conn.close() - return result +def check_db_path(): + if not DB_PATH: + raise ValueError("Database path is not set or is empty.") -def cache_url(db_path: str, url: str, html: str, cleaned_html: str, markdown: str, parsed_json: str, success: bool): - conn = sqlite3.connect(db_path) - cursor = conn.cursor() - cursor.execute(''' - INSERT INTO crawled_data (url, html, cleaned_html, markdown, parsed_json, success) - VALUES (?, ?, ?, ?, ?, ?) 
- ON CONFLICT(url) DO UPDATE SET - html = excluded.html, - cleaned_html = excluded.cleaned_html, - markdown = excluded.markdown, - parsed_json = excluded.parsed_json, - success = excluded.success - ''', (str(url), html, cleaned_html, markdown, parsed_json, success)) - conn.commit() - conn.close() - -def get_total_count(db_path: str) -> int: +def get_cached_url(url: str) -> Optional[Tuple[str, str, str, str, str, bool]]: + check_db_path() try: - conn = sqlite3.connect(db_path) + conn = sqlite3.connect(DB_PATH) + cursor = conn.cursor() + cursor.execute('SELECT url, html, cleaned_html, markdown, parsed_json, success FROM crawled_data WHERE url = ?', (url,)) + result = cursor.fetchone() + conn.close() + return result + except Exception as e: + print(f"Error retrieving cached URL: {e}") + return None + +def cache_url(url: str, html: str, cleaned_html: str, markdown: str, parsed_json: str, success: bool): + check_db_path() + try: + conn = sqlite3.connect(DB_PATH) + cursor = conn.cursor() + cursor.execute(''' + INSERT INTO crawled_data (url, html, cleaned_html, markdown, parsed_json, success) + VALUES (?, ?, ?, ?, ?, ?) 
+ ON CONFLICT(url) DO UPDATE SET + html = excluded.html, + cleaned_html = excluded.cleaned_html, + markdown = excluded.markdown, + parsed_json = excluded.parsed_json, + success = excluded.success + ''', (url, html, cleaned_html, markdown, parsed_json, success)) + conn.commit() + conn.close() + except Exception as e: + print(f"Error caching URL: {e}") + +def get_total_count() -> int: + check_db_path() + try: + conn = sqlite3.connect(DB_PATH) cursor = conn.cursor() cursor.execute('SELECT COUNT(*) FROM crawled_data') result = cursor.fetchone() conn.close() return result[0] except Exception as e: + print(f"Error getting total count: {e}") return 0 - -# Crete function to cler the database -def clear_db(db_path: str): - conn = sqlite3.connect(db_path) - cursor = conn.cursor() - cursor.execute('DELETE FROM crawled_data') - conn.commit() - conn.close() \ No newline at end of file + +def clear_db(): + check_db_path() + try: + conn = sqlite3.connect(DB_PATH) + cursor = conn.cursor() + cursor.execute('DELETE FROM crawled_data') + conn.commit() + conn.close() + except Exception as e: + print(f"Error clearing database: {e}") \ No newline at end of file diff --git a/crawl4ai/extraction_strategy.py b/crawl4ai/extraction_strategy.py index 46069919..91e44e3f 100644 --- a/crawl4ai/extraction_strategy.py +++ b/crawl4ai/extraction_strategy.py @@ -7,6 +7,8 @@ from .prompts import PROMPT_EXTRACT_BLOCKS from .config import * from .utils import * from functools import partial +from .model_loader import load_bert_base_uncased, load_bge_small_en_v1_5, load_spacy_model + class ExtractionStrategy(ABC): """ @@ -15,6 +17,7 @@ class ExtractionStrategy(ABC): def __init__(self): self.DEL = "<|DEL|>" + self.name = self.__class__.__name__ @abstractmethod def extract(self, url: str, html: str, *q, **kwargs) -> List[Dict[str, Any]]: @@ -67,7 +70,7 @@ class LLMExtractionStrategy(ExtractionStrategy): def extract(self, url: str, html: str) -> List[Dict[str, Any]]: - print("Extracting blocks ...") + 
print("[LOG] Extracting blocks from URL:", url) variable_values = { "URL": url, "HTML": escape_json_string(sanitize_html(html)), @@ -98,7 +101,7 @@ class LLMExtractionStrategy(ExtractionStrategy): "content": unparsed }) - print("Extracted", len(blocks), "blocks.") + print("[LOG] Extracted", len(blocks), "blocks from URL:", url) return blocks def _merge(self, documents): @@ -125,6 +128,7 @@ class LLMExtractionStrategy(ExtractionStrategy): """ Process sections sequentially with a delay for rate limiting issues, specifically for LLMExtractionStrategy. """ + merged_sections = self._merge(sections) parsed_json = [] if self.provider.startswith("groq/"): @@ -144,7 +148,7 @@ class LLMExtractionStrategy(ExtractionStrategy): return parsed_json -class CosinegStrategy(ExtractionStrategy): +class CosineStrategy(ExtractionStrategy): def __init__(self, word_count_threshold=20, max_dist=0.2, linkage_method='ward', top_k=3, model_name = 'BAAI/bge-small-en-v1.5'): """ Initialize the strategy with clustering parameters. 
@@ -164,20 +168,13 @@ class CosinegStrategy(ExtractionStrategy): self.linkage_method = linkage_method self.top_k = top_k self.timer = time.time() - - if model_name == "bert-base-uncased": - self.tokenizer = BertTokenizer.from_pretrained('bert-base-uncased', resume_download=None) - self.model = BertModel.from_pretrained('bert-base-uncased', resume_download=None) - elif model_name == "sshleifer/distilbart-cnn-12-6": - # self.model = IPEXModel.from_pretrained("Intel/bge-small-en-v1.5-rag-int8-static") - # self.tokenizer = AutoTokenizer.from_pretrained("Intel/bge-small-en-v1.5-rag-int8-static") - pass - elif model_name == "BAAI/bge-small-en-v1.5": - self.tokenizer = AutoTokenizer.from_pretrained('BAAI/bge-small-en-v1.5', resume_download=None) - self.model = AutoModel.from_pretrained('BAAI/bge-small-en-v1.5', resume_download=None) - self.model.eval() - self.nlp = spacy.load("models/reuters") + if model_name == "bert-base-uncased": + self.tokenizer, self.model = load_bert_base_uncased() + elif model_name == "BAAI/bge-small-en-v1.5": + self.tokenizer, self.model = load_bge_small_en_v1_5() + + self.nlp = load_spacy_model() print(f"[LOG] Model loaded {model_name}, models/reuters, took " + str(time.time() - self.timer) + " seconds") def get_embeddings(self, sentences: List[str]): diff --git a/crawl4ai/model_loader.py b/crawl4ai/model_loader.py new file mode 100644 index 00000000..b068f5f8 --- /dev/null +++ b/crawl4ai/model_loader.py @@ -0,0 +1,20 @@ +from functools import lru_cache +from transformers import BertTokenizer, BertModel, AutoTokenizer, AutoModel +import spacy + +@lru_cache() +def load_bert_base_uncased(): + tokenizer = BertTokenizer.from_pretrained('bert-base-uncased', resume_download=None) + model = BertModel.from_pretrained('bert-base-uncased', resume_download=None) + return tokenizer, model + +@lru_cache() +def load_bge_small_en_v1_5(): + tokenizer = AutoTokenizer.from_pretrained('BAAI/bge-small-en-v1.5', resume_download=None) + model = 
AutoModel.from_pretrained('BAAI/bge-small-en-v1.5', resume_download=None) + model.eval() + return tokenizer, model + +@lru_cache() +def load_spacy_model(): + return spacy.load("models/reuters") diff --git a/crawl4ai/utils.py b/crawl4ai/utils.py index 7cdaf538..37729656 100644 --- a/crawl4ai/utils.py +++ b/crawl4ai/utils.py @@ -10,6 +10,8 @@ from html2text import HTML2Text from .prompts import PROMPT_EXTRACT_BLOCKS from .config import * +class InvalidCSSSelectorError(Exception): + pass def beautify_html(escaped_html): """ @@ -140,13 +142,25 @@ class CustomHTML2Text(HTML2Text): super().handle_tag(tag, attrs, start) -def get_content_of_website(html, word_count_threshold = MIN_WORD_THRESHOLD): +def get_content_of_website(html, word_count_threshold = MIN_WORD_THRESHOLD, css_selector = None): try: + if not html: + return None # Parse HTML content with BeautifulSoup soup = BeautifulSoup(html, 'html.parser') # Get the content within the tag body = soup.body + + # If css_selector is provided, extract content based on the selector + if css_selector: + selected_elements = body.select(css_selector) + if not selected_elements: + raise InvalidCSSSelectorError(f"Invalid CSS selector , No elements found for CSS selector: {css_selector}") + div_tag = soup.new_tag('div') + for el in selected_elements: + div_tag.append(el) + body = div_tag # Remove script, style, and other tags that don't carry useful content from body for tag in body.find_all(['script', 'style', 'link', 'meta', 'noscript']): @@ -255,7 +269,7 @@ def get_content_of_website(html, word_count_threshold = MIN_WORD_THRESHOLD): # Remove comments - for comment in soup.find_all(text=lambda text: isinstance(text, Comment)): + for comment in soup.find_all(string=lambda text: isinstance(text, Comment)): comment.extract() # Remove consecutive empty newlines and replace multiple spaces with a single space @@ -281,7 +295,7 @@ def get_content_of_website(html, word_count_threshold = MIN_WORD_THRESHOLD): except Exception as e: 
print('Error processing HTML content:', str(e)) - return None + raise InvalidCSSSelectorError(f"Invalid CSS selector: {css_selector}") from e def extract_xml_tags(string): tags = re.findall(r'<(\w+)>', string) diff --git a/crawl4ai/web_crawler.py b/crawl4ai/web_crawler.py index 0952adf5..361c06dd 100644 --- a/crawl4ai/web_crawler.py +++ b/crawl4ai/web_crawler.py @@ -2,7 +2,7 @@ import os, time from pathlib import Path from .models import UrlModel, CrawlResult -from .database import init_db, get_cached_url, cache_url +from .database import init_db, get_cached_url, cache_url, DB_PATH from .utils import * from .chunking_strategy import * from .extraction_strategy import * @@ -10,6 +10,7 @@ from .crawler_strategy import * from typing import List from concurrent.futures import ThreadPoolExecutor from .config import * +# from .model_loader import load_bert_base_uncased, load_bge_small_en_v1_5, load_spacy_model class WebCrawler: @@ -36,11 +37,11 @@ class WebCrawler: def warmup(self): print("[LOG] 🌤️ Warming up the WebCrawler") - single_url = UrlModel(url='https://crawl4ai.uccode.io/', forced=False) result = self.run( - single_url, + url='https://crawl4ai.uccode.io/', word_count_threshold=5, - extraction_strategy= CosinegStrategy(), + extraction_strategy= CosineStrategy(), + bypass_cache=False, verbose = False ) self.ready = True @@ -60,10 +61,11 @@ class WebCrawler: **kwargs, ) -> CrawlResult: return self.run( - url_model, + url_model.url, word_count_threshold, extraction_strategy, chunking_strategy, + bypass_cache=url_model.forced, **kwargs, ) pass @@ -71,77 +73,85 @@ class WebCrawler: def run( self, - url_model: UrlModel, + url: str, word_count_threshold=MIN_WORD_THRESHOLD, extraction_strategy: ExtractionStrategy = NoExtractionStrategy(), chunking_strategy: ChunkingStrategy = RegexChunking(), + bypass_cache: bool = False, + css_selector: str = None, verbose=True, **kwargs, ) -> CrawlResult: + # Check if extraction strategy is an instance of ExtractionStrategy if not 
raise an error + if not isinstance(extraction_strategy, ExtractionStrategy): + raise ValueError("Unsupported extraction strategy") + if not isinstance(chunking_strategy, ChunkingStrategy): + raise ValueError("Unsupported chunking strategy") + # make sure word_count_threshold is not lesser than MIN_WORD_THRESHOLD if word_count_threshold < MIN_WORD_THRESHOLD: word_count_threshold = MIN_WORD_THRESHOLD # Check cache first - cached = get_cached_url(self.db_path, str(url_model.url)) - if cached and not url_model.forced: - return CrawlResult( - **{ - "url": cached[0], - "html": cached[1], - "cleaned_html": cached[2], - "markdown": cached[3], - "parsed_json": cached[4], - "success": cached[5], - "error_message": "", - } - ) + if not bypass_cache: + cached = get_cached_url(url) + if cached: + return CrawlResult( + **{ + "url": cached[0], + "html": cached[1], + "cleaned_html": cached[2], + "markdown": cached[3], + "parsed_json": cached[4], + "success": cached[5], + "error_message": "", + } + ) # Initialize WebDriver for crawling t = time.time() - try: - html = self.crawler_strategy.crawl(str(url_model.url)) - success = True - error_message = "" - except Exception as e: - html = "" - success = False - error_message = str(e) - + html = self.crawler_strategy.crawl(url) + success = True + error_message = "" # Extract content from HTML - result = get_content_of_website(html, word_count_threshold) + try: + result = get_content_of_website(html, word_count_threshold, css_selector=css_selector) + if result is None: + raise ValueError(f"Failed to extract content from the website: {url}") + except InvalidCSSSelectorError as e: + raise ValueError(str(e)) + cleaned_html = result.get("cleaned_html", html) markdown = result.get("markdown", "") # Print a profession LOG style message, show time taken and say crawling is done if verbose: print( - f"[LOG] 🚀 Crawling done for {url_model.url}, success: {success}, time taken: {time.time() - t} seconds" + f"[LOG] 🚀 Crawling done for {url}, 
success: {success}, time taken: {time.time() - t} seconds" ) parsed_json = [] if verbose: - print(f"[LOG] 🔥 Extracting semantic blocks for {url_model.url}") + print(f"[LOG] 🔥 Extracting semantic blocks for {url}, Strategy: {extraction_strategy.name}") t = time.time() # Split markdown into sections sections = chunking_strategy.chunk(markdown) # sections = merge_chunks_based_on_token_threshold(sections, CHUNK_TOKEN_THRESHOLD) parsed_json = extraction_strategy.run( - str(url_model.url), sections, + url, sections, ) parsed_json = json.dumps(parsed_json) if verbose: print( - f"[LOG] 🚀 Extraction done for {url_model.url}, time taken: {time.time() - t} seconds." + f"[LOG] 🚀 Extraction done for {url}, time taken: {time.time() - t} seconds." ) # Cache the result cleaned_html = beautify_html(cleaned_html) cache_url( - self.db_path, - str(url_model.url), + url, html, cleaned_html, markdown, @@ -150,7 +160,7 @@ class WebCrawler: ) return CrawlResult( - url=str(url_model.url), + url=url, html=html, cleaned_html=cleaned_html, markdown=markdown, diff --git a/docs/chunking_strategies.json b/docs/chunking_strategies.json new file mode 100644 index 00000000..ec855cc7 --- /dev/null +++ b/docs/chunking_strategies.json @@ -0,0 +1,12 @@ +{ + "RegexChunking": "### RegexChunking\n\n`RegexChunking` is a text chunking strategy that splits a given text into smaller parts using regular expressions.\nThis is useful for preparing large texts for processing by language models, ensuring they are divided into manageable segments.\n\n#### Constructor Parameters:\n- `patterns` (list, optional): A list of regular expression patterns used to split the text. Default is to split by double newlines (`['\\n\\n']`).\n\n#### Example usage:\n```python\nchunker = RegexChunking(patterns=[r'\\n\\n', r'\\. '])\nchunks = chunker.chunk(\"This is a sample text. 
It will be split into chunks.\")\n```", + + "NlpSentenceChunking": "### NlpSentenceChunking\n\n`NlpSentenceChunking` uses a natural language processing model to chunk a given text into sentences. This approach leverages SpaCy to accurately split text based on sentence boundaries.\n\n#### Constructor Parameters:\n- `model` (str, optional): The SpaCy model to use for sentence detection. Default is `'en_core_web_sm'`.\n\n#### Example usage:\n```python\nchunker = NlpSentenceChunking(model='en_core_web_sm')\nchunks = chunker.chunk(\"This is a sample text. It will be split into sentences.\")\n```", + + "TopicSegmentationChunking": "### TopicSegmentationChunking\n\n`TopicSegmentationChunking` uses the TextTiling algorithm to segment a given text into topic-based chunks. This method identifies thematic boundaries in the text.\n\n#### Constructor Parameters:\n- `num_keywords` (int, optional): The number of keywords to extract for each topic segment. Default is `3`.\n\n#### Example usage:\n```python\nchunker = TopicSegmentationChunking(num_keywords=3)\nchunks = chunker.chunk(\"This is a sample text. It will be split into topic-based segments.\")\n```", + + "FixedLengthWordChunking": "### FixedLengthWordChunking\n\n`FixedLengthWordChunking` splits a given text into chunks of fixed length, based on the number of words.\n\n#### Constructor Parameters:\n- `chunk_size` (int, optional): The number of words in each chunk. Default is `100`.\n\n#### Example usage:\n```python\nchunker = FixedLengthWordChunking(chunk_size=100)\nchunks = chunker.chunk(\"This is a sample text. It will be split into fixed-length word chunks.\")\n```", + + "SlidingWindowChunking": "### SlidingWindowChunking\n\n`SlidingWindowChunking` uses a sliding window approach to chunk a given text. Each chunk has a fixed length, and the window slides by a specified step size.\n\n#### Constructor Parameters:\n- `window_size` (int, optional): The number of words in each chunk. 
Default is `100`.\n- `step` (int, optional): The number of words to slide the window. Default is `50`.\n\n#### Example usage:\n```python\nchunker = SlidingWindowChunking(window_size=100, step=50)\nchunks = chunker.chunk(\"This is a sample text. It will be split using a sliding window approach.\")\n```" + } + \ No newline at end of file diff --git a/docs/extraction_strategies.json b/docs/extraction_strategies.json new file mode 100644 index 00000000..207ab981 --- /dev/null +++ b/docs/extraction_strategies.json @@ -0,0 +1,10 @@ +{ + "NoExtractionStrategy": "### NoExtractionStrategy\n\n`NoExtractionStrategy` is a basic extraction strategy that returns the entire HTML content without any modification. It is useful for cases where no specific extraction is required. Only clean html, and markdown.\n\n#### Constructor Parameters:\nNone.\n\n#### Example usage:\n```python\nextractor = NoExtractionStrategy()\nextracted_content = extractor.extract(url, html)\n```", + + "LLMExtractionStrategy": "### LLMExtractionStrategy\n\n`LLMExtractionStrategy` uses a Language Model (LLM) to extract meaningful blocks or chunks from the given HTML content. This strategy leverages an external provider for language model completions.\n\n#### Constructor Parameters:\n- `provider` (str, optional): The provider to use for the language model completions. Default is `DEFAULT_PROVIDER` (following provider/model eg. openai/gpt-4o).\n- `api_token` (str, optional): The API token for the provider. If not provided, it will try to load from the environment variable `OPENAI_API_KEY`.\n\n#### Example usage:\n```python\nextractor = LLMExtractionStrategy(provider='openai', api_token='your_api_token')\nextracted_content = extractor.extract(url, html)\n```", + + "CosineStrategy": "### CosineStrategy\n\n`CosineStrategy` uses hierarchical clustering based on cosine similarity to extract clusters of text from the given HTML content. 
This strategy is suitable for identifying related content sections.\n\n#### Constructor Parameters:\n- `word_count_threshold` (int, optional): Minimum number of words per cluster. Default is `20`.\n- `max_dist` (float, optional): The maximum cophenetic distance on the dendrogram to form clusters. Default is `0.2`.\n- `linkage_method` (str, optional): The linkage method for hierarchical clustering. Default is `'ward'`.\n- `top_k` (int, optional): Number of top categories to extract. Default is `3`.\n- `model_name` (str, optional): The model name for embedding generation. Default is `'BAAI/bge-small-en-v1.5'`.\n\n#### Example usage:\n```python\nextractor = CosineStrategy(word_count_threshold=20, max_dist=0.2, linkage_method='ward', top_k=3, model_name='BAAI/bge-small-en-v1.5')\nextracted_content = extractor.extract(url, html)\n```", + + "TopicExtractionStrategy": "### TopicExtractionStrategy\n\n`TopicExtractionStrategy` uses the TextTiling algorithm to segment the HTML content into topics and extracts keywords for each segment. This strategy is useful for identifying and summarizing thematic content.\n\n#### Constructor Parameters:\n- `num_keywords` (int, optional): Number of keywords to represent each topic segment. 
Default is `3`.\n\n#### Example usage:\n```python\nextractor = TopicExtractionStrategy(num_keywords=3)\nextracted_content = extractor.extract(url, html)\n```" + } + \ No newline at end of file diff --git a/docs/quickstart.py b/docs/quickstart.py new file mode 100644 index 00000000..cbdfbe0d --- /dev/null +++ b/docs/quickstart.py @@ -0,0 +1,33 @@ +import os +from crawl4ai.web_crawler import WebCrawler +from crawl4ai.chunking_strategy import * +from crawl4ai.extraction_strategy import * + + +def main(): + crawler = WebCrawler() + crawler.warmup() + + # Single page crawl + result = crawler.run( + url="https://www.nbcnews.com/business", + word_count_threshold=5, # Minimum word count for a HTML tag to be considered as a worthy block + chunking_strategy=RegexChunking(patterns=["\n\n"]), # Default is RegexChunking + extraction_strategy=CosineStrategy( + word_count_threshold=20, max_dist=0.2, linkage_method="ward", top_k=3 + ), # Default is CosineStrategy + # extraction_strategy= LLMExtractionStrategy(provider= "openai/gpt-4o", api_token = os.getenv('OPENAI_API_KEY')), + bypass_cache=True, + extract_blocks=True, # Whether to extract semantical blocks of text from the HTML + css_selector="", # Eg: "div.article-body" or all H2 tags like "h2" + verbose=True, + include_raw_html=True, # Whether to include the raw HTML content in the response + ) + + + print("[LOG] 📦 Crawl result:") + print(result.model_dump()) + + +if __name__ == "__main__": + main() diff --git a/examples/quickstart.py b/examples/quickstart.py deleted file mode 100644 index 82091e96..00000000 --- a/examples/quickstart.py +++ /dev/null @@ -1,32 +0,0 @@ -from crawl4ai.web_crawler import WebCrawler -from crawl4ai.models import UrlModel -from crawl4ai.utils import get_content_of_website -import os - -def main(): - # Initialize the WebCrawler with just the database path - crawler = WebCrawler(db_path='crawler_data.db') - - # Fetch a single page - single_url = UrlModel(url='https://www.nbcnews.com/business', 
forced=True) - result = crawler.fetch_page( - single_url, - provider= "openai/gpt-3.5-turbo", - api_token = os.getenv('OPENAI_API_KEY'), - use_cached_html = True, - extract_blocks_flag=True, - word_count_threshold=10 - ) - print(result.model_dump()) - - # Fetch multiple pages - # urls = [ - # UrlModel(url='http://example.com', forced=False), - # UrlModel(url='http://example.org', forced=False) - # ] - # results = crawler.fetch_pages(urls, provider= "openai/gpt-4-turbo", api_token = os.getenv('OPENAI_API_KEY')) - # for res in results: - # print(res.model_copy()) - -if __name__ == '__main__': - main() \ No newline at end of file diff --git a/main.py b/main.py index 71b43c36..3cc141b7 100644 --- a/main.py +++ b/main.py @@ -1,24 +1,19 @@ -from fastapi import FastAPI, HTTPException, Request -from fastapi.responses import HTMLResponse -from fastapi.staticfiles import StaticFiles -from fastapi.responses import JSONResponse -from pydantic import BaseModel, HttpUrl -from typing import List, Optional -from crawl4ai.web_crawler import WebCrawler -from crawl4ai.models import UrlModel -import asyncio -from concurrent.futures import ThreadPoolExecutor, as_completed -import chromedriver_autoinstaller -from functools import lru_cache -from crawl4ai.database import get_total_count, clear_db import os -import uuid -# Import the CORS middleware +import importlib +import asyncio +from functools import lru_cache + +from fastapi import FastAPI, HTTPException, Request +from fastapi.responses import HTMLResponse, JSONResponse +from fastapi.staticfiles import StaticFiles from fastapi.middleware.cors import CORSMiddleware +from pydantic import BaseModel, HttpUrl +from concurrent.futures import ThreadPoolExecutor, as_completed +from typing import List, Optional -# Task management -tasks = {} +from crawl4ai.web_crawler import WebCrawler +from crawl4ai.database import get_total_count, clear_db # Configuration __location__ = os.path.realpath(os.path.join(os.getcwd(), os.path.dirname(__file__))) 
@@ -41,22 +36,25 @@ app.add_middleware( # Mount the pages directory as a static directory app.mount("/pages", StaticFiles(directory=__location__ + "/pages"), name="pages") +# chromedriver_autoinstaller.install() # Ensure chromedriver is installed +@lru_cache() +def get_crawler(): + # Initialize and return a WebCrawler instance + return WebCrawler() -chromedriver_autoinstaller.install() # Ensure chromedriver is installed - -class UrlsInput(BaseModel): +class CrawlRequest(BaseModel): urls: List[HttpUrl] provider_model: str api_token: str include_raw_html: Optional[bool] = False - forced: bool = False + bypass_cache: bool = False extract_blocks: bool = True word_count_threshold: Optional[int] = 5 + extraction_strategy: Optional[str] = "CosineStrategy" + chunking_strategy: Optional[str] = "RegexChunking" + css_selector: Optional[str] = None + verbose: Optional[bool] = True -@lru_cache() -def get_crawler(): - # Initialize and return a WebCrawler instance - return WebCrawler(db_path='crawler_data.db') @app.get("/", response_class=HTMLResponse) async def read_index(): @@ -66,20 +64,30 @@ async def read_index(): @app.get("/total-count") async def get_total_url_count(): - count = get_total_count(db_path='crawler_data.db') + count = get_total_count() return JSONResponse(content={"count": count}) # Add endpoit to clear db @app.get("/clear-db") async def clear_database(): - clear_db(db_path='crawler_data.db') + clear_db() return JSONResponse(content={"message": "Database cleared."}) +def import_strategy(module_name: str, class_name: str): + try: + module = importlib.import_module(module_name) + strategy_class = getattr(module, class_name) + return strategy_class() + except ImportError: + raise HTTPException(status_code=400, detail=f"Module {module_name} not found.") + except AttributeError: + raise HTTPException(status_code=400, detail=f"Class {class_name} not found in {module_name}.") + @app.post("/crawl") -async def crawl_urls(urls_input: UrlsInput, request: Request): +async 
def crawl_urls(crawl_request: CrawlRequest, request: Request): global current_requests # Raise error if api_token is not provided - if not urls_input.api_token: + if not crawl_request.api_token: raise HTTPException(status_code=401, detail="API token is required.") async with lock: if current_requests >= MAX_CONCURRENT_REQUESTS: @@ -87,87 +95,50 @@ async def crawl_urls(urls_input: UrlsInput, request: Request): current_requests += 1 try: - # Prepare URL models for crawling - url_models = [UrlModel(url=url, forced=urls_input.forced) for url in urls_input.urls] + extraction_strategy = import_strategy("crawl4ai.extraction_strategy", crawl_request.extraction_strategy) + chunking_strategy = import_strategy("crawl4ai.chunking_strategy", crawl_request.chunking_strategy) # Use ThreadPoolExecutor to run the synchronous WebCrawler in async manner with ThreadPoolExecutor() as executor: loop = asyncio.get_event_loop() futures = [ - loop.run_in_executor(executor, get_crawler().fetch_page, url_model, urls_input.provider_model, urls_input.api_token, urls_input.extract_blocks, urls_input.word_count_threshold) - for url_model in url_models + loop.run_in_executor( + executor, + get_crawler().run, + str(url), + crawl_request.word_count_threshold, + extraction_strategy, + chunking_strategy, + crawl_request.bypass_cache, + crawl_request.css_selector, + crawl_request.verbose + ) + for url in crawl_request.urls ] results = await asyncio.gather(*futures) # if include_raw_html is False, remove the raw HTML content from the results - if not urls_input.include_raw_html: + if not crawl_request.include_raw_html: for result in results: result.html = None - + return {"results": [result.dict() for result in results]} finally: async with lock: current_requests -= 1 + +@app.get("/strategies/extraction", response_class=JSONResponse) +async def get_extraction_strategies(): + # Load docs/extraction_strategies.json" and return as JSON response + with 
open(f"{__location__}/docs/extraction_strategies.json", "r") as file: + return JSONResponse(content=file.read()) -@app.post("/crawl_async") -async def crawl_urls(urls_input: UrlsInput, request: Request): - global current_requests - if not urls_input.api_token: - raise HTTPException(status_code=401, detail="API token is required.") - - async with lock: - if current_requests >= MAX_CONCURRENT_REQUESTS: - raise HTTPException(status_code=429, detail="Too many requests - please try again later.") - current_requests += 1 - - task_id = str(uuid.uuid4()) - tasks[task_id] = {"status": "pending", "results": None} - - try: - url_models = [UrlModel(url=url, forced=urls_input.forced) for url in urls_input.urls] - - loop = asyncio.get_running_loop() - loop.create_task( - process_crawl_task(url_models, urls_input.provider_model, urls_input.api_token, task_id, urls_input.extract_blocks) - ) - return {"task_id": task_id} - finally: - async with lock: - current_requests -= 1 - -async def process_crawl_task(url_models, provider, api_token, task_id, extract_blocks_flag): - try: - with ThreadPoolExecutor() as executor: - loop = asyncio.get_running_loop() - futures = [ - loop.run_in_executor(executor, get_crawler().fetch_page, url_model, provider, api_token, extract_blocks_flag) - for url_model in url_models - ] - results = await asyncio.gather(*futures) - - tasks[task_id] = {"status": "done", "results": results} - except Exception as e: - tasks[task_id] = {"status": "failed", "error": str(e)} - -@app.get("/task/{task_id}") -async def get_task_status(task_id: str): - task = tasks.get(task_id) - if not task: - raise HTTPException(status_code=404, detail="Task not found") - - if task['status'] == 'done': - return { - "status": task['status'], - "results": [result.dict() for result in task['results']] - } - elif task['status'] == 'failed': - return { - "status": task['status'], - "error": task['error'] - } - else: - return {"status": task['status']} +@app.get("/strategies/chunking", 
response_class=JSONResponse) +async def get_chunking_strategies(): + with open(f"{__location__}/docs/chunking_strategies.json", "r") as file: + return JSONResponse(content=file.read()) + if __name__ == "__main__": import uvicorn uvicorn.run(app, host="0.0.0.0", port=8000) \ No newline at end of file diff --git a/pages/index.html b/pages/index.html index 63d32dd9..e354ae3e 100644 --- a/pages/index.html +++ b/pages/index.html @@ -9,12 +9,15 @@ - + + + + + - -
-
-

🔥🕷️ Crawl4AI: Open-source LLM Friendly Web Crawler & Scrapper

+ +
+
+

🔥🕷️ Crawl4AI: Web Data for your Thoughts

+
+
+ 📊 Total Websites Processed + 2
- -
-
- 📊 Total Website Procceced - 0 -
-
- -
+

Try It Now

-
-
- - -
- -
- - +
+
+
+ + +
+
+ + +
+
+ + +
+
+ + +
+
+ + +
+
+ + +
+
+ + +
+
+
+ + +
+
+ + +
+ +
-
- - - -
-
- - - -
-
- - -
- -
-
- -
+
+
- - + -
-
-
+
+
+ > + >
-
+ +
- - - + +
-
+
-                                    
-                                    
-                                
+ + + + + + + + + +
+
+

Installation 💻

+

There are two ways to use Crawl4AI: as a library in your Python projects or as a standalone local server.

+ +

You can also try Crawl4AI in a Google Colab Open In Colab

+ +

Using Crawl4AI as a Library 📚

+

To install Crawl4AI as a library, follow these steps:

+ +
    +
  1. + Install the package from GitHub: +
    pip install git+https://github.com/unclecode/crawl4ai.git
    +
  2. +
  3. + Alternatively, you can clone the repository and install the package locally: +
    virtualenv venv
    +source venv/bin/activate
    +git clone https://github.com/unclecode/crawl4ai.git
    +cd crawl4ai
    +pip install -e .
    +        
    +
  4. +
  5. + Import the necessary modules in your Python script: +
    from crawl4ai.web_crawler import WebCrawler
    +from crawl4ai.chunking_strategy import *
    +from crawl4ai.extraction_strategy import *
    +import os
     
    -        
    +crawler = WebCrawler() +crawler.warmup() # IMPORTANT: Warmup the engine before running the first crawl + +# Single page crawl +result = crawler.run( + url='https://www.nbcnews.com/business', + word_count_threshold=5, # Minimum word count for a HTML tag to be considered as a worthy block + chunking_strategy= RegexChunking( patterns = ["\\n\\n"]), # Default is RegexChunking + extraction_strategy= CosineStrategy(word_count_threshold=20, max_dist=0.2, linkage_method='ward', top_k=3), # Default is CosineStrategy + # extraction_strategy= LLMExtractionStrategy(provider= "openai/gpt-4o", api_token = os.getenv('OPENAI_API_KEY')), + bypass_cache=False, + extract_blocks =True, # Whether to extract semantical blocks of text from the HTML + css_selector = "", # Eg: "div.article-body" + verbose=True, + include_raw_html=True, # Whether to include the raw HTML content in the response +) +print(result.model_dump()) +
    +
  6. +
+

For more information about how to run Crawl4AI as a local server, please refer to the GitHub repository.

+ +

📖 Parameters

+
+ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
ParameterDescriptionRequiredDefault Value
urls + A list of URLs to crawl and extract data from. + Yes-
include_raw_html + Whether to include the raw HTML content in the response. + Nofalse
bypass_cache + Whether to force a fresh crawl even if the URL has been previously crawled. + Nofalse
extract_blocks + Whether to extract semantical blocks of text from the HTML. + Notrue
word_count_threshold + The minimum number of words a block must contain to be considered meaningful (minimum + value is 5). + No5
extraction_strategy + The strategy to use for extracting content from the HTML (e.g., "CosineStrategy"). + NoCosineStrategy
chunking_strategy + The strategy to use for chunking the text before processing (e.g., "RegexChunking"). + NoRegexChunking
css_selector + The CSS selector to target specific parts of the HTML for extraction. + NoNone
verboseWhether to enable verbose logging.Notrue
+
+
+ +
+
+

Extraction Strategies

+
+
+
+ +
+
+

Chunking Strategies

+
+
+
+ +

🤔 Why building this?

@@ -192,7 +504,7 @@

-
+

⚙️ Installation

@@ -202,7 +514,7 @@

-
diff --git a/tests/__init__.py b/tests/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/test_web_crawler.py b/tests/test_web_crawler.py new file mode 100644 index 00000000..99360f42 --- /dev/null +++ b/tests/test_web_crawler.py @@ -0,0 +1,111 @@ +import unittest, os +from crawl4ai.web_crawler import WebCrawler +from crawl4ai.chunking_strategy import RegexChunking, FixedLengthWordChunking, SlidingWindowChunking +from crawl4ai.extraction_strategy import CosineStrategy, LLMExtractionStrategy, TopicExtractionStrategy, NoExtractionStrategy + +class TestWebCrawler(unittest.TestCase): + + def setUp(self): + self.crawler = WebCrawler() + + def test_warmup(self): + self.crawler.warmup() + self.assertTrue(self.crawler.ready, "WebCrawler failed to warm up") + + def test_run_default_strategies(self): + result = self.crawler.run( + url='https://www.nbcnews.com/business', + word_count_threshold=5, + chunking_strategy=RegexChunking(), + extraction_strategy=CosineStrategy(), bypass_cache=True + ) + self.assertTrue(result.success, "Failed to crawl and extract using default strategies") + + def test_run_different_strategies(self): + url = 'https://www.nbcnews.com/business' + + # Test with FixedLengthWordChunking and LLMExtractionStrategy + result = self.crawler.run( + url=url, + word_count_threshold=5, + chunking_strategy=FixedLengthWordChunking(chunk_size=100), + extraction_strategy=LLMExtractionStrategy(provider="openai/gpt-3.5-turbo", api_token=os.getenv('OPENAI_API_KEY')), bypass_cache=True + ) + self.assertTrue(result.success, "Failed to crawl and extract with FixedLengthWordChunking and LLMExtractionStrategy") + + # Test with SlidingWindowChunking and TopicExtractionStrategy + result = self.crawler.run( + url=url, + word_count_threshold=5, + chunking_strategy=SlidingWindowChunking(window_size=100, step=50), + extraction_strategy=TopicExtractionStrategy(num_keywords=5), bypass_cache=True + ) + self.assertTrue(result.success, "Failed to crawl and 
extract with SlidingWindowChunking and TopicExtractionStrategy") + + def test_invalid_url(self): + with self.assertRaises(Exception) as context: + self.crawler.run(url='invalid_url', bypass_cache=True) + self.assertIn("Invalid URL", str(context.exception)) + + def test_unsupported_extraction_strategy(self): + with self.assertRaises(Exception) as context: + self.crawler.run(url='https://www.nbcnews.com/business', extraction_strategy="UnsupportedStrategy", bypass_cache=True) + self.assertIn("Unsupported extraction strategy", str(context.exception)) + + def test_invalid_css_selector(self): + with self.assertRaises(ValueError) as context: + self.crawler.run(url='https://www.nbcnews.com/business', css_selector="invalid_selector", bypass_cache=True) + self.assertIn("Invalid CSS selector", str(context.exception)) + + + def test_crawl_with_cache_and_bypass_cache(self): + url = 'https://www.nbcnews.com/business' + + # First crawl with cache enabled + result = self.crawler.run(url=url, bypass_cache=False) + self.assertTrue(result.success, "Failed to crawl and cache the result") + + # Second crawl with bypass_cache=True + result = self.crawler.run(url=url, bypass_cache=True) + self.assertTrue(result.success, "Failed to bypass cache and fetch fresh data") + + def test_fetch_multiple_pages(self): + urls = [ + 'https://www.nbcnews.com/business', + 'https://www.bbc.com/news' + ] + results = [] + for url in urls: + result = self.crawler.run( + url=url, + word_count_threshold=5, + chunking_strategy=RegexChunking(), + extraction_strategy=CosineStrategy(), + bypass_cache=True + ) + results.append(result) + + self.assertEqual(len(results), 2, "Failed to crawl and extract multiple pages") + for result in results: + self.assertTrue(result.success, "Failed to crawl and extract a page in the list") + + def test_run_fixed_length_word_chunking_and_no_extraction(self): + result = self.crawler.run( + url='https://www.nbcnews.com/business', + word_count_threshold=5, + 
chunking_strategy=FixedLengthWordChunking(chunk_size=100), + extraction_strategy=NoExtractionStrategy(), bypass_cache=True + ) + self.assertTrue(result.success, "Failed to crawl and extract with FixedLengthWordChunking and NoExtractionStrategy") + + def test_run_sliding_window_and_no_extraction(self): + result = self.crawler.run( + url='https://www.nbcnews.com/business', + word_count_threshold=5, + chunking_strategy=SlidingWindowChunking(window_size=100, step=50), + extraction_strategy=NoExtractionStrategy(), bypass_cache=True + ) + self.assertTrue(result.success, "Failed to crawl and extract with SlidingWindowChunking and NoExtractionStrategy") + +if __name__ == '__main__': + unittest.main()