- Test all methods

- Update index.hml
- Update Readme
- Resolve some bugs
This commit is contained in:
unclecode
2024-05-14 21:27:41 +08:00
parent 5fea6c064b
commit f6e59157bf
17 changed files with 1004 additions and 402 deletions

View File

@@ -3,15 +3,17 @@ from dotenv import load_dotenv
load_dotenv() # Load environment variables from .env file
# Default provider
# Default provider, ONLY used when the extraction strategy is LLMExtractionStrategy
DEFAULT_PROVIDER = "openai/gpt-4-turbo"
# Provider-model dictionary
# Provider-model dictionary, ONLY used when the extraction strategy is LLMExtractionStrategy
PROVIDER_MODELS = {
"ollama/llama3": "no-token-needed", # Any model from Ollama no need for API token
"groq/llama3-70b-8192": os.getenv("GROQ_API_KEY"),
"groq/llama3-8b-8192": os.getenv("GROQ_API_KEY"),
"openai/gpt-3.5-turbo": os.getenv("OPENAI_API_KEY"),
"openai/gpt-4-turbo": os.getenv("OPENAI_API_KEY"),
"openai/gpt-4o": os.getenv("OPENAI_API_KEY"),
"anthropic/claude-3-haiku-20240307": os.getenv("ANTHROPIC_API_KEY"),
"anthropic/claude-3-opus-20240229": os.getenv("ANTHROPIC_API_KEY"),
"anthropic/claude-3-sonnet-20240229": os.getenv("ANTHROPIC_API_KEY"),

View File

@@ -5,18 +5,20 @@ from selenium.webdriver.common.by import By
from selenium.webdriver.support.ui import WebDriverWait
from selenium.webdriver.support import expected_conditions as EC
from selenium.webdriver.chrome.options import Options
from selenium.common.exceptions import InvalidArgumentException
import chromedriver_autoinstaller
from typing import List
import requests
import os
from pathlib import Path
class CrawlerStrategy(ABC):
@abstractmethod
def crawl(self, url: str) -> str:
def crawl(self, url: str, **kwargs) -> str:
pass
class CloudCrawlerStrategy(CrawlerStrategy):
def crawl(self, url: str) -> str:
def crawl(self, url: str, use_cached_html = False, css_selector = None) -> str:
data = {
"urls": [url],
"provider_model": "",
@@ -40,19 +42,34 @@ class LocalSeleniumCrawlerStrategy(CrawlerStrategy):
self.options.add_argument("--disable-dev-shm-usage")
self.options.add_argument("--headless")
chromedriver_autoinstaller.install()
# chromedriver_autoinstaller.install()
self.service = Service(chromedriver_autoinstaller.install())
self.driver = webdriver.Chrome(service=self.service, options=self.options)
def crawl(self, url: str, use_cached_html = False) -> str:
def crawl(self, url: str, use_cached_html = False, css_selector = None) -> str:
if use_cached_html:
return get_content_of_website(url)
self.driver.get(url)
WebDriverWait(self.driver, 10).until(
EC.presence_of_all_elements_located((By.TAG_NAME, "html"))
)
html = self.driver.page_source
return html
cache_file_path = os.path.join(Path.home(), ".crawl4ai", "cache", url.replace("/", "_"))
if os.path.exists(cache_file_path):
with open(cache_file_path, "r") as f:
return f.read()
try:
self.driver.get(url)
WebDriverWait(self.driver, 10).until(
EC.presence_of_all_elements_located((By.TAG_NAME, "html"))
)
html = self.driver.page_source
# Store in cache
cache_file_path = os.path.join(Path.home(), ".crawl4ai", "cache", url.replace("/", "_"))
with open(cache_file_path, "w") as f:
f.write(html)
return html
except InvalidArgumentException:
raise InvalidArgumentException(f"Invalid URL {url}")
except Exception as e:
raise Exception(f"Failed to crawl {url}: {str(e)}")
def quit(self):
self.driver.quit()

View File

@@ -1,7 +1,15 @@
import os
from pathlib import Path
import sqlite3
from typing import Optional
from typing import Optional, Tuple
DB_PATH = os.path.join(Path.home(), ".crawl4ai")
os.makedirs(DB_PATH, exist_ok=True)
DB_PATH = os.path.join(DB_PATH, "crawl4ai.db")
def init_db(db_path: str):
global DB_PATH
conn = sqlite3.connect(db_path)
cursor = conn.cursor()
cursor.execute('''
@@ -16,46 +24,65 @@ def init_db(db_path: str):
''')
conn.commit()
conn.close()
DB_PATH = db_path
def get_cached_url(db_path: str, url: str) -> Optional[tuple]:
conn = sqlite3.connect(db_path)
cursor = conn.cursor()
cursor.execute('SELECT url, html, cleaned_html, markdown, parsed_json, success FROM crawled_data WHERE url = ?', (url,))
result = cursor.fetchone()
conn.close()
return result
def check_db_path():
if not DB_PATH:
raise ValueError("Database path is not set or is empty.")
def cache_url(db_path: str, url: str, html: str, cleaned_html: str, markdown: str, parsed_json: str, success: bool):
conn = sqlite3.connect(db_path)
cursor = conn.cursor()
cursor.execute('''
INSERT INTO crawled_data (url, html, cleaned_html, markdown, parsed_json, success)
VALUES (?, ?, ?, ?, ?, ?)
ON CONFLICT(url) DO UPDATE SET
html = excluded.html,
cleaned_html = excluded.cleaned_html,
markdown = excluded.markdown,
parsed_json = excluded.parsed_json,
success = excluded.success
''', (str(url), html, cleaned_html, markdown, parsed_json, success))
conn.commit()
conn.close()
def get_total_count(db_path: str) -> int:
def get_cached_url(url: str) -> Optional[Tuple[str, str, str, str, str, bool]]:
check_db_path()
try:
conn = sqlite3.connect(db_path)
conn = sqlite3.connect(DB_PATH)
cursor = conn.cursor()
cursor.execute('SELECT url, html, cleaned_html, markdown, parsed_json, success FROM crawled_data WHERE url = ?', (url,))
result = cursor.fetchone()
conn.close()
return result
except Exception as e:
print(f"Error retrieving cached URL: {e}")
return None
def cache_url(url: str, html: str, cleaned_html: str, markdown: str, parsed_json: str, success: bool):
check_db_path()
try:
conn = sqlite3.connect(DB_PATH)
cursor = conn.cursor()
cursor.execute('''
INSERT INTO crawled_data (url, html, cleaned_html, markdown, parsed_json, success)
VALUES (?, ?, ?, ?, ?, ?)
ON CONFLICT(url) DO UPDATE SET
html = excluded.html,
cleaned_html = excluded.cleaned_html,
markdown = excluded.markdown,
parsed_json = excluded.parsed_json,
success = excluded.success
''', (url, html, cleaned_html, markdown, parsed_json, success))
conn.commit()
conn.close()
except Exception as e:
print(f"Error caching URL: {e}")
def get_total_count() -> int:
check_db_path()
try:
conn = sqlite3.connect(DB_PATH)
cursor = conn.cursor()
cursor.execute('SELECT COUNT(*) FROM crawled_data')
result = cursor.fetchone()
conn.close()
return result[0]
except Exception as e:
print(f"Error getting total count: {e}")
return 0
# Crete function to cler the database
def clear_db(db_path: str):
conn = sqlite3.connect(db_path)
cursor = conn.cursor()
cursor.execute('DELETE FROM crawled_data')
conn.commit()
conn.close()
def clear_db():
check_db_path()
try:
conn = sqlite3.connect(DB_PATH)
cursor = conn.cursor()
cursor.execute('DELETE FROM crawled_data')
conn.commit()
conn.close()
except Exception as e:
print(f"Error clearing database: {e}")

View File

@@ -7,6 +7,8 @@ from .prompts import PROMPT_EXTRACT_BLOCKS
from .config import *
from .utils import *
from functools import partial
from .model_loader import load_bert_base_uncased, load_bge_small_en_v1_5, load_spacy_model
class ExtractionStrategy(ABC):
"""
@@ -15,6 +17,7 @@ class ExtractionStrategy(ABC):
def __init__(self):
self.DEL = "<|DEL|>"
self.name = self.__class__.__name__
@abstractmethod
def extract(self, url: str, html: str, *q, **kwargs) -> List[Dict[str, Any]]:
@@ -67,7 +70,7 @@ class LLMExtractionStrategy(ExtractionStrategy):
def extract(self, url: str, html: str) -> List[Dict[str, Any]]:
print("Extracting blocks ...")
print("[LOG] Extracting blocks from URL:", url)
variable_values = {
"URL": url,
"HTML": escape_json_string(sanitize_html(html)),
@@ -98,7 +101,7 @@ class LLMExtractionStrategy(ExtractionStrategy):
"content": unparsed
})
print("Extracted", len(blocks), "blocks.")
print("[LOG] Extracted", len(blocks), "blocks from URL:", url)
return blocks
def _merge(self, documents):
@@ -125,6 +128,7 @@ class LLMExtractionStrategy(ExtractionStrategy):
"""
Process sections sequentially with a delay for rate limiting issues, specifically for LLMExtractionStrategy.
"""
merged_sections = self._merge(sections)
parsed_json = []
if self.provider.startswith("groq/"):
@@ -144,7 +148,7 @@ class LLMExtractionStrategy(ExtractionStrategy):
return parsed_json
class CosinegStrategy(ExtractionStrategy):
class CosineStrategy(ExtractionStrategy):
def __init__(self, word_count_threshold=20, max_dist=0.2, linkage_method='ward', top_k=3, model_name = 'BAAI/bge-small-en-v1.5'):
"""
Initialize the strategy with clustering parameters.
@@ -164,20 +168,13 @@ class CosinegStrategy(ExtractionStrategy):
self.linkage_method = linkage_method
self.top_k = top_k
self.timer = time.time()
if model_name == "bert-base-uncased":
self.tokenizer = BertTokenizer.from_pretrained('bert-base-uncased', resume_download=None)
self.model = BertModel.from_pretrained('bert-base-uncased', resume_download=None)
elif model_name == "sshleifer/distilbart-cnn-12-6":
# self.model = IPEXModel.from_pretrained("Intel/bge-small-en-v1.5-rag-int8-static")
# self.tokenizer = AutoTokenizer.from_pretrained("Intel/bge-small-en-v1.5-rag-int8-static")
pass
elif model_name == "BAAI/bge-small-en-v1.5":
self.tokenizer = AutoTokenizer.from_pretrained('BAAI/bge-small-en-v1.5', resume_download=None)
self.model = AutoModel.from_pretrained('BAAI/bge-small-en-v1.5', resume_download=None)
self.model.eval()
self.nlp = spacy.load("models/reuters")
if model_name == "bert-base-uncased":
self.tokenizer, self.model = load_bert_base_uncased()
elif model_name == "BAAI/bge-small-en-v1.5":
self.tokenizer, self.model = load_bge_small_en_v1_5()
self.nlp = load_spacy_model()
print(f"[LOG] Model loaded {model_name}, models/reuters, took " + str(time.time() - self.timer) + " seconds")
def get_embeddings(self, sentences: List[str]):

20
crawl4ai/model_loader.py Normal file
View File

@@ -0,0 +1,20 @@
from functools import lru_cache
from transformers import BertTokenizer, BertModel, AutoTokenizer, AutoModel
import spacy
@lru_cache()
def load_bert_base_uncased():
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased', resume_download=None)
model = BertModel.from_pretrained('bert-base-uncased', resume_download=None)
return tokenizer, model
@lru_cache()
def load_bge_small_en_v1_5():
tokenizer = AutoTokenizer.from_pretrained('BAAI/bge-small-en-v1.5', resume_download=None)
model = AutoModel.from_pretrained('BAAI/bge-small-en-v1.5', resume_download=None)
model.eval()
return tokenizer, model
@lru_cache()
def load_spacy_model():
return spacy.load("models/reuters")

View File

@@ -10,6 +10,8 @@ from html2text import HTML2Text
from .prompts import PROMPT_EXTRACT_BLOCKS
from .config import *
class InvalidCSSSelectorError(Exception):
pass
def beautify_html(escaped_html):
"""
@@ -140,13 +142,25 @@ class CustomHTML2Text(HTML2Text):
super().handle_tag(tag, attrs, start)
def get_content_of_website(html, word_count_threshold = MIN_WORD_THRESHOLD):
def get_content_of_website(html, word_count_threshold = MIN_WORD_THRESHOLD, css_selector = None):
try:
if not html:
return None
# Parse HTML content with BeautifulSoup
soup = BeautifulSoup(html, 'html.parser')
# Get the content within the <body> tag
body = soup.body
# If css_selector is provided, extract content based on the selector
if css_selector:
selected_elements = body.select(css_selector)
if not selected_elements:
raise InvalidCSSSelectorError(f"Invalid CSS selector , No elements found for CSS selector: {css_selector}")
div_tag = soup.new_tag('div')
for el in selected_elements:
div_tag.append(el)
body = div_tag
# Remove script, style, and other tags that don't carry useful content from body
for tag in body.find_all(['script', 'style', 'link', 'meta', 'noscript']):
@@ -255,7 +269,7 @@ def get_content_of_website(html, word_count_threshold = MIN_WORD_THRESHOLD):
# Remove comments
for comment in soup.find_all(text=lambda text: isinstance(text, Comment)):
for comment in soup.find_all(string=lambda text: isinstance(text, Comment)):
comment.extract()
# Remove consecutive empty newlines and replace multiple spaces with a single space
@@ -281,7 +295,7 @@ def get_content_of_website(html, word_count_threshold = MIN_WORD_THRESHOLD):
except Exception as e:
print('Error processing HTML content:', str(e))
return None
raise InvalidCSSSelectorError(f"Invalid CSS selector: {css_selector}") from e
def extract_xml_tags(string):
tags = re.findall(r'<(\w+)>', string)

View File

@@ -2,7 +2,7 @@ import os, time
from pathlib import Path
from .models import UrlModel, CrawlResult
from .database import init_db, get_cached_url, cache_url
from .database import init_db, get_cached_url, cache_url, DB_PATH
from .utils import *
from .chunking_strategy import *
from .extraction_strategy import *
@@ -10,6 +10,7 @@ from .crawler_strategy import *
from typing import List
from concurrent.futures import ThreadPoolExecutor
from .config import *
# from .model_loader import load_bert_base_uncased, load_bge_small_en_v1_5, load_spacy_model
class WebCrawler:
@@ -36,11 +37,11 @@ class WebCrawler:
def warmup(self):
print("[LOG] 🌤️ Warming up the WebCrawler")
single_url = UrlModel(url='https://crawl4ai.uccode.io/', forced=False)
result = self.run(
single_url,
url='https://crawl4ai.uccode.io/',
word_count_threshold=5,
extraction_strategy= CosinegStrategy(),
extraction_strategy= CosineStrategy(),
bypass_cache=False,
verbose = False
)
self.ready = True
@@ -60,10 +61,11 @@ class WebCrawler:
**kwargs,
) -> CrawlResult:
return self.run(
url_model,
url_model.url,
word_count_threshold,
extraction_strategy,
chunking_strategy,
bypass_cache=url_model.forced,
**kwargs,
)
pass
@@ -71,77 +73,85 @@ class WebCrawler:
def run(
self,
url_model: UrlModel,
url: str,
word_count_threshold=MIN_WORD_THRESHOLD,
extraction_strategy: ExtractionStrategy = NoExtractionStrategy(),
chunking_strategy: ChunkingStrategy = RegexChunking(),
bypass_cache: bool = False,
css_selector: str = None,
verbose=True,
**kwargs,
) -> CrawlResult:
# Check if extraction strategy is an instance of ExtractionStrategy if not raise an error
if not isinstance(extraction_strategy, ExtractionStrategy):
raise ValueError("Unsupported extraction strategy")
if not isinstance(chunking_strategy, ChunkingStrategy):
raise ValueError("Unsupported chunking strategy")
# make sure word_count_threshold is not lesser than MIN_WORD_THRESHOLD
if word_count_threshold < MIN_WORD_THRESHOLD:
word_count_threshold = MIN_WORD_THRESHOLD
# Check cache first
cached = get_cached_url(self.db_path, str(url_model.url))
if cached and not url_model.forced:
return CrawlResult(
**{
"url": cached[0],
"html": cached[1],
"cleaned_html": cached[2],
"markdown": cached[3],
"parsed_json": cached[4],
"success": cached[5],
"error_message": "",
}
)
if not bypass_cache:
cached = get_cached_url(url)
if cached:
return CrawlResult(
**{
"url": cached[0],
"html": cached[1],
"cleaned_html": cached[2],
"markdown": cached[3],
"parsed_json": cached[4],
"success": cached[5],
"error_message": "",
}
)
# Initialize WebDriver for crawling
t = time.time()
try:
html = self.crawler_strategy.crawl(str(url_model.url))
success = True
error_message = ""
except Exception as e:
html = ""
success = False
error_message = str(e)
html = self.crawler_strategy.crawl(url)
success = True
error_message = ""
# Extract content from HTML
result = get_content_of_website(html, word_count_threshold)
try:
result = get_content_of_website(html, word_count_threshold, css_selector=css_selector)
if result is None:
raise ValueError(f"Failed to extract content from the website: {url}")
except InvalidCSSSelectorError as e:
raise ValueError(str(e))
cleaned_html = result.get("cleaned_html", html)
markdown = result.get("markdown", "")
# Print a profession LOG style message, show time taken and say crawling is done
if verbose:
print(
f"[LOG] 🚀 Crawling done for {url_model.url}, success: {success}, time taken: {time.time() - t} seconds"
f"[LOG] 🚀 Crawling done for {url}, success: {success}, time taken: {time.time() - t} seconds"
)
parsed_json = []
if verbose:
print(f"[LOG] 🔥 Extracting semantic blocks for {url_model.url}")
print(f"[LOG] 🔥 Extracting semantic blocks for {url}, Strategy: {extraction_strategy.name}")
t = time.time()
# Split markdown into sections
sections = chunking_strategy.chunk(markdown)
# sections = merge_chunks_based_on_token_threshold(sections, CHUNK_TOKEN_THRESHOLD)
parsed_json = extraction_strategy.run(
str(url_model.url), sections,
url, sections,
)
parsed_json = json.dumps(parsed_json)
if verbose:
print(
f"[LOG] 🚀 Extraction done for {url_model.url}, time taken: {time.time() - t} seconds."
f"[LOG] 🚀 Extraction done for {url}, time taken: {time.time() - t} seconds."
)
# Cache the result
cleaned_html = beautify_html(cleaned_html)
cache_url(
self.db_path,
str(url_model.url),
url,
html,
cleaned_html,
markdown,
@@ -150,7 +160,7 @@ class WebCrawler:
)
return CrawlResult(
url=str(url_model.url),
url=url,
html=html,
cleaned_html=cleaned_html,
markdown=markdown,