- Debug
- Refactor code for new version
This commit is contained in:
unclecode
2024-05-16 17:31:44 +08:00
parent f6e59157bf
commit 5b80be956d
23 changed files with 3116 additions and 1019 deletions

View File

@@ -38,7 +38,12 @@ class RegexChunking(ChunkingStrategy):
class NlpSentenceChunking(ChunkingStrategy):
def __init__(self, model='en_core_web_sm'):
import spacy
self.nlp = spacy.load(model)
try:
self.nlp = spacy.load(model)
except IOError:
spacy.cli.download(model)
self.nlp = spacy.load(model)
# raise ImportError(f"Spacy model '{model}' not found. Please download the model using 'python -m spacy download {model}'")
def chunk(self, text: str) -> list:
doc = self.nlp(text)
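A minimal usage sketch of the updated chunker; the module path and the exact shape of the output are assumptions based on the diff above, not verified against the package:

from crawl4ai.chunking_strategy import NlpSentenceChunking  # assumed module path

# If the spaCy model is missing, the constructor now downloads it instead of failing.
chunker = NlpSentenceChunking(model="en_core_web_sm")
chunks = chunker.chunk("Crawl4AI fetches pages. It then splits them into sentences.")
print(chunks)  # expected: a list of sentence-level text chunks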

View File

@@ -18,15 +18,16 @@ class CrawlerStrategy(ABC):
pass
class CloudCrawlerStrategy(CrawlerStrategy):
def crawl(self, url: str, use_cached_html = False, css_selector = None) -> str:
def __init__(self, use_cached_html = False):
super().__init__()
self.use_cached_html = use_cached_html
def crawl(self, url: str) -> str:
data = {
"urls": [url],
"provider_model": "",
"api_token": "token",
"include_raw_html": True,
"forced": True,
"extract_blocks": False,
"word_count_threshold": 10
}
response = requests.post("http://crawl4ai.uccode.io/crawl", json=data)
@@ -35,19 +36,24 @@ class CloudCrawlerStrategy(CrawlerStrategy):
return html
class LocalSeleniumCrawlerStrategy(CrawlerStrategy):
def __init__(self):
def __init__(self, use_cached_html=False, js_code=None):
super().__init__()
self.options = Options()
self.options.headless = True
self.options.add_argument("--no-sandbox")
self.options.add_argument("--disable-dev-shm-usage")
self.options.add_argument("--disable-gpu")
self.options.add_argument("--disable-extensions")
self.options.add_argument("--headless")
self.use_cached_html = use_cached_html
self.js_code = js_code
# chromedriver_autoinstaller.install()
self.service = Service(chromedriver_autoinstaller.install())
self.driver = webdriver.Chrome(service=self.service, options=self.options)
def crawl(self, url: str, use_cached_html = False, css_selector = None) -> str:
if use_cached_html:
def crawl(self, url: str) -> str:
if self.use_cached_html:
cache_file_path = os.path.join(Path.home(), ".crawl4ai", "cache", url.replace("/", "_"))
if os.path.exists(cache_file_path):
with open(cache_file_path, "r") as f:
@@ -58,6 +64,15 @@ class LocalSeleniumCrawlerStrategy(CrawlerStrategy):
WebDriverWait(self.driver, 10).until(
EC.presence_of_all_elements_located((By.TAG_NAME, "html"))
)
# Execute JS code if provided
if self.js_code:
self.driver.execute_script(self.js_code)
# Optionally, wait for some condition after executing the JS code
WebDriverWait(self.driver, 10).until(
lambda driver: driver.execute_script("return document.readyState") == "complete"
)
html = self.driver.page_source
# Store in cache
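A hedged sketch of the new js_code hook; the constructor keywords and the crawl() signature come from the diff, while the module path and the example script are illustrative:

from crawl4ai.crawler_strategy import LocalSeleniumCrawlerStrategy  # assumed module path

# Scroll to the bottom of the page before the HTML is captured;
# crawl() then waits for document.readyState == "complete" after running the script.
strategy = LocalSeleniumCrawlerStrategy(
    use_cached_html=False,
    js_code="window.scrollTo(0, document.body.scrollHeight);",
)
html = strategy.crawl("https://example.com")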

View File

@@ -8,9 +8,9 @@ DB_PATH = os.path.join(Path.home(), ".crawl4ai")
os.makedirs(DB_PATH, exist_ok=True)
DB_PATH = os.path.join(DB_PATH, "crawl4ai.db")
def init_db(db_path: str):
def init_db():
global DB_PATH
conn = sqlite3.connect(db_path)
conn = sqlite3.connect(DB_PATH)
cursor = conn.cursor()
cursor.execute('''
CREATE TABLE IF NOT EXISTS crawled_data (
@@ -18,13 +18,12 @@ def init_db(db_path: str):
html TEXT,
cleaned_html TEXT,
markdown TEXT,
parsed_json TEXT,
extracted_content TEXT,
success BOOLEAN
)
''')
conn.commit()
conn.close()
DB_PATH = db_path
def check_db_path():
if not DB_PATH:
@@ -35,7 +34,7 @@ def get_cached_url(url: str) -> Optional[Tuple[str, str, str, str, str, bool]]:
try:
conn = sqlite3.connect(DB_PATH)
cursor = conn.cursor()
cursor.execute('SELECT url, html, cleaned_html, markdown, parsed_json, success FROM crawled_data WHERE url = ?', (url,))
cursor.execute('SELECT url, html, cleaned_html, markdown, extracted_content, success FROM crawled_data WHERE url = ?', (url,))
result = cursor.fetchone()
conn.close()
return result
@@ -43,21 +42,21 @@ def get_cached_url(url: str) -> Optional[Tuple[str, str, str, str, str, bool]]:
print(f"Error retrieving cached URL: {e}")
return None
def cache_url(url: str, html: str, cleaned_html: str, markdown: str, parsed_json: str, success: bool):
def cache_url(url: str, html: str, cleaned_html: str, markdown: str, extracted_content: str, success: bool):
check_db_path()
try:
conn = sqlite3.connect(DB_PATH)
cursor = conn.cursor()
cursor.execute('''
INSERT INTO crawled_data (url, html, cleaned_html, markdown, parsed_json, success)
INSERT INTO crawled_data (url, html, cleaned_html, markdown, extracted_content, success)
VALUES (?, ?, ?, ?, ?, ?)
ON CONFLICT(url) DO UPDATE SET
html = excluded.html,
cleaned_html = excluded.cleaned_html,
markdown = excluded.markdown,
parsed_json = excluded.parsed_json,
extracted_content = excluded.extracted_content,
success = excluded.success
''', (url, html, cleaned_html, markdown, parsed_json, success))
''', (url, html, cleaned_html, markdown, extracted_content, success))
conn.commit()
conn.close()
except Exception as e:
@@ -85,4 +84,15 @@ def clear_db():
conn.commit()
conn.close()
except Exception as e:
print(f"Error clearing database: {e}")
print(f"Error clearing database: {e}")
def flush_db():
check_db_path()
try:
conn = sqlite3.connect(DB_PATH)
cursor = conn.cursor()
cursor.execute('DROP TABLE crawled_data')
conn.commit()
conn.close()
except Exception as e:
print(f"Error flushing database: {e}")

View File

@@ -3,19 +3,20 @@ from typing import Any, List, Dict, Optional, Union
from concurrent.futures import ThreadPoolExecutor, as_completed
import json, time
# from optimum.intel import IPEXModel
from .prompts import PROMPT_EXTRACT_BLOCKS
from .prompts import PROMPT_EXTRACT_BLOCKS, PROMPT_EXTRACT_BLOCKS_WITH_INSTRUCTION
from .config import *
from .utils import *
from functools import partial
from .model_loader import load_bert_base_uncased, load_bge_small_en_v1_5, load_spacy_model
from transformers import pipeline
from sklearn.metrics.pairwise import cosine_similarity
import numpy as np
class ExtractionStrategy(ABC):
"""
Abstract base class for all extraction strategies.
"""
def __init__(self):
def __init__(self, **kwargs):
self.DEL = "<|DEL|>"
self.name = self.__class__.__name__
@@ -38,12 +39,12 @@ class ExtractionStrategy(ABC):
:param sections: List of sections (strings) to process.
:return: A list of processed JSON blocks.
"""
parsed_json = []
extracted_content = []
with ThreadPoolExecutor() as executor:
futures = [executor.submit(self.extract, url, section, **kwargs) for section in sections]
for future in as_completed(futures):
parsed_json.extend(future.result())
return parsed_json
extracted_content.extend(future.result())
return extracted_content
class NoExtractionStrategy(ExtractionStrategy):
def extract(self, url: str, html: str, *q, **kwargs) -> List[Dict[str, Any]]:
@@ -53,37 +54,41 @@ class NoExtractionStrategy(ExtractionStrategy):
return [{"index": i, "tags": [], "content": section} for i, section in enumerate(sections)]
class LLMExtractionStrategy(ExtractionStrategy):
def __init__(self, provider: str = DEFAULT_PROVIDER, api_token: Optional[str] = None):
def __init__(self, provider: str = DEFAULT_PROVIDER, api_token: Optional[str] = None, instruction:str = None, **kwargs):
"""
Initialize the strategy with clustering parameters.
:param word_count_threshold: Minimum number of words per cluster.
:param max_dist: The maximum cophenetic distance on the dendrogram to form clusters.
:param linkage_method: The linkage method for hierarchical clustering.
:param provider: The provider to use for extraction.
:param api_token: The API token for the provider.
:param instruction: The instruction to use for the LLM model.
"""
super().__init__()
self.provider = provider
self.api_token = api_token or PROVIDER_MODELS.get(provider, None) or os.getenv("OPENAI_API_KEY")
self.instruction = instruction
if not self.api_token:
raise ValueError("API token must be provided for LLMExtractionStrategy. Update the config.py or set OPENAI_API_KEY environment variable.")
def extract(self, url: str, html: str) -> List[Dict[str, Any]]:
print("[LOG] Extracting blocks from URL:", url)
def extract(self, url: str, ix:int, html: str) -> List[Dict[str, Any]]:
# print("[LOG] Extracting blocks from URL:", url)
print(f"[LOG] Call LLM for {url} - block index: {ix}")
variable_values = {
"URL": url,
"HTML": escape_json_string(sanitize_html(html)),
}
if self.instruction:
variable_values["REQUEST"] = self.instruction
prompt_with_variables = PROMPT_EXTRACT_BLOCKS
prompt_with_variables = PROMPT_EXTRACT_BLOCKS if not self.instruction else PROMPT_EXTRACT_BLOCKS_WITH_INSTRUCTION
for variable in variable_values:
prompt_with_variables = prompt_with_variables.replace(
"{" + variable + "}", variable_values[variable]
)
response = perform_completion_with_backoff(self.provider, prompt_with_variables, self.api_token)
try:
blocks = extract_xml_data(["blocks"], response.choices[0].message.content)['blocks']
blocks = json.loads(blocks)
@@ -101,7 +106,7 @@ class LLMExtractionStrategy(ExtractionStrategy):
"content": unparsed
})
print("[LOG] Extracted", len(blocks), "blocks from URL:", url)
print("[LOG] Extracted", len(blocks), "blocks from URL:", url, "block index:", ix)
return blocks
def _merge(self, documents):
@@ -130,29 +135,30 @@ class LLMExtractionStrategy(ExtractionStrategy):
"""
merged_sections = self._merge(sections)
parsed_json = []
extracted_content = []
if self.provider.startswith("groq/"):
# Sequential processing with a delay
for section in merged_sections:
parsed_json.extend(self.extract(url, section))
for ix, section in enumerate(merged_sections):
extracted_content.extend(self.extract(url, ix, section))
time.sleep(0.5) # 500 ms delay between each processing
else:
# Parallel processing using ThreadPoolExecutor
with ThreadPoolExecutor(max_workers=4) as executor:
extract_func = partial(self.extract, url)
futures = [executor.submit(extract_func, section) for section in merged_sections]
futures = [executor.submit(extract_func, ix, section) for ix, section in enumerate(merged_sections)]
for future in as_completed(futures):
parsed_json.extend(future.result())
extracted_content.extend(future.result())
return parsed_json
return extracted_content
class CosineStrategy(ExtractionStrategy):
def __init__(self, word_count_threshold=20, max_dist=0.2, linkage_method='ward', top_k=3, model_name = 'BAAI/bge-small-en-v1.5'):
def __init__(self, semantic_filter = None, word_count_threshold=10, max_dist=0.2, linkage_method='ward', top_k=3, model_name = 'BAAI/bge-small-en-v1.5', **kwargs):
"""
Initialize the strategy with clustering parameters.
:param semantic_filter: A keyword filter for document filtering.
:param word_count_threshold: Minimum number of words per cluster.
:param max_dist: The maximum cophenetic distance on the dendrogram to form clusters.
:param linkage_method: The linkage method for hierarchical clustering.
@@ -163,11 +169,14 @@ class CosineStrategy(ExtractionStrategy):
from transformers import AutoTokenizer, AutoModel
import spacy
self.semantic_filter = semantic_filter
self.word_count_threshold = word_count_threshold
self.max_dist = max_dist
self.linkage_method = linkage_method
self.top_k = top_k
self.timer = time.time()
self.buffer_embeddings = np.array([])
if model_name == "bert-base-uncased":
self.tokenizer, self.model = load_bert_base_uncased()
@@ -177,13 +186,42 @@ class CosineStrategy(ExtractionStrategy):
self.nlp = load_spacy_model()
print(f"[LOG] Model loaded {model_name}, models/reuters, took " + str(time.time() - self.timer) + " seconds")
def get_embeddings(self, sentences: List[str]):
def filter_documents_embeddings(self, documents: List[str], semantic_filter: str, threshold: float = 0.5) -> List[str]:
"""
Filter documents based on the cosine similarity of their embeddings with the semantic_filter embedding.
:param documents: List of text chunks (documents).
:param semantic_filter: A string containing the keywords for filtering.
:param threshold: Cosine similarity threshold for filtering documents.
:return: Filtered list of documents.
"""
if not semantic_filter:
return documents
# Compute embedding for the keyword filter
query_embedding = self.get_embeddings([semantic_filter])[0]
# Compute embeddings for the documents
document_embeddings = self.get_embeddings(documents)
# Calculate cosine similarity between the query embedding and document embeddings
similarities = cosine_similarity([query_embedding], document_embeddings).flatten()
# Filter documents based on the similarity threshold
filtered_docs = [doc for doc, sim in zip(documents, similarities) if sim >= threshold]
return filtered_docs
def get_embeddings(self, sentences: List[str], bypass_buffer=True):
"""
Get BERT embeddings for a list of sentences.
:param sentences: List of text chunks (sentences).
:return: NumPy array of embeddings.
"""
# if self.buffer_embeddings.any() and not bypass_buffer:
# return self.buffer_embeddings
import torch
# Tokenize sentences and convert to tensor
encoded_input = self.tokenizer(sentences, padding=True, truncation=True, return_tensors='pt')
@@ -193,6 +231,7 @@ class CosineStrategy(ExtractionStrategy):
# Get embeddings from the last hidden state (mean pooling)
embeddings = model_output.last_hidden_state.mean(1)
self.buffer_embeddings = embeddings.numpy()
return embeddings.numpy()
def hierarchical_clustering(self, sentences: List[str]):
@@ -206,7 +245,7 @@ class CosineStrategy(ExtractionStrategy):
from scipy.cluster.hierarchy import linkage, fcluster
from scipy.spatial.distance import pdist
self.timer = time.time()
embeddings = self.get_embeddings(sentences)
embeddings = self.get_embeddings(sentences, bypass_buffer=False)
# print(f"[LOG] 🚀 Embeddings computed in {time.time() - self.timer:.2f} seconds")
# Compute pairwise cosine distances
distance_matrix = pdist(embeddings, 'cosine')
@@ -247,6 +286,12 @@ class CosineStrategy(ExtractionStrategy):
# Assume `html` is a list of text chunks for this strategy
t = time.time()
text_chunks = html.split(self.DEL) # Split by lines or paragraphs as needed
# Pre-filter documents using embeddings and semantic_filter
text_chunks = self.filter_documents_embeddings(text_chunks, self.semantic_filter)
if not text_chunks:
return []
# Perform clustering
labels = self.hierarchical_clustering(text_chunks)
@@ -290,7 +335,7 @@ class CosineStrategy(ExtractionStrategy):
return self.extract(url, self.DEL.join(sections), **kwargs)
class TopicExtractionStrategy(ExtractionStrategy):
def __init__(self, num_keywords: int = 3):
def __init__(self, num_keywords: int = 3, **kwargs):
"""
Initialize the topic extraction strategy with parameters for topic segmentation.
@@ -358,7 +403,7 @@ class TopicExtractionStrategy(ExtractionStrategy):
return self.extract(url, self.DEL.join(sections), **kwargs)
class ContentSummarizationStrategy(ExtractionStrategy):
def __init__(self, model_name: str = "sshleifer/distilbart-cnn-12-6"):
def __init__(self, model_name: str = "sshleifer/distilbart-cnn-12-6", **kwargs):
"""
Initialize the content summarization strategy with a specific model.
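Hedged sketches of the two extended strategies; the constructor keywords are taken from the diff, while the module path, provider string, and token are illustrative placeholders:

from crawl4ai.extraction_strategy import LLMExtractionStrategy, CosineStrategy  # assumed module path

# Instruction-driven LLM extraction: when an instruction is set,
# PROMPT_EXTRACT_BLOCKS_WITH_INSTRUCTION is used instead of PROMPT_EXTRACT_BLOCKS.
llm_strategy = LLMExtractionStrategy(
    provider="openai/gpt-4",   # illustrative provider string
    api_token="sk-...",        # or rely on the OPENAI_API_KEY environment variable
    instruction="Extract only the pricing information.",
)

# Embedding-based clustering with the new semantic pre-filter:
# chunks whose cosine similarity to the filter falls below the 0.5 default are dropped.
cosine_strategy = CosineStrategy(
    semantic_filter="machine learning",
    word_count_threshold=10,
)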

View File

@@ -11,5 +11,6 @@ class CrawlResult(BaseModel):
success: bool
cleaned_html: str = None
markdown: str = None
parsed_json: str = None
extracted_content: str = None
metadata: dict = None
error_message: str = None

View File

@@ -59,7 +59,7 @@ Please provide your output within <blocks> tags, like this:
Remember, the output should be a complete, parsable JSON wrapped in <blocks> tags, with no omissions or errors. The JSON objects should semantically break down the content into relevant blocks, maintaining the original order."""
PROMPT_EXTRACT_BLOCKS = """YHere is the URL of the webpage:
PROMPT_EXTRACT_BLOCKS = """Here is the URL of the webpage:
<url>{URL}</url>
And here is the cleaned HTML content of that webpage:
@@ -107,4 +107,61 @@ Please provide your output within <blocks> tags, like this:
}]
</blocks>
Remember, the output should be a complete, parsable JSON wrapped in <blocks> tags, with no omissions or errors. The JSON objects should semantically break down the content into relevant blocks, maintaining the original order."""
PROMPT_EXTRACT_BLOCKS_WITH_INSTRUCTION = """Here is the URL of the webpage:
<url>{URL}</url>
And here is the cleaned HTML content of that webpage:
<html>
{HTML}
</html>
Your task is to break down this HTML content into semantically relevant blocks, following the provided user's REQUEST, and for each block, generate a JSON object with the following keys:
- index: an integer representing the index of the block in the content
- content: a list of strings containing the text content of the block
- tags: a list containing one semantic tag that describes what the block is about
This is the user's REQUEST; pay close attention to it:
<request>
{REQUEST}
</request>
To generate the JSON objects:
1. Carefully read through the HTML content and identify logical breaks or shifts in the content that would warrant splitting it into separate blocks.
2. For each block:
a. Assign it an index based on its order in the content.
b. Analyze the content and generate ONE semantic tag that describes what the block is about.
c. Extract the text content EXACTLY AS GIVEN IN THE DATA, clean it up if needed, and store it as a list of strings in the "content" field.
3. Ensure that the order of the JSON objects matches the order of the blocks as they appear in the original HTML content.
4. Double-check that each JSON object includes all required keys (index, tag, content) and that the values are in the expected format (integer, list of strings, etc.).
5. Make sure the generated JSON is complete and parsable, with no errors or omissions.
6. Make sure to escape any special characters in the HTML content, as well as single and double quotes, to avoid JSON parsing issues.
7. Never alter the extracted content, just copy and paste it as it is.
Please provide your output within <blocks> tags, like this:
<blocks>
[{
"index": 0,
"tags": ["introduction"],
"content": ["This is the first paragraph of the article, which provides an introduction and overview of the main topic."]
},
{
"index": 1,
"tags": ["background"],
"content": ["This is the second paragraph, which delves into the history and background of the topic.",
"It provides context and sets the stage for the rest of the article."]
}]
</blocks>
**Make sure to follow the user's REQUEST so that the extracted blocks align with the instruction.**
Remember, the output should be a complete, parsable JSON wrapped in <blocks> tags, with no omissions or errors. The JSON objects should semantically break down the content into relevant blocks, maintaining the original order."""

View File

@@ -461,17 +461,17 @@ def merge_chunks_based_on_token_threshold(chunks, token_threshold):
return merged_sections
def process_sections(url: str, sections: list, provider: str, api_token: str) -> list:
parsed_json = []
extracted_content = []
if provider.startswith("groq/"):
# Sequential processing with a delay
for section in sections:
parsed_json.extend(extract_blocks(url, section, provider, api_token))
extracted_content.extend(extract_blocks(url, section, provider, api_token))
time.sleep(0.5) # 500 ms delay between each processing
else:
# Parallel processing using ThreadPoolExecutor
with ThreadPoolExecutor() as executor:
futures = [executor.submit(extract_blocks, url, section, provider, api_token) for section in sections]
for future in as_completed(futures):
parsed_json.extend(future.result())
extracted_content.extend(future.result())
return parsed_json
return extracted_content

View File

@@ -1,8 +1,9 @@
import os, time
os.environ["TOKENIZERS_PARALLELISM"] = "false"
from pathlib import Path
from .models import UrlModel, CrawlResult
from .database import init_db, get_cached_url, cache_url, DB_PATH
from .database import init_db, get_cached_url, cache_url, DB_PATH, flush_db
from .utils import *
from .chunking_strategy import *
from .extraction_strategy import *
@@ -16,11 +17,13 @@ from .config import *
class WebCrawler:
def __init__(
self,
db_path: str = None,
# db_path: str = None,
crawler_strategy: CrawlerStrategy = LocalSeleniumCrawlerStrategy(),
always_by_pass_cache: bool = False,
):
self.db_path = db_path
# self.db_path = db_path
self.crawler_strategy = crawler_strategy
self.always_by_pass_cache = always_by_pass_cache
# Create the .crawl4ai folder in the user's home directory if it doesn't exist
self.crawl4ai_folder = os.path.join(Path.home(), ".crawl4ai")
@@ -28,10 +31,11 @@ class WebCrawler:
os.makedirs(f"{self.crawl4ai_folder}/cache", exist_ok=True)
# If db_path is not provided, use the default path
if not db_path:
self.db_path = f"{self.crawl4ai_folder}/crawl4ai.db"
# if not db_path:
# self.db_path = f"{self.crawl4ai_folder}/crawl4ai.db"
init_db(self.db_path)
flush_db()
init_db()
self.ready = False
@@ -93,7 +97,7 @@ class WebCrawler:
word_count_threshold = MIN_WORD_THRESHOLD
# Check cache first
if not bypass_cache:
if not bypass_cache and not self.always_by_pass_cache:
cached = get_cached_url(url)
if cached:
return CrawlResult(
@@ -102,7 +106,7 @@ class WebCrawler:
"html": cached[1],
"cleaned_html": cached[2],
"markdown": cached[3],
"parsed_json": cached[4],
"extracted_content": cached[4],
"success": cached[5],
"error_message": "",
}
@@ -130,7 +134,7 @@ class WebCrawler:
f"[LOG] 🚀 Crawling done for {url}, success: {success}, time taken: {time.time() - t} seconds"
)
parsed_json = []
extracted_content = []
if verbose:
print(f"[LOG] 🔥 Extracting semantic blocks for {url}, Strategy: {extraction_strategy.name}")
t = time.time()
@@ -138,10 +142,10 @@ class WebCrawler:
sections = chunking_strategy.chunk(markdown)
# sections = merge_chunks_based_on_token_threshold(sections, CHUNK_TOKEN_THRESHOLD)
parsed_json = extraction_strategy.run(
extracted_content = extraction_strategy.run(
url, sections,
)
parsed_json = json.dumps(parsed_json)
extracted_content = json.dumps(extracted_content)
if verbose:
print(
@@ -155,7 +159,7 @@ class WebCrawler:
html,
cleaned_html,
markdown,
parsed_json,
extracted_content,
success,
)
@@ -164,7 +168,7 @@ class WebCrawler:
html=html,
cleaned_html=cleaned_html,
markdown=markdown,
parsed_json=parsed_json,
extracted_content=extracted_content,
success=success,
error_message=error_message,
)
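Putting it together, a hedged end-to-end sketch; the run() method name and keyword names are inferred from the hunks above and may not match the released API exactly:

from crawl4ai import WebCrawler  # assumed top-level export
from crawl4ai.extraction_strategy import LLMExtractionStrategy

crawler = WebCrawler(always_by_pass_cache=True)  # db_path is no longer accepted
result = crawler.run(
    url="https://example.com",
    extraction_strategy=LLMExtractionStrategy(instruction="Summarize each section."),
)
print(result.extracted_content)  # renamed from parsed_json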